Implementations of deep reinforcement learning algorithms in JAX, tested on the OpenAI Gym CartPole environment. Implemented algorithms:
- DQN - Mnih V, Kavukcuoglu K, Silver D, et al. Playing Atari with Deep Reinforcement Learning. arXiv:1312.5602, 2013.
- DQN with target network - Mnih V, Kavukcuoglu K, Silver D, et al. Human-level control through deep reinforcement learning. Nature, 2015.
- DDQN - Van Hasselt H, Guez A, Silver D. Deep reinforcement learning with double Q-learning. AAAI, 2016.
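The three variants listed above differ mainly in how the bootstrap target is formed. The sketch below computes one-step TD targets for a batch of transitions in JAX. It illustrates the standard formulations from the cited papers and is not code taken from src/agents; the function name `td_targets` and the `variant` strings are hypothetical.

```python
import jax.numpy as jnp


def td_targets(rewards, dones, q_next_online, q_next_target, gamma=0.99, variant="ddqn"):
    """One-step TD targets (illustrative sketch, hypothetical API).

    rewards, dones: shape [B]; q_next_online / q_next_target: Q-values of the
    next states from the online and target networks, shape [B, A].
    """
    if variant == "dqn":
        # DQN: bootstrap from the same online network that is being trained.
        next_v = jnp.max(q_next_online, axis=-1)
    elif variant == "dqn_target":
        # DQN with target network: bootstrap from a periodically synced copy.
        next_v = jnp.max(q_next_target, axis=-1)
    elif variant == "ddqn":
        # Double DQN: the online network selects the action,
        # the target network evaluates it.
        a_star = jnp.argmax(q_next_online, axis=-1)
        next_v = jnp.take_along_axis(q_next_target, a_star[:, None], axis=-1)[:, 0]
    else:
        raise ValueError(f"unknown variant: {variant}")
    # Terminal transitions (done == 1) do not bootstrap.
    return rewards + gamma * (1.0 - dones) * next_v
```

With `variant="dqn"` the `q_next_target` argument is ignored, mirroring the original DQN, which had no separate target network.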
# Install deps
pip install -r requirements.txt
# Using the launch script (by default it runs multiple seeds)
./launch.sh
# Using python
python3 run.py --agent dqn --train_eps 100 --n_layers 3 --seed 1 --test_eps 30 --lr 0.03 --batch_size 256 --warm_up_steps 500 --epsilon_hlife 1500 --save_dir out/CartPole-v1/dqn/example_run/1
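The `--epsilon_hlife` flag suggests an epsilon-greedy exploration rate that decays with a half-life. The sketch below assumes epsilon decays from `eps_start` towards `eps_min`, halving the remaining gap every `hlife` steps. The schedule shape, `eps_start`, `eps_min` and both function names are assumptions for illustration, not necessarily what run.py does.

```python
import jax
import jax.numpy as jnp


def epsilon_at(step, hlife=1500, eps_start=1.0, eps_min=0.01):
    # ASSUMPTION: exponential decay towards eps_min; the distance to eps_min
    # halves every `hlife` steps.
    return eps_min + (eps_start - eps_min) * 0.5 ** (step / hlife)


def epsilon_greedy(key, q_values, epsilon):
    # With probability epsilon take a uniformly random action,
    # otherwise take the greedy (argmax-Q) action.
    explore_key, action_key = jax.random.split(key)
    random_action = jax.random.randint(action_key, (), 0, q_values.shape[-1])
    greedy_action = jnp.argmax(q_values)
    return jnp.where(jax.random.uniform(explore_key) < epsilon, random_action, greedy_action)
```

Under this assumption, with `--epsilon_hlife 1500` epsilon is roughly 0.505 after 1500 steps and roughly 0.26 after 3000 steps.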
# Run a demo with a trained model; this searches the <out> directory for the best-performing model
python3 run.py --demo --save_dir out --agent dqn
Results are logged during each run. The notebook notebooks/results.ipynb
plots the training and test reward curves, using src/utils.py to parse
the logs.
notebooks/results.ipynb
- Visualization of training and test curves
src/agents
- Directory containing the algorithms
src/model.py
- JAX code with the neural network implementation, loss function and SGD
src/run.py
- Top-level interface to train, test and demo models
src/utils.py
- Code to parse logs
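As an illustration of the pieces src/model.py is described as containing, here is a minimal pure-JAX sketch of an MLP Q-network, a squared TD-error loss and a plain SGD update. The layer sizes, He-style initialisation, ReLU activations and mean-squared loss are assumptions; the repository may use a different architecture or loss (for example Huber), and how `--n_layers` maps to layer sizes is not specified here. All function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    # One (W, b) pair per layer; sizes = [obs_dim, hidden..., n_actions].
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (d_in, d_out)) * jnp.sqrt(2.0 / d_in)
        params.append((w, jnp.zeros(d_out)))
    return params


def q_values(params, obs):
    # ReLU hidden layers, linear output layer with one Q-value per action.
    x = obs
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def loss_fn(params, obs, actions, targets):
    # Mean squared TD error on the Q-values of the actions actually taken.
    q = q_values(params, obs)
    q_taken = jnp.take_along_axis(q, actions[:, None], axis=-1)[:, 0]
    return jnp.mean((q_taken - targets) ** 2)


@jax.jit
def sgd_step(params, obs, actions, targets, lr):
    loss, grads = jax.value_and_grad(loss_fn)(params, obs, actions, targets)
    # Plain SGD: p <- p - lr * g, applied to every leaf of the params pytree.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss
```

The TD targets passed to `sgd_step` would come from the online or target network depending on the agent; they act as constants here because `jax.value_and_grad` only differentiates with respect to `params`.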