Skip to content

Commit

Permalink
moving rendering to rlberry (#411)
Browse files Browse the repository at this point in the history
* moving chain and gridworld to rlberry_scool
* adding save_gif and moving back the rendering to rlberry (main)
  • Loading branch information
JulienT01 authored Feb 20, 2024
1 parent 4aa369b commit e5423ff
Show file tree
Hide file tree
Showing 12 changed files with 968 additions and 5 deletions.
4 changes: 4 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ Dev version
-----------


*PR #411*

* Moving "rendering" to rlberry

*PR #397*

* Automatic save after fit() in ExperienceManager
Expand Down
4 changes: 2 additions & 2 deletions examples/demo_env/video_plot_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
# sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_chain.jpg'


from rlberry_research.envs.finite import Chain
from rlberry_scool.envs.finite import Chain

env = Chain(10, 0.1)
env.enable_rendering()
for tt in range(5):
env.step(env.action_space.sample())
env.render()
video = env.save_video("_video/video_plot_chain.mp4")
env.save_video("_video/video_plot_chain.mp4")
6 changes: 3 additions & 3 deletions examples/demo_env/video_plot_gridworld.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
"""
# sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_gridworld.jpg'

from rlberry_research.agents.dynprog import ValueIterationAgent
from rlberry_research.envs.finite import GridWorld
from rlberry_scool.agents.dynprog import ValueIterationAgent
from rlberry_scool.envs.finite import GridWorld


env = GridWorld(7, 10, walls=((2, 2), (3, 3)))
Expand All @@ -33,4 +33,4 @@
# See the doc of GridWorld for more informations on the default parameters of GridWorld.
break
# Save the video
video = env.save_video("_video/video_plot_gridworld.mp4", framerate=10)
env.save_video("_video/video_plot_gridworld.mp4", framerate=10)
3 changes: 3 additions & 0 deletions rlberry/rendering/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .core import Scene, GeometricPrimitive
from .render_interface import RenderInterface
from .render_interface import RenderInterface2D
39 changes: 39 additions & 0 deletions rlberry/rendering/common_shapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import numpy as np
from rlberry.rendering import GeometricPrimitive


def bar_shape(p0, p1, width):
shape = GeometricPrimitive("QUADS")

x0, y0 = p0
x1, y1 = p1

direction = np.array([x1 - x0, y1 - y0])
norm = np.sqrt((direction * direction).sum())
direction = direction / norm

# get vector perpendicular to direction
u_vec = np.zeros(2)
u_vec[0] = -direction[1]
u_vec[1] = direction[0]

u_vec = u_vec * width / 2

shape.add_vertex((x0 + u_vec[0], y0 + u_vec[1]))
shape.add_vertex((x0 - u_vec[0], y0 - u_vec[1]))
shape.add_vertex((x1 - u_vec[0], y1 - u_vec[1]))
shape.add_vertex((x1 + u_vec[0], y1 + u_vec[1]))
return shape


def circle_shape(center, radius, n_points=50):
shape = GeometricPrimitive("POLYGON")

x0, y0 = center
theta = np.linspace(0.0, 2 * np.pi, n_points)
for tt in theta:
xx = radius * np.cos(tt)
yy = radius * np.sin(tt)
shape.add_vertex((x0 + xx, y0 + yy))

return shape
56 changes: 56 additions & 0 deletions rlberry/rendering/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""
Provide classes for geometric primitives in OpenGL and scenes.
"""


class Scene:
"""
Class representing a scene, which is a vector of GeometricPrimitive objects
"""

def __init__(self):
self.shapes = []

def add_shape(self, shape):
self.shapes.append(shape)


class GeometricPrimitive:
"""
Class representing an OpenGL geometric primitive.
Primitive type (GL_LINE_LOOP by defaut)
If using OpenGLRender2D, one of the following:
POINTS
LINES
LINE_STRIP
LINE_LOOP
POLYGON
TRIANGLES
TRIANGLE_STRIP
TRIANGLE_FAN
QUADS
QUAD_STRIP
If using PyGameRender2D:
POLYGON
TODO: Add support to more pygame shapes,
see https://www.pygame.org/docs/ref/draw.html
"""

def __init__(self, primitive_type="GL_LINE_LOOP"):
# primitive type
self.type = primitive_type
# color in RGB
self.color = (0.25, 0.25, 0.25)
# list of vertices. each vertex is a tuple with coordinates in space
self.vertices = []

def add_vertex(self, vertex):
self.vertices.append(vertex)

def set_color(self, color):
self.color = color
252 changes: 252 additions & 0 deletions rlberry/rendering/opengl_render2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
"""
OpenGL code for 2D rendering, using pygame.
"""

import numpy as np
from os import environ

from rlberry.rendering import Scene

import rlberry

logger = rlberry.logger
environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1"

_IMPORT_SUCESSFUL = True
_IMPORT_ERROR_MSG = ""
try:
import pygame as pg
from pygame.locals import DOUBLEBUF, OPENGL

from OpenGL.GLU import gluOrtho2D
from OpenGL.GL import glMatrixMode, glLoadIdentity, glClearColor
from OpenGL.GL import glClear, glFlush, glBegin, glEnd
from OpenGL.GL import glColor3f, glVertex2f
from OpenGL.GL import glReadBuffer, glReadPixels
from OpenGL.GL import GL_PROJECTION, GL_COLOR_BUFFER_BIT
from OpenGL.GL import GL_POINTS, GL_LINES, GL_LINE_STRIP, GL_LINE_LOOP
from OpenGL.GL import GL_POLYGON, GL_TRIANGLES, GL_TRIANGLE_STRIP
from OpenGL.GL import GL_TRIANGLE_FAN, GL_QUADS, GL_QUAD_STRIP
from OpenGL.GL import GL_FRONT, GL_RGB, GL_UNSIGNED_BYTE

except Exception as ex:
_IMPORT_SUCESSFUL = False
_IMPORT_ERROR_MSG = str(ex)


class OpenGLRender2D:
"""
Class to render a list of scenes using OpenGL and pygame.
"""

def __init__(self):
# parameters
self.window_width = 800
self.window_height = 800 # multiples of 16 are preferred
self.background_color = (0.6, 0.75, 1.0)
self.refresh_interval = 50
self.window_name = "rlberry render"
self.clipping_area = (-1.0, 1.0, -1.0, 1.0)

# time counter
self.time_count = 0

# background scene
self.background = Scene()
# data to be rendered (list of scenes)
self.data = []

def set_window_name(self, name):
self.window_name = name

def set_refresh_interval(self, interval):
self.refresh_interval = interval

def set_clipping_area(self, area):
"""
The clipping area is tuple with elements (left, right, bottom, top)
Default = (-1.0, 1.0, -1.0, 1.0)
"""
self.clipping_area = area
base_size = max(self.window_width, self.window_height)
width_range = area[1] - area[0]
height_range = area[3] - area[2]
base_range = max(width_range, height_range)
width_range /= base_range
height_range /= base_range
self.window_width = int(base_size * width_range)
self.window_height = int(base_size * height_range)

# width and height must be divisible by 2
if self.window_width % 2 == 1:
self.window_width += 1
if self.window_height % 2 == 1:
self.window_height += 1

def set_data(self, data):
self.data = data

def set_background(self, background):
self.background = background

def initGL(self):
"""
initialize GL
"""
glMatrixMode(GL_PROJECTION)
glLoadIdentity()
gluOrtho2D(
self.clipping_area[0],
self.clipping_area[1],
self.clipping_area[2],
self.clipping_area[3],
)

def display(self):
"""
Callback function, handler for window re-paint
"""
# Set background color (clear background)
glClearColor(
self.background_color[0],
self.background_color[1],
self.background_color[2],
1.0,
)
glClear(GL_COLOR_BUFFER_BIT)

# Display background
for shape in self.background.shapes:
self.draw_geometric2d(shape)

# Display objects
if len(self.data) > 0:
idx = self.time_count % len(self.data)
for shape in self.data[idx].shapes:
self.draw_geometric2d(shape)

self.time_count += 1
glFlush()

@staticmethod
def draw_geometric2d(shape):
"""
Draw a 2D shape, of type GeometricPrimitive
"""
if shape.type == "POINTS":
glBegin(GL_POINTS)
elif shape.type == "LINES":
glBegin(GL_LINES)
elif shape.type == "LINE_STRIP":
glBegin(GL_LINE_STRIP)
elif shape.type == "LINE_LOOP":
glBegin(GL_LINE_LOOP)
elif shape.type == "POLYGON":
glBegin(GL_POLYGON)
elif shape.type == "TRIANGLES":
glBegin(GL_TRIANGLES)
elif shape.type == "TRIANGLE_STRIP":
glBegin(GL_TRIANGLE_STRIP)
elif shape.type == "TRIANGLE_FAN":
glBegin(GL_TRIANGLE_FAN)
elif shape.type == "QUADS":
glBegin(GL_QUADS)
elif shape.type == "QUAD_STRIP":
glBegin(GL_QUAD_STRIP)
else:
logger.error("Invalid type for geometric primitive!")
raise NameError

# set color
glColor3f(shape.color[0], shape.color[1], shape.color[2])

# create vertices
for vertex in shape.vertices:
glVertex2f(vertex[0], vertex[1])
glEnd()

def run_graphics(self, loop=True):
"""
Sequentially displays scenes in self.data
If loop is True, keep rendering until user closes the window.
"""
global _IMPORT_SUCESSFUL

if _IMPORT_SUCESSFUL:
pg.init()
display = (self.window_width, self.window_height)
pg.display.set_mode(display, DOUBLEBUF | OPENGL)
pg.display.set_caption(self.window_name)
self.initGL()
while True:
for event in pg.event.get():
if event.type == pg.QUIT:
pg.quit()
return
#
self.display()
#
pg.display.flip()
pg.time.wait(self.refresh_interval)

# if not loop, stop
if not loop:
pg.quit()
return
else:
logger.error(
f"Not possible to render the environment due to the following error: {_IMPORT_ERROR_MSG}"
)
return

def get_gl_image_str(self):
# see https://gist.github.com/Jerdak/7364746
glReadBuffer(GL_FRONT)
pixels = glReadPixels(
0, 0, self.window_width, self.window_height, GL_RGB, GL_UNSIGNED_BYTE
)
return pixels

def get_video_data(self):
"""
Stores scenes in self.data in a list of numpy arrays that can be used
to save a video.
"""
global _IMPORT_SUCESSFUL

if _IMPORT_SUCESSFUL:
video_data = []

pg.init()
display = (self.window_width, self.window_height)
_ = pg.display.set_mode(display, DOUBLEBUF | OPENGL)
pg.display.set_caption(self.window_name)
self.initGL()

self.time_count = 0
while self.time_count <= len(self.data):
#
self.display()
#
pg.display.flip()

#
# See https://stackoverflow.com/a/42754578/5691288
#
string_image = self.get_gl_image_str()
temp_surf = pg.image.frombytes(
string_image, (self.window_width, self.window_height), "RGB"
)
tmp_arr = pg.surfarray.array3d(temp_surf)
imgdata = np.moveaxis(tmp_arr, 0, 1)
imgdata = np.flipud(imgdata)
video_data.append(imgdata)

pg.quit()
return video_data
else:
logger.error(
f"Not possible to render the environment due to the following error: {_IMPORT_ERROR_MSG}"
)
return []
Loading

0 comments on commit e5423ff

Please sign in to comment.