Tensorflow加载pb文件继续训练
Tensorflow中模型即代码
上面这句话说的很对,当我们只有一个预训练好的pb
文件,我们如何加载这个模型继续训练呢?今天就来解决这个问题.
1. 加载pb文件
我们得到了一个.pb
文件,不论是提取他的参数还是用他进行推理,都得加载这个文件.请看下面几行代码:
def load_model(model, input_map=None):
# Check if the model is a model directory (containing a metagraph and a checkpoint file)
# or if it is a protobuf file with a frozen graph
model_exp = os.path.expanduser(model)
if (os.path.isfile(model_exp)):
print('Model filename: %s' % model_exp)
with gfile.GFile(model_exp, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, input_map=input_map, name='')
else:
print('Model directory: %s' % model_exp)
meta_file, ckpt_file = get_model_filenames(model_exp)
print('Metagraph file: %s' % meta_file)
print('Checkpoint file: %s' % ckpt_file)
saver = tf.train.import_meta_graph(os.path.join(model_exp, meta_file), input_map=input_map)
saver.restore(tf.get_default_session(), os.path.join(model_exp, ckpt_file))
讲解
这里是分两种情况,我们就看第一种:直接使用tf.gfile.GFile
打开此文件,然后获得当前文件的图定义,读入数据,最后将此图导入.
2. 获得操作
一般我们加载了图之后,都是去获得他的占位符去进行输入,然后输出.为了得到所有的权重,我们使用g.get_operations()
获得所有的操作节点.
import tensorflow as tf |
注意:
导入pb
文件之后使用tf.global_variables()
等获取变量的方式都是无效的,获得的都是空值.如下所示:
In [6]: tf.global_variables()
Out[6]: []
In [7]: tf.trainable_variables()
Out[7]: []
3. 获得tensor
当我们有了操作列表之后如何进行读取变量呢?让我们先看看操作列表中的数据:
In [11]: optlist
Out[11]:
[<tf.Operation 'inputs' type=Placeholder>,
<tf.Operation 'MobileNetV1/SpaceToBatchND/block_shape' type=Const>,
<tf.Operation 'MobileNetV1/SpaceToBatchND/paddings' type=Const>,
<tf.Operation 'MobileNetV1/SpaceToBatchND' type=SpaceToBatchND>,
<tf.Operation 'MobileNetV1/Conv2d_0_3x3/weights' type=Const>,
<tf.Operation 'MobileNetV1/Conv2d_0_3x3/weights/read' type=Identity>,
<tf.Operation 'MobileNetV1/Conv2d_0_3x3/Conv2D' type=Conv2D>,
<tf.Operation 'MobileNetV1/Conv2d_0_3x3/BatchNorm/beta' type=Const>,
<tf.Operation 'MobileNetV1/Conv2d_0_3x3/BatchNorm/beta/read' type=Identity>
我们可以看到这里的这些不是变量,并且种类烦杂,不过我直接说明xxxx/read
等操作就是读取预存的权重的操作.因此我们可以直接把这些操作过滤出来.
def get_vars_from_optlist(optlist: list)->list: |
现在我们有个对应的读取变量操作列表,但是要读取变量还是要进行转化,因为varlist
只是一个操作,还没有变成可运行的tensor
,所以我只要在操作名后面加上:0
,同时get_tensor_by_name()
即可得到对应的tensor
def convert_vars_to_tensor(g, varlist: list)->list: |
4. 读取变量
有了tensorlist
,我们可以来读取变量了.为了restore
的方便,我将他保存成字典的形式,并且修改每一个key
都与原图中的变量名相同,这样restore
的时候直接判断名字是否相同即可.
# 将所有变量存入字典 |
5. 保存字典
使用下面的这个函数保存我们的vardict
def save_pkl(obj, name):
with open(name, 'wb') as f:
pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
6. 恢复权重
注意: 要恢复权重,你必须得有原图的定义,否则你必须重新写一个.
- 首先定义一个原图,接下来需要回复权重.
- 搜集此图中所有可训练的变量(我这里用except_last控制是否加载最后一层权重)
- 加载之前保存的字典文件
- 使用
tf.assign()
,将modelvarlist
与pre_weight_dict
中名字相同的变量进行赋值,存入optlist
- 使用
sess.run(optlist)
,进行赋值操作 - 大功告成~
def restore_form_pkl(sess: tf.Session(), pklpath: str, except_last=True): |