Skip to content

Commit

Permalink
WIP: set plots to grey, fix y axis alignment and common scale for logits
Browse files Browse the repository at this point in the history
related to: #5
  • Loading branch information
anthonio9 committed Jul 4, 2024
1 parent 93185dd commit 6ec3ec3
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
6 changes: 6 additions & 0 deletions penn/plot/to_latex/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ def from_audio(
times = penn.HOPSIZE_SECONDS * np.arange(pitch[0].shape[-1])
periodicity = periodicity.detach().numpy()

logits = torch.nan_to_num(
logits,
neginf=torch.min(logits[torch.logical_not(torch.isneginf(logits))]),
posinf=torch.max(logits[torch.logical_not(torch.isposinf(logits))])
)
logits = torch.softmax(logits, dim=2)

return pitch, times, periodicity, logits


Expand Down
14 changes: 12 additions & 2 deletions penn/plot/to_latex/mplt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ def plot_logits(axes : plt.Axes,
"""
logits = logits.squeeze()

style = {
"norm" : plt.matplotlib.colors.LogNorm(vmin=logits.min(),
vmax=logits.max())
}

# breakpoint()

# divide logits into strings
logits_chunks = np.split(logits, logits.shape[1], axis=1)

Expand All @@ -27,7 +34,9 @@ def plot_logits(axes : plt.Axes,

for axis, logits_chunk in zip(axes, logits_chunks):
logits_chunk = logits_chunk.squeeze(axis=1).T
axis.pcolormesh(times, freqs, logits_chunk)
axis.pcolormesh(times, freqs, logits_chunk,
cmap="gray_r",
**style)
axis.set_ylim([0, 500])
axis.set_xlim([times[0], times[-1]])

Expand Down Expand Up @@ -242,6 +251,7 @@ def plot_with_matplotlib(
# Create plot
figure, axes = plt.subplots(nrows=penn.PITCH_CATS,
ncols=1)
axes = np.flip(axes)
else:
# Create plot
figure, axis = plt.subplots(figsize=(7, 3))
Expand Down Expand Up @@ -304,7 +314,7 @@ def plot_with_matplotlib(

figure.suptitle(f"Pitch thresholded with periodicity above {threshold}, {title}")

# figure.legend(handles, labels, loc='lower right')
figure.legend(handles, labels, loc='lower right')
figure.set_tight_layout({'pad' : 0.5})

# figure.show()
Expand Down

0 comments on commit 6ec3ec3

Please sign in to comment.