突然发现TensorFlow R12训练好的样本,用旧版restore会报错:

1
ValueError: Restore called with invalid save path

于是查了查TensorFlow v0.12.0 RC0’s release note:

New checkpoint format becomes the default in tf.train.Saver. Old V1 checkpoints continue to be readable; controlled by the write_version argument, tf.train.Saver now by default writes out in the new V2 format. It significantly reduces the peak memory required and latency incurred during restore.

就是说tf.train.Saver换新的checkpoint格式了,减少了峰值内存占用,但是旧的也能读,多了一个write_version标签, 新的V2,旧的V1,像这样子:

old format (V1) new format (V2)
model.ckpt-12345 model.ckpt-12345.index
model.ckpt-12345.meta model.ckpt-12345.meta
model.ckpt-12345.data-00000-of-00001

但是看了这段话显然我还不知道到底该怎么随意切换呢?然后我扒了扒源码(居然是我第一次扒TensorFlow这么新鲜的代码),找到了 saver.py, a。其中 tf.train.saver 构造函数是:

1
2
def __init__(self, write_version=saver_pb2.SaverDef.V2):
self._write_version = write_version

所以,要写成旧版的checkpoint格式,就要这样:

1
2
3
4
5
import tensorflow as tf
from tensorflow.core.protobuf import saver_pb2
...
saver = tf.train.Saver(write_version = saver_pb2.SaverDef.V1)
saver.save(sess, './model.ckpt', global_step = step)

虽然可以用了,但是有好多WARNING … :

1
2
3
4
5
6
WARNING:tensorflow:*******************************************************
WARNING:tensorflow:TensorFlow's V1 checkpoint format has been deprecated.
WARNING:tensorflow:Consider switching to the more efficient V2 format:
WARNING:tensorflow: `tf.train.Saver(write_version=tf.train.SaverDef.V2)`
WARNING:tensorflow:now on by default.
WARNING:tensorflow:*******************************************************

所以呢,旧的模型文件可以这样凑合用一下,还是尽快把程序统一升级到R12吧,毕竟已经支持pip无缝安装了。