ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

打造自己的图像识别模型

2019-05-15 17:38:34  阅读:279  来源: 互联网

标签:satellite 图像识别 训练 模型 train 打造 数据 dir


1.目标

本篇文章介绍的重点是如何使用TensorFlow在自己的图像数据上训练深度学习模型,主要涉及的方法是对已经预训练好的ImageNet模型进行微调(Fine-tune)。使用谷歌的Colaboratory(python3 环境)实现。

2.微调原理

什么是微调?这里以VGG16为例进行讲解。

如图下图所示,VGG16的结构为卷积+全连接层。卷积层分为5个部分共13层,即图中的conv1~conv5。还有3层是全连接层,即图中的fc6、fe7、fc8。卷积层加上全连接层合起来一共为16层,因此它被称为VGG16。如果要将VGG16的结构用于一个新的数据集,首先要去掉fc8这一层。原因是fc8层的输入是fc7层的特征,输出是1000类的概率,这1000类正好对应了ImageNet 模型中的1000个类别。在自己的数据中,类别数一般不是1000类,因此fc8层的结构在此时是不适用的,必须将fc8层去掉,重新采用符合数据集类别数的全连接层,作为新的fe8。比如数据集为5类,那么新的fc8的输出也应当是5类。

此外,在训练的时候,网络的参数的初始值并不是随机化生成的,而是采用VGG16在ImageNet 上已经训练好的参数作为训练的初始值。这样做的原因在于,在ImageNet数据集上训练过的VGG16中的参数已经包含了大量有用的卷积过滤器,与其从零开始初始化VGG16的所有参数,不如使用已经训练好的参数当作训练的起点。这样做不仅可以节约大量训练时间,而且有助于分类器性能的提高。

载入VGG16的参数后,就可以开始训练了。此时需要指定训练层数的范围。一般来说,可以选择以下几种范围进行训练:

  • 只训练fc8。训练范围一定要包含fc8这一层。之前说过,fc8的结构被调整过,因此它的参数不能直接从lmageNet预训练模型中取得。可以只训练fe8,保持其他层的参数不动。这就相当于将VGG16当作一个“特征提取器”:用fc7层提取的特征做一个Softmax模型分类。这样做的好处是训练速度快,但往往性能不会太好。
  • 训练所有参数。还可以对网络中的所有参数进行训练,这种方法的训练速度可能比较慢,但是能取得较高的性能,可以充分发挥深度模型的威力。
  • 训练部分参数。通常是固定浅层参数不变,训练深层参数。如固定 conv1、conv2部分的参数不训练,只训练 conv3、conv 4、conv 5、fc6、fc7、fc8的参数。

这种训练方法就是所谓的对神经网络模型做微调。借助微调,可以从预训练模型出发,将神经网络应用到自己的数据集上。下面介绍如何在TensorFlow中进行微调。

3.TensorFlow Slim 微调

TensorFlow Slim 是Google公司公布的一个图像分类工具包,它不仅定义了一些方便的接口,还提供了很多ImageNet数据集上常用的网络结构和预训练模型。截至2017年7月,Slim提供包括VGG16、VGG19、Inception V1~V4、ResNet 50、ResNet 101、MobileNet在内大多数常用模型的结构以及预训练模型,更多的模型还会被持续添加进来。

3.1 数据准备

首先要将自己的数据集切分为训练集和验证集,训练集用于训练模型,验证集用来验证模型的准确率。本次使用的是卫星图片分类数据集,这个数据集一共有6个类别,见下表所示:

类别名 含义
Wetland 农田
Glacier 冰川
Urban 城市区域
Rock 岩石
water 水域
Wood 森林

在data_prepare目录中,用一个pic文件夹保存原始的图像文件,图像文件保存的结构如下:

data prepare/
    pic/
        train/
            wood/
            water/
            rock/
            wetland/
            glacier/
            urban/
        validation/
            wood/
            water/
            rock/
            wetland/
            glacier/
            urban/

将图片分为trainvalidation两个目录,分别表示训练使用的图片和验证使用的图片。在每个目录中,分别以类别名为文件夹名保存所有图像。在每个类别文件夹下,存放的就是原始的图像(如jpg格式的图像文件)。下面,在data_prepare文件夹下,使用预先编制好的脚本data_convert.py,将图片转换为为tfrecord格式:

!python data_ convert.py -t pic/ \
    --train-shards 2 \
    --validation-shards 2 \
    --num-threads 2 \
    --dataset-name satellite

解释这里参数的含义:

  • -t pic/:表示转换pic文件夹中的数据。pic文件夹中必须有一个train目录和一个validation目录,分别代表训练和验证数据集。每个目录下按类别存放了图像数据。
  • --train-shards 2:将训练数据集分为两块,即最后的训练数据就是两个tfrecord格式的文件。如果读者的数据集较大,可以考虑将其分为更多的数据块。
  • --validation-shards 2:将验证数据集分为两块。
  • --num-threads 2:采用两个线程产生数据。注意线程数必须要能整除train-shards 和validation-shards,来保证每个线程处理的数据块数是相同的。
  • --dataset-name satellite:给生成的数据集起一个名字。这里将数据集起名叫“satellite”,最后生成文件的开头就是satelite_train 和satelite_validation。

运行上述命令后,就可以在pic文件夹中找到5个新生成的文件,分别是训练数据 satellite_train_00000-of-00002.tfrecord、satellite_train_00001-of-00002.tfrecord,以及验证数据 satellite validation_00000-of-00002.tfrecord、satellite validation_00001-of-00002.tfrecord。另外,还有一个文本文件label.txt,它表示图片的内部标签(数字)到真实类别(字符串)之间的映射顺序。如图片在tfrecord中的标签为0,那么就对应label.txt 第一行的类别,在tfrecord的标签为1,就对应label.txt中第二行的类别,依此类推。

3.2 下载TensorFlow Slim

如果需要使用Slim微调模型,首先要下载Slim的源代码。Slim的源代码保存在tensorflow/models项目中,可以使用下面的git命令下载tensorflow/models:

git clone https://github.com/tensorflow/models. git

找到models/research/目录中的slim文件夹,这就是要用到的TensorFlow Slim的源代码。这里简单介绍TensorFlow Slim的代码结构,见下表。

文件夹或文件名 用途
datasets/ 定义一些训练时使用的数据集。如果需要训练自己的数据,必须同样在datasets文件夹中进行定义,会在下面对此进行介绍
nets/ 定义了一些常用的网络结构,如AlexNet、VGGl6、VGG19、Inception 系列、ResNet、MobileNet等
preprocessing/ 在模型读入图片前,常常需要对图像做预处理和数据增强。这个文件夹针对不同的网络,分别定义了它们的预处理方法
scripts 包含了一些训练的示例脚本
train_ image_classifier.py 训练模型的入口代码
eval_image_classifier.py 验证模型性能的入口代码
download_and _convert data.py 下载并转换数据集格式的入口代码

上表只列出了TensorFlow Slim中最重要的几个文件以及文件夹的作用。其他还有少量文件和文件夹,如果读者对它们的作用感兴趣,可以自行参阅其文档。

3.3 定义新的datasets文件

在slim/datasets中,定义了所有可以使用的数据库,为了使用在第3.1节中创建的tfrecord数据进行训练,必须要在datasets中定义新的数据库。

首先,在datasets/目录下新建一个文件satellite.py,并将flowers.py文件中的内容复制到satellite.py中。接下来,需要修改以下几处内容。

第一处是_FILE_PATTERNSPLITS_TO_SIZES_NUM_CLASSES,将其进行以下修改:

_FILE_PATTERN='satellite _%s*. tfrecord'
SPLTTS_TO_SIZES={' train:4800,' validation':1200}
_NUM_CLASSES=6

_FILE_PATTERN 变量定义了数据的文件名的格式和训练集、验证集的数量。

_NUM_CLASSES 变量定义了数据集中图片的类别数目。

第二处修改为image/format部分,将之修改为:

'image/format': tf. FixedLenFeature((), tf. string, default_value ='jpg'),

此处定义了图片的默认格式。收集的卫星图片的格式为jpg图片,因此修改为jpg。最后,读者也可以对文件中的注释内容进行合适的修改。修改完satellite.py后,还需要在同目录的dataset_factory.py文件中注册satellite数据库。如下:

datasets_map={
'cifar10':cifarl0,
'flowers':flowers,
'imagenet':imagenet,
'satellite':satellite,}

3.4 准备训练文件夹

定义完数据集后,在slim文件夹下再新建一个satellite目录,在这个目录中,完成最后的几项准备工作:

  • 新建一个data目录,并将第3.1节中准备好的5个转换好格式的训练数据复制进去。
  • 新建一个空的train_dir目录,用来保存训练过程中的日志和模型。
  • 新建一个pretrained目录,在slim的GitHub页面找到InceptionV3模型的下载地址http:/download.tensorflow.org/models/inception_V3_2016_0828.tar.gz,下载并解压后,会得到一个 inception_v3.ckpt文件,将该文件复制到pretrained目录下。

3.5 开始训练

在slim文件夹下,运行以下命令就可以开始训练了:

!python train_image_classifier. py \
--train_dir=satellite/train_dir \
--dataset_name=satellite \
--dataset_split_name=train \
--model_name=inception_v3 \
--checkpoint_path=satellite/pretrained/inception_v3. ckpt \
--checkpoint_exclude_scopes=InceptionV3/Logits, InceptionV3/AuxLogits \
--trainable_scopes=InceptionV3/Logits, InceptionV3/AuxLogits \
--max_number_of steps=100000 \
--batch_size=32 \
--learning_rate=0.001 \
--learning_rate_decay_type=fixed  \
--save_interval_secs=300 \
--save_summaries_secs=2 \
--log_every_n_steps=10 \
--optimizer=rmsprop \
--weight_decay=0.00004

这里的参数比较多,下面一一进行介绍:

  • --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits:首先来解释参数trainable_scopes的作用,因为它非常重要。trainable_scopes规定了在模型中微调变量的范围。这里的设定表示只对InceptionV3/Logits,InceptionV3/AuxLogits 两个变量进行微调,其他变量都保持不动。InceptionV3/Logits,InceptionV3/AuxLogits 就相当于在第2章中所讲的fc8,它们是Inception V3的“末端层”。如果不设定trainable_scopes,就会对模型中所有的参数进行训练。
  • --train_dir-satellite/train_dir:表明会在satellite/train_dir目录下保存日志和checkpoint。
  • --dataset_name=satellite、--dataset_split_name=train:指定训练的数据集。
  • --dataset_dir=satellite/data:指定训练数据集保存的位置。
  • --model_name=inception_v3:使用的模型名称。
  • --checkpoint_path=satellite/pretrained/inception_v3.ckpt:预训练模型的保存位置。
  • --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits:在恢复预训练模型时,不恢复这两层。正如之前所说,这两层是InceptionV3模型的末端层,对应着ImageNet数据集的1000类,和当前的数据集不符,因此不要去恢复它。
  • --max_number_of_steps=100000:最大的执行步数。
  • --batch_size=32:每步使用的batch数量。
  • ---learning_rate=0.001:学习率。
  • --learning_rate_decay_type=fixed:学习率是否自动下降,此处使用固定的学习率。
  • --save_interval_secs=300:每隔300s,程序会把当前模型保存到train_dir中。此处就是目录satellite/train_dir。
  • --save_summaries_secs=2:每隔2s,就会将日志写入到train_dir中。可以用TensorBoard 查看该日志。此处为了方便观察,设定的时间间隔较多,实际训练时,为了性能考虑,可以设定较长的时间间隔。
  • --log_every_n_steps=10:每隔10步,就会在屏幕上打出训练信息。
  • --optimizer=rmsprop:表示选定的优化器。
  • --weight_decay=0.00004:选定的weight_decay值。即模型中所有参数的二次正则化超参数。

以上命令是只训练未端层InceptionV3/Logits,InceptionV3/AuxLogits,还可以对所有层进行训练:与只训练末端层的命令相比,只有一处发生了变化,即去掉了--trainable_scopes参数。

3.6 验证模型准确率

在slim文件下执行下列命令:

!python eval_image_classifier.py  \
--checkpoint_path=satellite/train_dir  \
-eval_dir=satellite/eval_dir  \
--dataset_name=satellite  \
--dataset_split_name=validation  \
--dataset_dir=satellite/data  \
--model_name=inception_v3

这里参数的含义为:

  • --checkpoint_path=satellite/train_dir:这个参数既可以接收一个目录的路径,也可以接收一个文件的路径。如果接收的是一个目录的路径,如这里的satellite/train_dir,就会在这个目录中寻找最新保存的模型文件,执行验证。也可以指定一个模型进行验证,以第300步的模型为例,在satellite/train_dir文件夹下它被保存为model.ckpt-300.meta、model.ckpt-300.index、model.ckpt-300.data-00000-of-00001三个文件。此时,如果要对它执行验证,给checkpoint_path传递的参数应该为satellite/train_dir/model.ckpt-300。|
  • --eval_dir-=satellite/eval_dir:执行结果的日志就保存在eval_dir中,同样可以通过TensorBoard查看。
  • --dataset_name=satellite、--dataset_split_name=validation:指定需要执行的数据集。注意此处是使用验证集(validation)执行验证。
  • --dataset_dir=satellite/data:数据集保存的位置。
  • --model_name=inception_v3:使用的模型。

执行后,应该会出现类似下面的结果:

eval/Accuracy[0.51]
eval/Recal1_5[0.97333336]

Accuracy 表示模型的分类准确率,而Recall_5表示Top5的准确率,如果不需要top5 。而需要top2或者top3准确率,只要在eval_image_classifier.py中修改下面的部分就可以了:

names_to_values, names_to_updates=slim. metrics. aggregate_metric map({
'Accuracy': slim. metrics. streaming_accuracy (predictions,labels),
'Recall_5': slim. metrics. streaming_recall_at_k(1ogits,labels,5),
})

4 代码及数据集

百度网盘 提取码:8qqt

标签:satellite,图像识别,训练,模型,train,打造,数据,dir
来源: https://www.cnblogs.com/Terrypython/p/10858803.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有