Skip to content

Commit

Permalink
Add ground truth back to the plot
Browse files Browse the repository at this point in the history
  • Loading branch information
anthonio9 committed Jan 27, 2024
1 parent 583ea20 commit af24713
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions penn/plot/logits/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,13 @@ def process_logits(logits: torch.Tensor):
distributions = distributions.cpu().squeeze(2).T
new_distributions = distributions

figsize=(18, 2)
figsize=(18, 10)

# Prepare the ptich posteriorgram in case it's multipitch
if len(distributions.shape) == 4:
distr_chunk = torch.chunk(distributions, distributions.shape[-2], -2)
distr_chunk = [distr.squeeze(dim=-2).squeeze(dim=0) for distr in distr_chunk]
new_distributions = torch.vstack(distr_chunk)
figsize = (18, 10)

return new_distributions, figsize

Expand All @@ -95,7 +94,8 @@ def logits_matplotlib(logits, bins=None, voiced=None, stem=None):
axis.spines['right'].set_visible(False)
axis.spines['bottom'].set_visible(False)
axis.spines['left'].set_visible(False)
xticks = torch.arange(0, len(distributions), int(penn.SAMPLE_RATE / penn.HOPSIZE))

xticks = torch.arange(0, distributions.shape[-1], int(penn.SAMPLE_RATE / penn.HOPSIZE))
xlabels = torch.round(xticks * (penn.HOPSIZE / penn.SAMPLE_RATE)).int()
axis.get_xaxis().set_ticks(xticks.tolist(), xlabels.tolist())

Expand All @@ -113,29 +113,30 @@ def logits_matplotlib(logits, bins=None, voiced=None, stem=None):
axis.set_ylabel('Frequency (Hz)')
axis.set_title(f"track: {stem}")

if bins is not None and voiced is not None and not penn.LOSS_MULTI_HOT:
if bins is not None and voiced is not None:
nbins = bins.detach().cpu().numpy()
nvoiced = voiced.detach().cpu().numpy()

nbins = nbins.squeeze().T
nvoiced = nvoiced.squeeze().T

offset = np.arange(0, penn.PITCH_CATS)*penn.PITCH_BINS
offset = np.arange(0, penn.PITCH_CATS)*penn.PITCH_BINS * int(not penn.LOSS_MULTI_HOT)

nbins += offset

nbins_masked = np.ma.MaskedArray(nbins, np.logical_not(nvoiced))

axis.plot(nbins_masked, 'r--', linewidth=2)

if predicted_bins is not None:
npredicted_bins = predicted_bins.detach().cpu().numpy()
npredicted_bins = npredicted_bins.squeeze().T

npredicted_bins += offset
npredicted_bins_masked = np.ma.MaskedArray(npredicted_bins, np.logical_not(nvoiced))

axis.plot(npredicted_bins_masked, 'b:', linewidth=2)
for nbins_row in range(penn.PITCH_CATS):
axis.plot(nbins_masked[:, nbins_row], 'r--', linewidth=1)

# if predicted_bins is not None:
# npredicted_bins = predicted_bins.detach().cpu().numpy()
# npredicted_bins = npredicted_bins.squeeze().T
#
# npredicted_bins += offset
# npredicted_bins_masked = np.ma.MaskedArray(npredicted_bins, np.logical_not(nvoiced))
#
# axis.plot(npredicted_bins_masked, 'b:', linewidth=2)

axis.imshow(distributions, aspect='auto', origin='lower')

Expand Down

0 comments on commit af24713

Please sign in to comment.