Skip to content

Commit

Permalink
WIP: fix polypennfcn plots
Browse files Browse the repository at this point in the history
  • Loading branch information
anthonio9 committed Jul 1, 2024
1 parent 5e51594 commit 93185dd
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 35 deletions.
4 changes: 1 addition & 3 deletions penn/plot/to_latex/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,10 @@ def from_audio(
# Infer
logits.append(penn.infer(frames, checkpoint=checkpoint).detach())

breakpoint()

# Concatenate results
if penn.FCN:
logits = torch.cat(logits, dim=-2)
logits = logits.permute(2, 1, 3, 0)
logits = logits.permute(3, 1, 2, 0)
else:
logits = torch.cat(logits)
pitch = None
Expand Down
63 changes: 31 additions & 32 deletions penn/plot/to_latex/mplt.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def plot_pitch(axis : plt.Axes,
"""
max_pitch = []
min_pitch = []
breakpoint()

mask_for_pitch = pitch!=0

Expand Down Expand Up @@ -260,22 +259,22 @@ def plot_with_matplotlib(
else:
plot_stft(axes, audio, sr, time_offset=time_offset)

# if pred_pitch is not None and pred_times is not None:
# if mutlipitch:
# plot_multipitch(
# axes, pred_pitch, pred_times,
# linewidth=0.5,
# periodicity=periodicity,
# threshold=threshold,
# label="predicted")
# else:
# plot_pitch(
# axes[0], pred_pitch, pred_times,
# linewidth=0.5,
# periodicity=periodicity,
# threshold=threshold,
# label="predicted")
#
if pred_pitch is not None and pred_times is not None:
if mutlipitch:
plot_multipitch(
axes, pred_pitch, pred_times,
linewidth=0.5,
periodicity=periodicity,
threshold=threshold,
label="predicted")
else:
plot_pitch(
axes[0], pred_pitch, pred_times,
linewidth=0.5,
periodicity=periodicity,
threshold=threshold,
label="predicted")

if gt_pitch is not None and gt_times is not None:
if mutlipitch:
plot_multipitch(
Expand All @@ -287,21 +286,21 @@ def plot_with_matplotlib(
axes[0], gt_pitch, gt_times,
plot_red=True,
label="truth")
#
# # prepare the legend
# handles, labels = axes[-1].get_legend_handles_labels()
#
# for ind, axis in enumerate(axes):
# axis.set_title(f"String {ind}", x=0.06, y=0.7, color='r')
#
# if periodicity is not None:
# if mutlipitch:
# t_handles, t_labels = plot_multiperiodicity(axes, periodicity, pred_times, threshold)
# else:
# t_handles, t_labels = plot_periodicity(axes[0], periodicity, pred_times, threshold)
#
# handles.extend(t_handles)
# labels.extend(t_labels)

# prepare the legend
handles, labels = axes[-1].get_legend_handles_labels()

for ind, axis in enumerate(axes):
axis.set_title(f"String {ind}", x=0.06, y=0.7, color='r')

if periodicity is not None:
if mutlipitch:
t_handles, t_labels = plot_multiperiodicity(axes, periodicity, pred_times, threshold)
else:
t_handles, t_labels = plot_periodicity(axes[0], periodicity, pred_times, threshold)

handles.extend(t_handles)
labels.extend(t_labels)

figure.suptitle(f"Pitch thresholded with periodicity above {threshold}, {title}")

Expand Down

0 comments on commit 93185dd

Please sign in to comment.