KerasでGANは難しい?

公開日:2018-10-17
最終更新:2018-10-24

※2018年3月に調べた内容・Googleドライブに眠っていた個人用メモで投稿テストしたもの

結論

  • KerasでGANを実装するのは難しい(2018/3/1)
  • まだ銀の弾丸はなさそう

KERASでGANを学習するときの問題点

複数ネットワークの学習が必要だが、Kerasは1ネットワークの学習が前提。

Kerasの学習実行部分は↓のようである。

model.fit(...)

 modelは1つのネットワークを表すオブジェクトで、GANのような複数ネットワークの学習は想定されていない。2つのネットワークを学習するときは、↓のように2つのModelのfit()を交互に呼び出す必要がある。

for i in range():
    generator.fit(...)
    discriminator.fit(...)

学習対象切り替えの煩雑さ

 GANのような複数ネットワークの学習ではGとDを交互に学習するが、その際にDの学習のON/OFFの切り替えが必要。Kerasではtrainableフラグで学習のON/OFFを切り替えられる(参考:https://qiita.com/mokemokechicken/items/937a82cfdc31e9a6ca12)が、それをDを構成する各レイヤに対して設定するのは大変(=コードが大きくなる)という問題がある。

そこで様々なWorkaroundが提案されているので、現状をまとめてみた。

GAN実装の実現策

kerasのContainer機能を使用

Container内にLayerを定義すれば、「Containerの内包する全てのLayerにtrainableを設定する必要はない」

tensorflowと併用

tensorflowでは細かく処理を指定できるが、データフローをフルスクラッチする必要がある。特に入力画像の処理が手間。

参考1 tfの簡単なインターフェースとしてのkeras: https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html

Kerasを利用して簡単にレイヤを定義し、Tensorflowでloss計算・学習を行う

# kerasでレイヤ定義
preds = Dense(10, activation='softmax')(x) 
# tensorflowでレイヤ定義
labels = tf.placeholder(tf.float32, shape=(None, 10))
# kerasのレイヤもtensorflowレイヤと同様に扱える
loss = tf.reduce_mean(categorical_crossentropy(labels, preds))

参考2 kerasのissue KerasにGAN機能追加しないか?

https://github.com/keras-team/keras/issues/5312

GANのバリエーションは多岐にわたるので統一されたインターフェースはまだ作れないと結論付けた。今(2017年)はTFでスクラッチするべし。

記事が少しでもいいなと思ったらクラップを送ってみよう!
20
+1
@lilacsの技術ブログ

よく一緒に読まれている記事

0件のコメント

ブログ開設 or ログイン してコメントを送ってみよう
目次をみる

技術ブログをはじめよう

Qrunch(クランチ)は、ITエンジニアリングに携わる全ての人のための技術ブログプラットフォームです。

技術ブログを開設する

Qrunchでアウトプットをはじめよう

Qrunch(クランチ)は、ITエンジニアリングに携わる全ての人のための技術ブログプラットフォームです。

Markdownで書ける

ログ機能でアウトプットを加速

デザインのカスタマイズが可能

技術ブログ開設

ここから先はアカウント(ブログ)開設が必要です

英数字4文字以上
.qrunch.io
英数字6文字以上
ログインする