Skip to content

Commit

Permalink
WIP: plot pred pitch in colors, gt in red, stft in b&w
Browse files Browse the repository at this point in the history
related to: #5
  • Loading branch information
anthonio9 committed May 27, 2024
1 parent 04bca1a commit c2c8bb0
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 8 deletions.
4 changes: 2 additions & 2 deletions penn/plot/to_latex/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def from_file_to_file(audio_file, ground_truth_file, checkpoint, output_file=Non
penn.plot.to_latex.mplt.plot_with_matplotlib(
audio=audio,
sr=sr,
# pitch_pred=pred_freq,
# pred_times=pred_times,
pred_pitch=pred_freq,
pred_times=pred_times,
gt_pitch=gt_pitch,
gt_times=gt_times)
40 changes: 34 additions & 6 deletions penn/plot/to_latex/mplt.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,19 @@ def plot_stft(axis : plt.Axes,
window_length=window_length,
hop_length=hop_length)

axis.pcolormesh(times, freqs, np.abs(stft), )
axis.pcolormesh(times, freqs, np.abs(stft), cmap='grey')
axis.set_ylim([50, 300])
axis.set_xlim([times[0], times[-1]])

# take inspiration from this post: https://dsp.stackexchange.com/a/70136


def plot_pitch(axis : plt.Axes, pitch, times, set_pitch_lims=True):
def plot_pitch(axis : plt.Axes,
pitch,
times,
set_pitch_lims=True,
plot_red=False,
linewidth=1):
"""
Add a plot of pitch. Optionally, set the frequency limits based
on the max i min values of the provided pitch.
Expand All @@ -42,15 +47,24 @@ def plot_pitch(axis : plt.Axes, pitch, times, set_pitch_lims=True):
pitch - pitch array
times - times array
set_pitch_lims - flag indicating if freqnecy limits are to be adjusted
plot_red - set true to plot all lines in a red color
linewidth - set the matplotlib plot line width
"""
max_pitch = []
min_pitch = []

for no_slice, pitch_slice in enumerate(pitch):
pitch_masked = np.ma.MaskedArray(pitch, pitch==0)

for no_slice, pitch_slice in enumerate(pitch_masked):
y = pitch_slice.reshape(-1)
x = times

axis.scatter(x, y, label=f"String {no_slice}")
# axis.scatter(x, y, label=f"String {no_slice}")
if plot_red:
axis.plot(x, y, 'r-', linewidth=linewidth, label=f"String {no_slice}")
else:
axis.plot(x, y, '-', linewidth=linewidth, label=f"String {no_slice}")


if pitch_slice.size > 0:
max_pitch.append(pitch_slice.max())
Expand All @@ -66,6 +80,17 @@ def plot_pitch(axis : plt.Axes, pitch, times, set_pitch_lims=True):
if set_pitch_lims:
axis.set_ylim([ymin, ymax])

axis.set_ylabel('Frequency [Hz]')
axis.set_xlabel('Time [s]')


def plot_periodicity(axis : plt.Axes, periodicity, threshold=None):
"""
Plot the periodicity plot with or without threshold.
"""
if threshold is None:
pass


def plot_with_matplotlib(audio, sr=penn.SAMPLE_RATE, pred_pitch=None, pred_times=None, gt_pitch=None, gt_times=None, periodicity=None, threshold=None):
"""
Expand All @@ -85,10 +110,13 @@ def plot_with_matplotlib(audio, sr=penn.SAMPLE_RATE, pred_pitch=None, pred_times
plot_stft(axis, audio, sr)

if pred_pitch is not None and pred_times is not None:
plot_pitch(axis, pred_pitch, pred_times)
plot_pitch(axis, pred_pitch, pred_times, linewidth=2)

if gt_pitch is not None and gt_times is not None:
plot_pitch(axis, gt_pitch, gt_times)
plot_pitch(axis, gt_pitch, gt_times, plot_red=True)

if periodicity is not None:
plot_periodicity(axis, periodicity, threshold)

# figure.show()
plt.show()

0 comments on commit c2c8bb0

Please sign in to comment.