Source code for tkitAutoMask.mask

# -*- coding: utf-8 -*-
import math
from functools import reduce
# import pytorch_lightning as pl
import torch
from torch import nn
import torch.nn.functional as F
import random
import numpy as np
# helpers

[docs]def prob_mask_like(t, prob):
    # boolean mask with the same shape as t, True with probability `prob` per position
    return torch.zeros_like(t).float().uniform_(0, 1) < prob
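A quick sanity check (illustrative only; the helper draws fresh random numbers on every call):

>>> t = torch.zeros(2, 8, dtype=torch.long)
>>> m = prob_mask_like(t, 0.15)
>>> m.shape, m.dtype
(torch.Size([2, 8]), torch.bool)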
[docs]def mask_with_tokens(t, token_ids):
    """Mask by token id: True wherever t equals one of token_ids."""
    init_no_mask = torch.full_like(t, False, dtype=torch.bool)
    mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
    return mask
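For example, with hypothetical BERT-style special token ids 101 ([CLS]), 102 ([SEP]) and 0 ([PAD]):

>>> t = torch.tensor([[101, 7, 8, 102, 0]])
>>> mask_with_tokens(t, {101, 102, 0})
tensor([[ True, False, False,  True,  True]])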
[docs]def get_mask_subset_with_prob(mask, prob):
    batch, seq_len, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * seq_len)

    num_tokens = mask.sum(dim=-1, keepdim=True)
    mask_excess = (mask.cumsum(dim=-1) > (num_tokens * prob).ceil())
    mask_excess = mask_excess[:, :max_masked]

    rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-1)
    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)

    new_mask = torch.zeros((batch, seq_len + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    return new_mask[:, 1:].bool()
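A minimal sketch of the selection budget: when every position is maskable, exactly
math.ceil(prob * seq_len) positions are chosen (which positions is random):

>>> mask = torch.ones(1, 10, dtype=torch.bool)
>>> subset = get_mask_subset_with_prob(mask, 0.3)
>>> int(subset.sum())  # ceil(0.3 * 10) = 3
3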
[docs]def get_mask_subset_with_prob_tri(mask, prob, subset_with_prob=True):
    """Optimized triangular dynamic masking.

    Randomly picks either the upper or the lower triangle of the mask, dropping
    the rest of the maskable positions. When subset_with_prob is True, a
    proportional probability mask is merged in as well.
    """
    batch, seq_len, device = *mask.shape, mask.device
    # randomly choose between the upper and the lower triangle
    if random.choice([0, 1]) == 0:
        a_mask = mask.triu(diagonal=2)   # diagonal sets the offset
    else:
        a_mask = mask.tril(diagonal=-2)  # diagonal sets the offset
    mask1 = a_mask.bool()
    if subset_with_prob:
        # merge in the probability mask
        mask2 = get_mask_subset_with_prob(mask, prob * 2)
        a_mask = torch.where(mask1, mask1, mask2)
        return a_mask
    else:
        return mask1
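With subset_with_prob=False the result is exactly one of the two shifted triangles,
chosen at random (a minimal check on a fully maskable batch):

>>> mask = torch.ones(4, 6, dtype=torch.bool)
>>> out = get_mask_subset_with_prob_tri(mask, 0.15, subset_with_prob=False)
>>> upper, lower = mask.triu(diagonal=2), mask.tril(diagonal=-2)
>>> bool(torch.equal(out, upper) or torch.equal(out, lower))
True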
[docs]def get_mask_subset_with_prob_diagonal(mask, prob, subset_with_prob=True):
    """Optimized diagonal-band masking.

    Each call masks a random-length contiguous span per batch row, which is
    useful for predicting span contents. When subset_with_prob is True, a
    proportional probability mask is merged in as well.
    """
    batch, seq_len, device = *mask.shape, mask.device
    x = mask.triu()  # diagonal sets the offset
    # draw a random span length, capped at 24
    diagonal = random.randint(1, int(seq_len / 2))
    max_d = min(24, diagonal)
    diagonal = random.randint(min(5, max_d), max_d)  # lower bound clamped so the range is never empty
    y = mask.tril(diagonal=diagonal)  # diagonal sets the offset
    a_mask = torch.where(y == 0, y, x)  # intersection of the two triangles: a diagonal band
    mask1 = a_mask.bool()
    if subset_with_prob:
        # merge in the probability mask
        mask2 = get_mask_subset_with_prob(mask, prob * 2)
        a_mask = torch.where(mask1, mask1, mask2)
        return a_mask
    else:
        return mask1
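With subset_with_prob=False each batch row receives one contiguous band; for a
single row the span starts at position 0 (a minimal check, span length depends
on the random draw):

>>> mask = torch.ones(1, 20, dtype=torch.bool)
>>> out = get_mask_subset_with_prob_diagonal(mask, 0.15, subset_with_prob=False)
>>> bool(out[0, :int(out.sum())].all())  # one contiguous run of True values
True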
# main class
[docs]class autoMask(nn.Module):
    """Dynamically mask input data.

    Example:

    >>> from transformers import BertTokenizer
    >>> tokenizer = BertTokenizer.from_pretrained("uer/chinese_roberta_L-2_H-128")
    >>> tomask = autoMask(
    >>>     # transformer,
    >>>     mask_token_id = tokenizer.mask_token_id,  # the token id reserved for masking
    >>>     pad_token_id = tokenizer.pad_token_id,    # the token id for padding
    >>>     mask_prob = 0.05,     # masking probability for masked language modeling
    >>>     replace_prob = 0.90,  # ~10% probability that a token will not be masked but still included in the loss, as detailed in the paper
    >>>     mask_ignore_token_ids = [tokenizer.cls_token_id, tokenizer.eos_token_id]  # other tokens to exclude from masking, e.g. [cls] and [sep]
    >>> )

    Override the default pad_token_id and mask_token_id as needed; the defaults
    assume https://huggingface.co/uer/chinese_roberta_L-2_H-128/blob/main/vocab.txt
    """
    def __init__(
            self,
            # transformer,
            mask_prob = 0.15,
            replace_prob = 0.9,
            num_tokens = None,
            random_token_prob = 0.,
            mask_token_id = 103,
            pad_token_id = -100,
            mask_ignore_token_ids = [],
            probabilitis = [0.9, 0.05, 0.05]):  # selection probabilities for the three masking strategies
        super().__init__()
        # self.transformer = transformer

        # mlm related probabilities
        self.mask_prob = mask_prob
        self.replace_prob = replace_prob
        self.num_tokens = num_tokens
        self.random_token_prob = random_token_prob

        # token ids
        self.pad_token_id = pad_token_id
        self.mask_token_id = mask_token_id
        self.mask_ignore_token_ids = set([*mask_ignore_token_ids, pad_token_id])
        self.probabilitis = probabilitis
[docs]    def forward(self, input, indices=False, **kwargs):
        """
        indices: if True, also return the indices of the masked positions
        """
        # do not mask [pad] tokens, or any other tokens designated to be excluded ([cls], [sep]);
        # also do not include these special tokens among the tokens chosen at random
        no_mask = mask_with_tokens(input, self.mask_ignore_token_ids)
        # pick one of the three masking strategies according to self.probabilitis
        c = np.random.choice(len(self.probabilitis), 1, p=self.probabilitis)
        if c[0] == 0:
            # default: probabilistic masking
            mask = get_mask_subset_with_prob(~no_mask, self.mask_prob)
        elif c[0] == 1:
            # random upper/lower triangle
            mask = get_mask_subset_with_prob_tri(~no_mask, self.mask_prob)
        else:
            # diagonal band, i.e. contiguous-span masking
            mask = get_mask_subset_with_prob_diagonal(~no_mask, self.mask_prob)

        # get mask indices
        mask_indices = torch.nonzero(mask, as_tuple=True)

        # mask the input with mask tokens with probability `replace_prob`
        # (keep tokens unchanged with probability 1 - replace_prob)
        masked_input = input.clone().detach()

        # if random token probability > 0 for mlm
        if self.random_token_prob > 0:
            assert self.num_tokens is not None, 'num_tokens keyword must be supplied when instantiating MLM if using random token replacement'
            random_token_prob = prob_mask_like(input, self.random_token_prob)
            random_tokens = torch.randint(0, self.num_tokens, input.shape, device=input.device)
            random_no_mask = mask_with_tokens(random_tokens, self.mask_ignore_token_ids)
            random_token_prob &= ~random_no_mask
            random_indices = torch.nonzero(random_token_prob, as_tuple=True)
            masked_input[random_indices] = random_tokens[random_indices]

        # [mask] the input
        replace_prob = prob_mask_like(input, self.replace_prob)
        masked_input = masked_input.masked_fill(mask * replace_prob, self.mask_token_id)

        # in the labels, replace every token that was not selected for masking with the padding token
        labels = input.masked_fill(~mask, self.pad_token_id)

        if indices:
            return masked_input, labels, mask_indices
        else:
            return masked_input, labels
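A minimal end-to-end sketch (the token ids below are hypothetical BERT-style ids;
any integer id tensor works):

>>> tomask = autoMask(mask_token_id=103, pad_token_id=0, mask_prob=0.15,
...                   mask_ignore_token_ids=[101, 102])
>>> input_ids = torch.tensor([[101, 2769, 4263, 872, 1962, 102, 0, 0]])
>>> masked_input, labels = tomask(input_ids)
>>> masked_input.shape == labels.shape == input_ids.shape
True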