# Source code for tkitAttNLocal.attnlocal

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn


# import pytorch_lightning as pl

class AttNLocal(nn.Module):
    """Keep a diagonal band of a square, attention-style tensor.

    For every row ``i`` of the input, ``forward`` returns the ``limit``
    consecutive values that start at column ``i``. The last dimension is
    right-padded with ``limit`` zeros first, so windows that run past the
    final column are completed with zeros.

    Args:
        maxlen: Sequence length; ``forward`` expects an input of shape
            ``(batch, maxlen, maxlen)``.
        limit: Number of values kept per row (width of the band).
    """

    def __init__(self, maxlen=128, limit=20):
        super().__init__()
        self.limit = limit
        self.maxlen = maxlen

    def autoBulidlMaskLimit(self):
        """Build the band mask as nested lists of 0/1.

        Returns:
            A list of ``maxlen`` rows, each ``maxlen + limit`` long. Row ``i``
            holds ``limit`` ones starting at column ``i`` and zeros elsewhere.
        """
        width = self.maxlen + self.limit
        return [
            (row * [0] + [1] * self.limit + [0] * self.maxlen)[:width]
            for row in range(self.maxlen)
        ]

    def forward(self, x):
        """Extract the diagonal band from ``x``.

        Args:
            x: Tensor of shape ``(batch, maxlen, maxlen)``.

        Returns:
            Tensor of shape ``(batch, maxlen, limit)`` whose entry
            ``[b, i, j]`` is ``x[b, i, i + j]`` when ``i + j < maxlen`` and
            ``0`` otherwise.

        Raises:
            ValueError: If ``x`` is not 3-D or its last two dimensions are
                not both equal to ``maxlen``.
        """
        batch, length, dim = x.size()
        if length != self.maxlen or dim != self.maxlen:
            raise ValueError(
                "expected input of shape (batch, {0}, {0}), got {1}".format(
                    self.maxlen, tuple(x.size())
                )
            )

        # Allocate the padding directly on x's device to avoid a
        # host-to-device copy on every call.
        pad = torch.zeros(batch, length, self.limit, device=x.device)
        padded = torch.cat((x, pad), dim=-1)

        # Row i reads columns i .. i + limit - 1 of the padded tensor. The two
        # index tensors broadcast to (length, limit), so the band is gathered
        # in one indexing operation instead of through a per-batch boolean
        # mask built from Python lists.
        rows = torch.arange(length, device=x.device).unsqueeze(1)
        cols = rows + torch.arange(self.limit, device=x.device).unsqueeze(0)
        return padded[:, rows, cols]
if __name__ == "__main__":
    print("start test")
    # Smoke test: feed a tensor whose feature dimension equals its sequence
    # length and report the size of the extracted band.
    sample = torch.randn(32, 10, 10)
    print("a", sample)
    local_attn = AttNLocal(10, 5)
    band = local_attn(sample)
    print(band.size())