导读

tensorflow的checkpoint模型文件,只包含了模型的参数并不包含模型结构,为了方便使用tensorflow的serving进行部署,我们需要将checkpoint模型转换为saved_model格式

转换代码如下

def ckpt_to_pb(ckpt_path,output_pd_path):
"""
ckpt_path:checkpoint模型文件的目录
output_pd_path:savedmodel模型文件保存的目录
"""
	#加载模型的参数文件
    experiment_folder = "/tmp/"
    config = json.load(open(experiment_folder + 'config.json'))
	#根据的模型参数文件获取模型的结构(输入和输出)
    [x, y_, is_train, y, normalized_y, cost] = train.tf_define_model_and_cost(config)

    graph = tf.Graph()
    with tf.compat.v1.Session(graph=graph) as sess:
		#定义模型的输入输出节点
        SignatureDef = sm.signature_def_utils.build_signature_def(
            inputs={
                "x_input": sm.utils.build_tensor_info(x),
                "is_train": sm.utils.build_tensor_info(is_train)
            },
            outputs={
                "y_sigmoid": sm.utils.build_tensor_info(normalized_y)
            },
            method_name=sm.signature_constants.PREDICT_METHOD_NAME,
        )
		#加载checkpoint模型参数	
        loader = tf.compat.v1.train.import_meta_graph(ckpt_path + ".meta")
        loader.restore(sess,ckpt_path)
		
		#将checkpoint模型转换为savedmodel
        builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(output_pd_path)
        builder.add_meta_graph_and_variables(sess,tags = [tf.compat.v1.saved_model.tag_constants.SERVING],
                                             signature_def_map={sm.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: SignatureDef},
                                             strip_default_attrs=True)

        builder.save()

加载savedmodel模型进行预测

import tensorflow as tf

export_dir = "/save_model"
#加载savedmodel模型
imported = tf.saved_model.load(export_dir)
model = imported.signatures["serving_default"]
#模型预测
pred = model(x_input=tf.convert_to_tensor(input_array), is_train=tf.constant(False))
#获取模型的预测结果
pred = pred["y_sigmoid"].numpy()
Logo

GitCode 天启AI是一款由 GitCode 团队打造的智能助手,基于先进的LLM(大语言模型)与多智能体 Agent 技术构建,致力于为用户提供高效、智能、多模态的创作与开发支持。它不仅支持自然语言对话,还具备处理文件、生成 PPT、撰写分析报告、开发 Web 应用等多项能力,真正做到“一句话,让 Al帮你完成复杂任务”。

更多推荐