From 409d0d99c204ee43e43e14ca2c1f1662d7a5d924 Mon Sep 17 00:00:00 2001 From: itskalvik Date: Sat, 17 Aug 2024 13:21:26 -0400 Subject: [PATCH] Update data utils docs --- docs/API-reference.md | 2 + sgptools/models/cma_es.py | 1 + sgptools/utils/data.py | 103 +++++++++++++++++++++++++++++++------- sgptools/utils/misc.py | 28 ++++++++--- 4 files changed, 110 insertions(+), 24 deletions(-) diff --git a/docs/API-reference.md b/docs/API-reference.md index 363d9dd..521194f 100644 --- a/docs/API-reference.md +++ b/docs/API-reference.md @@ -34,6 +34,8 @@ ________________________________________________________________________________ ::: sgptools.utils.metrics --- ::: sgptools.utils.gpflow +--- +::: sgptools.utils.data ____________________________________________________________________________________________________________________________________________________________ --- ::: sgptools.kernels.neural_kernel diff --git a/sgptools/models/cma_es.py b/sgptools/models/cma_es.py index 9b10e4f..7ee2439 100644 --- a/sgptools/models/cma_es.py +++ b/sgptools/models/cma_es.py @@ -33,6 +33,7 @@ class CMA_ES: as waypoints of a path num_robots (int): Number of robots, used when modeling multi-robot IPP with a distance budget + transform (Transform): Transform object """ def __init__(self, X_train, noise_variance, kernel, distance_budget=None, diff --git a/sgptools/utils/data.py b/sgptools/utils/data.py index da7520f..8b1e816 100644 --- a/sgptools/utils/data.py +++ b/sgptools/utils/data.py @@ -18,47 +18,86 @@ from sklearn.preprocessing import StandardScaler from hkb_diamondsquare.DiamondSquare import diamond_square +# Load optional dependency try: from osgeo import gdal except: pass + #################################################### # Utils used to prepare synthetic datasets -''' -Remove points inside polygons -''' def remove_polygons(X, Y, polygons): + ''' + Remove points inside polygons. + + Args: + X (ndarray): (N,); array of x-coordinate + Y (ndarray): (N,); array of y-coordinate + polygons (list of matplotlib path polygon): Polygons to remove from the X, Y points + + Returns: + X (ndarray): (N,); array of x-coordinate + Y (ndarray): (N,); array of y-coordinate + ''' points = np.array([X.flatten(), Y.flatten()]).T for polygon in polygons: p = path.Path(polygon) points = points[~p.contains_points(points)] return points[:, 0], points[:, 1] -''' -Remove points inside circle patches -''' def remove_circle_patches(X, Y, circle_patches): + ''' + Remove points inside polycircle patchesgons. + + Args: + X (ndarray): (N,); array of x-coordinate + Y (ndarray): (N,); array of y-coordinate + polygons (list of matplotlib circle patches): Circle patches to remove from the X, Y points + + Returns: + X (ndarray): (N,); array of x-coordinate + Y (ndarray): (N,); array of y-coordinate + ''' points = np.array([X.flatten(), Y.flatten()]).T for circle_patch in circle_patches: points = points[~circle_patch.contains_points(points)] return points[:, 0], points[:, 1] -''' -Generate a point at a distance d from a point at angle theta - -Args: - point: (N, 2) array of points - d: distance - theta: angle in radians -''' def point_pos(point, d, theta): + ''' + Generate a point at a distance d from a point at angle theta. + + Args: + point (ndarray): (N, 2); array of points + d (float): distance + theta (float): angle in radians + + Returns: + X (ndarray): (N,); array of x-coordinate + Y (ndarray): (N,); array of y-coordinate + ''' return np.c_[point[:, 0] + d*np.cos(theta), point[:, 1] + d*np.sin(theta)] #################################################### -def prep_tif_dataset(dataset_path=None): +def prep_tif_dataset(dataset_path): + '''Load and preprocess a dataset from a GeoTIFF file (.tif file). The input features + are set to the x and y pixel block coordinates and the labels are read from the file. + The method also removes all invalid points. + + Large tif files + need to be downsampled using the following command: + ```gdalwarp -tr 50 50 .tif .tif``` + + Args: + dataset_path (str): Path to the dataset file, used only when dataset_type is 'tif'. + + Returns: + X: (n, d); Dataset input features + y: (n, 1); Dataset labels + ''' ds = gdal.Open(dataset_path) cols = ds.RasterXSize rows = ds.RasterYSize @@ -75,6 +114,15 @@ def prep_tif_dataset(dataset_path=None): #################################################### def prep_synthetic_dataset(): + '''Generates a 50x50 grid of synthetic elevation data using the diamond square algorithm. + ```https://github.com/buckinha/DiamondSquare``` + + Args: + + Returns: + X: (n, d); Dataset input features + y: (n, 1); Dataset labels + ''' data = diamond_square(shape=(50,50), min_height=0, max_height=30, @@ -92,15 +140,34 @@ def prep_synthetic_dataset(): #################################################### -def get_dataset(dataset, dataset_path=None, +def get_dataset(dataset_type, dataset_path=None, num_train=1000, num_test=2500, num_candidates=150): + """Method to generate/load datasets and preprocess them for SP/IPP. The method uses kmeans to + generate train and test sets. + Args: + dataset_type (str): 'tif' or 'synthetic'. 'tif' will load and proprocess data from a GeoTIFF file. + 'synthetic' will use the diamond square algorithm to generate synthetic elevation data. + dataset_path (str): Path to the dataset file, used only when dataset_type is 'tif'. + num_train (int): Number of training samples to generate. + num_test (int): Number of testing samples to generate. + num_candidates (int): Number of candidate locations to generate. + + Returns: + X_train (ndarray): (n, d); Training set inputs + y_train (ndarray): (n, 1); Training set labels + X_test (ndarray): (n, d); Testing set inputs + y_test (ndarray): (n, 1); Testing set labels + candidates (ndarray): (n, d); Candidate sensor placement locations + X: (n, d); Full dataset inputs + y: (n, 1); Full dataset labels + """ # Load the data - if dataset == 'tif': + if dataset_type == 'tif': X, y = prep_tif_dataset(dataset_path=dataset_path) - elif dataset == 'synthetic': + elif dataset_type == 'synthetic': X, y = prep_synthetic_dataset() X_train = get_inducing_pts(X, num_train) diff --git a/sgptools/utils/misc.py b/sgptools/utils/misc.py index d39617b..4583d74 100644 --- a/sgptools/utils/misc.py +++ b/sgptools/utils/misc.py @@ -103,17 +103,33 @@ def interpolate_path(waypoints, sampling_rate=0.05): interpolated_path.extend(points) return np.array(interpolated_path) -# Reorder the waypoints to match the order of the points in the path -# The waypoints are mathched to the closest points in the path -def reoder_path(path, waypoints): +def _reoder_path(path, waypoints): + """Reorder the waypoints to match the order of the points in the path. + The waypoints are mathched to the closest points in the path. Used by project_waypoints. + + Args: + path (n, d): Robot path, i.e., waypoints in the path traversal order + waypoints (n, d): Waypoints that need to be reordered to match the target path + + Returns: + waypoints (n, d): Reordered waypoints of the robot's path + """ dists = pairwise_distances(path, Y=waypoints, metric='euclidean') _, col_ind = linear_sum_assignment(dists) Xu = waypoints[col_ind].copy() return Xu -# Project the waypoints back to the candidate set while retaining the -# waypoint visitation order def project_waypoints(waypoints, candidates): + """Project the waypoints back to the candidate set while retaining the + waypoint visitation order. + + Args: + waypoints (n, d): Waypoints of the robot's path + candidates (ndarray): (n, 2); Discrete set of candidate locations + + Returns: + waypoints (n, d): Projected waypoints of the robot's path + """ waypoints_disc = cont2disc(waypoints, candidates) - waypoints_valid = reoder_path(waypoints, waypoints_disc) + waypoints_valid = _reoder_path(waypoints, waypoints_disc) return waypoints_valid \ No newline at end of file