Skip to content

Commit

Permalink
Iterate over multiple logits plots
Browse files Browse the repository at this point in the history
related to: #5
  • Loading branch information
anthonio9 committed Feb 1, 2024
1 parent af24713 commit 38b5c20
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 15 deletions.
4 changes: 4 additions & 0 deletions penn/plot/logits/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def parse_args():
'--gpu',
type=int,
help='The index of the GPU to use for inference')
parser.add_argument(
'--iters',
type=int,
help='Number of dummy iterations on the loader before extracting the data')
return parser.parse_known_args()[0]


Expand Down
35 changes: 20 additions & 15 deletions penn/plot/logits/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def logits_matplotlib(logits, bins=None, voiced=None, stem=None):
predicted_bins, pitch, periodicity = penn.postprocess(logits)

# Change font size
matplotlib.rcParams.update({'font.size': 5})
matplotlib.rcParams.update({'font.size': 10})

# Setup figure
figure, axis = plt.subplots(figsize=figsize)
Expand All @@ -101,14 +101,14 @@ def logits_matplotlib(logits, bins=None, voiced=None, stem=None):

yticks = torch.linspace(0, penn.PITCH_BINS - 1, 5)
ylabels = penn.convert.bins_to_frequency(yticks)
ylabels_chunk = [ylabels for _ in range(penn.PITCH_CATS)]
ylabels = torch.cat(ylabels_chunk)
yticks = torch.linspace(0,
penn.PITCH_BINS * penn.PITCH_CATS - 1,
5 * penn.PITCH_CATS)

ylabels = ylabels.round().int().tolist()
axis.get_yaxis().set_ticks(yticks, ylabels)
# ylabels_chunk = [ylabels for _ in range(penn.PITCH_CATS)]
# ylabels = torch.cat(ylabels_chunk)
# yticks = torch.linspace(0,
# penn.PITCH_BINS * penn.PITCH_CATS - 1,
# 5 * penn.PITCH_CATS)

# ylabels = ylabels.round().int().tolist()
# axis.get_yaxis().set_ticks(yticks, ylabels)
axis.set_xlabel('Time (seconds)')
axis.set_ylabel('Frequency (Hz)')
axis.set_title(f"track: {stem}")
Expand All @@ -127,7 +127,7 @@ def logits_matplotlib(logits, bins=None, voiced=None, stem=None):
nbins_masked = np.ma.MaskedArray(nbins, np.logical_not(nvoiced))

for nbins_row in range(penn.PITCH_CATS):
axis.plot(nbins_masked[:, nbins_row], 'r--', linewidth=1)
axis.plot(nbins_masked[:, nbins_row], 'r--', marker='o', linewidth=1)

# if predicted_bins is not None:
# npredicted_bins = predicted_bins.detach().cpu().numpy()
Expand Down Expand Up @@ -177,12 +177,15 @@ def from_audio(
###############################################################################


def from_model_and_testset(model, loader, gpu=None):
def from_model_and_testset(model, loader, gpu=None, iters=0):
device = torch.device('cpu' if gpu is None else f'cuda:{gpu}')

# Prepare model for inference
with penn.inference_context(model):

for iter in range(iters):
next(loader)

# Iterate over test set
audio, bins, _, voiced, stem = next(loader)

Expand Down Expand Up @@ -213,7 +216,7 @@ def from_model_and_testset(model, loader, gpu=None):
return logits_matplotlib(logits, bins, voiced, stem)


def from_testset(checkpoint=None, gpu=None):
def from_testset(checkpoint=None, gpu=None, iters=0):
# Initialize model
model = penn.Model()

Expand All @@ -229,7 +232,7 @@ def from_testset(checkpoint=None, gpu=None):
loader = penn.data.loader(penn.EVALUATION_DATASETS, 'test')
loader = iter(loader)

return from_model_and_testset(model, loader)
return from_model_and_testset(model, loader, iters=iters)


def from_file(audio_file, checkpoint=None, gpu=None):
Expand All @@ -241,13 +244,15 @@ def from_file(audio_file, checkpoint=None, gpu=None):
return from_audio(audio, penn.SAMPLE_RATE, checkpoint, gpu)


def from_file_to_file(audio_file=None, output_file=None, checkpoint=None, gpu=None):
def from_file_to_file(audio_file=None, output_file=None, checkpoint=None, gpu=None, iters=0):
"""Plot pitch and periodicity and save to disk"""
# Plot
if audio_file is not None:
figure = from_file(audio_file, checkpoint, gpu)
else:
figure = from_testset(checkpoint, gpu)
figure = from_testset(checkpoint, gpu, iters=iters)

breakpoint()

# Save to disk
if output_file is not None:
Expand Down

0 comments on commit 38b5c20

Please sign in to comment.