ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

tensorflow/serving部署keras模型

2021-04-30 15:00:41  阅读:226  来源: 互联网

标签:serving keras 模型 path tensorflow model


之前写了一篇tensorflow/serving部署tensorflow模型的文章,记录了详细的操作步骤与常见的错误及解决方案,具体见:TensorFlow Serving模型转换与部署

本文主要记录tensorflow/serving部署keras模型过程中的一些重要步骤,以便后续查阅。

我们在keras中保存模型通常用model.save或者model.save_weights函数。
其中,model.save函数保存的模型往往比的是模型的结构与权重,而model.save_weights函数保存的仅仅是模型的结构,因此model.save函数保存的模型往往比model.save_weights函数保存的模型要大些。

在前一篇tensorflow/serving介绍中TensorFlow Serving模型转换与部署,我们知道tensorflow/serving需要pb格式的模型,而本篇文章我们讨论的keras模型是.h5.weights格式的,因此,首先我们需要将.h5.weights格式的keras模型转换为tensorflow/serving框架可识别的pb格式模型,转换代码如下:

def keras_model_to_tfs(model, export_path):
    signature = tf.saved_model.signature_def_utils.predict_signature_def(
        inputs={'input_x': model.input}, 
        outputs={'output_y': model.output}
    )
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)
    legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
    builder.add_meta_graph_and_variables(
        sess=K.get_session(),
        tags=[tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature,
        },
        legacy_init_op=legacy_init_op)
    builder.save()
    print('Build done.')

简要说明一下keras_model_to_tfs函数的参数
model:导入的keras模型,用keras的load_modelload_weights导入的模型
export_path:转换成pb格式模型后的保存路径

模型转换完成后,剩下的工作就是部署tensorflow/serving框架,并利用grpc接口调用模型预测。
关于具体的tensorflow/serving的部署,可参考之前文章:TensorFlow Serving模型转换与部署,预测代码在之前那篇文章中也有,本文再次贴出一个。

def tfserving_grpc(title, content):
    content = content or title
    content = filter_waste(content)
    model_dir = os.path.join(project_path, 'models_weights')
    with open(os.path.join(model_dir, 'tokenizer.plk'), 'rb') as f:
        tokenizer = pickle.load(f)
    x = tokenizer.texts_to_sequences([jieba.lcut(content)])
    x = x[0]
    if len(x) > MAX_LEN:
        x = x[:MAX_LEN]
    else:
        x = x + [0] * (MAX_LEN - len(x))

    # ip地址为部署tensorflow/serving的IP
    channel = grpc.insecure_channel('xx.xx.xx.xx:8500')  
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'new_yq_model'
    # request.model_spec.version.value = 1000001
    request.model_spec.signature_name = 'serving_default'

    request.inputs["input_x"].CopyFrom(tf.contrib.util.make_tensor_proto(np.array([x], dtype=np.float)))
    response = stub.Predict(request, 10.0)

    results = {}
    for key in response.outputs:
        tensor_proto = response.outputs[key]
        results[key] = tf.contrib.util.make_ndarray(tensor_proto)

    return results

最后给一个main函数的整体过程代码。

model = build_model(len(tokenizer.index_word))
model.load_weights(os.path.join(model_dir, 'best_model.weights'))
model.summary()
export_path = './tfs_models'
keras_model_to_tfs(model, export_path)

参考

使用tensorflow serving部署keras模型(tensorflow 2.0.0)
keras、tensorflow serving踩坑记

标签:serving,keras,模型,path,tensorflow,model
来源: https://blog.csdn.net/tianyunzqs/article/details/116303326

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有