Skip to content

Commit

Permalink
Fix some input type issues in `supermarq-benchmarks/supermarq.plottin…
Browse files Browse the repository at this point in the history
…g.py` (#1024)

Fixes type check issues found in
#1018 and addresses
#1018 (comment)
  • Loading branch information
bharat-thotakura authored Aug 9, 2024
1 parent 0e72d8e commit af8942e
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,9 @@
],
"source": [
"supermarq.plotting.plot_benchmark(\n",
" [\"A single GHZ benchmark\", [\"ghz10\"], [ghz_features]],\n",
" title=\"A single GHZ benchmark\",\n",
" labels=[\"ghz10\"],\n",
" features=[ghz_features],\n",
" spoke_labels=[\"PC\", \"CD\", \"Ent\", \"Liv\", \"Mea\", \"Par\"],\n",
")"
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,9 @@
],
"source": [
"supermarq.plotting.plot_benchmark(\n",
" [\"A single GHZ benchmark\", [\"ghz10\"], [ghz_features]],\n",
" title=\"A single GHZ benchmark\",\n",
" labels=[\"ghz10\"],\n",
" features=[ghz_features],\n",
" spoke_labels=[\"PC\", \"CD\", \"Ent\", \"Liv\", \"Mea\", \"Par\"],\n",
")"
]
Expand Down
17 changes: 10 additions & 7 deletions supermarq-benchmarks/supermarq/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,9 @@ def plot_correlations(


def plot_benchmark(
data: list[str | list[str] | list[list[float]]],
title: str,
labels: list[str],
features: list[list[float]],
show: bool = True,
savefn: str | None = None,
spoke_labels: list[str] | None = None,
Expand All @@ -191,9 +193,10 @@ def plot_benchmark(
"""Create a radar plot showing the feature vectors of the given benchmarks.
Args:
data: Contains the title, feature data, and labels in the format:
[title, [benchmark labels], [[features_1], [features_2], ...]].
show: Display the plot using `plt.show`.
title: The string title of the plot.
labels: A list of string benchmark labels for the plot.
features: A list of feature data in the format: [[features_1], [features_2], ...]].
show: Boolean flag to display the plot using `plt.show`.
savefn: Path to save the plot, if `None`, the plot is not saved.
spoke_labels: Optional labels for the feature vector dimensions.
legend_loc: Optional argument to fine tune the legend placement.
Expand All @@ -205,18 +208,18 @@ def plot_benchmark(
theta = radar_factory(num_spokes)

_, ax = plt.subplots(dpi=150, subplot_kw=dict(projection="radar"))
assert isinstance(ax, RadarAxesMeta)

title, labels, case_data = data
ax.set_rgrids([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_title(
title,
str(title),
weight="bold",
size="medium",
position=(0.5, 1.1),
horizontalalignment="center",
verticalalignment="center",
)
for d, label in zip(case_data, labels):
for d, label in zip(features, labels):
ax.plot(theta, d, label=label)
ax.fill(theta, d, alpha=0.25)
ax.set_varlabels(spoke_labels)
Expand Down
12 changes: 6 additions & 6 deletions supermarq-benchmarks/supermarq/plotting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@

def test_plot_benchmark() -> None:
supermarq.plotting.plot_benchmark(
["test title", ["b1", "b2", "b3"], [[0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.3, 0.3]]],
"test title",
["b1", "b2", "b3"],
[[0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.3, 0.3]],
spoke_labels=["f1", "f2", "f3"],
show=False,
)

supermarq.plotting.plot_benchmark(
[
"test title",
["b1", "b2", "b3", "b4", "b5"],
[[0.1] * 5, [0.2] * 5, [0.3] * 5, [0.4] * 5, [0.5] * 5],
],
"test title",
["b1", "b2", "b3", "b4", "b5"],
[[0.1] * 5, [0.2] * 5, [0.3] * 5, [0.4] * 5, [0.5] * 5],
spoke_labels=None,
show=False,
)
Expand Down

0 comments on commit af8942e

Please sign in to comment.