Skip to content

Commit

Permalink
Add Python Interface (#38)
Browse files Browse the repository at this point in the history
* Make members of reach study class protected

* Added Python bindings

* Moved Python bindings of interfaces into separate files; created separate file for Python utilities

* Simplified python binding classes/functions

* Updated Python bindings to utilize updated structures

* Use CMake subdirectory for Python code

* Updated CMakeLists

* Fixed free function definitions

* Changed module name and install location

* Added heat map scripts

* Added Python test script

* Updated boost python object to YAML function

* Added compile definition for Python build flag

* Updated file path for local files in unit test

* Ran clang-format

* Revised to/from Eigen utilities

* Added python bindings for heat map colorization functions

* Updated heat map generator script to utilize internal colorization functions

* Renamed utils.h to utils.hpp

* Install Python scripts

* Ran cmake format

* Added Python requirements file

* Updated CI configuration to run Python unit tests

* Removed unnecessary forward declaration
  • Loading branch information
marip8 authored Dec 6, 2022
1 parent 0e65fa1 commit 711595d
Show file tree
Hide file tree
Showing 20 changed files with 1,222 additions and 8 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/focal_noetic.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}
ROS_DISTRO: noetic
PREFIX: ${{ github.repository }}_

jobs:
ci:
Expand Down Expand Up @@ -56,9 +57,11 @@ jobs:
UPSTREAM_WORKSPACE: dependencies.repos
PREFIX: ${{ github.repository }}_
CMAKE_ARGS: '-DENABLE_TESTING=ON -DENABLE_RUN_TESTING=OFF'
BEFORE_RUN_TARGET_TEST_EMBED: 'ici_with_unset_variables source $BASEDIR/${PREFIX}target_ws/install/setup.bash'
DOCKER_IMAGE: 'ros:${{ env.ROS_DISTRO }}'
DOCKER_COMMIT: ${{ steps.meta.outputs.tags }}
AFTER_INSTALL_TARGET_DEPENDENCIES: 'python3 -m pip install -r reach/requirements.txt -qq'
BEFORE_RUN_TARGET_TEST_EMBED: 'ici_with_unset_variables source $BASEDIR/${PREFIX}target_ws/install/setup.bash'
AFTER_RUN_TARGET_TEST: 'rosenv python3 -m pytest -v'

- name: Push post-build Docker
if: ${{ github.ref == 'refs/heads/master' || github.event_name == 'release' }}
Expand Down
31 changes: 25 additions & 6 deletions reach/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@ find_package(ros_industrial_cmake_boilerplate REQUIRED)
extract_package_metadata(pkg)
project(${pkg_extracted_name} VERSION ${pkg_extracted_version} LANGUAGES CXX)

option(BUILD_PYTHON "Build Python bindings" ON)

# Python dependencies need to be found first
if(BUILD_PYTHON)
find_package(Python REQUIRED COMPONENTS Interpreter Development)
find_package(PythonLibs 3 REQUIRED)
find_package(Boost REQUIRED COMPONENTS python numpy)
endif()

find_package(Boost REQUIRED COMPONENTS serialization program_options)
find_package(Eigen3 REQUIRED)
find_package(PCL REQUIRED COMPONENTS io search)
Expand All @@ -23,6 +32,8 @@ if(OPENMP_FOUND)
endif()
endif()

set(TARGETS "")

# Interface library
add_library(${PROJECT_NAME}_interface INTERFACE)
target_include_directories(${PROJECT_NAME}_interface INTERFACE "$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>"
Expand All @@ -35,8 +46,10 @@ target_compile_definitions(
EVALUATOR_SECTION="eval"
IK_SOLVER_SECTION="ik"
LOGGER_SECTION="logger"
TARGET_POSE_GEN_SECTION="pose")
TARGET_POSE_GEN_SECTION="pose"
BUILD_PYTHON=${BUILD_PYTHON})
target_cxx_version(${PROJECT_NAME}_interface INTERFACE VERSION 14)
list(APPEND TARGETS ${PROJECT_NAME}_interface)

# Reach Study Library
add_library(
Expand All @@ -56,6 +69,7 @@ target_link_libraries(
target_compile_definitions(${PROJECT_NAME} PUBLIC SEARCH_LIBRARIES_ENV="REACH_PLUGINS"
PLUGIN_LIBRARIES="${PROJECT_NAME}_plugins:reach_ros_plugins")
target_cxx_version(${PROJECT_NAME} PUBLIC VERSION 14)
list(APPEND TARGETS ${PROJECT_NAME})

# Plugins Library
add_library(
Expand All @@ -72,16 +86,25 @@ target_link_libraries(
${PCL_LIBRARIES}
boost_plugin_loader::boost_plugin_loader)
target_cxx_version(${PROJECT_NAME}_plugins PUBLIC VERSION 14)
list(APPEND TARGETS ${PROJECT_NAME}_plugins)

# Reach Study App
add_executable(${PROJECT_NAME}_app src/app/reach_study.cpp)
target_link_libraries(${PROJECT_NAME}_app PRIVATE ${PROJECT_NAME} Boost::program_options)
target_cxx_version(${PROJECT_NAME}_app PUBLIC VERSION 14)
list(APPEND TARGETS ${PROJECT_NAME}_app)

# Data Loader App
add_executable(${PROJECT_NAME}_data_loader src/app/data_loader.cpp)
target_link_libraries(${PROJECT_NAME}_data_loader PRIVATE ${PROJECT_NAME} Boost::program_options)
target_cxx_version(${PROJECT_NAME}_data_loader PUBLIC VERSION 14)
list(APPEND TARGETS ${PROJECT_NAME}_data_loader)

if(BUILD_PYTHON)
message("Building Python bindings")
add_subdirectory(src/python)
install(DIRECTORY scripts/ DESTINATION bin)
endif()

# ######################################################################################################################
# TEST ##
Expand Down Expand Up @@ -109,8 +132,4 @@ configure_package(
yaml-cpp
boost_plugin_loader
OpenMP
TARGETS ${PROJECT_NAME}_interface
${PROJECT_NAME}
${PROJECT_NAME}_plugins
${PROJECT_NAME}_app
${PROJECT_NAME}_data_loader)
TARGETS ${TARGETS})
18 changes: 18 additions & 0 deletions reach/include/reach/interfaces/display.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ namespace YAML
class Node;
}

#ifdef BUILD_PYTHON
namespace boost
{
namespace python
{
class dict;
} // namespace python
} // namespace boost
#endif

namespace reach
{
/**
Expand Down Expand Up @@ -52,6 +62,10 @@ struct Display

/** @brief Visualizes the results of a reach study */
virtual void showResults(const ReachResult& db) const = 0;

#ifdef BUILD_PYTHON
void updateRobotPose(const boost::python::dict&) const;
#endif
};

/**
Expand All @@ -71,6 +85,10 @@ struct DisplayFactory
{
return DISPLAY_SECTION;
}

#ifdef BUILD_PYTHON
Display::ConstPtr create(const boost::python::dict& pyyaml_config) const;
#endif
};

} // namespace reach
Expand Down
18 changes: 18 additions & 0 deletions reach/include/reach/interfaces/evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ namespace YAML
class Node;
}

#ifdef BUILD_PYTHON
namespace boost
{
namespace python
{
class dict;
}
} // namespace boost
#endif

namespace reach
{
/**
Expand All @@ -43,6 +53,10 @@ struct Evaluator
* @details The better the reachability of the pose, the higher the score should be.
*/
virtual double calculateScore(const std::map<std::string, double>& pose) const = 0;

#ifdef BUILD_PYTHON
double calculateScore(const boost::python::dict& pose) const;
#endif
};

/**
Expand All @@ -62,6 +76,10 @@ struct EvaluatorFactory
{
return EVALUATOR_SECTION;
}

#ifdef BUILD_PYTHON
Evaluator::ConstPtr create(const boost::python::dict& pyyaml_config) const;
#endif
};

} // namespace reach
Expand Down
23 changes: 23 additions & 0 deletions reach/include/reach/interfaces/ik_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@ namespace YAML
class Node;
}

#ifdef BUILD_PYTHON
namespace boost
{
namespace python
{
namespace numpy
{
class ndarray;
}
class list;
class dict;
} // namespace python
} // namespace boost
#endif

namespace reach
{
/**
Expand All @@ -46,6 +61,10 @@ struct IKSolver
/** @brief Solves IK for a given target pose and seed state */
virtual std::vector<std::vector<double>> solveIK(const Eigen::Isometry3d& target,
const std::map<std::string, double>& seed) const = 0;

#ifdef BUILD_PYTHON
boost::python::list solveIK(const boost::python::numpy::ndarray& target, const boost::python::dict& seed) const;
#endif
};

/** @brief Plugin interface for generating IK solver interfaces */
Expand All @@ -63,6 +82,10 @@ struct IKSolverFactory
{
return IK_SOLVER_SECTION;
}

#ifdef BUILD_PYTHON
IKSolver::ConstPtr create(const boost::python::dict& pyyaml_config) const;
#endif
};

} // namespace reach
Expand Down
14 changes: 14 additions & 0 deletions reach/include/reach/interfaces/logger.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@ namespace YAML
class Node;
}

#ifdef BUILD_PYTHON
namespace boost
{
namespace python
{
class dict;
} // namespace python
} // namespace boost
#endif

namespace reach
{
class ReachResultSummary;
Expand Down Expand Up @@ -40,6 +50,10 @@ struct LoggerFactory
{
return LOGGER_SECTION;
}

#ifdef BUILD_PYTHON
Logger::Ptr create(const boost::python::dict& pyyaml_config) const;
#endif
};

} // namespace reach
Expand Down
14 changes: 14 additions & 0 deletions reach/include/reach/interfaces/target_pose_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@ namespace YAML
class Node;
}

#ifdef BUILD_PYTHON
namespace boost
{
namespace python
{
class dict;
} // namespace python
} // namespace boost
#endif

namespace reach
{
/** @brief Interface for generating Cartesian target poses for the reach study */
Expand Down Expand Up @@ -43,6 +53,10 @@ struct TargetPoseGeneratorFactory
{
return TARGET_POSE_GEN_SECTION;
}

#ifdef BUILD_PYTHON
TargetPoseGenerator::ConstPtr create(const boost::python::dict& pyyaml_config) const;
#endif
};

} // namespace reach
Expand Down
2 changes: 1 addition & 1 deletion reach/include/reach/reach_study.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class ReachStudy
*/
std::tuple<double, double> getAverageNeighborsCount() const;

private:
protected:
Parameters params_;
ReachDatabase db_;

Expand Down
6 changes: 6 additions & 0 deletions reach/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
numpy
scipy
open3d
tqdm
matplotlib
pytest
60 changes: 60 additions & 0 deletions reach/scripts/heat_map_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import argparse
import numpy as np
import open3d as o3d
import os.path
from reach import ReachDatabase, load, computeHeatMapColors, normalizeScores
from scipy.interpolate import RBFInterpolator


def main():
parser = argparse.ArgumentParser(description="Generate a reachability heatmap from a preexisting reach database.")
parser.add_argument(type=str, dest="db_file", help="Filepath of the reach database")
parser.add_argument(type=str, dest="mesh_file", help='Filepath of the part mesh')
parser.add_argument("-k", "--kernel", type=str, default="thin_plate_spline", help="Kernel for RBF interpolation")
parser.add_argument("-e", "--epsilon", type=float, default=None, help="Shape parameter for RBF interpolation")
parser.add_argument("-s", "--smoothing", type=float, default=0.0, help="Smoothing parameter for RBF interpolation")
parser.add_argument("-o", "--output-mesh", type=str, default=None, help="Filepath for output heatmap")
parser.add_argument("-n", "--number-subdivisions", type=int, default=2,
help="Order of subdivision. Each triangle is divided once for n iterations")
parser.add_argument("-fcr", "--full-color-range", action='store_true', default=False,
help="Display scores using the full color range rather than only scaling scores by the max")
args = parser.parse_args()

# Load database
if not os.path.exists(args.db_file):
raise FileExistsError(f'File \'{args.db_file}\' does not exist')
db = load(args.db_file)

# Use the last set of results in the database
res = db.results[-1]

# Loop over records in database to extract point position and scores into Numpy array
positions = np.array([r.goal()[0:3, 3] for r in res])
scores = normalizeScores(res, args.full_color_range)

# Calculate the RBF
rbf = RBFInterpolator(y=positions, d=scores, kernel=args.kernel, epsilon=args.epsilon,
smoothing=args.smoothing)

# Load the mesh and subdivide it
mesh = o3d.io.read_triangle_mesh(args.mesh_file).subdivide_midpoint(args.number_subdivisions)

# Extract the vertices of the sub-sampled mesh as a numpy array and calculate the interpolated score for each
dims = rbf.y.shape[1]
vert_scores = rbf(np.asarray(mesh.vertices)[:, :dims])

# Clip the scores on [0, 1]
vert_scores = np.clip(vert_scores, a_min=0.0, a_max=1.0)

# Colorize the mesh vertices
mesh.vertex_colors = o3d.utility.Vector3dVector(computeHeatMapColors(vert_scores.tolist()))

# Visualize the output
o3d.visualization.draw_geometries([mesh], mesh_show_wireframe=False)

if args.output_mesh is not None:
o3d.io.write_triangle_mesh(args.output_mesh, mesh)


if __name__ == '__main__':
main()
Loading

0 comments on commit 711595d

Please sign in to comment.