From e5423ff420bfd3d05960b209ff9967139b9338ee Mon Sep 17 00:00:00 2001 From: Ju T <53004817+JulienT01@users.noreply.github.com> Date: Tue, 20 Feb 2024 09:25:54 +0100 Subject: [PATCH] moving rendering to rlberry (#411) * moving chain and gridworld to rlberry_scool * adding save_gif and moving back the rendering to rlberry (main) --- docs/changelog.rst | 4 + examples/demo_env/video_plot_chain.py | 4 +- examples/demo_env/video_plot_gridworld.py | 6 +- rlberry/rendering/__init__.py | 3 + rlberry/rendering/common_shapes.py | 39 +++ rlberry/rendering/core.py | 56 ++++ rlberry/rendering/opengl_render2d.py | 252 ++++++++++++++++++ rlberry/rendering/pygame_render2d.py | 197 ++++++++++++++ rlberry/rendering/render_interface.py | 166 ++++++++++++ rlberry/rendering/tests/__init__.py | 0 .../tests/test_rendering_interface.py | 140 ++++++++++ rlberry/rendering/utils.py | 106 ++++++++ 12 files changed, 968 insertions(+), 5 deletions(-) create mode 100644 rlberry/rendering/__init__.py create mode 100644 rlberry/rendering/common_shapes.py create mode 100644 rlberry/rendering/core.py create mode 100644 rlberry/rendering/opengl_render2d.py create mode 100644 rlberry/rendering/pygame_render2d.py create mode 100644 rlberry/rendering/render_interface.py create mode 100644 rlberry/rendering/tests/__init__.py create mode 100644 rlberry/rendering/tests/test_rendering_interface.py create mode 100644 rlberry/rendering/utils.py diff --git a/docs/changelog.rst b/docs/changelog.rst index 3c005f73c..45aa07551 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -8,6 +8,10 @@ Dev version ----------- +*PR #411* + + * Moving "rendering" to rlberry + *PR #397* * Automatic save after fit() in ExperienceManager diff --git a/examples/demo_env/video_plot_chain.py b/examples/demo_env/video_plot_chain.py index 42c2b3c8b..46d954749 100644 --- a/examples/demo_env/video_plot_chain.py +++ b/examples/demo_env/video_plot_chain.py @@ -11,11 +11,11 @@ # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_chain.jpg' -from rlberry_research.envs.finite import Chain +from rlberry_scool.envs.finite import Chain env = Chain(10, 0.1) env.enable_rendering() for tt in range(5): env.step(env.action_space.sample()) env.render() -video = env.save_video("_video/video_plot_chain.mp4") +env.save_video("_video/video_plot_chain.mp4") diff --git a/examples/demo_env/video_plot_gridworld.py b/examples/demo_env/video_plot_gridworld.py index 872b46fbb..85647b84e 100644 --- a/examples/demo_env/video_plot_gridworld.py +++ b/examples/demo_env/video_plot_gridworld.py @@ -12,8 +12,8 @@ """ # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_gridworld.jpg' -from rlberry_research.agents.dynprog import ValueIterationAgent -from rlberry_research.envs.finite import GridWorld +from rlberry_scool.agents.dynprog import ValueIterationAgent +from rlberry_scool.envs.finite import GridWorld env = GridWorld(7, 10, walls=((2, 2), (3, 3))) @@ -33,4 +33,4 @@ # See the doc of GridWorld for more informations on the default parameters of GridWorld. break # Save the video -video = env.save_video("_video/video_plot_gridworld.mp4", framerate=10) +env.save_video("_video/video_plot_gridworld.mp4", framerate=10) diff --git a/rlberry/rendering/__init__.py b/rlberry/rendering/__init__.py new file mode 100644 index 000000000..5bdd0e295 --- /dev/null +++ b/rlberry/rendering/__init__.py @@ -0,0 +1,3 @@ +from .core import Scene, GeometricPrimitive +from .render_interface import RenderInterface +from .render_interface import RenderInterface2D diff --git a/rlberry/rendering/common_shapes.py b/rlberry/rendering/common_shapes.py new file mode 100644 index 000000000..91f942c14 --- /dev/null +++ b/rlberry/rendering/common_shapes.py @@ -0,0 +1,39 @@ +import numpy as np +from rlberry.rendering import GeometricPrimitive + + +def bar_shape(p0, p1, width): + shape = GeometricPrimitive("QUADS") + + x0, y0 = p0 + x1, y1 = p1 + + direction = np.array([x1 - x0, y1 - y0]) + norm = np.sqrt((direction * direction).sum()) + direction = direction / norm + + # get vector perpendicular to direction + u_vec = np.zeros(2) + u_vec[0] = -direction[1] + u_vec[1] = direction[0] + + u_vec = u_vec * width / 2 + + shape.add_vertex((x0 + u_vec[0], y0 + u_vec[1])) + shape.add_vertex((x0 - u_vec[0], y0 - u_vec[1])) + shape.add_vertex((x1 - u_vec[0], y1 - u_vec[1])) + shape.add_vertex((x1 + u_vec[0], y1 + u_vec[1])) + return shape + + +def circle_shape(center, radius, n_points=50): + shape = GeometricPrimitive("POLYGON") + + x0, y0 = center + theta = np.linspace(0.0, 2 * np.pi, n_points) + for tt in theta: + xx = radius * np.cos(tt) + yy = radius * np.sin(tt) + shape.add_vertex((x0 + xx, y0 + yy)) + + return shape diff --git a/rlberry/rendering/core.py b/rlberry/rendering/core.py new file mode 100644 index 000000000..0cc5e92ef --- /dev/null +++ b/rlberry/rendering/core.py @@ -0,0 +1,56 @@ +""" +Provide classes for geometric primitives in OpenGL and scenes. +""" + + +class Scene: + """ + Class representing a scene, which is a vector of GeometricPrimitive objects + """ + + def __init__(self): + self.shapes = [] + + def add_shape(self, shape): + self.shapes.append(shape) + + +class GeometricPrimitive: + """ + Class representing an OpenGL geometric primitive. + + Primitive type (GL_LINE_LOOP by defaut) + + If using OpenGLRender2D, one of the following: + POINTS + LINES + LINE_STRIP + LINE_LOOP + POLYGON + TRIANGLES + TRIANGLE_STRIP + TRIANGLE_FAN + QUADS + QUAD_STRIP + + If using PyGameRender2D: + POLYGON + + + TODO: Add support to more pygame shapes, + see https://www.pygame.org/docs/ref/draw.html + """ + + def __init__(self, primitive_type="GL_LINE_LOOP"): + # primitive type + self.type = primitive_type + # color in RGB + self.color = (0.25, 0.25, 0.25) + # list of vertices. each vertex is a tuple with coordinates in space + self.vertices = [] + + def add_vertex(self, vertex): + self.vertices.append(vertex) + + def set_color(self, color): + self.color = color diff --git a/rlberry/rendering/opengl_render2d.py b/rlberry/rendering/opengl_render2d.py new file mode 100644 index 000000000..64ec79646 --- /dev/null +++ b/rlberry/rendering/opengl_render2d.py @@ -0,0 +1,252 @@ +""" +OpenGL code for 2D rendering, using pygame. +""" + +import numpy as np +from os import environ + +from rlberry.rendering import Scene + +import rlberry + +logger = rlberry.logger +environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1" + +_IMPORT_SUCESSFUL = True +_IMPORT_ERROR_MSG = "" +try: + import pygame as pg + from pygame.locals import DOUBLEBUF, OPENGL + + from OpenGL.GLU import gluOrtho2D + from OpenGL.GL import glMatrixMode, glLoadIdentity, glClearColor + from OpenGL.GL import glClear, glFlush, glBegin, glEnd + from OpenGL.GL import glColor3f, glVertex2f + from OpenGL.GL import glReadBuffer, glReadPixels + from OpenGL.GL import GL_PROJECTION, GL_COLOR_BUFFER_BIT + from OpenGL.GL import GL_POINTS, GL_LINES, GL_LINE_STRIP, GL_LINE_LOOP + from OpenGL.GL import GL_POLYGON, GL_TRIANGLES, GL_TRIANGLE_STRIP + from OpenGL.GL import GL_TRIANGLE_FAN, GL_QUADS, GL_QUAD_STRIP + from OpenGL.GL import GL_FRONT, GL_RGB, GL_UNSIGNED_BYTE + +except Exception as ex: + _IMPORT_SUCESSFUL = False + _IMPORT_ERROR_MSG = str(ex) + + +class OpenGLRender2D: + """ + Class to render a list of scenes using OpenGL and pygame. + """ + + def __init__(self): + # parameters + self.window_width = 800 + self.window_height = 800 # multiples of 16 are preferred + self.background_color = (0.6, 0.75, 1.0) + self.refresh_interval = 50 + self.window_name = "rlberry render" + self.clipping_area = (-1.0, 1.0, -1.0, 1.0) + + # time counter + self.time_count = 0 + + # background scene + self.background = Scene() + # data to be rendered (list of scenes) + self.data = [] + + def set_window_name(self, name): + self.window_name = name + + def set_refresh_interval(self, interval): + self.refresh_interval = interval + + def set_clipping_area(self, area): + """ + The clipping area is tuple with elements (left, right, bottom, top) + Default = (-1.0, 1.0, -1.0, 1.0) + """ + self.clipping_area = area + base_size = max(self.window_width, self.window_height) + width_range = area[1] - area[0] + height_range = area[3] - area[2] + base_range = max(width_range, height_range) + width_range /= base_range + height_range /= base_range + self.window_width = int(base_size * width_range) + self.window_height = int(base_size * height_range) + + # width and height must be divisible by 2 + if self.window_width % 2 == 1: + self.window_width += 1 + if self.window_height % 2 == 1: + self.window_height += 1 + + def set_data(self, data): + self.data = data + + def set_background(self, background): + self.background = background + + def initGL(self): + """ + initialize GL + """ + glMatrixMode(GL_PROJECTION) + glLoadIdentity() + gluOrtho2D( + self.clipping_area[0], + self.clipping_area[1], + self.clipping_area[2], + self.clipping_area[3], + ) + + def display(self): + """ + Callback function, handler for window re-paint + """ + # Set background color (clear background) + glClearColor( + self.background_color[0], + self.background_color[1], + self.background_color[2], + 1.0, + ) + glClear(GL_COLOR_BUFFER_BIT) + + # Display background + for shape in self.background.shapes: + self.draw_geometric2d(shape) + + # Display objects + if len(self.data) > 0: + idx = self.time_count % len(self.data) + for shape in self.data[idx].shapes: + self.draw_geometric2d(shape) + + self.time_count += 1 + glFlush() + + @staticmethod + def draw_geometric2d(shape): + """ + Draw a 2D shape, of type GeometricPrimitive + """ + if shape.type == "POINTS": + glBegin(GL_POINTS) + elif shape.type == "LINES": + glBegin(GL_LINES) + elif shape.type == "LINE_STRIP": + glBegin(GL_LINE_STRIP) + elif shape.type == "LINE_LOOP": + glBegin(GL_LINE_LOOP) + elif shape.type == "POLYGON": + glBegin(GL_POLYGON) + elif shape.type == "TRIANGLES": + glBegin(GL_TRIANGLES) + elif shape.type == "TRIANGLE_STRIP": + glBegin(GL_TRIANGLE_STRIP) + elif shape.type == "TRIANGLE_FAN": + glBegin(GL_TRIANGLE_FAN) + elif shape.type == "QUADS": + glBegin(GL_QUADS) + elif shape.type == "QUAD_STRIP": + glBegin(GL_QUAD_STRIP) + else: + logger.error("Invalid type for geometric primitive!") + raise NameError + + # set color + glColor3f(shape.color[0], shape.color[1], shape.color[2]) + + # create vertices + for vertex in shape.vertices: + glVertex2f(vertex[0], vertex[1]) + glEnd() + + def run_graphics(self, loop=True): + """ + Sequentially displays scenes in self.data + + If loop is True, keep rendering until user closes the window. + """ + global _IMPORT_SUCESSFUL + + if _IMPORT_SUCESSFUL: + pg.init() + display = (self.window_width, self.window_height) + pg.display.set_mode(display, DOUBLEBUF | OPENGL) + pg.display.set_caption(self.window_name) + self.initGL() + while True: + for event in pg.event.get(): + if event.type == pg.QUIT: + pg.quit() + return + # + self.display() + # + pg.display.flip() + pg.time.wait(self.refresh_interval) + + # if not loop, stop + if not loop: + pg.quit() + return + else: + logger.error( + f"Not possible to render the environment due to the following error: {_IMPORT_ERROR_MSG}" + ) + return + + def get_gl_image_str(self): + # see https://gist.github.com/Jerdak/7364746 + glReadBuffer(GL_FRONT) + pixels = glReadPixels( + 0, 0, self.window_width, self.window_height, GL_RGB, GL_UNSIGNED_BYTE + ) + return pixels + + def get_video_data(self): + """ + Stores scenes in self.data in a list of numpy arrays that can be used + to save a video. + """ + global _IMPORT_SUCESSFUL + + if _IMPORT_SUCESSFUL: + video_data = [] + + pg.init() + display = (self.window_width, self.window_height) + _ = pg.display.set_mode(display, DOUBLEBUF | OPENGL) + pg.display.set_caption(self.window_name) + self.initGL() + + self.time_count = 0 + while self.time_count <= len(self.data): + # + self.display() + # + pg.display.flip() + + # + # See https://stackoverflow.com/a/42754578/5691288 + # + string_image = self.get_gl_image_str() + temp_surf = pg.image.frombytes( + string_image, (self.window_width, self.window_height), "RGB" + ) + tmp_arr = pg.surfarray.array3d(temp_surf) + imgdata = np.moveaxis(tmp_arr, 0, 1) + imgdata = np.flipud(imgdata) + video_data.append(imgdata) + + pg.quit() + return video_data + else: + logger.error( + f"Not possible to render the environment due to the following error: {_IMPORT_ERROR_MSG}" + ) + return [] diff --git a/rlberry/rendering/pygame_render2d.py b/rlberry/rendering/pygame_render2d.py new file mode 100644 index 000000000..a8d5b3990 --- /dev/null +++ b/rlberry/rendering/pygame_render2d.py @@ -0,0 +1,197 @@ +""" +Code for 2D rendering, using pygame (without OpenGL) +""" + +import numpy as np +from os import environ + +from rlberry.rendering import Scene + +import rlberry + +logger = rlberry.logger + +environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1" + +_IMPORT_SUCESSFUL = True +_IMPORT_ERROR_MSG = "" +try: + import pygame as pg + +except Exception as ex: + _IMPORT_SUCESSFUL = False + _IMPORT_ERROR_MSG = str(ex) + + +class PyGameRender2D: + """Class to render a list of scenes using pygame.""" + + def __init__(self): + # parameters + self.window_width = 800 + self.window_height = 800 # multiples of 16 are preferred + self.background_color = (150, 190, 255) + self.refresh_interval = 50 + self.window_name = "rlberry render" + self.clipping_area = (-1.0, 1.0, -1.0, 1.0) + + # time counter + self.time_count = 0 + + # background scene + self.background = Scene() + # data to be rendered (list of scenes) + self.data = [] + + def set_window_name(self, name): + self.window_name = name + + def set_refresh_interval(self, interval): + self.refresh_interval = interval + + def set_clipping_area(self, area): + """ + The clipping area is tuple with elements (left, right, bottom, top) + Default = (-1.0, 1.0, -1.0, 1.0) + """ + self.clipping_area = area + base_size = max(self.window_width, self.window_height) + width_range = area[1] - area[0] + height_range = area[3] - area[2] + base_range = max(width_range, height_range) + width_range /= base_range + height_range /= base_range + self.window_width = int(base_size * width_range) + self.window_height = int(base_size * height_range) + + # width and height must be divisible by 2 + if self.window_width % 2 == 1: + self.window_width += 1 + if self.window_height % 2 == 1: + self.window_height += 1 + + def set_data(self, data): + self.data = data + + def set_background(self, background): + self.background = background + + def display(self): + """ + Callback function, handler for window re-paint + """ + # Set background color (clear background) + self.screen.fill(self.background_color) + + # Display background + for shape in self.background.shapes: + self.draw_geometric2d(shape) + + # Display objects + if len(self.data) > 0: + idx = self.time_count % len(self.data) + for shape in self.data[idx].shapes: + self.draw_geometric2d(shape) + + self.time_count += 1 + + def draw_geometric2d(self, shape): + """ + Draw a 2D shape, of type GeometricPrimitive + """ + if shape.type in ["POLYGON"]: + area = self.clipping_area + width_range = area[1] - area[0] + height_range = area[3] - area[2] + + vertices = [] + for vertex in shape.vertices: + xx = vertex[0] * self.window_width / width_range + yy = vertex[1] * self.window_height / height_range + + # put origin at bottom left instead of top left + yy = self.window_height - yy + + pg_vertex = (xx, yy) + vertices.append(pg_vertex) + + color = (255 * shape.color[0], 255 * shape.color[1], 255 * shape.color[2]) + pg.draw.polygon(self.screen, color, vertices) + + else: + raise NotImplementedError( + "Shape type %s not implemented in pygame renderer." % shape.type + ) + + def run_graphics(self, loop=True): + """ + Sequentially displays scenes in self.data + """ + global _IMPORT_SUCESSFUL + + if _IMPORT_SUCESSFUL: + pg.init() + display = (self.window_width, self.window_height) + self.screen = pg.display.set_mode(display) + pg.display.set_caption(self.window_name) + while True: + for event in pg.event.get(): + if event.type == pg.QUIT: + pg.quit() + return + # + self.display() + # + pg.display.flip() + pg.time.wait(self.refresh_interval) + + # if not loop, stop + if not loop: + pg.quit() + return + else: + logger.error( + f"Not possible to render the environment due to the following error: {_IMPORT_ERROR_MSG}" + ) + return + + def get_video_data(self): + """ + Stores scenes in self.data in a list of numpy arrays that can be used + to save a video. + """ + global _IMPORT_SUCESSFUL + + if _IMPORT_SUCESSFUL: + video_data = [] + + pg.init() + display = (self.window_width, self.window_height) + self.screen = pg.display.set_mode(display) + pg.display.set_caption(self.window_name) + + self.time_count = 0 + while self.time_count <= len(self.data): + # + self.display() + # + pg.display.flip() + + # + # See https://stackoverflow.com/a/42754578/5691288 + # + string_image = pg.image.tobytes(self.screen, "RGB") + temp_surf = pg.image.frombytes( + string_image, (self.window_width, self.window_height), "RGB" + ) + tmp_arr = pg.surfarray.array3d(temp_surf) + imgdata = np.moveaxis(tmp_arr, 0, 1) + video_data.append(imgdata) + + pg.quit() + return video_data + else: + logger.error( + f"Not possible to render the environment due to the following error: {_IMPORT_ERROR_MSG}" + ) + return [] diff --git a/rlberry/rendering/render_interface.py b/rlberry/rendering/render_interface.py new file mode 100644 index 000000000..123d0918a --- /dev/null +++ b/rlberry/rendering/render_interface.py @@ -0,0 +1,166 @@ +""" +Interface that allows 2D rendering. +""" + + +from abc import ABC, abstractmethod + +from rlberry.rendering.opengl_render2d import OpenGLRender2D +from rlberry.rendering.pygame_render2d import PyGameRender2D +from rlberry.rendering.utils import video_write, gif_write + +import rlberry + +logger = rlberry.logger + + +class RenderInterface(ABC): + """ + Common interface for rendering in rlberry. + """ + + def __init__(self): + self._rendering_enabled = False + + def is_render_enabled(self): + return self._rendering_enabled + + def enable_rendering(self): + self._rendering_enabled = True + + def disable_rendering(self): + self._rendering_enabled = False + + def save_video(self, filename, **kwargs): + """ + Save video file. + """ + pass + + def get_video(self, **kwargs): + """ + Get video data. + """ + pass + + @abstractmethod + def render(self, **kwargs): + """ + Display on screen. + """ + pass + + +class RenderInterface2D(RenderInterface): + """ + Interface for 2D rendering in rlberry. + """ + + def __init__(self): + RenderInterface.__init__(self) + self._rendering_enabled = False + self._rendering_type = "2d" + self._state_history_for_rendering = [] + self._refresh_interval = 50 # in milliseconds + self._clipping_area = (-1.0, 1.0, -1.0, 1.0) # (left,right,bottom,top) + + # rendering type, either 'pygame' or 'opengl' + self.renderer_type = "opengl" + + def get_renderer(self): + if self.renderer_type == "opengl": + return OpenGLRender2D() + elif self.renderer_type == "pygame": + return PyGameRender2D() + else: + raise NotImplementedError("Unknown renderer type.") + + @abstractmethod + def get_scene(self, state): + """ + Return scene (list of shapes) representing a given state + """ + pass + + @abstractmethod + def get_background(self): + """ + Returne a scene (list of shapes) representing the background + """ + pass + + def append_state_for_rendering(self, state): + self._state_history_for_rendering.append(state) + + def set_refresh_interval(self, interval): + self._refresh_interval = interval + + def clear_render_buffer(self): + self._state_history_for_rendering = [] + + def set_clipping_area(self, area): + self._clipping_area = area + + def _get_background_and_scenes(self): + # background + background = self.get_background() + + # data: convert states to scenes + scenes = [] + for state in self._state_history_for_rendering: + scene = self.get_scene(state) + scenes.append(scene) + return background, scenes + + def render(self, loop=True, **kwargs): + """ + Function to render an environment that implements the interface. + """ + + if self.is_render_enabled(): + # background and data + background, data = self._get_background_and_scenes() + + if len(data) == 0: + logger.info("No data to render.") + return + + # render + renderer = self.get_renderer() + + renderer.window_name = self.name + renderer.set_refresh_interval(self._refresh_interval) + renderer.set_clipping_area(self._clipping_area) + renderer.set_data(data) + renderer.set_background(background) + renderer.run_graphics(loop) + return 0 + else: + logger.info("Rendering not enabled for the environment.") + return 1 + + def get_video(self, framerate=25, **kwargs): + # background and data + background, data = self._get_background_and_scenes() + + if len(data) == 0: + logger.info("No data to save.") + return + + # get video data from renderer + renderer = self.get_renderer() + renderer.window_name = self.name + renderer.set_refresh_interval(self._refresh_interval) + renderer.set_clipping_area(self._clipping_area) + renderer.set_data(data) + renderer.set_background(background) + + return renderer.get_video_data() + + def save_video(self, filename, framerate=25, **kwargs): + video_data = self.get_video(framerate=framerate, **kwargs) + video_write(filename, video_data, framerate=framerate) + + def save_gif(self, filename, framerate=25, **kwargs): + video_data = self.get_video(framerate=framerate, **kwargs) + gif_write(filename, video_data) diff --git a/rlberry/rendering/tests/__init__.py b/rlberry/rendering/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/rlberry/rendering/tests/test_rendering_interface.py b/rlberry/rendering/tests/test_rendering_interface.py new file mode 100644 index 000000000..9d5366ecf --- /dev/null +++ b/rlberry/rendering/tests/test_rendering_interface.py @@ -0,0 +1,140 @@ +import os +import pytest +import sys + +from pyvirtualdisplay import Display +from rlberry_research.envs.classic_control import MountainCar +from rlberry_research.envs.classic_control import Acrobot +from rlberry_research.envs.classic_control import Pendulum +from rlberry_scool.envs.finite import Chain +from rlberry_scool.envs.finite import GridWorld +from rlberry_research.envs.benchmarks.grid_exploration.four_room import FourRoom +from rlberry_research.envs.benchmarks.grid_exploration.six_room import SixRoom +from rlberry_research.envs.benchmarks.grid_exploration.apple_gold import AppleGold +from rlberry_research.envs.benchmarks.ball_exploration import PBall2D, SimplePBallND +from rlberry_research.envs.benchmarks.generalization.twinrooms import TwinRooms +from rlberry.rendering import RenderInterface +from rlberry.rendering import RenderInterface2D +from rlberry.envs import Wrapper + +import tempfile + +try: + display = Display(visible=0, size=(1400, 900)) + display.start() +except Exception: + pass + +classes = [ + Acrobot, + Pendulum, + MountainCar, + GridWorld, + Chain, + PBall2D, + SimplePBallND, + FourRoom, + SixRoom, + AppleGold, + TwinRooms, +] + + +@pytest.mark.parametrize("ModelClass", classes) +def test_instantiation(ModelClass): + env = ModelClass() + + if isinstance(env, RenderInterface): + env.disable_rendering() + assert not env.is_render_enabled() + env.enable_rendering() + assert env.is_render_enabled() + + +@pytest.mark.xfail(sys.platform != "linux", reason="bug with mac and windows???") +@pytest.mark.parametrize("ModelClass", classes) +def test_render2d_interface(ModelClass): + env = ModelClass() + + if isinstance(env, RenderInterface2D): + env.enable_rendering() + + if env.is_online(): + for _ in range(2): + observation, info = env.reset() + for _ in range(5): + assert env.observation_space.contains(observation) + action = env.action_space.sample() + observation, _, _, _, _ = env.step(action) + env.render(loop=False) + + with tempfile.TemporaryDirectory() as tmpdirname: + saving_path = tmpdirname + "/test_video.mp4" + + env.save_video(saving_path) + env.clear_render_buffer() + + +@pytest.mark.xfail(sys.platform != "linux", reason="bug with mac and windows???") +@pytest.mark.parametrize("ModelClass", classes) +def test_render2d_interface_wrapped(ModelClass): + env = Wrapper(ModelClass()) + + if isinstance(env.env, RenderInterface2D): + env.enable_rendering() + if env.is_online(): + for _ in range(2): + observation, info = env.reset() + for _ in range(5): + assert env.observation_space.contains(observation) + action = env.action_space.sample() + observation, _, _, _, _ = env.step(action) + env.render(loop=False) + + with tempfile.TemporaryDirectory() as tmpdirname: + saving_path = tmpdirname + "/test_video.mp4" + env.save_video(saving_path) + env.clear_render_buffer() + try: + os.remove("test_video.mp4") + except Exception: + pass + + +def test_render_appelGold(): + env = AppleGold() + env.render_mode = "human" + env = Wrapper(env) + + if env.is_online(): + for _ in range(2): + observation, info = env.reset() + for _ in range(5): + assert env.observation_space.contains(observation) + action = env.action_space.sample() + observation, _, _, _, _ = env.step(action) + env.render(loop=False) + + with tempfile.TemporaryDirectory() as tmpdirname: + saving_path = tmpdirname + "/test_video.mp4" + env.save_video(saving_path) + env.clear_render_buffer() + try: + os.remove("test_video.mp4") + except Exception: + pass + + +def test_write_gif(): + env = Chain(10, 0.3) + env.enable_rendering() + for tt in range(20): + env.step(env.action_space.sample()) + with tempfile.TemporaryDirectory() as tmpdirname: + saving_path = tmpdirname + "/test_gif.mp4" + env.save_gif(saving_path) + assert os.path.isfile(saving_path) + try: + os.remove(saving_path) + except Exception: + pass diff --git a/rlberry/rendering/utils.py b/rlberry/rendering/utils.py new file mode 100644 index 000000000..ac1be2858 --- /dev/null +++ b/rlberry/rendering/utils.py @@ -0,0 +1,106 @@ +import numpy as np +import rlberry + +logger = rlberry.logger +import imageio + + +_FFMPEG_INSTALLED = True +try: + import ffmpeg +except Exception: + _FFMPEG_INSTALLED = False + + +def video_write(fn, images, framerate=60, vcodec="libx264"): + """ + Save list of images to a video file. + + Source: + https://github.com/kkroening/ffmpeg-python/issues/246#issuecomment-520200981 + Modified so that framerate is given to .input(), as suggested in the + thread, to avoid + skipping frames. + + Parameters + ---------- + fn : string + filename + images : list or np.array + list of images to save to a video. + framerate : int + """ + global _FFMPEG_INSTALLED + + try: + if len(images) == 0: + logger.warning("Calling video_write() with empty images.") + return + + if not _FFMPEG_INSTALLED: + logger.error( + "video_write(): Unable to save video, ffmpeg-python \ + package required (https://github.com/kkroening/ffmpeg-python)" + ) + return + + if not isinstance(images, np.ndarray): + images = np.asarray(images) + _, height, width, channels = images.shape + process = ( + ffmpeg.input( + "pipe:", + format="rawvideo", + pix_fmt="rgb24", + s="{}x{}".format(width, height), + r=framerate, + ) + .output(fn, pix_fmt="yuv420p", vcodec=vcodec) + .overwrite_output() + .run_async(pipe_stdin=True) + ) + for frame in images: + process.stdin.write(frame.astype(np.uint8).tobytes()) + process.stdin.close() + process.wait() + + except Exception as ex: + logger.warning( + "Not possible to save \ +video, due to exception: {}".format( + str(ex) + ) + ) + + +def gif_write(fn, images): + """ + Save list of images to a gif file + + Parameters + ---------- + fn : string + filename + images : list or np.array + list of images to save to a gif. + """ + + try: + if len(images) == 0: + logger.warning("Calling gif_write() with empty images.") + return + + if not isinstance(images, np.ndarray): + images = np.asarray(images) + + with imageio.get_writer(fn, mode="I") as writer: + for frame in images: + writer.append_data(frame) + + except Exception as ex: + logger.warning( + "Not possible to save \ +gif, due to exception: {}".format( + str(ex) + ) + )