使用MxNet进行图片的二分类

版权声明:本文所有内容都是原创,如有转载请注明出处,谢谢。

1. 引言

偶尔遇到了这样一个问题,如何进行图片分类。最近deep learning很火,我们就试着用了一把,效果不错。Mxnet里面有很多训练的例子,可是如何拿训练好的模型去预测,如何生成训练数据,都没有实际的例子。这个就做一个实际的例子,来看看mxnet如何在我们工作中被使用。这个应用可以用来分别图片中是否是人,是否是狗,是否是猫。。。

2. 生成训练Rec文件

1
2
python ../tools/make_list.py --train_ratio 0.9 --recursive true ../data/2016-06-12/ samples0612
python ../tools/im2rec.py --train_ratio 0.9 --recursive true --resize 224 samples0612_train ../data/2016-06-12/

3. 利用ImageClassification 进行训练

这里我们使用了inception-bn的网络。相对于lenet, inception的深度更深,参数更复杂,训练和预测也更慢一点,但效果更好。相对于inception,incetion-bn对参数做了batch normalization,更加容易收敛,效果也更好。

1
python2.7 ../train_imagenet.py --network inception-bn --data-dir ./  --model-prefix inception --gpus 0 --num-classes 2 --num-epochs 20 --load-epoch 3 --train-dataset samples0612_train.rec --val-dataset samples0612_val.rec --batch-size 70 --log-dir ./ --log-file train.log

4. 生成测试Rec文件

1
2
python2.7 ${CURR_DIR}/tools/make_list.py ${IMAGE_DATA} ${CURR_DIR}/Images
python2.7 ${CURR_DIR}/tools/im2rec.py --resize 224 ${CURR_DIR}/Images ${IMAGE_DATA}

5. 利用训练好的模型进行预测

1
2
3
4
5
data = mx.io.ImageRecordIter( path_imglst = 'Images.lst',  path_imgrec = 'Images.rec', ctx = mx.gpu(0), data_shape = (3,224,224), batch_size = 50, mean_r  = 123.68, mean_g      = 116.779, mean_b      = 103.939 )
model = mx.model.FeedForward.load('inception-0', 3, data_shape = (3,224,224), batch_size = 50, ctx= mx.gpu(0))
y = model.predict( data )
## y = 1 - y
pkl.dump(y, open( 'y-labels', 'w'))

6. 经验之谈

  1. 图片分类对训练数据很敏感。如果训练数据都是错误的,那么预测就根本不可能正确了。所以训练数据需要非常认真的检查
  2. 训练数据不仅仅需要正确,也需要全面。对于图片方面的训练,更多的图片就能覆盖更多的方面
  3. 训练和预测的时候,参数需要是一样的。其中,我们设置了 mean_r, mean_g, mean_b。