Skip to content

Commit

Permalink
Update plotting utility
Browse files Browse the repository at this point in the history
  • Loading branch information
itskalvik committed Sep 12, 2024
1 parent 61f5c4f commit daad713
Showing 1 changed file with 52 additions and 19 deletions.
71 changes: 52 additions & 19 deletions benchmarks/plots.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -21,8 +21,16 @@
"num_robots = 4\n",
"sampling_rate = 5\n",
"dataset = 'bathymetry'\n",
"budget = True\n",
"\n",
"if budget:\n",
" budget = '_B'\n",
"else:\n",
" budget = ''\n",
"\n",
"filename = f'../datasets/results/{dataset}_{num_robots}R_{sampling_rate}S.json'\n",
"remove_online = False\n",
"\n",
"filename = f'results/{dataset}_{num_robots}R_{sampling_rate}S{budget}.json'\n",
"save_filename = filename[:-5] + '-{}.pdf'\n",
"results = json.load(open(filename, 'r'))\n",
"\n",
Expand All @@ -36,10 +44,9 @@
"ipp_time_std = defaultdict(list)\n",
"\n",
"xrange = np.array(list(results.keys())).astype(int)\n",
"\n",
"for num_sensors in results.keys():\n",
" for method in results[num_sensors].keys():\n",
" if 'Agg' in method:\n",
" continue\n",
" for metric in results[num_sensors][method].keys():\n",
" if metric=='RMSE':\n",
" rmse[method].append(np.mean(results[num_sensors][method][metric]))\n",
Expand All @@ -57,41 +64,67 @@
" ipp_time[method].append(np.mean(data))\n",
" ipp_time_std[method].append(np.std(data))\n",
"\n",
"colors = ['C2', 'C3', 'C0', 'C1', 'C4', 'C5']\n",
"datasets = ['intel', 'precipitation', 'soil', 'salinity']\n",
"colors = ['C2', 'C3', 'C0', 'C1', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9']\n",
"methods = ['Adaptive-SGP', \n",
" 'Adaptive-CMA-ES', \n",
" 'Online-SGP', \n",
" 'Online-CMA-ES']\n",
"\n",
"if sampling_rate == 2 and num_robots == 1 and len(budget) == 0:\n",
" methods.append('Online-BO')\n",
" methods.append('Online-Greedy-MI')\n",
" methods.append('Online-Greedy-SGP')\n",
" methods.append('Online-Discrete-SGP')\n",
"\n",
"'''\n",
"# Discrete Online Methods\n",
"colors = ['C2', 'C3', 'C1']\n",
"methods = ['Online-Discrete-SGP',\n",
" 'Online-Greedy-MI', \n",
" 'Online-Greedy-SGP']\n",
"\n",
"# Continuous Online Methods\n",
"methods = ['Online-SGP', \n",
" 'Online-CMA-ES', \n",
" 'Online-BO']\n",
"'''\n",
"\n",
"plt.figure()\n",
"for i, key in enumerate(rmse.keys()):\n",
"for i, key in enumerate(methods):\n",
" xrange_ = xrange[:len(rmse[key])]\n",
" label = key.strip().replace('Adaptive-SGP', 'Adaptive-SGP (Ours)')\n",
" if 'CMA-ES' in key:\n",
" label = key.replace('CMA-ES', 'CIPP')\n",
" plt.plot(xrange, rmse[key], label=label, c=colors[i])\n",
" plt.fill_between(xrange,\n",
" label = label.replace('CMA-ES', 'CIPP')\n",
" if remove_online:\n",
" label = label.replace('Online-', '')\n",
" plt.plot(xrange_, rmse[key], label=label, c=colors[i])\n",
" plt.fill_between(xrange_,\n",
" np.array(rmse[key])+rmse_std[key], \n",
" np.array(rmse[key])-rmse_std[key], \n",
" alpha=0.2,\n",
" color=colors[i])\n",
"plt.legend()\n",
"plt.legend(loc='upper right')\n",
"plt.xlabel(\"Number of Waypoints\")\n",
"plt.ylabel(\"RMSE\")\n",
"plt.savefig(f'{dataset}_{num_robots}R_{sampling_rate}S-RMSE.pdf', bbox_inches='tight')\n",
"plt.savefig(f'{dataset}_{num_robots}R_{sampling_rate}S{budget}-RMSE.pdf', bbox_inches='tight')\n",
"plt.show()\n",
"\n",
"plt.figure()\n",
"for i, key in enumerate(ipp_time.keys()):\n",
"for i, key in enumerate(methods):\n",
" xrange_ = xrange[:len(ipp_time[key])]\n",
" label = key.strip().replace('Adaptive-SGP', 'Adaptive-SGP (Ours)')\n",
" if 'CMA-ES' in key:\n",
" label = key.replace('CMA-ES', 'CIPP')\n",
" plt.plot(xrange, ipp_time[key], label=label, c=colors[i])\n",
" plt.fill_between(xrange,\n",
" label = label.replace('CMA-ES', 'CIPP')\n",
" if remove_online:\n",
" label = label.replace('Online-', '')\n",
" plt.plot(xrange_, ipp_time[key], label=label, c=colors[i])\n",
" plt.fill_between(xrange_,\n",
" np.array(ipp_time[key])+ipp_time_std[key], \n",
" np.array(ipp_time[key])-ipp_time_std[key], \n",
" alpha=0.2,\n",
" color=colors[i])\n",
"plt.legend(loc='upper left')\n",
"plt.xlabel(\"Number of Waypoints\")\n",
"plt.ylabel(\"IPP Runtime (s)\")\n",
"plt.savefig(f'{dataset}_{num_robots}R_{sampling_rate}S-IPP.pdf', bbox_inches='tight')\n",
"plt.savefig(f'{dataset}_{num_robots}R_{sampling_rate}S{budget}-IPP.pdf', bbox_inches='tight')\n",
"plt.show()"
]
}
Expand Down

0 comments on commit daad713

Please sign in to comment.