WIP: Introduce a time offset and duration of choice
related to: #5
anthonio9 committed Jun 3, 2024
1 parent c2c8bb0 commit 2834fde
Showing 3 changed files with 102 additions and 20 deletions.
8 changes: 8 additions & 0 deletions penn/plot/to_latex/__main__.py
@@ -32,6 +32,14 @@ def parse_args():
'--gpu',
type=int,
help='The index of the GPU to use for inference')
parser.add_argument(
'--start',
type=float,
help='Start timestamp of the audio file in seconds')
parser.add_argument(
'--duration',
type=float,
help='Duration of the audio excerpt in seconds')
return parser.parse_known_args()[0]


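For context, a minimal sketch of how the new --start and --duration flags might be wired into the plotting code. This call site is not part of the diff, and the audio_file, ground_truth_file, checkpoint and output_file argument names are assumptions based on the from_file_to_file signature shown in core.py below.

from penn.plot.to_latex.core import from_file_to_file

args = parse_args()  # parse_args() as defined above
from_file_to_file(
    args.audio_file,                # assumed flag names, not shown in this diff
    args.ground_truth_file,
    args.checkpoint,
    output_file=args.output_file,
    gpu=args.gpu,
    start=args.start if args.start is not None else 0.0,
    duration=args.duration)

Note that --start is added without an argparse default, so it arrives as None when the flag is omitted; the guard above (or a default=0.0 in add_argument) keeps the start * penn.SAMPLE_RATE arithmetic in from_file_to_file from failing.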
56 changes: 48 additions & 8 deletions penn/plot/to_latex/core.py
@@ -37,25 +37,38 @@ def from_audio(
logits = torch.cat(logits)
pitch = None
times = None
periodicity = None

with torchutil.time.context('decode'):
# pitch is in Hz
predicted, pitch, periodicity = penn.postprocess(logits)
pitch = pitch.detach().numpy()[0, ...]
pitch = np.split(pitch, pitch.shape[0])
# pitch = np.split(pitch, pitch.shape[0])
times = penn.HOPSIZE_SECONDS * np.arange(pitch[0].shape[-1])
periodicity = periodicity.detach().numpy()

return pitch, times
return pitch, times, periodicity


def get_ground_truth(ground_truth_file):
def get_ground_truth(ground_truth_file,
start : float=0,
duration : float=None):
assert isfile(ground_truth_file)

jams_track = jams.load(str(ground_truth_file))
duration = jams_track.file_metadata.duration
notes_dict = penn.data.preprocess.jams_to_notes(jams_track)
pitch_array, times_array = penn.data.preprocess.notes_dict_to_pitch_array(notes_dict, duration)
return pitch_array, times_array

start_frame = 0
end_frame = -1

start_frame = np.argmin(np.abs(times_array - start))

if duration is not None:
end_frame = np.argmin(np.abs(times_array - start - duration))

return pitch_array[..., start_frame:end_frame], times_array[..., start_frame:end_frame]
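For reference, a small standalone sketch of the nearest-frame lookup used above: with a uniformly spaced times_array, np.argmin over the absolute difference returns the index of the frame closest to the requested timestamp. The 10 ms hop is only illustrative, not penn's actual hop size.

import numpy as np

times_array = 0.01 * np.arange(1000)   # 10 s of frames at an illustrative 10 ms hop
start, duration = 2.5, 3.0             # crop 2.5 s .. 5.5 s

start_frame = np.argmin(np.abs(times_array - start))           # frame nearest 2.5 s
end_frame = np.argmin(np.abs(times_array - start - duration))  # frame nearest 5.5 s

print(start_frame, end_frame)   # 250 550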


def plot_over_gt_with_plotly(audio, sr, pred_freq, pred_times, gt, return_fig=False):
@@ -113,23 +126,48 @@ def plot_over_gt_with_plotly(audio, sr, pred_freq, pred_times, gt, return_fig=False):
fig.show()


def from_file_to_file(audio_file, ground_truth_file, checkpoint, output_file=None, gpu=None):
def from_file_to_file(audio_file,
ground_truth_file,
checkpoint,
output_file=None,
gpu=None,
start : float=0.0,
duration : float=None):
# Load audio
audio = penn.load.audio(audio_file)

if checkpoint is None:
return

# get the timestamps in frame numbers
start_frame = round(start * penn.SAMPLE_RATE)

end_frame = -1
if duration is not None:
end_frame = round((start + duration)* penn.SAMPLE_RATE)

audio = audio[..., start_frame : end_frame]

# get logits
pred_freq, pred_times = from_audio(audio, penn.SAMPLE_RATE, checkpoint, gpu)
pred_freq, pred_times, periodicity = from_audio(audio, penn.SAMPLE_RATE, checkpoint, gpu)
pred_times += start

# get the ground truth
gt_pitch, gt_times = get_ground_truth(ground_truth_file)
gt_pitch, gt_times = get_ground_truth(ground_truth_file, start, duration)

# get the stft of the audio
audio, sr = torchaudio.load(audio_file)
audio = audio.cpu().numpy()

# get the timestamps in frame numbers
start_frame = round(start * sr)

end_frame = -1
if duration is not None:
end_frame = round((start + duration)* sr)

audio = audio[..., start_frame:end_frame]

# now that we have the ground truth, the STFT and the predicted pitch, plot all with matplotlib and plotly
# well, do we have predicted pitch?
# plot_over_gt_with_plotly(audio, sr, pred_freq, pred_times, gt)
@@ -140,4 +178,6 @@ def from_file_to_file(audio_file, ground_truth_file, checkpoint, output_file=None, gpu=None):
pred_pitch=pred_freq,
pred_times=pred_times,
gt_pitch=gt_pitch,
gt_times=gt_times)
gt_times=gt_times,
periodicity=periodicity,
time_offset=start)
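And a minimal sketch of the seconds-to-samples conversion used twice in from_file_to_file above, once at penn.SAMPLE_RATE for the model input and once at the torchaudio sample rate for the STFT; the sample rate here is an illustrative value rather than penn's actual constant.

import numpy as np

sample_rate = 22050                              # illustrative value
audio = np.zeros(10 * sample_rate)               # 10 s of silent audio

start, duration = 2.5, 3.0
start_frame = round(start * sample_rate)                 # 55125
end_frame = round((start + duration) * sample_rate)      # 121275

excerpt = audio[..., start_frame:end_frame]
print(excerpt.shape)                                     # (66150,), i.e. exactly 3 s

One side note on the code above: when duration is None, the slice ends at end_frame = -1, which drops the final sample (or ground-truth frame); slicing with None instead of -1 would keep it.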
58 changes: 46 additions & 12 deletions penn/plot/to_latex/mplt.py
@@ -8,7 +8,8 @@ def plot_stft(axis : plt.Axes,
audio,
sr=penn.SAMPLE_RATE,
window_length=2048*4,
hop_length=penn.data.preprocess.GSET_HOPSIZE):
hop_length=penn.data.preprocess.GSET_HOPSIZE,
time_offset=0):
"""
Add an STFT plot of the given audio.
@@ -24,6 +25,7 @@
sr=sr,
window_length=window_length,
hop_length=hop_length)
times += time_offset

axis.pcolormesh(times, freqs, np.abs(stft), cmap='grey')
axis.set_ylim([50, 300])
Expand All @@ -37,7 +39,9 @@ def plot_pitch(axis : plt.Axes,
times,
set_pitch_lims=True,
plot_red=False,
linewidth=1):
linewidth=1,
periodicity=None,
threshold=0.05):
"""
Add a plot of pitch. Optionally, set the frequency limits based
on the max and min values of the provided pitch.
@@ -53,17 +57,27 @@
max_pitch = []
min_pitch = []

pitch_masked = np.ma.MaskedArray(pitch, pitch==0)
mask_for_pitch = pitch==0

for no_slice, pitch_slice in enumerate(pitch_masked):
y = pitch_slice.reshape(-1)
if periodicity is not None:
periodicity_for_mask = periodicity.squeeze()
periodicity_mask = periodicity_for_mask >= threshold
# also mask frames whose periodicity falls below the threshold
mask_for_pitch = np.logical_or(mask_for_pitch, np.logical_not(periodicity_mask))

# pitch_masked = np.ma.MaskedArray(pitch, mask_for_pitch)
pitch_split = np.split(pitch, pitch.shape[0])
mask_split = np.split(mask_for_pitch, mask_for_pitch.shape[0])

for no_slice, (pitch_slice, mask_slice) in enumerate(zip(pitch_split, mask_split)):
pitch_slice = pitch_slice.reshape(-1)
y = np.ma.MaskedArray(pitch_slice, mask_slice)
x = times

# axis.scatter(x, y, label=f"String {no_slice}")
if plot_red:
axis.plot(x, y, 'r-', linewidth=linewidth, label=f"String {no_slice}")
else:
axis.plot(x, y, '-', linewidth=linewidth, label=f"String {no_slice}")
axis.scatter(x, y, linewidth=linewidth, label=f"String {no_slice}")


if pitch_slice.size > 0:
@@ -84,15 +98,32 @@
axis.set_xlabel('Time [s]')


def plot_periodicity(axis : plt.Axes, periodicity, threshold=None):
def plot_periodicity(axis : plt.Axes,
periodicity : np.ndarray,
threshold : float=0.05):
"""
Plot the periodicity curves, with or without a threshold.
"""
if threshold is None:
pass

periodicity_for_plot = periodicity.squeeze().T

offset = np.arange(0, penn.PITCH_CATS) * int(not penn.LOSS_MULTI_HOT)
periodicity_for_plot += offset

twin = axis.twinx()
twin.set_ylim(ymin=0, ymax=penn.PITCH_CATS)

twin.plot(periodicity_for_plot, 'g:', linewidth=2)

if threshold is not None:
periodicity_mask = periodicity_for_plot >= threshold
# mask periodicity under the threshold
periodicity_masked = np.ma.MaskedArray(periodicity_for_plot, np.logical_not(periodicity_mask))

twin.plot(periodicity_masked, 'm:', linewidth=2)


def plot_with_matplotlib(audio, sr=penn.SAMPLE_RATE, pred_pitch=None, pred_times=None, gt_pitch=None, gt_times=None, periodicity=None, threshold=None):
def plot_with_matplotlib(audio, sr=penn.SAMPLE_RATE, pred_pitch=None, pred_times=None, gt_pitch=None, gt_times=None, periodicity=None, threshold=0.05, time_offset=0):
"""
Plot the STFT of the given audio. Optionally put raw pitch data
or even thresholded periodicity data on top of it.
@@ -107,10 +138,13 @@ def plot_with_matplotlib(audio, sr=penn.SAMPLE_RATE, pred_pitch=None, pred_times=None, gt_pitch=None, gt_times=None, periodicity=None, threshold=0.05, time_offset=0):
axis.spines['bottom'].set_visible(False)
axis.spines['left'].set_visible(False)

plot_stft(axis, audio, sr)
plot_stft(axis, audio, sr, time_offset=time_offset)

if pred_pitch is not None and pred_times is not None:
plot_pitch(axis, pred_pitch, pred_times, linewidth=2)
plot_pitch(axis, pred_pitch, pred_times,
linewidth=1.5,
periodicity=periodicity,
threshold=0.9)

if gt_pitch is not None and gt_times is not None:
plot_pitch(axis, gt_pitch, gt_times, plot_red=True)
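To make the thresholding above concrete, a minimal synthetic sketch of hiding low-periodicity frames with numpy masked arrays and stacking one offset curve per string on a twin axis, in the spirit of plot_periodicity; the six strings, the 0.5 threshold and the random data are assumptions for illustration, not values from the repository.

import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
periodicity = rng.random((6, 200))       # 6 strings x 200 frames of synthetic periodicity
threshold = 0.5

# Hide every frame whose periodicity falls below the threshold
masked = np.ma.MaskedArray(periodicity, periodicity < threshold)

# Offset each string by an integer so the curves stack instead of overlapping
offset = np.arange(6).reshape(-1, 1)

fig, axis = plt.subplots()
twin = axis.twinx()
twin.set_ylim(0, 6)
twin.plot((masked + offset).T, 'm:', linewidth=2)
plt.show()

Masked entries simply leave gaps in the plotted lines, which is what makes the MaskedArray approach convenient here.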
