# -*- coding: utf-8 -*-
import pickle
import random

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
# Automatic early stopping:
# https://pytorch-lightning.readthedocs.io/en/1.2.1/common/early_stopping.html
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import torch.optim as optim
from tqdm.auto import tqdm
import torchmetrics
# from tkitbilstm import BiLSTMAttention as Bilstm
from torchcrf import CRF
class autoModel(pl.LightningModule):
    """
    Sequence labeling model (embedding + BiLSTM + CRF), adapted from the bertlm model.
    https://colab.research.google.com/drive/1-OEwiD9ouGjWrSFEWhgEnWiNvwwxlqd7#scrollTo=no6DwOqaE9Jw
    Used for prediction.

    https://github.com/lucidrains/performer-pytorch
    """
    def __init__(
        self,
        learning_rate=3e-4,
        T_max=5,
        hidden_size=256,
        vocab_size=21128,
        ignore_index=0,
        out_num_classes=12,
        en_num_layers=2,
        de_num_layers=2,
        optimizer_name="AdamW",
        dropout=0.2,
        batch_size=2,
        trainfile="./data/train.pkt",
        valfile="./data/val.pkt",
        testfile="./data/test.pkt",
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        print(self.hparams)
        # An earlier variant used tkitbilstm.BiLSTMAttention as the encoder
        # (see the commented import above).
        # Token embedding; index 0 is reserved for padding.
        self.embedding = nn.Embedding(
            self.hparams.vocab_size, self.hparams.hidden_size, padding_idx=0
        )
        # Bidirectional LSTM encoder with en_num_layers stacked layers.
        self.model = nn.LSTM(
            self.hparams.hidden_size,
            self.hparams.hidden_size,
            dropout=self.hparams.dropout,
            num_layers=self.hparams.en_num_layers,
            batch_first=False,
            bidirectional=True,
        )
        
        
        # Projection from the BiLSTM output (2 * hidden_size because it is
        # bidirectional) to per-tag emission scores for the CRF.
        self.c = nn.Sequential(
            nn.Dropout(self.hparams.dropout),
            nn.Tanh(),
            nn.Linear(self.hparams.hidden_size * 2, self.hparams.out_num_classes),
            nn.Dropout(self.hparams.dropout),
            nn.Tanh(),
        )
        self.d = nn.Dropout(self.hparams.dropout)
        # Linear-chain CRF decoder; expects emissions of shape (seq_len, batch, num_tags).
        self.decoder = CRF(self.hparams.out_num_classes, batch_first=False)

        # self.accuracy = torchmetrics.Accuracy(ignore_index=self.hparams.ignore_index)
        # NOTE: newer torchmetrics releases require a task argument,
        # e.g. Accuracy(task="multiclass", num_classes=out_num_classes).
        self.accuracy = torchmetrics.Accuracy()
    def forward(self, x, y, x_attention_mask, y_attention_mask, decode=False):
        # x, y: (batch, seq_len). The attention masks are accepted for API
        # compatibility but are not used yet; the CRF runs over the full sequence.
        x = self.embedding(x)
        # The LSTM and CRF are configured with batch_first=False, so move the
        # sequence dimension to the front.
        y = y.permute(1, 0)
        x = x.permute(1, 0, 2)
        x = self.d(x)
        x, _ = self.model(x)
        # Emission scores: (seq_len, batch, out_num_classes).
        x = self.c(x)

        # torchcrf's CRF.forward returns the log-likelihood, so negate it for a loss.
        loss = -self.decoder(x, y.long(), reduction="token_mean")
        if decode:
            pred = self.decoder.decode(x)
            return pred, loss
        return loss
    def training_step(self, batch, batch_idx):
        # The training loop step; independent of forward's decode branch.
        x, x_attention_mask, y, y_attention_mask = batch
        loss = self(x, y, x_attention_mask, y_attention_mask)
        self.log('train_loss', loss)
        return loss
    def validation_step(self, batch, batch_idx):
        x, x_attention_mask, y, y_attention_mask = batch
        y = y.int()
        pred, loss = self(x, y, x_attention_mask, y_attention_mask, decode=True)
        # crf.decode returns a list of tag sequences; flatten predictions and
        # targets to compute token-level accuracy.
        acc = self.accuracy(
            torch.tensor(pred, device=self.device).view(-1).int(), y.reshape(-1)
        )
        metrics = {"val_acc": acc, "val_loss": loss}
        self.log_dict(metrics)
        return metrics
    def test_step(self, batch, batch_idx):
        x, x_attention_mask, y, y_attention_mask = batch
        y = y.int()
        pred, loss = self(x, y, x_attention_mask, y_attention_mask, decode=True)
        acc = self.accuracy(
            torch.tensor(pred, device=self.device).view(-1).int(), y.reshape(-1)
        )
        metrics = {"test_acc": acc, "test_loss": loss}
        self.log_dict(metrics)
        return metrics
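
    # The original module stores learning_rate, T_max and optimizer_name as
    # hyperparameters but does not define configure_optimizers, which Lightning
    # needs for training. A minimal sketch follows, assuming the named torch.optim
    # optimizer plus a CosineAnnealingLR schedule; the author's actual setup may differ.
    def configure_optimizers(self):
        optimizer = getattr(optim, self.hparams.optimizer_name)(
            self.parameters(), lr=self.hparams.learning_rate
        )
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.hparams.T_max
        )
        return [optimizer], [scheduler]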
        
    def train_dataloader(self):
        # The dataset files are expected to have been saved with torch.save,
        # e.g. a TensorDataset of (input ids, input mask, tag ids, tag mask).
        train = torch.load(self.hparams.trainfile)
        return DataLoader(train, batch_size=int(self.hparams.batch_size),
                          num_workers=2, pin_memory=True, shuffle=True)

    def val_dataloader(self):
        val = torch.load(self.hparams.valfile)
        return DataLoader(val, batch_size=int(self.hparams.batch_size),
                          num_workers=2, pin_memory=True)

    def test_dataloader(self):
        test = torch.load(self.hparams.testfile)
        return DataLoader(test, batch_size=int(self.hparams.batch_size),
                          num_workers=2, pin_memory=True)
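

# A minimal usage sketch (an assumption, not part of the original module): build the
# .pkt dataset files as TensorDatasets of (input ids, input mask, tag ids, tag mask)
# and train with a pytorch_lightning Trainer. The toy data, file paths and callback
# settings below are illustrative only.
if __name__ == "__main__":
    import os

    seed_everything(42)
    os.makedirs("./data", exist_ok=True)

    # Hypothetical toy data: 8 sequences of length 128 with random token / tag ids,
    # shown only to illustrate the expected tensor layout. The same dataset is saved
    # as train/val/test purely for demonstration.
    x = torch.randint(1, 21128, (8, 128))
    x_mask = torch.ones_like(x)
    y = torch.randint(1, 12, (8, 128))
    y_mask = torch.ones_like(y)
    dataset = TensorDataset(x, x_mask, y, y_mask)
    torch.save(dataset, "./data/train.pkt")
    torch.save(dataset, "./data/val.pkt")
    torch.save(dataset, "./data/test.pkt")

    model = autoModel(batch_size=2)
    trainer = Trainer(
        max_epochs=3,
        callbacks=[
            EarlyStopping(monitor="val_loss", mode="min", patience=3),
            ModelCheckpoint(monitor="val_loss", mode="min"),
            LearningRateMonitor(logging_interval="step"),
        ],
    )
    trainer.fit(model)
    trainer.test(model)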