*Trained agents playing Breakout, Pong, Space Invaders, and Ms Pacman.*
This repository contains a PyTorch implementation of the Deep Q-Network (DQN) algorithm for playing Atari games. The implementation is based on the original paper by Mnih et al. (2015) and contains the following extensions:
- Double Q-Learning (van Hasselt et al., 2015), sketched below
- Dueling Network Architectures (Wang et al., 2016)
- Prioritized Experience Replay (Schaul et al., 2016)
- Vectorized environment for parallel training
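Of these extensions, Double Q-Learning only changes how the bootstrap target is computed: the online network selects the greedy next action, while the target network evaluates it. The snippet below is a minimal PyTorch-style sketch of that target for illustration only; it is not taken from this repository, and the helper name `double_dqn_target` and its arguments are made up for the example.

```python
import torch

def double_dqn_target(q_online, q_target, next_obs, rewards, dones, gamma=0.99):
    """Double DQN bootstrap target for a batch of transitions (illustrative sketch)."""
    with torch.no_grad():
        # select the greedy next action with the online network...
        next_actions = q_online(next_obs).argmax(dim=1, keepdim=True)
        # ...but evaluate that action with the target network
        next_q = q_target(next_obs).gather(1, next_actions).squeeze(1)
    # no bootstrapping on terminal transitions
    return rewards + gamma * (1.0 - dones.float()) * next_q
```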
It is recommended to install the dependencies in a virtual environment, preferably with conda. The following commands will create a new environment and install the required packages using Poetry.
```bash
conda create -n dqn python=3.10
conda activate dqn
pip install poetry
poetry install
```
The DQN agent can be trained from the command line, with configuration managed by Hydra. The following command trains an agent to play the Breakout Atari game.
```bash
python src/dqn_atari/main.py model.env_id=BreakoutNoFrameskip-v4
```
Run `python src/dqn_atari/main.py --help` to see all available options.
Notable features include:
- Checkpointing: set `train.checkpoint_every` to save the model every `n` steps.
- Evaluation: set `train.eval_every` to evaluate the model every `n` steps.
- GIF recording: set `train.num_gifs` to record `n` GIFs during evaluation.
Note: only the last checkpoint saves the replay buffer, since it can be large.
The agent can also be trained directly from Python. The following code snippet demonstrates how to train a DQN agent to play the Breakout Atari game.
```python
from dqn_atari import DQN
from dqn_atari import PrioritizedReplayBuffer

# initialize the DQN agent
dqn_model = DQN(
    'BreakoutNoFrameskip-v4',
    num_envs=8,  # or 1 for a single environment
    double_dqn=True,
    dueling=True,
    layers=[64, 64],
    buffer_class=PrioritizedReplayBuffer,
)

# train the agent
dqn_model.train(
    training_steps=1_000_000,
    eval_every=10_000,  # evaluate every 10k steps
    eval_runs=30,
)
```
The trained model can be saved and loaded using the `save` and `load` methods. All the training parameters are saved along with the model, so you can continue training the model from where you left off. Because the replay buffer can be large, it is not saved by default. To save the replay buffer, set `include_buffer=True`.
```python
# save and load the model
dqn_model.save('my/folder/breakout.pt', include_buffer=True)
dqn_model = DQN.load('my/folder/breakout.pt')

# continue training the model
dqn_model.train(training_steps=1_000_000)
```
The trained agent can be used to play the game step by step or evaluated directly. The following code snippet demonstrates both.
```python
import gymnasium as gym
from dqn_atari import wrap_atari

# create the environment
env = gym.make('BreakoutNoFrameskip-v4')
env = wrap_atari(env)

# use `get_action` to play the game
state, _ = env.reset()
done = False
total_reward = 0
while not done:
    action = dqn_model.get_action(state)
    state, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
    total_reward += reward
print(f'Total reward: {total_reward}')

# or use `evaluate` to evaluate the agent directly
mean_reward = dqn_model.evaluate(eval_runs=20)
print(f'Average reward: {mean_reward}')
```
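For a quick visualization of a trained agent outside of the built-in `train.num_gifs` option, a rollout can also be recorded manually. The sketch below is an illustrative example, not part of the package: it assumes `imageio` is installed (it is not necessarily a project dependency) and uses Gymnasium's `render_mode='rgb_array'` to capture frames.

```python
import gymnasium as gym
import imageio  # assumed to be installed separately
from dqn_atari import wrap_atari

# create the environment with RGB rendering so frames can be captured
env = gym.make('BreakoutNoFrameskip-v4', render_mode='rgb_array')
env = wrap_atari(env)

frames = []
state, _ = env.reset()
done = False
while not done:
    frames.append(env.render())  # full-resolution RGB frame of the current state
    action = dqn_model.get_action(state)  # dqn_model from the examples above
    state, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated

# write the collected frames to a GIF
imageio.mimsave('breakout.gif', frames)
```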