
Implementation of Reinforcement Learning Algorithms (DQN) using Jax and Gym


erees1/jax-rl


Deep reinforcement learning algorithms implemented with Jax


Implementations of deep reinforcement learning algorithms using Jax, tested on the OpenAI Gym CartPole environment.

Algorithms

  1. DQN - Mnih V, Kavukcuoglu K, Silver D, et al. Playing Atari with Deep Reinforcement Learning
  2. DQN with target network - Mnih V, Kavukcuoglu K, Silver D, et al. Human-level control through deep reinforcement learning
  3. DDQN - Van Hasselt H, Guez A, Silver D. Deep reinforcement learning with double Q-learning
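As a rough illustration of how the three algorithms differ, the sketch below computes their bootstrapped targets in JAX for a batch of transitions. It is not taken from this repository: the function names (`max_target`, `double_q_target`, `td_loss`) are hypothetical, next-state Q-values are assumed to be precomputed arrays of shape `(batch, n_actions)`, and `dones` is assumed to be 1.0 for terminal transitions. Plain DQN would pass the online network's Q-values to `max_target`, DQN with a target network would pass the target network's Q-values, and DDQN uses the online network to choose the next action and the target network to evaluate it. The loss is a plain squared error for simplicity; a Huber loss is a common alternative.

```python
import jax
import jax.numpy as jnp


def max_target(q_next, rewards, dones, gamma=0.99):
    """r + gamma * max_a Q(s', a), with no bootstrapping on terminal transitions.

    DQN: q_next comes from the online network being trained.
    DQN with target network: q_next comes from the periodically copied target network.
    """
    return rewards + gamma * (1.0 - dones) * jnp.max(q_next, axis=-1)


def double_q_target(q_next_online, q_next_target, rewards, dones, gamma=0.99):
    """DDQN: the online network selects a', the target network evaluates Q(s', a')."""
    best_actions = jnp.argmax(q_next_online, axis=-1)
    evaluated = jnp.take_along_axis(q_next_target, best_actions[:, None], axis=-1)[:, 0]
    return rewards + gamma * (1.0 - dones) * evaluated


def td_loss(q_values, actions, targets):
    """Mean squared TD error on the actions actually taken.

    stop_gradient treats the target as a constant, so gradients flow only through Q(s, a).
    """
    q_taken = jnp.take_along_axis(q_values, actions[:, None], axis=-1)[:, 0]
    return jnp.mean((q_taken - jax.lax.stop_gradient(targets)) ** 2)


# Toy batch of two transitions with two actions (as in CartPole); the second is terminal.
q_next_online = jnp.array([[1.0, 2.0], [3.0, 0.5]])
q_next_target = jnp.array([[1.5, 0.5], [2.0, 4.0]])
rewards = jnp.array([1.0, 1.0])
dones = jnp.array([0.0, 1.0])

print(max_target(q_next_online, rewards, dones, gamma=0.9))  # DQN
print(max_target(q_next_target, rewards, dones, gamma=0.9))  # DQN with target network
print(double_q_target(q_next_online, q_next_target, rewards, dones, gamma=0.9))  # DDQN
```

In the toy batch, the online network prefers action 1 for the first transition, but the target network values that action at only 0.5, so the DDQN target (1.45) is lower than either max-based target (2.8 and 2.35). Decoupling action selection from action evaluation in this way is how Double Q-learning reduces overestimation of Q-values.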

Usage

# Install dependencies
pip install -r requirements.txt

Training models

# Using the launch script (by default it runs multiple seeds)
./launch.sh

# Using Python directly
python3 run.py --agent dqn --train_eps 100 --n_layers 3 --seed 1 --test_eps 30 --lr 0.03 --batch_size 256 --warm_up_steps 500 --epsilon_hlife 1500 --save_dir out/CartPole-v1/dqn/example_run/1
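The `--epsilon_hlife` flag suggests epsilon-greedy exploration in which epsilon decays with a half-life measured in steps. The sketch below shows one way such a schedule could work in JAX. It assumes, without confirmation from the repository's code, that epsilon starts at 1.0 and halves every `epsilon_hlife` steps down to a floor; the names `epsilon_at`, `epsilon_greedy` and the `eps_min` floor are hypothetical.

```python
import jax
import jax.numpy as jnp


def epsilon_at(step, hlife, eps_start=1.0, eps_min=0.01):
    """Exponential decay: epsilon halves every `hlife` steps (assumed meaning of --epsilon_hlife)."""
    decayed = eps_start * 0.5 ** (step / hlife)
    return jnp.maximum(decayed, eps_min)


def epsilon_greedy(key, q_values, epsilon):
    """With probability epsilon pick a uniformly random action, otherwise the greedy one."""
    explore_key, action_key = jax.random.split(key)
    random_action = jax.random.randint(action_key, (), 0, q_values.shape[-1])
    greedy_action = jnp.argmax(q_values)
    explore = jax.random.uniform(explore_key) < epsilon
    return jnp.where(explore, random_action, greedy_action)


# Hypothetical usage with the half-life from the command above (--epsilon_hlife 1500).
key = jax.random.PRNGKey(1)
q_values = jnp.array([0.2, 0.7])  # Q-values for CartPole's two actions
for step in (0, 1500, 3000):
    key, subkey = jax.random.split(key)
    eps = epsilon_at(step, hlife=1500)
    action = epsilon_greedy(subkey, q_values, eps)
    print(step, float(eps), int(action))
```

Passing an explicit PRNG key and splitting it at every step, rather than relying on global random state, is the usual JAX idiom; it would keep exploration reproducible for a given seed.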

Demo

# Run the demo with a trained model; it searches the out directory for the best-performing model
python3 run.py --demo --save_dir out --agent dqn

Results

Results are logged during runs. The notebook notebooks/results.ipynb plots reward curves from training and testing, using utils.py to parse the logs.
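Because the launch script runs several seeds, reward curves are typically smoothed and then averaged across seeds before plotting. The following is a hypothetical sketch of that step, not the logic in utils.py (whose log format is not described here): it assumes per-episode rewards have already been parsed into an array of shape `(n_seeds, n_episodes)`, and the helper names `smooth` and `summarize_seeds` are illustrative.

```python
import jax.numpy as jnp


def smooth(rewards, window):
    """Moving average over `window` consecutive episodes for each seed (row).

    Uses cumulative sums: each output is (c[t + window] - c[t]) / window.
    Output length per row is n_episodes - window + 1.
    """
    c = jnp.cumsum(jnp.pad(rewards, ((0, 0), (1, 0))), axis=1)
    return (c[:, window:] - c[:, :-window]) / window


def summarize_seeds(rewards, window=10):
    """Mean and standard deviation across seeds of the smoothed reward curves."""
    smoothed = smooth(rewards, window)
    return smoothed.mean(axis=0), smoothed.std(axis=0)


# Illustrative toy input for two seeds and four episodes (not real results).
rewards = jnp.array([[10.0, 20.0, 30.0, 40.0],
                     [30.0, 40.0, 50.0, 60.0]])
mean, std = summarize_seeds(rewards, window=2)
print(mean, std)
```

The mean curve can then be plotted with the standard deviation as a shaded band, for example with matplotlib's `fill_between`.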

Structure
