強化学習の勉強 (2) Deep Q-Network

公開日:2019-02-01
最終更新:2019-02-02
※この記事は外部サイト(https://qiita.com/hrs1985/items/2e18099e85...)からのクロス投稿です

これの続きです

5章の Deep Q-Network (DQN) についての実装を試してみた。

Q学習のときには q_table (離散化されたstate × action) の表を作っておいて、state から次の action を決めるのはその表を元にやっていたが、DQN では state を入力、action を出力とするニューラルネットワークでやるらしい。

120エピソード程度の学習でこんな感じになった。

DQN の実装上の工夫は以下の4点。

  1. Experience Replay
    各ステップごとに学習をしてしまうと、時間的に相関が高いデータを連続して学習することになり学習が不安定化する。代わりに、各ステップの情報をメモリに保存しておき、そこからサンプリングしたものを使って学習を行う。

  2. Fixed Target Q-Network
    行動を決定する main-network と行動価値を計算する target-network を分ける。ただし今回は main-network をミニバッチ学習する形で簡便な実装を行う。

  3. 報酬のクリッピング
    各ステップの報酬は -1、0、1 のいずれかとする。

  4. Huber 関数を用いた誤差
    二乗誤差を使うと誤差関数の出力が大きくなりすぎて学習が不安定化することがあるらしい。

$$
L_1(x) = \left\{
\begin{array}{ll}
\frac{1}{2}x^2 & (|x| \leq 1) \\
|x| - \frac{1}{2} & (|x| \gt 1)
\end{array}
\right.
$$

以下コードのメモ。実際はJupyter Notebookで実行している。

import numpy as np  
import matplotlib.pyplot as plt  
%matplotlib inline  
import gym  

from JSAnimation.IPython_display import display_animation  
from matplotlib import animation   
from IPython.display import display  

# 動画を出力する関数です  
def display_frames_as_gif(frames):  
    plt.figure(figsize=(frames[0].shape[1]/72.0, frames[0].shape[0]/72.0), dpi=72)  
    patch = plt.imshow(frames[0])  
    plt.axis('off')  

    def animate(i):  
        patch.set_data(frames[i])  

    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50)  
    anim.save('movie_cart_ple_dqn.mp4')  
    display(display_animation(anim, default_mode='loop'))  

from collections import namedtuple  

# 各ステップでの情報を保持するための namedtuple です  
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))  

# 定数の設定  
ENV = 'CartPole-v0'  
GAMMA = 0.99  
MAX_STEPS = 200  
NUM_EPISODES = 500  

# 経験を保存するメモリクラスを定義します。  

class ReplayMemory:  

    def __init__(self, CAPACITY):  
        self.capacity = CAPACITY # メモリの最大長さ  
        self.memory = []  
        self.index = 0  

    def push(self, state, action, state_next, reward):  
        if len(self.memory) < self.capacity:  
            self.memory.append(None) #メモリが満タンじゃないときには追加  

        self.memory[self.index] = Transition(state, action, state_next, reward)  
        self.index = (self.index + 1) % self.capacity  

    def sample(self, batch_size):  
        return random.sample(self.memory, batch_size)  

    def __len__(self):  
        return len(self.memory)  


import random  
import torch  
from torch import nn  
from torch import optim  
import torch.nn.functional as F  

BATCH_SIZE = 32  
CAPACITY = 10000  

# エージェントの行動方針を決めるためのクラスです。  
class Brain:  
    def __init__(self, num_states, num_actions):  
        self.num_actions = num_actions  

        # メモリオブジェクトの生成  
        self.memory = ReplayMemory(CAPACITY)  

        # ニューラルネットワークの構築  
        self.model = nn.Sequential()  
        self.model.add_module('fc1', nn.Linear(num_states, 32))  
        self.model.add_module('relu1', nn.ReLU())  
        self.model.add_module('fc2', nn.Linear(32, 32))  
        self.model.add_module('relu2', nn.ReLU())  
        self.model.add_module('fc3', nn.Linear(32, num_actions))  

        print(self.model)  

        # オプティマイザの設定  
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.0001)  

    def replay(self):  
        '''Experience Replay'''  

        # メモリサイズの確認  
        # メモリサイズがミニバッチサイズより小さい間は何もしない。  
        if len(self.memory) < BATCH_SIZE:  
            return  

        # ミニバッチの作成  
        transitions = self.memory.sample(BATCH_SIZE)  
        batch = Transition(*zip(*transitions))  

        state_batch = torch.cat(batch.state)  
        action_batch = torch.cat(batch.action)  
        reward_batch = torch.cat(batch.reward)  
        non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])  

        # ネットワークを推論モードにする  
        self.model.eval()  

        # state_batchをモデルに与え、推論結果からaction_batchで行った行動に対応するQ値を取ってくる  
        # つまりあるstateにおける行ったactionの価値をとってきている。  
        state_action_value = self.model(state_batch).gather(1, action_batch)  

        # max{Q(s_t+1, a)}を求める  
        non_final_mask = torch.ByteTensor(  
        tuple(map(lambda s: s is not None, batch.next_state)))  

        next_state_values = torch.zeros(BATCH_SIZE)  

        next_state_values[non_final_mask] = self.model(non_final_next_states).max(1)[0].detach()  

        # Q(s_t, a_t)をQ学習の式から求める。  
        expected_state_action_values = reward_batch + GAMMA * next_state_values  

        # ネットワークパラメータの更新  
        self.model.train()  
        loss = F.smooth_l1_loss(state_action_value, expected_state_action_values.unsqueeze(1))  
        self.optimizer.zero_grad()  
        loss.backward()  
        self.optimizer.step()  

    def decide_action(self, state, episode):  
        '''state に応じて行動を決定する関数'''  
        epsilon = 0.5 * (1 / (episode + 1))  

        if epsilon <= np.random.uniform(0, 1):  
            self.model.eval()  
            with torch.no_grad():  
                action = self.model(state).max(1)[1].view(1, 1)  
        else:  
            action = torch.LongTensor([[random.randrange(self.num_actions)]])  

        return action  

# エージェントクラスです。state に応じて action を行います。  
class Agent:  
    def __init__(self, num_states, num_actions):  
        self.brain = Brain(num_states, num_actions)  

    def update_q_function(self):  
        self.brain.replay()  

    def get_action(self, state, episode):  
        action = self.brain.decide_action(state, episode)  
        return action  

    def memorize(self, state, action, state_next, reward):  
        self.brain.memory.push(state, action, state_next, reward)  

# CartPole 実行環境のクラスです。  
class Environment:  
    def __init__(self):  
        self.env = gym.make(ENV)  
        self.num_states = self.env.observation_space.shape[0]  
        self.num_actions = self.env.action_space.n  

        self.agent = Agent(self.num_states, self.num_actions)  

    def run(self):  

        episode_10_list = np.zeros(10) #直近10エピソードで振り子が立ち続けたステップ数を記録  

        complete_episodes = 0  
        episode_final = False  
        frames = []  

        for episode in range(NUM_EPISODES):  
            observation = self.env.reset()  

            state = observation  
            state = torch.from_numpy(state).type(torch.FloatTensor)  
            state = torch.unsqueeze(state, 0)  

            for step in range(MAX_STEPS):  

                if episode_final is True:  
                    frames.append(self.env.render(mode='rgb_array'))  

                action = self.agent.get_action(state, episode)  

                #行動actionの実行によってs_t+1とdoneフラグを取得  
                observation_next, _, done, _ = self.env.step(action.item())  

                if done:  
                    state_next = None  

                    episode_10_list = np.hstack((episode_10_list[1:], step + 1))  

                    if step < 195:  
                        reward = torch.FloatTensor([-1.0]) # 195ステップ未満で倒れたら報酬-1  
                        complete_episodes = 0  
                    else:  
                        reward = torch.FloatTensor([1.0])  
                        complete_episodes = complete_episodes + 1  

                else:  
                    reward = torch.FloatTensor([0.0])  
                    state_next = observation_next  
                    state_next = torch.from_numpy(state_next).type(torch.FloatTensor)  
                    state_next = torch.unsqueeze(state_next, 0)  

                self.agent.memorize(state, action, state_next, reward)  
                self.agent.update_q_function()  

                state = state_next  

                if done:  
                    print("%d episode: Finished after %d steps: 10試行の平均step数 = %.lf"%(  
                    episode, step + 1, episode_10_list.mean()))  
                    observation = self.env.reset()  
                    break  

            if episode_final is True:  
                display_frames_as_gif(frames)  
                break  

            if complete_episodes >= 10:  
                print("10回連続成功")  
                episode_final = True  

cartpole_env = Environment()  
cartpole_env.run()
記事が少しでもいいなと思ったらクラップを送ってみよう!
0
+1
@kiyoの技術ブログ

よく一緒に読まれている記事

0件のコメント

ブログ開設 or ログイン してコメントを送ってみよう
目次をみる

技術ブログをはじめよう

Qrunch(クランチ)は、ITエンジニアリングに携わる全ての人のための技術ブログプラットフォームです。

技術ブログを開設する

Qrunchでアウトプットをはじめよう

Qrunch(クランチ)は、ITエンジニアリングに携わる全ての人のための技術ブログプラットフォームです。

Markdownで書ける

ログ機能でアウトプットを加速

デザインのカスタマイズが可能

技術ブログ開設

ここから先はアカウント(ブログ)開設が必要です

英数字4文字以上
.qrunch.io
英数字6文字以上
ログインする