diff --git a/penn/plot/logits/core.py b/penn/plot/logits/core.py index 4e81cc9..66adde9 100644 --- a/penn/plot/logits/core.py +++ b/penn/plot/logits/core.py @@ -63,14 +63,13 @@ def process_logits(logits: torch.Tensor): distributions = distributions.cpu().squeeze(2).T new_distributions = distributions - figsize=(18, 2) + figsize=(18, 10) # Prepare the ptich posteriorgram in case it's multipitch if len(distributions.shape) == 4: distr_chunk = torch.chunk(distributions, distributions.shape[-2], -2) distr_chunk = [distr.squeeze(dim=-2).squeeze(dim=0) for distr in distr_chunk] new_distributions = torch.vstack(distr_chunk) - figsize = (18, 10) return new_distributions, figsize @@ -95,7 +94,8 @@ def logits_matplotlib(logits, bins=None, voiced=None, stem=None): axis.spines['right'].set_visible(False) axis.spines['bottom'].set_visible(False) axis.spines['left'].set_visible(False) - xticks = torch.arange(0, len(distributions), int(penn.SAMPLE_RATE / penn.HOPSIZE)) + + xticks = torch.arange(0, distributions.shape[-1], int(penn.SAMPLE_RATE / penn.HOPSIZE)) xlabels = torch.round(xticks * (penn.HOPSIZE / penn.SAMPLE_RATE)).int() axis.get_xaxis().set_ticks(xticks.tolist(), xlabels.tolist()) @@ -113,29 +113,30 @@ def logits_matplotlib(logits, bins=None, voiced=None, stem=None): axis.set_ylabel('Frequency (Hz)') axis.set_title(f"track: {stem}") - if bins is not None and voiced is not None and not penn.LOSS_MULTI_HOT: + if bins is not None and voiced is not None: nbins = bins.detach().cpu().numpy() nvoiced = voiced.detach().cpu().numpy() nbins = nbins.squeeze().T nvoiced = nvoiced.squeeze().T - offset = np.arange(0, penn.PITCH_CATS)*penn.PITCH_BINS + offset = np.arange(0, penn.PITCH_CATS)*penn.PITCH_BINS * int(not penn.LOSS_MULTI_HOT) nbins += offset nbins_masked = np.ma.MaskedArray(nbins, np.logical_not(nvoiced)) - axis.plot(nbins_masked, 'r--', linewidth=2) - - if predicted_bins is not None: - npredicted_bins = predicted_bins.detach().cpu().numpy() - npredicted_bins = npredicted_bins.squeeze().T - - npredicted_bins += offset - npredicted_bins_masked = np.ma.MaskedArray(npredicted_bins, np.logical_not(nvoiced)) - - axis.plot(npredicted_bins_masked, 'b:', linewidth=2) + for nbins_row in range(penn.PITCH_CATS): + axis.plot(nbins_masked[:, nbins_row], 'r--', linewidth=1) + + # if predicted_bins is not None: + # npredicted_bins = predicted_bins.detach().cpu().numpy() + # npredicted_bins = npredicted_bins.squeeze().T + # + # npredicted_bins += offset + # npredicted_bins_masked = np.ma.MaskedArray(npredicted_bins, np.logical_not(nvoiced)) + # + # axis.plot(npredicted_bins_masked, 'b:', linewidth=2) axis.imshow(distributions, aspect='auto', origin='lower')