To use a model that someone else has already trained, we must load the parameters stored in its .ckpt file. But how do we load parameters from a .ckpt file *selectively*? First, let's inspect which parameters the .ckpt actually contains:
```python
import os
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

model_dir = './model'  # directory containing the model
checkpoint_path = os.path.join(model_dir, "fineturing_model.ckpt")  # locate the ckpt file
# Read the data from the checkpoint
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
# reader = tf.train.NewCheckpointReader(checkpoint_path)  # equivalent: use tf.train's NewCheckpointReader
var_to_shape_map = reader.get_variable_to_shape_map()
# Print each tensor's name and shape
for key in var_to_shape_map:
    print("tensor_name: ", key, reader.get_tensor(key).shape)
```
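The `var_to_shape_map` returned by the reader behaves like an ordinary dict mapping variable names to shapes, so it can be filtered with plain Python before deciding what to restore. A minimal sketch (no real checkpoint needed; the variable names and shapes below are made up for illustration):

```python
# Hypothetical stand-in for reader.get_variable_to_shape_map();
# names and shapes are invented for illustration only.
var_to_shape_map = {
    'var_name/wc1': [3, 3, 3, 3, 64],
    'var_name/bc1': [64],
    'var_name/out': [4096, 101],
}

# Keep only the convolution-kernel variables, selected by name prefix.
conv_keys = sorted(k for k in var_to_shape_map if k.startswith('var_name/wc'))
print(conv_keys)  # ['var_name/wc1']
```

The same prefix-matching idea is what drives the `include` list used with `slim.get_variables_to_restore` below.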
Then, comparing against the original model definition, we work out whether each parameter should be loaded.
Next, let's see how to load parameters selectively. The code is as follows:
```python
import tensorflow as tf
import tensorflow.contrib.slim as slim  # the module we will use

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    saver1 = tf.train.Saver()  # pinned to the default graph's variables
    model_name = 'xxxx/xxx/model.ckpt'
    # saver = tf.train.import_meta_graph('xxx/xxx/model.meta')
    # variables = tf.contrib.framework.get_variables_to_restore()
    # Selectively restore parameters
    include = ['var_name/wc1', 'var_name/wc2', 'var_name/wc3a'.....]
    variables_to_restore = slim.get_variables_to_restore(include=include)
    saver = tf.train.Saver(variables_to_restore)
    print(variables_to_restore)  # print the parameters to be loaded
    saver.restore(sess, model_name)
    saver1.save(sess, './model2/fineturing_model.ckpt')
```
Note: pay particular attention to the order in which saver1 and saver are defined. saver1 = tf.train.Saver() saves every parameter that appears in our model (call this set1); it is effectively pinned to the variables of the default graph. saver = tf.train.import_meta_graph(), on the other hand, loads the parameter set recorded in the meta graph (set2); you can think of it as declaring that the default graph contains exactly those parameters (set2). Since set2 is a proper subset of set1, if saver1 is defined after saver, then saving a parameter A that exists in set1 but not in set2 raises: NotFoundError (see above for traceback): Key Variable_xxx not found in checkpoint. We can also call tf.reset_default_graph() to reset the default graph.
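The set relationship described above can be sketched in plain Python: the names that trigger the NotFoundError are exactly the difference set1 − set2. The variable names here are hypothetical:

```python
# Hypothetical variable-name sets for illustration.
set1 = {'var_name/wc1', 'var_name/bc1', 'var_name/out'}  # all variables in our model
set2 = {'var_name/wc1', 'var_name/bc1'}                  # variables present in the checkpoint

assert set2 < set1  # set2 is a proper subset of set1

# Variables that would produce "Key ... not found in checkpoint"
# if we tried to restore all of set1 from a checkpoint holding only set2.
missing = sorted(set1 - set2)
print(missing)  # ['var_name/out']
```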
At this point we know how to selectively load a pretrained model for transfer learning.
Some pretrained models use the .model extension instead. For example, a pretrained model of the C3D network on UCF101 is named sports1m_finetuning_ucf101.model. How do we load those? The method is exactly the same. Taking sports1m_finetuning_ucf101.model as an example:
```python
import tensorflow as tf
import tensorflow.contrib.slim as slim  # the module we will use

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    saver1 = tf.train.Saver()  # pinned to the default graph's variables
    model_name = 'xxxx/xxx/sports1m_finetuning_ucf101.model'
    # Selectively restore parameters
    include = ['var_name/wc1', 'var_name/wc2', 'var_name/wc3a', 'var_name/wc3b',
               'var_name/wc4a', 'var_name/wc4b', 'var_name/wc5a', 'var_name/wc5b',
               'var_name/bc1', 'var_name/bc2', 'var_name/bc3a', 'var_name/bc3b',
               'var_name/bc4a', 'var_name/bc4b', 'var_name/bc5a', 'var_name/bc5b']
    variables_to_restore = slim.get_variables_to_restore(include=include)
    saver = tf.train.Saver(variables_to_restore)
    print(variables_to_restore)  # print the parameters to be loaded
    saver.restore(sess, model_name)
    saver1.save(sess, './model2/fineturing_model.ckpt')
```
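Typing out all sixteen names in `include` by hand is error-prone. Assuming the same `var_name/` scope and the wc/bc layer-naming pattern used above, the list can be generated instead:

```python
# Layer suffixes following the C3D naming pattern above.
layers = ['1', '2', '3a', '3b', '4a', '4b', '5a', '5b']

include = (['var_name/wc' + s for s in layers] +   # convolution kernels
           ['var_name/bc' + s for s in layers])    # biases
print(len(include))              # 16
print(include[0], include[8])    # var_name/wc1 var_name/bc1
```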
The related training code; pay attention to the comments:
```python
model_name = "./model/fineturing_model.ckpt"

def run_training(batch_size, dropout, epochs):
    # weight = [weights['wc1'], weights['wc2'], weights['wc3a'], weights['wc3b'],
    #           weights['wc4a'], weights['wc4b'], weights['wc5a'], weights['wc5b']]
    with tf.Graph().as_default():
        with tf.variable_scope('var_name') as var_scope:
            # Variable initialization
            weights = {
                'wc1': _variable_with_weight_decay('wc1', [3, 3, 3, 3, 64], 0.04, 0.00),
                'wc2': _variable_with_weight_decay('wc2', [3, 3, 3, 64, 128], 0.04, 0.00),
                'wc3a': _variable_with_weight_decay('wc3a', [3, 3, 3, 128, 256], 0.04, 0.00),
                'wc3b': _variable_with_weight_decay('wc3b', [3, 3, 3, 256, 256], 0.04, 0.00),
                'wc4a': _variable_with_weight_decay('wc4a', [3, 3, 3, 256, 512], 0.04, 0.00),
                'wc4b': _variable_with_weight_decay('wc4b', [3, 3, 3, 512, 512], 0.04, 0.00),
                'wc5a': _variable_with_weight_decay('wc5a', [3, 3, 3, 512, 512], 0.04, 0.00),
                'wc5b': _variable_with_weight_decay('wc5b', [3, 3, 3, 512, 512], 0.04, 0.00),
                'cam': _variable_with_weight_decay('cam', [1, 1, 512, c3d_model.NUM_CLASSES], 0.04, 0.00),
            }
            biases = {
                'bc1': _variable_with_weight_decay('bc1', [64], 0.04, 0.0),
                'bc2': _variable_with_weight_decay('bc2', [128], 0.04, 0.0),
                'bc3a': _variable_with_weight_decay('bc3a', [256], 0.04, 0.0),
                'bc3b': _variable_with_weight_decay('bc3b', [256], 0.04, 0.0),
                'bc4a': _variable_with_weight_decay('bc4a', [512], 0.04, 0.0),
                'bc4b': _variable_with_weight_decay('bc4b', [512], 0.04, 0.0),
                'bc5a': _variable_with_weight_decay('bc5a', [512], 0.04, 0.0),
                'bc5b': _variable_with_weight_decay('bc5b', [512], 0.04, 0.0),
            }
        images_placeholder, labels_placeholder = placeholder_inputs(batch_size)
        logits, CAM = c3d_model.inference_c3d(images_placeholder[:, :, :, :, :],
                                              dropout, batch_size, weights, biases)  # build the model graph
        loss_ = loss(logits, labels_placeholder)
        accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels_placeholder), tf.float32))
        train_op = tf.train.AdamOptimizer(1e-4).minimize(loss_)
        # softmax_ = soft(logits)
        print('**********')
        # reader = pywrap_tensorflow.NewCheckpointReader(model_name)
        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            sess.run(init)
            saver1 = tf.train.Saver()
            saver = tf.train.import_meta_graph('./model/fineturing_model.ckpt.meta')
            # tf.reset_default_graph()
            saver.restore(sess, model_name)
            saver1.save(sess, './model2/fineturing_model.ckpt')
            for i in range(epochs):
                ......
```
Reposted from: http://vmten.baihongyu.com/