From 38b5c205a97a1772e1d103e7187a26d71dad0494 Mon Sep 17 00:00:00 2001 From: antoni Date: Fri, 2 Feb 2024 00:50:11 +0200 Subject: [PATCH] Iterate over multiple logits plots related to: https://github.com/anthonio9/penn/issues/5 --- penn/plot/logits/__main__.py | 4 ++++ penn/plot/logits/core.py | 35 ++++++++++++++++++++--------------- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/penn/plot/logits/__main__.py b/penn/plot/logits/__main__.py index ec07a59..0eee4af 100644 --- a/penn/plot/logits/__main__.py +++ b/penn/plot/logits/__main__.py @@ -28,6 +28,10 @@ def parse_args(): '--gpu', type=int, help='The index of the GPU to use for inference') + parser.add_argument( + '--iters', + type=int, + help='Number of dummy iterations on the loader before extracting the data') return parser.parse_known_args()[0] diff --git a/penn/plot/logits/core.py b/penn/plot/logits/core.py index 66adde9..7256b70 100644 --- a/penn/plot/logits/core.py +++ b/penn/plot/logits/core.py @@ -84,7 +84,7 @@ def logits_matplotlib(logits, bins=None, voiced=None, stem=None): predicted_bins, pitch, periodicity = penn.postprocess(logits) # Change font size - matplotlib.rcParams.update({'font.size': 5}) + matplotlib.rcParams.update({'font.size': 10}) # Setup figure figure, axis = plt.subplots(figsize=figsize) @@ -101,14 +101,14 @@ def logits_matplotlib(logits, bins=None, voiced=None, stem=None): yticks = torch.linspace(0, penn.PITCH_BINS - 1, 5) ylabels = penn.convert.bins_to_frequency(yticks) - ylabels_chunk = [ylabels for _ in range(penn.PITCH_CATS)] - ylabels = torch.cat(ylabels_chunk) - yticks = torch.linspace(0, - penn.PITCH_BINS * penn.PITCH_CATS - 1, - 5 * penn.PITCH_CATS) - - ylabels = ylabels.round().int().tolist() - axis.get_yaxis().set_ticks(yticks, ylabels) + # ylabels_chunk = [ylabels for _ in range(penn.PITCH_CATS)] + # ylabels = torch.cat(ylabels_chunk) + # yticks = torch.linspace(0, + # penn.PITCH_BINS * penn.PITCH_CATS - 1, + # 5 * penn.PITCH_CATS) + + # ylabels = ylabels.round().int().tolist() + # axis.get_yaxis().set_ticks(yticks, ylabels) axis.set_xlabel('Time (seconds)') axis.set_ylabel('Frequency (Hz)') axis.set_title(f"track: {stem}") @@ -127,7 +127,7 @@ def logits_matplotlib(logits, bins=None, voiced=None, stem=None): nbins_masked = np.ma.MaskedArray(nbins, np.logical_not(nvoiced)) for nbins_row in range(penn.PITCH_CATS): - axis.plot(nbins_masked[:, nbins_row], 'r--', linewidth=1) + axis.plot(nbins_masked[:, nbins_row], 'r--', marker='o', linewidth=1) # if predicted_bins is not None: # npredicted_bins = predicted_bins.detach().cpu().numpy() @@ -177,12 +177,15 @@ def from_audio( ############################################################################### -def from_model_and_testset(model, loader, gpu=None): +def from_model_and_testset(model, loader, gpu=None, iters=0): device = torch.device('cpu' if gpu is None else f'cuda:{gpu}') # Prepare model for inference with penn.inference_context(model): + for iter in range(iters): + next(loader) + # Iterate over test set audio, bins, _, voiced, stem = next(loader) @@ -213,7 +216,7 @@ def from_model_and_testset(model, loader, gpu=None): return logits_matplotlib(logits, bins, voiced, stem) -def from_testset(checkpoint=None, gpu=None): +def from_testset(checkpoint=None, gpu=None, iters=0): # Initialize model model = penn.Model() @@ -229,7 +232,7 @@ def from_testset(checkpoint=None, gpu=None): loader = penn.data.loader(penn.EVALUATION_DATASETS, 'test') loader = iter(loader) - return from_model_and_testset(model, loader) + return from_model_and_testset(model, loader, iters=iters) def from_file(audio_file, checkpoint=None, gpu=None): @@ -241,13 +244,15 @@ def from_file(audio_file, checkpoint=None, gpu=None): return from_audio(audio, penn.SAMPLE_RATE, checkpoint, gpu) -def from_file_to_file(audio_file=None, output_file=None, checkpoint=None, gpu=None): +def from_file_to_file(audio_file=None, output_file=None, checkpoint=None, gpu=None, iters=0): """Plot pitch and periodicity and save to disk""" # Plot if audio_file is not None: figure = from_file(audio_file, checkpoint, gpu) else: - figure = from_testset(checkpoint, gpu) + figure = from_testset(checkpoint, gpu, iters=iters) + + breakpoint() # Save to disk if output_file is not None: