Source code for tkitRetransformer.tkitRetransformer

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from transformers import BertTokenizer, AutoModel, AutoConfig
import os
class tkitRetransformer:
    """tkitRetransformer is used to modify transformers models."""

    def __init__(self, from_pretrained, tokenizer):
        """Load the pretrained model and its config.

        from_pretrained: name or path of the pretrained model to edit
        tokenizer: tokenizer instance the model should be adapted to
        """
        MODEL_NAME = from_pretrained
        self.tokenizer = tokenizer
        self.config = AutoConfig.from_pretrained(MODEL_NAME)
        self.model = AutoModel.from_pretrained(MODEL_NAME)
    def edit(self):
        """Basic model-editing example.

        Switches the model to relative position embeddings and resizes the
        word and token type embeddings to match the new tokenizer and config.
        """
        self.config.position_embedding_type = "relative_key_query"
        self.config.vocab_size = self.tokenizer.vocab_size
        self.config.type_vocab_size = 100
        # Rebuild the word embeddings for the new tokenizer's vocabulary.
        self.model.embeddings.word_embeddings = nn.Embedding(
            self.tokenizer.vocab_size,
            self.model.embeddings.word_embeddings.embedding_dim,
            padding_idx=self.tokenizer.pad_token_id,
        )
        # Rebuild the token type embeddings for the enlarged type_vocab_size.
        self.model.embeddings.token_type_embeddings = nn.Embedding(
            self.config.type_vocab_size,
            self.model.embeddings.token_type_embeddings.embedding_dim,
        )
    def save(self, path='./model'):
        """Save the model.

        path: directory to save the model to
        """
        self.config.save_pretrained(path)
        self.tokenizer.save_pretrained(path)
        PATH = os.path.join(path, "pytorch_model.bin")
        torch.save(self.model.state_dict(), PATH)
        print("model saved to:", path)
if __name__ == "__main__":
    print("Testing model saving")
    MODEL_NAME = "uer/chinese_roberta_L-2_H-512"
    tokenizer_MODEL_NAME = "clue/roberta_chinese_clue_tiny"
    tokenizer = BertTokenizer.from_pretrained(tokenizer_MODEL_NAME)
    trt = tkitRetransformer(MODEL_NAME, tokenizer)
    trt.save()
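
A minimal usage sketch, not part of the module above: assuming the same model names as the test block, it edits the model, saves it, and reloads the edited weights. AutoModel.from_config builds a freshly initialized model from the saved config, and strict=False is used because modules introduced by the config change (such as the relative position distance embeddings) are absent from the saved state dict and therefore keep their random initialization.

import os
import torch
from transformers import AutoConfig, AutoModel, BertTokenizer

# Edit and save the model (names taken from the test block above).
tokenizer = BertTokenizer.from_pretrained("clue/roberta_chinese_clue_tiny")
trt = tkitRetransformer("uer/chinese_roberta_L-2_H-512", tokenizer)
trt.edit()
trt.save("./model")

# Reload: rebuild the architecture from the saved config, then load the weights.
config = AutoConfig.from_pretrained("./model")
tokenizer = BertTokenizer.from_pretrained("./model")
model = AutoModel.from_config(config)
state_dict = torch.load(os.path.join("./model", "pytorch_model.bin"))
model.load_state_dict(state_dict, strict=False)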