Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gif for gymnasium env #467

Merged
merged 7 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added docs/basics/userguide/visu_gymnasium_gif.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
35 changes: 35 additions & 0 deletions docs/basics/userguide/visualization.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,41 @@ env.disable_rendering()

The tool `save_gif` works by collecting all the frames generated by the steps of `env` for each of the steps done after `env.enable_rendering()`.

<span>&#9888;&#9888;&#9888;</span> **warning :** Gymnasium<span>&#9888;&#9888;&#9888;</span>

Be careful with [gymnasium](https://gymnasium.farama.org/) environments.
JulienT01 marked this conversation as resolved.
Show resolved Hide resolved
If you want to use the `save_gif` tool with a gymnasium environment, you need to :
- Use the {mod}`rlberry.envs.gym_make` function to create the (wrapped) environment.
- Use `render_mode = "rgb_array"` as parameter to create the environment. More information [here](https://gymnasium.farama.org/api/env/#gymnasium.Env.render).

An example here :
```python
from rlberry.envs import gym_make
from rlberry.seeding import Seeder


seeder = Seeder(123)
# `render_mode="rgb_array"` is required: frames are only collected in this mode.
env = gym_make("MountainCar-v0", render_mode="rgb_array")
env.reseed(seeder)


env.enable_rendering()
observation, info = env.reset()
for tt in range(100):
    # Draw a random action from the environment's action space.
    action = env.action_space.sample()
    observation, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    if done:
        # Warning: this will never happen in the present case.
        break
JulienT01 marked this conversation as resolved.
Show resolved Hide resolved

saving_path = "./visu_gymnasium_gif.gif"
env.save_gif(saving_path)
```

![](visu_gymnasium_gif.gif)

## Plotting training data and reward curves in rlberry

Training metrics in reinforcement learning are typically very unstable because there is randomness in the training process (e.g. neural networks) and the environment changes during the training. Then, if a single plot is very non-smooth due to the process variability, it may become necessary to smooth the curve in order to be able to get any information from it.
Expand Down
4 changes: 3 additions & 1 deletion docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ Changelog
Dev version
-----------

*nothing*
*PR #467*

* Allow saving a gif for gymnasium environments created with ``gym_make`` : https://github.com/rlberry-py/rlberry/issues/453


Version 0.7.3
Expand Down
19 changes: 18 additions & 1 deletion rlberry/envs/basewrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from rlberry.seeding import Seeder, safe_reseed
import numpy as np
from rlberry.envs.interface import Model
from rlberry.rendering import RenderInterface
from rlberry.spaces.from_gym import convert_space_from_gym
from rlberry.rendering.utils import video_write, gif_write


class Wrapper(Model):
class Wrapper(Model, RenderInterface):
"""
Wraps a given environment, similar to OpenAI gym's wrapper [1,2] (now updated to gymnasium).
Can also be used to wrap gym environments.
Expand Down Expand Up @@ -46,6 +48,7 @@ def __init__(self, env, wrap_spaces=False):
self.env = env
self.metadata = self.env.metadata
self.render_mode = self.env.render_mode
self.frames = []

if wrap_spaces:
self.observation_space = convert_space_from_gym(self.env.observation_space)
Expand Down Expand Up @@ -107,11 +110,14 @@ def reseed(self, seed_seq=None):
# Reset the wrapped environment and discard the frames recorded so far.
# `seed` and `options` are forwarded unchanged to `self.env.reset`, and its
# return value is returned as is.
def reset(self, seed=None, options=None):
# In "human" mode, `render` is called once before the wrapped env is reset.
if self.env.render_mode == "human":
self.render()
# Start a fresh recording: frames appended by `step` before this call are dropped.
self.frames = []
return self.env.reset(seed=seed, options=options)

# Forward `action` to the wrapped environment and return the result of
# `self.env.step(action)` unchanged.
def step(self, action):
# Note the order: `render` is called before `self.env.step`, so the rendering
# is done before the action is applied.
if self.render_mode == "human":
self.render()
elif self.render_mode == "rgb_array":
# Store the value returned by `render`; `get_video` returns this list.
self.frames.append(self.render())
return self.env.step(action)

def sample(self, state, action):
Expand Down Expand Up @@ -147,6 +153,17 @@ def is_generative(self):
except Exception:
return False

# Return `self.frames`, the list filled by `step` in "rgb_array" mode and
# emptied by `reset`. `kwargs` is accepted but not used here.
def get_video(self, **kwargs):
return self.frames

# Write the frames returned by `get_video` to `filename` with `video_write`.
# `framerate` (default 25) is passed to `video_write`; `kwargs` go to `get_video`.
def save_video(self, filename, framerate=25, **kwargs):
video_data = self.get_video(**kwargs)
video_write(filename, video_data, framerate=framerate)

# Write the frames returned by `get_video` to `filename` with `gif_write`.
# `kwargs` are passed to `get_video`; no framerate is given to `gif_write`.
def save_gif(self, filename, **kwargs):
video_data = self.get_video(**kwargs)
gif_write(filename, video_data)

def __repr__(self):
return str(self)

Expand Down
8 changes: 4 additions & 4 deletions rlberry/rendering/render_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def render(self, loop=True, **kwargs):
logger.info("Rendering not enabled for the environment.")
return 1

def get_video(self, framerate=25, **kwargs):
def get_video(self, **kwargs):
# background and data
background, data = self._get_background_and_scenes()

Expand All @@ -158,9 +158,9 @@ def get_video(self, framerate=25, **kwargs):
return renderer.get_video_data()

def save_video(self, filename, framerate=25, **kwargs):
video_data = self.get_video(framerate=framerate, **kwargs)
video_data = self.get_video(**kwargs)
video_write(filename, video_data, framerate=framerate)

def save_gif(self, filename, framerate=25, **kwargs):
video_data = self.get_video(framerate=framerate, **kwargs)
def save_gif(self, filename, **kwargs):
video_data = self.get_video(**kwargs)
gif_write(filename, video_data)
95 changes: 90 additions & 5 deletions rlberry/rendering/tests/test_rendering_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from rlberry_scool.envs.finite import Chain
from rlberry_scool.envs.finite import GridWorld
from rlberry_scool.agents.dynprog import ValueIterationAgent
from rlberry_research.agents import RSUCBVIAgent
from rlberry_research.envs.benchmarks.grid_exploration.four_room import FourRoom
from rlberry_research.envs.benchmarks.grid_exploration.six_room import SixRoom
from rlberry_research.envs.benchmarks.grid_exploration.apple_gold import AppleGold
Expand All @@ -16,6 +17,8 @@
from rlberry.rendering import RenderInterface
from rlberry.rendering import RenderInterface2D
from rlberry.envs import Wrapper
from rlberry.envs import gym_make
from rlberry.seeding import Seeder

import tempfile

Expand Down Expand Up @@ -156,7 +159,7 @@
done = terminated or truncated
if done:
# Warning: this will never happen in the present case because there is no terminal state.
# See the doc of GridWorld for more informations on the default parameters of GridWorld.
# See the doc of GridWorld for more information on the default parameters of GridWorld.
break

with tempfile.TemporaryDirectory() as tmpdirname:
Expand All @@ -169,7 +172,7 @@
pass


##### Fonctionne uniquement si on ajoute une dépendance à ffmpeg ############
# ##### Works only if you add a dependency to ffmpeg ############
# @pytest.mark.xfail(sys.platform == "darwin", reason="bug with Mac with pygame")
# @pytest.mark.parametrize("rendering_tool", RENDERING_TOOL)
# def test_gridworld_rendering_mp4(rendering_tool):
Expand All @@ -188,11 +191,11 @@
# done = terminated or truncated
# if done:
# # Warning: this will never happen in the present case because there is no terminal state.
# # See the doc of GridWorld for more informations on the default parameters of GridWorld.
# # See the doc of GridWorld for more information on the default parameters of GridWorld.
# break

# with tempfile.TemporaryDirectory() as tmpdirname:
# saving_path = tmpdirname + "/test_gif.mp4"
# saving_path = tmpdirname + "/test_video.mp4"
# env.save_video(saving_path)
# assert os.path.isfile(saving_path)
# try:
Expand All @@ -219,7 +222,89 @@
done = terminated or truncated
if done:
# Warning: this will never happen in the present case because there is no terminal state.
# See the doc of GridWorld for more informations on the default parameters of GridWorld.
# See the doc of GridWorld for more information on the default parameters of GridWorld.
break

env.render(loop=False)


@pytest.mark.skipif(sys.platform == "darwin", reason="bug with Mac with pygame")
@pytest.mark.parametrize("rendering_tool", RENDERING_TOOL)
def test_gym_make_rendering_gif(rendering_tool):
    """Check that `save_gif` writes a file for a gymnasium env built with `gym_make`."""
    seeder = Seeder(123)
    env = gym_make("MountainCar-v0", render_mode="rgb_array")
    env.reseed(seeder)
    env.renderer_type = rendering_tool

    agent = RSUCBVIAgent(
        env,
        gamma=0.99,
        horizon=200,
        bonus_scale_factor=0.1,
        copy_env=False,
        min_dist=0.1,
    )

    info = agent.fit(15)
    print(info)

    env.enable_rendering()
    observation, info = env.reset()
    for tt in range(100):
        action = agent.policy(observation)
        observation, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        if done:
            # Warning: this will never happen in the present case.
            break

    # The temporary directory and everything in it are removed when the
    # `with` block exits, so no manual (and error-swallowing) cleanup is needed.
    with tempfile.TemporaryDirectory() as tmpdirname:
        saving_path = os.path.join(tmpdirname, "test_gif.gif")
        env.save_gif(saving_path)
        assert os.path.isfile(saving_path)

Check warning on line 269 in rlberry/rendering/tests/test_rendering_interface.py

View check run for this annotation

Codecov / codecov/patch

rlberry/rendering/tests/test_rendering_interface.py#L268-L269

Added lines #L268 - L269 were not covered by tests


# ##### Works only if you add a dependency to ffmpeg ############
# @pytest.mark.skipif(sys.platform == "darwin", reason="bug with Mac with pygame")
# @pytest.mark.parametrize("rendering_tool", RENDERING_TOOL)
# def test_gym_make_rendering_mp4(rendering_tool):
# seeder = Seeder(123)
# env = gym_make("MountainCar-v0", render_mode="rgb_array")
# env.reseed(seeder)
# env.renderer_type = rendering_tool

# agent = RSUCBVIAgent(
# env,
# gamma=0.99,
# horizon=200,
# bonus_scale_factor=0.1,
# copy_env=False,
# min_dist=0.1,
# )

# info = agent.fit(15)
# print(info)

# env.enable_rendering()
# observation, info = env.reset()
# for tt in range(100):
# action = agent.policy(observation)
# observation, reward, terminated, truncated, info = env.step(action)
# done = terminated or truncated
# if done:
# # Warning: this will never happen in the present case.
# break

# with tempfile.TemporaryDirectory() as tmpdirname:
# saving_path = tmpdirname + "/test_video.mp4"
# env.save_video(saving_path)
# assert os.path.isfile(saving_path)
# try:
# os.remove(saving_path)
# except Exception:
# pass
Loading