-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* moving chain and gridworld to rlberry_scool * adding save_gif and moving back the rendering to rlberry (main)
- Loading branch information
Showing
12 changed files
with
968 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .core import Scene, GeometricPrimitive | ||
from .render_interface import RenderInterface | ||
from .render_interface import RenderInterface2D |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import numpy as np | ||
from rlberry.rendering import GeometricPrimitive | ||
|
||
|
||
def bar_shape(p0, p1, width): | ||
shape = GeometricPrimitive("QUADS") | ||
|
||
x0, y0 = p0 | ||
x1, y1 = p1 | ||
|
||
direction = np.array([x1 - x0, y1 - y0]) | ||
norm = np.sqrt((direction * direction).sum()) | ||
direction = direction / norm | ||
|
||
# get vector perpendicular to direction | ||
u_vec = np.zeros(2) | ||
u_vec[0] = -direction[1] | ||
u_vec[1] = direction[0] | ||
|
||
u_vec = u_vec * width / 2 | ||
|
||
shape.add_vertex((x0 + u_vec[0], y0 + u_vec[1])) | ||
shape.add_vertex((x0 - u_vec[0], y0 - u_vec[1])) | ||
shape.add_vertex((x1 - u_vec[0], y1 - u_vec[1])) | ||
shape.add_vertex((x1 + u_vec[0], y1 + u_vec[1])) | ||
return shape | ||
|
||
|
||
def circle_shape(center, radius, n_points=50): | ||
shape = GeometricPrimitive("POLYGON") | ||
|
||
x0, y0 = center | ||
theta = np.linspace(0.0, 2 * np.pi, n_points) | ||
for tt in theta: | ||
xx = radius * np.cos(tt) | ||
yy = radius * np.sin(tt) | ||
shape.add_vertex((x0 + xx, y0 + yy)) | ||
|
||
return shape |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
""" | ||
Provide classes for geometric primitives in OpenGL and scenes. | ||
""" | ||
|
||
|
||
class Scene: | ||
""" | ||
Class representing a scene, which is a vector of GeometricPrimitive objects | ||
""" | ||
|
||
def __init__(self): | ||
self.shapes = [] | ||
|
||
def add_shape(self, shape): | ||
self.shapes.append(shape) | ||
|
||
|
||
class GeometricPrimitive: | ||
""" | ||
Class representing an OpenGL geometric primitive. | ||
Primitive type (GL_LINE_LOOP by defaut) | ||
If using OpenGLRender2D, one of the following: | ||
POINTS | ||
LINES | ||
LINE_STRIP | ||
LINE_LOOP | ||
POLYGON | ||
TRIANGLES | ||
TRIANGLE_STRIP | ||
TRIANGLE_FAN | ||
QUADS | ||
QUAD_STRIP | ||
If using PyGameRender2D: | ||
POLYGON | ||
TODO: Add support to more pygame shapes, | ||
see https://www.pygame.org/docs/ref/draw.html | ||
""" | ||
|
||
def __init__(self, primitive_type="GL_LINE_LOOP"): | ||
# primitive type | ||
self.type = primitive_type | ||
# color in RGB | ||
self.color = (0.25, 0.25, 0.25) | ||
# list of vertices. each vertex is a tuple with coordinates in space | ||
self.vertices = [] | ||
|
||
def add_vertex(self, vertex): | ||
self.vertices.append(vertex) | ||
|
||
def set_color(self, color): | ||
self.color = color |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,252 @@ | ||
""" | ||
OpenGL code for 2D rendering, using pygame. | ||
""" | ||
|
||
import numpy as np | ||
from os import environ | ||
|
||
from rlberry.rendering import Scene | ||
|
||
import rlberry | ||
|
||
logger = rlberry.logger | ||
environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1" | ||
|
||
_IMPORT_SUCESSFUL = True | ||
_IMPORT_ERROR_MSG = "" | ||
try: | ||
import pygame as pg | ||
from pygame.locals import DOUBLEBUF, OPENGL | ||
|
||
from OpenGL.GLU import gluOrtho2D | ||
from OpenGL.GL import glMatrixMode, glLoadIdentity, glClearColor | ||
from OpenGL.GL import glClear, glFlush, glBegin, glEnd | ||
from OpenGL.GL import glColor3f, glVertex2f | ||
from OpenGL.GL import glReadBuffer, glReadPixels | ||
from OpenGL.GL import GL_PROJECTION, GL_COLOR_BUFFER_BIT | ||
from OpenGL.GL import GL_POINTS, GL_LINES, GL_LINE_STRIP, GL_LINE_LOOP | ||
from OpenGL.GL import GL_POLYGON, GL_TRIANGLES, GL_TRIANGLE_STRIP | ||
from OpenGL.GL import GL_TRIANGLE_FAN, GL_QUADS, GL_QUAD_STRIP | ||
from OpenGL.GL import GL_FRONT, GL_RGB, GL_UNSIGNED_BYTE | ||
|
||
except Exception as ex: | ||
_IMPORT_SUCESSFUL = False | ||
_IMPORT_ERROR_MSG = str(ex) | ||
|
||
|
||
class OpenGLRender2D: | ||
""" | ||
Class to render a list of scenes using OpenGL and pygame. | ||
""" | ||
|
||
def __init__(self): | ||
# parameters | ||
self.window_width = 800 | ||
self.window_height = 800 # multiples of 16 are preferred | ||
self.background_color = (0.6, 0.75, 1.0) | ||
self.refresh_interval = 50 | ||
self.window_name = "rlberry render" | ||
self.clipping_area = (-1.0, 1.0, -1.0, 1.0) | ||
|
||
# time counter | ||
self.time_count = 0 | ||
|
||
# background scene | ||
self.background = Scene() | ||
# data to be rendered (list of scenes) | ||
self.data = [] | ||
|
||
def set_window_name(self, name): | ||
self.window_name = name | ||
|
||
def set_refresh_interval(self, interval): | ||
self.refresh_interval = interval | ||
|
||
def set_clipping_area(self, area): | ||
""" | ||
The clipping area is tuple with elements (left, right, bottom, top) | ||
Default = (-1.0, 1.0, -1.0, 1.0) | ||
""" | ||
self.clipping_area = area | ||
base_size = max(self.window_width, self.window_height) | ||
width_range = area[1] - area[0] | ||
height_range = area[3] - area[2] | ||
base_range = max(width_range, height_range) | ||
width_range /= base_range | ||
height_range /= base_range | ||
self.window_width = int(base_size * width_range) | ||
self.window_height = int(base_size * height_range) | ||
|
||
# width and height must be divisible by 2 | ||
if self.window_width % 2 == 1: | ||
self.window_width += 1 | ||
if self.window_height % 2 == 1: | ||
self.window_height += 1 | ||
|
||
def set_data(self, data): | ||
self.data = data | ||
|
||
def set_background(self, background): | ||
self.background = background | ||
|
||
def initGL(self): | ||
""" | ||
initialize GL | ||
""" | ||
glMatrixMode(GL_PROJECTION) | ||
glLoadIdentity() | ||
gluOrtho2D( | ||
self.clipping_area[0], | ||
self.clipping_area[1], | ||
self.clipping_area[2], | ||
self.clipping_area[3], | ||
) | ||
|
||
def display(self): | ||
""" | ||
Callback function, handler for window re-paint | ||
""" | ||
# Set background color (clear background) | ||
glClearColor( | ||
self.background_color[0], | ||
self.background_color[1], | ||
self.background_color[2], | ||
1.0, | ||
) | ||
glClear(GL_COLOR_BUFFER_BIT) | ||
|
||
# Display background | ||
for shape in self.background.shapes: | ||
self.draw_geometric2d(shape) | ||
|
||
# Display objects | ||
if len(self.data) > 0: | ||
idx = self.time_count % len(self.data) | ||
for shape in self.data[idx].shapes: | ||
self.draw_geometric2d(shape) | ||
|
||
self.time_count += 1 | ||
glFlush() | ||
|
||
@staticmethod | ||
def draw_geometric2d(shape): | ||
""" | ||
Draw a 2D shape, of type GeometricPrimitive | ||
""" | ||
if shape.type == "POINTS": | ||
glBegin(GL_POINTS) | ||
elif shape.type == "LINES": | ||
glBegin(GL_LINES) | ||
elif shape.type == "LINE_STRIP": | ||
glBegin(GL_LINE_STRIP) | ||
elif shape.type == "LINE_LOOP": | ||
glBegin(GL_LINE_LOOP) | ||
elif shape.type == "POLYGON": | ||
glBegin(GL_POLYGON) | ||
elif shape.type == "TRIANGLES": | ||
glBegin(GL_TRIANGLES) | ||
elif shape.type == "TRIANGLE_STRIP": | ||
glBegin(GL_TRIANGLE_STRIP) | ||
elif shape.type == "TRIANGLE_FAN": | ||
glBegin(GL_TRIANGLE_FAN) | ||
elif shape.type == "QUADS": | ||
glBegin(GL_QUADS) | ||
elif shape.type == "QUAD_STRIP": | ||
glBegin(GL_QUAD_STRIP) | ||
else: | ||
logger.error("Invalid type for geometric primitive!") | ||
raise NameError | ||
|
||
# set color | ||
glColor3f(shape.color[0], shape.color[1], shape.color[2]) | ||
|
||
# create vertices | ||
for vertex in shape.vertices: | ||
glVertex2f(vertex[0], vertex[1]) | ||
glEnd() | ||
|
||
def run_graphics(self, loop=True): | ||
""" | ||
Sequentially displays scenes in self.data | ||
If loop is True, keep rendering until user closes the window. | ||
""" | ||
global _IMPORT_SUCESSFUL | ||
|
||
if _IMPORT_SUCESSFUL: | ||
pg.init() | ||
display = (self.window_width, self.window_height) | ||
pg.display.set_mode(display, DOUBLEBUF | OPENGL) | ||
pg.display.set_caption(self.window_name) | ||
self.initGL() | ||
while True: | ||
for event in pg.event.get(): | ||
if event.type == pg.QUIT: | ||
pg.quit() | ||
return | ||
# | ||
self.display() | ||
# | ||
pg.display.flip() | ||
pg.time.wait(self.refresh_interval) | ||
|
||
# if not loop, stop | ||
if not loop: | ||
pg.quit() | ||
return | ||
else: | ||
logger.error( | ||
f"Not possible to render the environment due to the following error: {_IMPORT_ERROR_MSG}" | ||
) | ||
return | ||
|
||
def get_gl_image_str(self): | ||
# see https://gist.github.com/Jerdak/7364746 | ||
glReadBuffer(GL_FRONT) | ||
pixels = glReadPixels( | ||
0, 0, self.window_width, self.window_height, GL_RGB, GL_UNSIGNED_BYTE | ||
) | ||
return pixels | ||
|
||
def get_video_data(self): | ||
""" | ||
Stores scenes in self.data in a list of numpy arrays that can be used | ||
to save a video. | ||
""" | ||
global _IMPORT_SUCESSFUL | ||
|
||
if _IMPORT_SUCESSFUL: | ||
video_data = [] | ||
|
||
pg.init() | ||
display = (self.window_width, self.window_height) | ||
_ = pg.display.set_mode(display, DOUBLEBUF | OPENGL) | ||
pg.display.set_caption(self.window_name) | ||
self.initGL() | ||
|
||
self.time_count = 0 | ||
while self.time_count <= len(self.data): | ||
# | ||
self.display() | ||
# | ||
pg.display.flip() | ||
|
||
# | ||
# See https://stackoverflow.com/a/42754578/5691288 | ||
# | ||
string_image = self.get_gl_image_str() | ||
temp_surf = pg.image.frombytes( | ||
string_image, (self.window_width, self.window_height), "RGB" | ||
) | ||
tmp_arr = pg.surfarray.array3d(temp_surf) | ||
imgdata = np.moveaxis(tmp_arr, 0, 1) | ||
imgdata = np.flipud(imgdata) | ||
video_data.append(imgdata) | ||
|
||
pg.quit() | ||
return video_data | ||
else: | ||
logger.error( | ||
f"Not possible to render the environment due to the following error: {_IMPORT_ERROR_MSG}" | ||
) | ||
return [] |
Oops, something went wrong.