使用tensorflow保存和恢复模型saver.restore

2024-03-04 0 602
目录
  • tensorflow保存和恢复模型saver.restore
    • 保存模型
    • 恢复模型(一定细看代码注释!!!)
  • tensorflow里的,保存和恢复模型的方式
    • 第一种情况
    • 第二种情况
  • 总结

    tensorflow保存和恢复模型saver.restore

    本文只对一些细节点做补充,大体的步骤就不详述了

    保存模型

    ① 首先我使用的是tensorflow-gpu 1.4.0

    ② 这个版本生成的ckpt文件是这样的:

    使用tensorflow保存和恢复模型saver.restore

    其中.meta存放的是网络模型和所有的变量;

    .index 和.data一起存放变量数据

    -0 -500表示checkpoint点

    ③ 保存的配置(一定细看代码注释!!!)

    import tensorflow as tf
    w1 = tf.Variable(变量的初始化, name=\’w1\’)
    w2 = tf.Variable(变量的初始化, name=\’w2\’)
    saver = tf.train.Saver([w1,w2],max_to_keep=5, keep_checkpoint_every_n_hours=2) # 这里是细节部分,可以指定保存的变量,每两小时保存最近的5个模型
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver.save(sess, \’./checkpoint_dir/MyModel\’,global_step=step,write_meta_graph=False)) # 因为模型没必要多次保存,所以写为False

    恢复模型(一定细看代码注释!!!)

    代码:

    import tensorflow as tf
    with tf.Session() as sess:
    saver = tf.train.import_meta_graph(模型路径) # 模型路径中必须指定到具体的模型下如:xx.ckpt-500.meta,且一般来讲,所有模型都是一样的,如果没有改变模型的条件下。
    # 下面的restore就是在当前的sess下恢复了所有的变量
    saver.restore(sess,数据路径) # 数据路径也必须指定到具体某个模型的数据,但创建这个路径的方法很多,比如调用最后一个保存的模型tf.train.latest_checkpoint(\’./checkpoint_dir\’),也可以是xx.ckpt-500.data,并且这两个是等效的,如果是xx.ckpt-0.data,就是第一个模型的数据
    print(sess.run(\’w1:0\’)) # 这里的w1必须加上:0

    tensorflow里的,保存和恢复模型的方式

    重点在于,第一个文件用于 训练,保存图meta和训练好的参数data(后缀),在另一个文件中导入这个图和训练好的参数,用于预测或者接着训练。

    大大减少了另一个文件里的 重复

    第一种情况

    产生变量的代码和恢复变量的代码在同一个文件时,可以直接如下调用:

    # 建模型
    saver = tf.train.Saver()

    with tf.Session() as sess:
    # 存模型,注意此处的model是文件名,不是路径
    saver.save(sess, \”/tmp/model\”)

    with tf.Session() as sess:
    # 恢复模型
    saver.restore(sess, \”/tmp/model\”)

    第二种情况

    不想在另一个文件中,把产生变量的 一大堆代码重敲一遍,可以直接从保存好的 meta文件和data文件中恢复出来

    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    # @Time : 2019/9/9 20:49
    # @Author : ZZL
    # @File : 保存检查点文件,并恢复.py
    import tensorflow as tf
    # Saving contents and operations.
    v1 = tf.placeholder(tf.float32, name=\”v1\”)
    v2 = tf.placeholder(tf.float32, name=\”v2\”)
    v3 = tf.multiply(v1, v2)
    vx = tf.Variable(10.0, name=\”vx\”)
    v4 = tf.add(v3, vx, name=\”v4\”)
    saver = tf.train.Saver([vx])
    with tf.Session() as sess:
    with tf.device(\’/cpu:0\’):
    sess.run(tf.global_variables_initializer())
    sess.run(vx.assign(tf.add(vx, vx)))
    result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
    print(result)
    print(saver.save(sess, \”./model_ex1\”)) # 该方法返回新创建的检查点文件的路径前缀。这个字符串可以直接传递给对“restore()”的调用。
    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    # @Time : 2019/9/9 20:54
    # @Author : ZZL
    # @File : 恢复文件.py
    import tensorflow as tf

    saver = tf.train.import_meta_graph(\”./model_ex1.meta\”)
    sess = tf.Session()
    saver.restore(sess, \”./model_ex1\”)
    result = sess.run(\”v4:0\”, feed_dict={\”v1:0\”: 12.0, \”v2:0\”: 3.3})
    print(result)

    先来个空图,loaded_graph,在会话中,导入之前构建好的图的文件 后缀meta,loader.restore(sess, save_model_path)

    在当前的loaded_graph中,导入构建好的图和图上的变量值。

    def test_model():

    test_features, test_labels = pickle.load(open(\’preprocess_test.p\’, mode=\’rb\’))
    loaded_graph = tf.Graph() # <tensorflow.python.framework.ops.Graph object at 0x0000017CB3702320>
    # print( loaded_graph)
    # print(tf.get_default_graph()) # <tensorflow.python.framework.ops.Graph object at 0x0000017C9A0C0C50>
    with tf.Session(graph=loaded_graph) as sess:
    # 读取模型
    loader = tf.train.import_meta_graph(save_model_path + \’.meta\’)
    print(loader)
    loader.restore(sess, save_model_path)

    print(tf.get_default_graph()) # <tensorflow.python.framework.ops.Graph object at 0x0000017CB3702320>
    # 从已经读入的模型中 获取tensors
    loaded_x = loaded_graph.get_tensor_by_name(\’x:0\’)
    loaded_y = loaded_graph.get_tensor_by_name(\’y:0\’)
    loaded_keep_prob = loaded_graph.get_tensor_by_name(\’keep_prob:0\’)
    loaded_logits = loaded_graph.get_tensor_by_name(\’logits:0\’)
    loaded_acc = loaded_graph.get_tensor_by_name(\’accuracy:0\’)

    # 获取每个batch的准确率,再求平均值,这样可以节约内存
    test_batch_acc_total = 0
    test_batch_count = 0

    for test_feature_batch, test_label_batch in helper.batch_features_labels(test_features, test_labels, batch_size):
    test_batch_acc_total += sess.run(
    loaded_acc,
    feed_dict={loaded_x: test_feature_batch, loaded_y: test_label_batch, loaded_keep_prob: 1.0})
    test_batch_count += 1

    总结

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持悠久资源网。

    您可能感兴趣的文章:

    • 解决tensorflow1.x版本加载saver.restore目录报错的问题
    • TensorFLow用Saver保存和恢复变量
    • TensorFlow模型保存/载入的两种方法
    • TensorFlow Saver:保存和读取模型参数.ckpt实例

    收藏 (0) 打赏

    感谢您的支持,我会继续努力的!

    打开微信/支付宝扫一扫,即可进行扫码打赏哦,分享从这里开始,精彩与您同在
    点赞 (0)

    悠久资源 Python 使用tensorflow保存和恢复模型saver.restore https://www.u-9.cn/jiaoben/python/183090.html

    常见问题

    相关文章

    发表评论
    暂无评论
    官方客服团队

    为您解决烦忧 - 24小时在线 专业服务