BETA

fastai動かしてみたLesson2①※殴り書き

投稿日:2019-04-21
最終更新:2019-04-22

fastai動かしてみたLesson2①

今回はこの講義

https://course.fast.ai/videos/?lesson=2

今後の講義の流れとしてはComputer vision(画像?)、自然言語処理、テーブル状のデータ、協調フィルタリングに関する機能を学んで、Embeddingについて学び、Computer visionと自然言語処理についてより詳しく学ぶ。

ビデオの10分辺りまではFastaiを利用した例を紹介している。

テディベア分類器を作ろう

参考:https://github.com/fastai/course-v3/blob/master/nbs/dl1/lesson2-download.ipynb

まず初めにGoogle画像検索を使って画像のURLを収集する。

Google画像検索で"Teddy bear"と検索し、Ctrl, Shift, Jを押すと、Javascriptのコンソール画面が出てくるので、

urls = Array.from(document.querySelectorAll('.rg_di .rg_meta')).map(el=>JSON.parse(el.textContent).ou);  

window.open('data:text/csv;charset=utf-8,' + escape(urls.join('\n')));

を入力し、url_teddies.txtという名前で保存。他に"black bear"と"grizzly bear"についても同様に検索して保存しておく。

保存した画像のURLから画像をダウンロードする

folder = 'black'  
file = 'urls_black.txt'  

folder = 'teddys'  
file = 'urls_teddys.txt'  

folder = 'grizzly'  
file = 'urls_grizzly.txt'  

ここで、画像を保存するディレクトリ名とURLの書いてあるtxtの名前を定義しておく。

そして、ディレクトリを作成する。

path = Path('data/bears')  
dest = path/folder  
dest.mkdir(parents=True, exist_ok=True)  

クラス名を定義して、画像をダウンロードする。

classes = ['teddys','grizzly','black']  
download_images(path/file, dest, max_pics=200)  
#もし上記でエラーなら以下で実行  
#download_images(path/file, dest, max_pics=20, max_workers=0)  

ImageDataBunchを作る

最初に、データ収集した時に画像以外の物が紛れ込んでいるかもしれないので、verify_imagesで確認してみる。

for c in classes:  
    print(c)  
    verify_images(path/c, delete=True, max_size=500)

ここで、再現性のためにランダムシードを設定しておく。

np.random.seed(42)  

ImageDataBunchを作成する。

data = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2,  
        ds_tfms=get_transforms(), size=224, num_workers=4).normalize(imagenet_stats)

from_folderはデータがフォルダ毎に分けられているときに使う
pathはteddy, black, grizzlyなどを含むファイルのパス
trainはどこにtrainのパス
valid_pctは何割をvalidaion(評価)データにするか
ds_tfmsはデータをどう変形するか
get_transforms()は回転させたり反転させたり指定できる
sizeは読み込ませる画像サイズ
num_workersはGPUの指定 エラーが起こったら0に指定するといいかもしれない
normalizeは標準化でimagenet_stats、つまりImageNetを基にして標準化する

以下で実際に読み込んだ画像を確認できる。

data.show_batch(rows=3, figsize=(7,8))

dataの中身は以下のようにしてみれる。

data.classes, data.c, len(data.train_ds), len(data.valid_ds)  

(['black', 'grizzly', 'teddys'], 3, 448, 111)

困ったときはhelp(data)や? dataなどをJupyter Notebook上で行うといいかもしれない。

Trainingしてみる

learn = create_cnn(data, models.resnet34, metrics=error_rate)

resnet34はResNetの34層のニューラルネットワークを呼び出している
metricsは評価関数で、error_rateはどれくらい外れたか、他にはaccuracyとかもある

create_cnnでresnet34, resnet50, resnet101, resnet152なども使える。

learn = Learner(data, models.wrn_22(), metrics=accuracy)
でWideResNetなども使える。

learn.fit_one_cycle(4)  
learn.save('stage-1')  

one cycleを4epochsで実行する。
learn.fit_one_cycle(4, 1e-3, wd=0.4)
などで、Weight decayやLearning Rateなども指定できる。

fitさせたらlearn.save(filename)でモデルを保存できる。

learn.save('stage-1')  

デフォルトでは大部分のモデルの重み(?)は固定されている。

learn.unfreeze()  

これで重みを動かす事が出来る。

fastaiには適切な学習率を求める関数がある。

learn.lr_find()  
learn.recorder.plot()  

learn.lr_find()で求めた後に、learn.recorder.plot()で図示出来る。

3e-5(=0.000003)を最低の学習率とした。

グラフの谷底の数に設定するのではなく、谷底に行く手前の坂道の上ぐらいを設定するといいのかもしれない。

製作者曰く大体1e-4~3e-4にしておけばいい感じになる。(ブログ著者もそう思う)

learn.save('stage-2')  

モデルを保存しておく。

結果を見てみる

結果を見るときは大体ClassificationInterpretationで混合行列を使えばよく分かる

interp = ClassificationInterpretation.from_learner(learn)  
interp.plot_confusion_matrix()  

この後FileDeleterを使ってデータを整理して精度を上げているが、fastai v1.0.51とかだと入っていなかった気がするので割愛。

予測してみる

fastaiだとこれで画像を見れる

img = open_image(path/'black'/'00000021.jpg')  
img  

モデルを作成して読み込む

classes = ['black', 'grizzly', 'teddys']  

data2 = ImageDataBunch.single_from_classes(path, classes, tfms=get_transforms(), size=224).normalize(imagenet_stats)  

learn = create_cnn(data2, models.resnet34)  
learn.load('stage-2')  

fastaiだとresnetはただの関数としてみなすとかなんとか(よくわかってないですすみません)
データを読み込んでCNNを作ったあと、あらかじめ保存しておいたstage-2を読み込む。

pred_class,pred_idx,outputs = learn.predict(img)  
pred_class  

'black'

こんな感じで予測を出せる。

'00000021.jpg'はblackなので正解。

Learning Rate(LR)やEpoch数が少ないと起こる問題

LRが大きいと

valid lossが極端に大きくなる
Epoch数は関係ない

LRが小さいと

error_rateが下がるのが小さい(accuracyが上がるのが小さい)
実際にlossをプロットしてみるとすごい緩やか

Epoch数が少ない

train loss > valid loss になっている

Epoch数が多い

精度が上がったり下がったりを繰り返す
つまりOver fitting(過剰適合)が発生する

参考:
https://github.com/hiromis/notes/blob/master/Lesson2.md

技術ブログをはじめよう Qrunch(クランチ)は、プログラマの技術アプトプットに特化したブログサービスです
駆け出しエンジニアからエキスパートまで全ての方々のアウトプットを歓迎しております!
or 外部アカウントで 登録 / ログイン する
クランチについてもっと詳しく

この記事が掲載されているブログ

大学生の書きなぐりブログ 間違ってる事も書いてるので自己責任で勉強しましょう

よく一緒に読まれる記事

0件のコメント

ブログ開設 or ログイン してコメントを送ってみよう
目次をみる
技術ブログをはじめよう Qrunch(クランチ)は、プログラマの技術アプトプットに特化したブログサービスです
or 外部アカウントではじめる
10秒で技術ブログが作れます!