博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Tensorflow加载模型(进阶版):如何利用预训练模型进行微调(fintuning)
阅读量:3903 次
发布时间:2019-05-23

本文共 5793 字,大约阅读时间需要 19 分钟。

我们要使用别人已经训练好的模型,就必须将.ckpt文件中的参数加载进来。我们如何有选择的加载.ckpt文件中的参数呢。首先我们要查看.ckpt都保存了哪些参数:

上代码:

import tensorflow as tf   import osfrom tensorflow.python import pywrap_tensorflowmodel_dir='./model'#设置模型所在文件夹checkpoint_path = os.path.join(model_dir, "fineturing_model.ckpt")#定位ckpt文件# 从checkpoint中读出数据reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)# reader = tf.train.NewCheckpointReader(checkpoint_path) # 用tf.train中的NewCheckpointReader方法var_to_shape_map = reader.get_variable_to_shape_map()# 输出权重tensor名字和值for key in var_to_shape_map:    print("tensor_name: ", key,reader.get_tensor(key).shape)

然后我们,照着原来的模型来搞清楚该参数是否应该加载:

接下来我们来看如何有选择的加载,代码如下:

import tensorflow as tfimport tensorflow.contrib.slim as slim #我们要用到的模块with tf.Session() as sess:            init = tf.global_variables_initializer()            sess.run(init)            saver1 = tf.train.Saver() #设置默认图            model_name = 'xxxx/xxx/model.ckpt'            #saver = tf.train.import_meta_graph('xxx/xxx/model.meta')            #variables = tf.contrib.framework.get_variables_to_restore()            #有选择的恢复参数            include = ['var_name/wc1','var_name/wc2','var_name/wc3a'.....]                        variables_to_restore = slim.get_variables_to_restore(include=include)            saver = tf.train.Saver(variables_to_restore)                         print(variables_to_restore) #打印要加载的参数            saver.restore(sess,model_name)            saver1.save(sess,'./model2/fineturing_model.ckpt')

注意:这里应该格外注意saver1 和 saver 的先后顺序关系。saver1 = tf.train.Saver()。默认将我们模型中所出现的参数(set1)全都保存,类似于限制默认图参数。而saver =  tf.train.import_meta_graph()。表示将模型中所出现的参数集(set2)加载进来,可以理解为定义默认图中就这些参数(set2)。但是set2 真含于 set1,因此如果saver1在saver后定义,当保存某个参数A存在于set1但不存在与set2时,会报错:Key NotFoundError (see above for traceback): Variable_xxx not found in checkpoint。我们也可以添加tf.reset_default_graph()来设置默认图。

到这里我们就知道了如何有选择的加载预训练模型来进行迁移学习了。

我们知道,有的模型后缀名是.model。例,c3d 网络在UCF101上的一个预训练模型:sports1m_finetuning_ucf101.model。这种应该怎么加载呢?方法其实是一样的,以sports1m_finetuning_ucf101.model为例:

import tensorflow as tfimport tensorflow.contrib.slim as slim #我们要用到的模块with tf.Session() as sess:            init = tf.global_variables_initializer()            sess.run(init)            saver1 = tf.train.Saver() #设置默认图            model_name = 'xxxx/xxx/sports1m_finetuning_ucf101.model'            #有选择的恢复参数            include = ['var_name/wc1','var_name/wc2','var_name/wc3a','var_name/wc3b','var_name/wc4a',"var_name/wc4b","var_name/wc5a",'var_name/wc5b',                       'var_name/bc1','var_name/bc2','var_name/bc3a','var_name/bc3b','var_name/bc4a',"var_name/bc4b","var_name/bc5a",'var_name/bc5b']                        variables_to_restore = slim.get_variables_to_restore(include=include)            saver = tf.train.Saver(variables_to_restore)                         print(variables_to_restore) #打印要加载的参数            saver.restore(sess,model_name)            saver1.save(sess,'./model2/fineturing_model.ckpt')

相关代码,注意中文注释部分:

model_name = "./model/fineturing_model.ckpt"def run_training(batch_size,dropout,epochs):        #weight = [weights['wc1'],weights['wc2'],weights['wc3a'],weigths['wc3b'],weights['wc4a'],weights['wc4b'],weights['wc5a'],weights['wc5b']]        with tf.Graph().as_default():                with tf.variable_scope('var_name') as var_scope: #变量初始化过程            weights = {                'wc1': _variable_with_weight_decay('wc1', [3, 3, 3, 3, 64], 0.04, 0.00),                'wc2': _variable_with_weight_decay('wc2', [3, 3, 3, 64, 128], 0.04, 0.00),                'wc3a': _variable_with_weight_decay('wc3a', [3, 3, 3, 128, 256], 0.04, 0.00),                'wc3b': _variable_with_weight_decay('wc3b', [3, 3, 3, 256, 256], 0.04, 0.00),                'wc4a': _variable_with_weight_decay('wc4a', [3, 3, 3, 256, 512], 0.04, 0.00),                'wc4b': _variable_with_weight_decay('wc4b', [3, 3, 3, 512, 512], 0.04, 0.00),                'wc5a': _variable_with_weight_decay('wc5a', [3, 3, 3, 512, 512], 0.04, 0.00),                'wc5b': _variable_with_weight_decay('wc5b', [3, 3, 3, 512, 512], 0.04, 0.00),                'cam':_variable_with_weight_decay('cam', [1,1,512,c3d_model.NUM_CLASSES], 0.04,0.00),            }            biases = {                'bc1': _variable_with_weight_decay('bc1', [64], 0.04, 0.0),                'bc2': _variable_with_weight_decay('bc2', [128], 0.04, 0.0),                'bc3a': _variable_with_weight_decay('bc3a', [256], 0.04, 0.0),                'bc3b': _variable_with_weight_decay('bc3b', [256], 0.04, 0.0),                'bc4a': _variable_with_weight_decay('bc4a', [512], 0.04, 0.0),                'bc4b': _variable_with_weight_decay('bc4b', [512], 0.04, 0.0),                'bc5a': _variable_with_weight_decay('bc5a', [512], 0.04, 0.0),                'bc5b': _variable_with_weight_decay('bc5b', [512], 0.04, 0.0),            }                    images_placeholder, labels_placeholder = placeholder_inputs(batch_size)        logits, CAM = c3d_model.inference_c3d(images_placeholder[:,:,:,:,:], dropout, batch_size, weights, biases)#导入模型结构        loss_ = loss(logits,labels_placeholder)        accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels_placeholder), tf.float32))        train_op = tf.train.AdamOptimizer(1e-4).minimize(loss_)        #softmax_ = soft(logits)        print('**********')        #reader = pywrap_tensorflow.NewCheckpointReader(model_name)        with tf.Session() as sess:            init = tf.global_variables_initializer()            sess.run(init)            saver1 = tf.train.Saver()            saver = tf.train.import_meta_graph('./model/fineturing_model.ckpt.meta')            #tf.reset_default_graph()            saver.restore(sess,model_name)                saver1.save(sess,'./model2/fineturing_model.ckpt')            for i in range ():                ......

 

转载地址:http://vmten.baihongyu.com/

你可能感兴趣的文章
GPU
查看>>
Android Audio Feature
查看>>
我的自传
查看>>
专业音频术语中英文对照
查看>>
集成电路专业术语简介
查看>>
成长日记
查看>>
从3个科技公司里学到的57条经验
查看>>
程序员应该投资的10件事
查看>>
多媒体
查看>>
沟通技巧
查看>>
专业camera/isp术语中英文对照
查看>>
摄像头
查看>>
我的理想,我的奋斗目标
查看>>
Nginx基于多域名、多端口、多IP配置虚拟主机
查看>>
一次Linux 系统受攻击的解决过程
查看>>
最新最全Apache源码编译安装
查看>>
最新mysql数据库源码编译安装。
查看>>
第一章 vue入门
查看>>
Linux文件引用计数的逻辑
查看>>
linux PCIe hotplug arch analysis
查看>>