Tensorflow 保存和加载model

Posted by yaohong on Wednesday, October 21, 2020

TOC

Tensorflow 保存和加载model

保存model的代码:

import tensorflow as tf
save_model_path 	= "/save_model" # 保存的文件夹路径
network.save(save_model_path);

保存的目录格式如下,期中saved_model.pb是主要的文件:

save_model
	-- assets
	-- variables
		-- variables.data-00000-of-00001
		-- variables.index
	-- saved_model.pb

加载model如下:

# 一行就行
load_model = tf.keras.models.load_model(save_model_path)
# 这是使用model进行推理
res = load_model.predict(train_images)
print("load model predict res:", res);

推理结果与原来一致:

load model predict res: [[0.0000000e+00 5.4650146e-10 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00 0.0000000e+00 1.0000000e+00 0.0000000e+00
  0.0000000e+00]
 [0.0000000e+00 1.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  1.0000000e+00]
 [0.0000000e+00 2.2018748e-10 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00 3.7330758e-33 1.0000000e+00 0.0000000e+00
  2.9055077e-38]
 [0.0000000e+00 2.2518517e-14 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00 1.0000000e+00 0.0000000e+00 0.0000000e+00
  1.8837843e-13]]

「点个赞」

Yaohong

点个赞

使用微信扫描二维码完成支付