解决tensorflow 的 Saver.restore()无法从本地读取变量的问题
最近做tensorflow 手写数字识别的时候遇到了一个问题,Saver的restore()方法无法从本地恢复变量,导致了每次都会重新训练。
原来代码
saver = tf.train.Saver(max_to_keep=5)epoch = tf.Variable(0, name='epoch', trainable=False)sess = tf.Session()sess.run(tf.global_variables_initializer())ckpt_dir = "./model/"if not os.path.exists(ckpt_dir): os.makedirs(ckpt_dir)ckpt = tf.train.latest_checkpoint(ckpt_dir)if ckpt != None: saver.restore(sess, ckpt)else: print('Train from scratch')start = sess.run(epoch)
修改代码
epoch = tf.Variable(0, name='epoch', trainable=False)saver = tf.train.Saver(max_to_keep=5)sess = tf.Session()sess.run(tf.global_variables_initializer())ckpt_dir = "./model/"if not os.path.exists(ckpt_dir): os.makedirs(ckpt_dir)ckpt = tf.train.latest_checkpoint(ckpt_dir)if ckpt != None: saver.restore(sess, ckpt)else: print('Train from scratch')start = sess.run(epoch)
其实主要改变的就是以下两行的顺序
epoch = tf.Variable(0, name='epoch', trainable=False)
saver = tf.train.Saver(max_to_keep=5)