- »
- 模块代码 »
- tkitAttNLocal.attnlocalNew
-
tkitAttNLocal.attnlocalNew 源代码
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
# import pytorch_lightning as pl
[文档]class AttNLocalNew(nn.Module):
"""
自动限制矩阵
实现斜对角线保留权重,其他的设为-inf
"""
def __init__(self, maxlen=128, limit=20):
super(AttNLocalNew, self).__init__()
self.limit = limit
self.maxlen = maxlen
pass
[文档] def forward(self, x):
# B, L, D = x.size()
mask = torch.ones_like(x).tril(diagonal=-1) + torch.ones_like(x).triu(diagonal=self.limit) # 下三角矩阵
x[mask == 1] = -float("Inf")
return x
pass
if __name__ == "__main__":
print("start test")
# 输入维度和长度一样的矩阵
a = torch.randn(5, 10, 10)
# print("a", a)
attL = AttNLocalNew(10, 5)
out = attL(a)
print(out)
print(out.argmax(-1))
# print()