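Using the official google-research BERT source code (TensorFlow version), I fine-tuned BERT on a new legal corpus for 100,000 iterations and saved a checkpoint every 1,000 steps. This produced a series of checkpoints in the output directory, each stored as a .data-00000-of-00001, an .index, and a .meta file.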

Take the three files that belong to the last checkpoint and rename them to bert_model.ckpt.data-00000-of-00001, bert_model.ckpt.index, and bert_model.ckpt.meta.
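Picking "the last three files" by hand is easy to get wrong: in a plain alphabetical listing, model.ckpt-100000.* sorts before model.ckpt-99000.*. The following is a minimal standard-library sketch that picks the highest step numerically and copies its three files under the new names. It assumes the TensorFlow Estimator naming scheme <prefix>.ckpt-<step>.<suffix>; latest_checkpoint_files and export_as_bert_model are hypothetical helpers, not part of the BERT repository.

```python
import os
import re
import shutil

# Assumed naming: e.g. "model.ckpt-100000.index" or "model.ckpt-100000.data-00000-of-00001"
CKPT_RE = re.compile(r"^.+\.ckpt-(?P<step>\d+)\.(?P<suffix>index|meta|data-\d{5}-of-\d{5})$")


def latest_checkpoint_files(ckpt_dir):
    """Return (step, {suffix: filename}) for the highest-step checkpoint in ckpt_dir."""
    by_step = {}
    for name in os.listdir(ckpt_dir):
        m = CKPT_RE.match(name)
        if m:
            by_step.setdefault(int(m["step"]), {})[m["suffix"]] = name
    if not by_step:
        raise FileNotFoundError(f"no *.ckpt-<step>.* files in {ckpt_dir}")
    step = max(by_step)  # numeric max, so 100000 beats 99000
    return step, by_step[step]


def export_as_bert_model(ckpt_dir, out_dir, new_prefix="bert_model.ckpt"):
    """Copy (not move) the latest checkpoint's files to out_dir as bert_model.ckpt.<suffix>."""
    step, files = latest_checkpoint_files(ckpt_dir)
    os.makedirs(out_dir, exist_ok=True)
    copied = []
    for suffix, name in sorted(files.items()):
        dst = os.path.join(out_dir, f"{new_prefix}.{suffix}")
        shutil.copyfile(os.path.join(ckpt_dir, name), dst)
        copied.append(os.path.basename(dst))
    return step, copied
```

Copying rather than renaming leaves the original training output untouched, so the step can be re-exported later if a different checkpoint turns out to work better.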

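Put them in the same directory as the config.json and vocab.txt used for fine-tuning (./chinese_L-12_H-768_A-12_improve1/ in the script's defaults), then run the script below. It writes pytorch_model.bin, which PyTorch code can then load.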
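Before converting, it may be worth confirming that config.json and vocab.txt agree. Each line of vocab.txt is one token id, and the embedding table has vocab_size rows, so a mismatch leaves the tokenizer and the model out of sync. For the released chinese_L-12_H-768_A-12 model, vocab_size is 21128. This is a minimal standard-library sketch; check_vocab_size is a hypothetical helper.

```python
import json


def check_vocab_size(config_path, vocab_path):
    """Raise ValueError if config.json's vocab_size differs from the vocab.txt line count."""
    with open(config_path, encoding="utf-8") as f:
        vocab_size = json.load(f)["vocab_size"]
    with open(vocab_path, encoding="utf-8") as f:
        n_tokens = sum(1 for _ in f)  # one token per line
    if vocab_size != n_tokens:
        raise ValueError(f"config.json says vocab_size={vocab_size}, "
                         f"but vocab.txt has {n_tokens} lines")
    return vocab_size
```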

```python
import argparse
import logging

import torch
# load_tf_weights_in_bert reads the checkpoint with TensorFlow, so TensorFlow must be installed too
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert

logging.basicConfig(level=logging.INFO)


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    # Build a randomly initialised PyTorch model with the architecture described in config.json
    config = BertConfig.from_json_file(bert_config_file)
    print(f"Building PyTorch model from configuration: {config}")
    model = BertForPreTraining(config)

    # Overwrite its weights with the variables stored in the TF checkpoint
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Save the weights (state dict) in PyTorch format
    print(f"Saving PyTorch model to {pytorch_dump_path}")
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # All three paths default to the directory holding the renamed checkpoint
    parser.add_argument("--tf_checkpoint_path",
                        default="./chinese_L-12_H-768_A-12_improve1/bert_model.ckpt",
                        type=str,
                        help="Path prefix of the TensorFlow checkpoint (without the .index/.data suffix).")
    parser.add_argument("--bert_config_file",
                        default="./chinese_L-12_H-768_A-12_improve1/config.json",
                        type=str,
                        help="The config JSON file of the pre-trained BERT model; "
                             "it specifies the model architecture.")
    parser.add_argument("--pytorch_dump_path",
                        default="./chinese_L-12_H-768_A-12_improve1/pytorch_model.bin",
                        type=str,
                        help="Path of the output PyTorch model.")
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                     args.bert_config_file,
                                     args.pytorch_dump_path)
```
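With the default paths, pytorch_model.bin ends up next to config.json and vocab.txt, which is the layout that transformers' from_pretrained expects for a local model directory. Before loading, a small standard-library check can confirm that the directory is complete. missing_model_files is a hypothetical helper, and the file list reflects what the BERT model and tokenizer classes read.

```python
import os

# config.json + pytorch_model.bin are read by the model class, vocab.txt by BertTokenizer
REQUIRED_FILES = ("config.json", "vocab.txt", "pytorch_model.bin")


def missing_model_files(model_dir):
    """Return the required files absent from model_dir (an empty list means ready to load)."""
    return [name for name in REQUIRED_FILES
            if not os.path.isfile(os.path.join(model_dir, name))]
```

When the list is empty, BertModel.from_pretrained(model_dir) and BertTokenizer.from_pretrained(model_dir) from transformers should be able to load the directory. Loading into BertModel rather than BertForPreTraining simply leaves the pre-training heads (the cls.* weights) unused.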
