# -*- coding: utf-8 -*-
import pickle
import random

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
# Early stopping:
# https://pytorch-lightning.readthedocs.io/en/1.2.1/common/early_stopping.html
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import torch.optim as optim
from tqdm.auto import tqdm
import torchmetrics
# from tkitbilstm import BiLSTMAttention as Bilstm
from torchcrf import CRF  # provided by the pytorch-crf package
class autoModel(pl.LightningModule):
    """BiLSTM-CRF tagger as a LightningModule.

    Adapted from the bertlm model, used for prediction:
    https://colab.research.google.com/drive/1-OEwiD9ouGjWrSFEWhgEnWiNvwwxlqd7#scrollTo=no6DwOqaE9Jw
    https://github.com/lucidrains/performer-pytorch
    """

    def __init__(
        self,
        learning_rate=3e-4,
        T_max=5,
        hidden_size=256,
        vocab_size=21128,
        ignore_index=0,
        out_num_classes=12,
        en_num_layers=2,
        de_num_layers=2,
        optimizer_name="AdamW",
        dropout=0.2,
        batch_size=2,
        trainfile="./data/train.pkt",
        valfile="./data/val.pkt",
        testfile="./data/test.pkt",
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        print(self.hparams)
        # Alternative encoder, kept for reference:
        # self.model = Bilstm(
        #     vocab_size=self.hparams.vocab_size,
        #     dim=self.hparams.hidden_size,
        #     n_hidden=self.hparams.hidden_size,
        #     out_num_classes=self.hparams.out_num_classes,
        #     embedding_enabled=True,
        #     attention=False)
        self.embedding = nn.Embedding(
            self.hparams.vocab_size, self.hparams.hidden_size, padding_idx=0
        )
        # Bidirectional LSTM encoder over (seq_len, batch, hidden_size) inputs.
        self.model = nn.LSTM(
            self.hparams.hidden_size,
            self.hparams.hidden_size,
            dropout=self.hparams.dropout,
            num_layers=self.hparams.en_num_layers,
            batch_first=False,
            bidirectional=True,
        )
        # Emission head: maps the bidirectional LSTM output (2 * hidden_size)
        # to per-tag emission scores for the CRF.
        self.c = nn.Sequential(
            nn.Dropout(self.hparams.dropout),
            nn.Tanh(),
            nn.Linear(self.hparams.hidden_size * 2, self.hparams.out_num_classes),
            nn.Dropout(self.hparams.dropout),
            nn.Tanh(),
        )
        self.d = nn.Dropout(self.hparams.dropout)
        self.decoder = CRF(self.hparams.out_num_classes, batch_first=False)
        # self.accuracy = torchmetrics.Accuracy(ignore_index=self.hparams.ignore_index)
        # Note: torchmetrics >= 0.11 requires a task argument,
        # e.g. Accuracy(task="multiclass", num_classes=out_num_classes).
        self.accuracy = torchmetrics.Accuracy()

    def forward(self, x, y, x_attention_mask, y_attention_mask, decode=False):
        # x, y: (batch, seq_len). The attention masks are accepted for API
        # compatibility but are not currently passed to the CRF.
        x = self.embedding(x)       # (batch, seq_len, hidden_size)
        y = y.permute(1, 0)         # (seq_len, batch)
        x = x.permute(1, 0, 2)      # (seq_len, batch, hidden_size)
        x = self.d(x)
        x, _ = self.model(x)        # (seq_len, batch, 2 * hidden_size)
        x = self.c(x)               # (seq_len, batch, out_num_classes)
        # The CRF returns a log-likelihood; negate it to get a loss to minimize.
        loss = -self.decoder(x, y.long(), reduction="token_mean")
        if decode:
            pred = self.decoder.decode(x)
            return pred, loss
        return loss
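
    # Calling convention (shapes illustrative): with token ids x of shape
    # (batch, seq_len) and tag ids y of the same shape,
    #   loss = model(x, y, x_mask, y_mask)                    # scalar loss
    #   pred, loss = model(x, y, x_mask, y_mask, decode=True)
    # where pred is a list of batch-many tag sequences, each seq_len long.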

    def training_step(self, batch, batch_idx):
        # training_step defines the training loop; it is independent of forward.
        x, x_attention_mask, y, y_attention_mask = batch
        loss = self(x, y, x_attention_mask, y_attention_mask)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, x_attention_mask, y, y_attention_mask = batch
        y = y.int()
        pred, loss = self(x, y, x_attention_mask, y_attention_mask, decode=True)
        # Flatten the decoded tag sequences and targets for token-level accuracy.
        # Padding positions are included, which can inflate the score.
        acc = self.accuracy(
            torch.tensor(pred).to(self.device).view(-1).int(), y.reshape(-1)
        )
        metrics = {"val_acc": acc, "val_loss": loss}
        self.log_dict(metrics)
        return metrics

    def test_step(self, batch, batch_idx):
        x, x_attention_mask, y, y_attention_mask = batch
        y = y.int()
        pred, loss = self(x, y, x_attention_mask, y_attention_mask, decode=True)
        acc = self.accuracy(
            torch.tensor(pred).to(self.device).view(-1).int(), y.reshape(-1)
        )
        metrics = {"test_acc": acc, "test_loss": loss}
        self.log_dict(metrics)
        return metrics
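
    # The hyperparameters include learning_rate, optimizer_name and T_max, but
    # no configure_optimizers is defined, which Lightning needs for training.
    # A minimal sketch, assuming the intent is the named torch.optim optimizer
    # plus cosine annealing (assumption, not from the original source):
    def configure_optimizers(self):
        optimizer = getattr(optim, self.hparams.optimizer_name)(
            self.parameters(), lr=self.hparams.learning_rate
        )
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.hparams.T_max
        )
        return [optimizer], [scheduler]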

    def train_dataloader(self):
        train = torch.load(self.hparams.trainfile)
        return DataLoader(train, batch_size=int(self.hparams.batch_size),
                          num_workers=2, pin_memory=True, shuffle=True)

    def val_dataloader(self):
        val = torch.load(self.hparams.valfile)
        return DataLoader(val, batch_size=int(self.hparams.batch_size),
                          num_workers=2, pin_memory=True)

    def test_dataloader(self):
        test = torch.load(self.hparams.testfile)
        return DataLoader(test, batch_size=int(self.hparams.batch_size),
                          num_workers=2, pin_memory=True)
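

# A minimal usage sketch, assuming the ./data/*.pkt files exist and each holds
# a dataset whose batches unpack as (x, x_attention_mask, y, y_attention_mask):
if __name__ == "__main__":
    seed_everything(42)
    model = autoModel(batch_size=8)
    trainer = Trainer(
        max_epochs=10,
        callbacks=[
            EarlyStopping(monitor="val_loss", patience=3),
            ModelCheckpoint(monitor="val_loss"),
            LearningRateMonitor(logging_interval="step"),
        ],
    )
    trainer.fit(model)
    trainer.test(model)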