Skip to content

Commit

Permalink
Add the polyphonic logits plot to wandb logging
Browse files Browse the repository at this point in the history
related to: #5
  • Loading branch information
anthonio9 committed Jan 18, 2024
1 parent 37d7693 commit 8629663
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 42 deletions.
2 changes: 0 additions & 2 deletions penn/plot/logits/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@ def parse_args():
parser = argparse.ArgumentParser(description='Create logits figure')
parser.add_argument(
'--audio_file',
required=True,
type=Path,
help='The audio file to plot the logits of')
parser.add_argument(
'--output_file',
required=True,
type=Path,
help='The jpg file to save the plot')
parser.add_argument(
Expand Down
213 changes: 173 additions & 40 deletions penn/plot/logits/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,111 @@

import penn

import plotly.express as px
import plotly.graph_objects as go


###############################################################################
# Create figure with plotly (compatible with wandb)
###############################################################################


def logits_plotly(
logits,
ground_truth_logits=None,
):

# convert to numpy
logits = logits.cpu().numpy()
logits_and_gnd = logits

if ground_truth_logits is not None:
ground_truth_logits = ground_truth_logits.cpu().numpy()
logits_and_gnd = np.concatenate((logits, ground_truth_logits), axis=0)

lh = logits.shape[0]
lw = logits.shape[0]


fig = px.imshow(
logits_and_gnd,
color_continuous_scale=px.colors.sequential.Cividis_r,
height=lh,
width= lh if lw <= lh else lw)

return fig


###############################################################################
# Create figure
###############################################################################


def logits_matplotlib(logits, bins=None, figsize=(18, 10)):
import matplotlib
import matplotlib.pyplot as plt

# Change font size
matplotlib.rcParams.update({'font.size': 5})

# Setup figure
figure, axis = plt.subplots(figsize=figsize)

# Make pretty
axis.spines['top'].set_visible(False)
axis.spines['right'].set_visible(False)
axis.spines['bottom'].set_visible(False)
axis.spines['left'].set_visible(False)
xticks = torch.arange(0, len(logits), int(penn.SAMPLE_RATE / penn.HOPSIZE))
xlabels = torch.round(xticks * (penn.HOPSIZE / penn.SAMPLE_RATE)).int()
axis.get_xaxis().set_ticks(xticks.tolist(), xlabels.tolist())

yticks = torch.linspace(0, penn.PITCH_BINS - 1, 5)
ylabels = penn.convert.bins_to_frequency(yticks)
ylabels_chunk = [ylabels for _ in range(penn.PITCH_CATS)]
ylabels = torch.cat(ylabels_chunk)
yticks = torch.linspace(0,
penn.PITCH_BINS * penn.PITCH_CATS - 1,
5 * penn.PITCH_CATS)

ylabels = ylabels.round().int().tolist()
axis.get_yaxis().set_ticks(yticks, ylabels)
axis.set_xlabel('Time (seconds)')
axis.set_ylabel('Frequency (Hz)')

# if bins is not None:
# bins = bins.squeeze(dim=0)
# bins_chunks = bins.chunk(penn.PITCH_CATS, dim=0)
# bins_chunks_hot = [
# torch.nn.functional.one_hot(chunk, penn.PITCH_BINS).float().squeeze(0).T
# for chunk in bins_chunks]
#
# bins = torch.vstack(bins_chunks_hot)
# bins[bins == 0] = logits.min()
# logits = torch.cat((logits, bins), dim=-1)
#
# xticks = torch.cat((xticks, xticks))
# xlabels = torch.cat((xlabels, xlabels))
# axis.get_xaxis().set_ticks(xticks.tolist(), xlabels.tolist())

# Plot pitch posteriorgram
# if len(distributions.shape) == 4:
# axis.imshow(new_distributions, extent=[0,100,0,1], aspect=80, origin='lower')
# else:
# axis.imshow(new_distributions, aspect='auto', origin='lower')
axis.imshow(logits, aspect='auto', origin='lower')

return figure


def from_audio(
audio,
sample_rate,
checkpoint=None,
gpu=None):
"""Plot logits with pitch overlay"""
import matplotlib
import matplotlib.pyplot as plt

logits = []

# Change font size
matplotlib.rcParams.update({'font.size': 5})

# Preprocess audio
for frames in penn.preprocess(
audio,
Expand Down Expand Up @@ -63,43 +148,85 @@ def from_audio(
new_distributions = torch.vstack(distr_chunk)
figsize = (18, 10)

# Setup figure
figure, axis = plt.subplots(figsize=figsize)
return logits_matplotlib(new_distributions, figsize)

# Make pretty
axis.spines['top'].set_visible(False)
axis.spines['right'].set_visible(False)
axis.spines['bottom'].set_visible(False)
axis.spines['left'].set_visible(False)
xticks = torch.arange(0, len(logits), int(penn.SAMPLE_RATE / penn.HOPSIZE))
xlabels = torch.round(xticks * (penn.HOPSIZE / penn.SAMPLE_RATE)).int()
axis.get_xaxis().set_ticks(xticks.tolist(), xlabels.tolist())

yticks = torch.linspace(0, penn.PITCH_BINS - 1, 5)
ylabels = penn.convert.bins_to_frequency(yticks)

if len(distributions.shape) == 4:
no_poly_cats = distributions.shape[-2]
ylabels = penn.convert.bins_to_frequency(yticks)
ylabels_chunk = [ylabels for _ in range(no_poly_cats)]
ylabels = torch.cat(ylabels_chunk)
yticks = torch.linspace(0,
penn.PITCH_BINS * no_poly_cats - 1,
5 * no_poly_cats)
###############################################################################
# Plotting the logits from a testset
###############################################################################

ylabels = ylabels.round().int().tolist()
axis.get_yaxis().set_ticks(yticks, ylabels)
axis.set_xlabel('Time (seconds)')
axis.set_ylabel('Frequency (Hz)')

# Plot pitch posteriorgram
# if len(distributions.shape) == 4:
# axis.imshow(new_distributions, extent=[0,100,0,1], aspect=80, origin='lower')
# else:
# axis.imshow(new_distributions, aspect='auto', origin='lower')
axis.imshow(new_distributions, aspect='auto', origin='lower')
def from_model_and_testset(model, loader, gpu=None):
device = torch.device('cpu' if gpu is None else f'cuda:{gpu}')

return figure
# Prepare model for inference
with penn.inference_context(model):

# Iterate over test set
audio, bins, _, voiced, stem = next(loader)

# Accumulate logits
logits = []

# Preprocess audio
batch_size = \
None if gpu is None else penn.EVALUATION_BATCH_SIZE
for frames in penn.preprocess(
audio[0],
penn.SAMPLE_RATE,
batch_size=batch_size,
center='half-hop'
):

# Copy to device
frames = frames.to(device)

# Infer
batch_logits = model(frames)

# Accumulate logits
logits.append(batch_logits)

logits = torch.cat(logits)

distributions = torch.nn.functional.softmax(logits, dim=1)

# Take the log again for display
distributions = torch.log(distributions)
distributions[torch.isinf(distributions)] = \
distributions[~torch.isinf(distributions)].min()

# Prepare for plotting
distributions = distributions.cpu().squeeze(2).T

# Prepare the ptich posteriorgram in case it's multipitch
if len(distributions.shape) == 4:
distr_chunk = torch.chunk(distributions, distributions.shape[-2], -2)
distr_chunk = [distr.squeeze(dim=-2).squeeze(dim=0) for distr in distr_chunk]
new_distributions = torch.vstack(distr_chunk)
figsize = (18, 10)

return logits_matplotlib(new_distributions, bins)


def from_testset(checkpoint=None, gpu=None):
# Initialize model
model = penn.Model()

# Maybe download from HuggingFace
if checkpoint is None:
return

checkpoint = torch.load(checkpoint, map_location='cpu')

# Load from disk
model.load_state_dict(checkpoint['model'])

loader = penn.data.loader(penn.EVALUATION_DATASETS, 'test')
loader = iter(loader)

return from_model_and_testset(model, loader)


def from_file(audio_file, checkpoint=None, gpu=None):
Expand All @@ -111,10 +238,16 @@ def from_file(audio_file, checkpoint=None, gpu=None):
return from_audio(audio, penn.SAMPLE_RATE, checkpoint, gpu)


def from_file_to_file(audio_file, output_file, checkpoint=None, gpu=None):
def from_file_to_file(audio_file=None, output_file=None, checkpoint=None, gpu=None):
"""Plot pitch and periodicity and save to disk"""
# Plot
figure = from_file(audio_file, checkpoint, gpu)
if audio_file is not None:
figure = from_file(audio_file, checkpoint, gpu)
else:
figure = from_testset(checkpoint, gpu)

breakpoint()

# Save to disk
figure.savefig(output_file, bbox_inches='tight', pad_inches=0, dpi=900)
if output_file is not None:
figure.savefig(output_file, bbox_inches='tight', pad_inches=0, dpi=900)
10 changes: 10 additions & 0 deletions penn/train/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def train(datasets, directory, gpu=None, use_wand=False):
torch.manual_seed(penn.RANDOM_SEED)
train_loader = penn.data.loader(datasets, 'train')
valid_loader = penn.data.loader(datasets, 'valid')
test_loader = penn.data.loader(datasets, 'test')
test_loader = iter(test_loader)

####################
# Create optimizer #
Expand Down Expand Up @@ -160,6 +162,14 @@ def train(datasets, directory, gpu=None, use_wand=False):
step=step,
epoch=epoch)

fig = penn.plot.logits.from_model_and_testset(
model=model,
loader=test_loader,
gpu=gpu)

if use_wand:
log_wandb.log({"test_logits": wandb.Image(fig)})

# Evaluate
if step % penn.LOG_INTERVAL == 0:
evaluate_fn = functools.partial(
Expand Down

0 comments on commit 8629663

Please sign in to comment.