将tensorflow版本的.ckpt模型转成pytorch的.bin模型
用google-research官方的bert源码(tensorflow版本)对新的法律语料进行微调,迭代次数为100000次,每隔1000次保存一下模型,得到的结果如下:将最后三个文件取出,改名为bert_model.ckpt.data-00000-of-00001、bert_model.ckpt.index、bert_model.ckpt.meta加上之前微调使用过的config.json以及
·
用google-research官方的bert源码(tensorflow版本)对新的法律语料进行微调,迭代次数为100000次,每隔1000次保存一下模型,得到的结果如下:
将最后三个文件取出,改名为bert_model.ckpt.data-00000-of-00001、bert_model.ckpt.index、bert_model.ckpt.meta
加上之前微调使用过的config.json以及vocab.txt文件,运行如下文件后生成pytorch.bin,之后就可以被pytorch得代码调用了。
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import torch
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
import logging
logging.basicConfig(level=logging.INFO)
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
# Initialise PyTorch model
config = BertConfig.from_json_file(bert_config_file)
print("Building PyTorch model from configuration: {}".format(str(config)))
model = BertForPreTraining(config)
# Load weights from tf checkpoint
load_tf_weights_in_bert(model, config, tf_checkpoint_path)
# Save pytorch-model
print("Save PyTorch model to {}".format(pytorch_dump_path))
torch.save(model.state_dict(), pytorch_dump_path)
#
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--tf_checkpoint_path",
default = './chinese_L-12_H-768_A-12_improve1/bert_model.ckpt',
type = str,
help = "Path to the TensorFlow checkpoint path.")
parser.add_argument("--bert_config_file",
default = './chinese_L-12_H-768_A-12_improve1/config.json',
type = str,
help = "The config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture.")
parser.add_argument("--pytorch_dump_path",
default = './chinese_L-12_H-768_A-12_improve1/pytorch_model.bin',
type = str,
help = "Path to the output PyTorch model.")
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
args.bert_config_file,
args.pytorch_dump_path)

GitCode 天启AI是一款由 GitCode 团队打造的智能助手,基于先进的LLM(大语言模型)与多智能体 Agent 技术构建,致力于为用户提供高效、智能、多模态的创作与开发支持。它不仅支持自然语言对话,还具备处理文件、生成 PPT、撰写分析报告、开发 Web 应用等多项能力,真正做到“一句话,让 Al帮你完成复杂任务”。
更多推荐
所有评论(0)