%matplotlib inline

Reinforcement Learning (DQN) Tutorial

Author: `Adam Paszke <https://github.com/apaszke>`_

This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
on the CartPole-v0 task from the `OpenAI Gym <https://gym.openai.com/>`__.

Task

The agent has to decide between two actions - moving the cart left or
right - so that the pole attached to it stays upright. You can find an
official leaderboard with various algorithms and visualizations at the
`Gym website <https://gym.openai.com/envs/CartPole-v0>`__.

.. figure:: /_static/img/cartpole.gif
   :alt: cartpole

   cartpole

As the agent observes the current state of the environment and chooses
an action, the environment transitions to a new state, and also
returns a reward that indicates the consequences of the action. In this
task, rewards are +1 for every incremental timestep and the environment
terminates if the pole falls over too far or the cart moves more than 2.4
units away from center. This means better performing scenarios will run
for a longer duration, accumulating a larger return.
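
To make the interaction concrete, here is a minimal sketch of the
agent-environment loop using the classic Gym API and a random policy.
It is illustrative only; the DQN agent we build below replaces the
random action choice.

# Minimal sketch: run one episode with a random policy (classic Gym API).
import gym

env = gym.make('CartPole-v0')
obs = env.reset()
total_reward, done = 0.0, False
while not done:
    action = env.action_space.sample()        # random choice: 0 = left, 1 = right
    obs, reward, done, _ = env.step(action)   # reward is +1 for each surviving step
    total_reward += reward                    # return grows with episode duration
print(total_reward)                           # random play lasts only a few dozen steps
env.close()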

The CartPole task is designed so that the inputs to the agent are 4 real
values representing the environment state (position, velocity, etc.).
However, neural networks can solve the task purely by looking at the
scene, so we'll use a patch of the screen centered on the cart as an
input. Because of this, our results aren't directly comparable to the
ones from the official leaderboard - our task is much harder.
Unfortunately this does slow down the training, because we have to
render all the frames.
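
As a hedged sketch of what "looking at the scene" means in code, one can
render a frame and downscale it to a small patch. The tutorial later defines
a fuller get_screen() that also crops around the cart, so treat the names
and sizes here as illustrative.

# Illustrative sketch: render an RGB frame and downscale it (classic Gym API).
import gym
import numpy as np
import torch
import torchvision.transforms as T

env = gym.make('CartPole-v0').unwrapped
env.reset()
frame = env.render(mode='rgb_array')                  # H x W x C uint8 array
frame = frame.transpose((2, 0, 1))                    # -> C x H x W for torch
frame = np.ascontiguousarray(frame, dtype=np.float32) / 255
screen = torch.from_numpy(frame)
resize = T.Compose([T.ToPILImage(), T.Resize(40), T.ToTensor()])
patch = resize(screen).unsqueeze(0)                   # 1 x C x H x W batch
env.close()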

Strictly speaking, we will present the state as the difference between
the current screen patch and the previous one. This will allow the agent
to take the velocity of the pole into account from one image.
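
As a one-line illustration (get_screen() here stands in for the
patch-extraction helper the tutorial defines later):

# Illustrative: the state is the pixel-wise difference of consecutive patches.
last_screen = get_screen()
current_screen = get_screen()
state = current_screen - last_screen   # a single image that encodes motion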

Packages

First, let's import the packages we need. We'll need
`gym <https://gym.openai.com/docs>`__ for the environment
(install it with ``pip install gym``).
We'll also use the following from PyTorch:

  • neural networks (torch.nn)
  • optimization (torch.optim)
  • automatic differentiation (torch.autograd)
  • utilities for vision tasks (`torchvision <https://github.com/pytorch/vision>`__, a separate package).
import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T


# .unwrapped removes the TimeLimit wrapper, so episodes aren't capped at 200 steps
env = gym.make('CartPole-v0').unwrapped

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# if gpu is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Replay Memory

We'll be using experience replay memory for training our DQN. It stores
the transitions that the agent observes, allowing us to reuse this data
later. By sampling from it randomly, the transitions that build up a
batch are decorrelated. It has been shown that this greatly stabilizes
and improves the DQN training procedure.

For this, we're going to need two classes:

  • Transition - a named tuple representing a single transition in
    our environment. It essentially maps (state, action) pairs
    to their (next_state, reward) result, with the state being the
    screen difference image as described later on.
  • ReplayMemory - a cyclic buffer of bounded size that holds the
    transitions observed recently. It also implements a .sample()
    method for selecting a random batch of transitions for training.
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)  # grow the list until capacity is reached
        self.memory[self.position] = Transition(*args)
        # advance the write index, wrapping around to overwrite the oldest entries
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
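
As a quick sanity check, here is a hedged usage sketch; the tensors are
placeholders standing in for the real screen-difference states built later
in the tutorial.

# Illustrative usage of ReplayMemory with placeholder tensors.
memory = ReplayMemory(10000)

state = torch.zeros(1, 3, 40, 90)             # placeholder screen-difference image
action = torch.tensor([[0]], device=device)   # e.g. "push cart left"
next_state = torch.zeros(1, 3, 40, 90)
reward = torch.tensor([1.0], device=device)
memory.push(state, action, next_state, reward)

if len(memory) >= 1:
    transitions = memory.sample(1)            # list of Transition namedtuples
    print(transitions[0].reward)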