Skip to content

Commit

Permalink
Better handling of flat data in numba function and spindles detect (#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelvallat authored Jul 8, 2022
1 parent a281d55 commit b1890be
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ c. Added a new dataset containing 8 hours of ECG data. The dataset is in compres

a. When using an MNE.Raw object, conversion of the data from Volts to micro-Volts is now performed within MNE. `PR 70 <https://github.com/raphaelvallat/yasa/pull/70>`_
b. Use `black <https://black.readthedocs.io/en/stable/>`_ code formatting.
c. Better handling of flat data in :py:func:`yasa.spindles_detect`. The function previously returned a division by zero error if part of the data was flat. See `issue 85 <https://github.com/raphaelvallat/yasa/issues/85>`_

**Dependencies**

Expand Down
4 changes: 4 additions & 0 deletions yasa/numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def _corr(x, y):
r_d1 = np.sqrt(xm2s)
r_d2 = np.sqrt(ym2s)
r_den = r_d1 * r_d2
if r_den == 0:
return np.nan
return r_num / r_den


Expand Down Expand Up @@ -69,6 +71,8 @@ def _slope_lstsq(x, y):
sy += y[j]
den = n_times * sx2 - (sx**2)
num = n_times * sxy - sx * sy
if den == 0:
return np.nan
return num / den


Expand Down
6 changes: 6 additions & 0 deletions yasa/tests/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ def test_spindles_detect(self):
# Test with hypnogram
spindles_detect(data, sf, hypno=np.ones(data.size))

# Test with 1-sec of flat data -- we should still have 2 detected spindles
data_flat = data.copy()
data_flat[100:200] = 1
sp = spindles_detect(data_flat, sf).summary()
assert sp.shape[0] == 2

# Single channel with Isolation Forest + hypnogram
sp = spindles_detect(data_full[1, :], sf, hypno=hypno_full, remove_outliers=True)

Expand Down
4 changes: 4 additions & 0 deletions yasa/tests/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,19 @@ def test_numba(self):
"""Test numba functions"""
x = np.asarray([4, 5, 7, 8, 5, 6], dtype=np.float64)
y = np.asarray([1, 5, 4, 6, 8, 5], dtype=np.float64)
y_flat = np.ones_like(x)

np.testing.assert_almost_equal(_corr(x, y), np.corrcoef(x, y)[0][1])
assert np.isnan(_corr(x, y_flat))
assert _covar(x, y) == np.cov(x, y)[0][1]
assert _rms(x) == np.sqrt(np.mean(np.square(x)))

# Least square slope and detrending
y = np.arange(30) + 3 * np.random.random(30)
times = np.arange(y.size, dtype=np.float64)
slope = _slope_lstsq(times, y)
assert np.isnan(_slope_lstsq(y_flat, x))
assert np.array_equal(_detrend(x, y_flat), np.zeros_like(y_flat))
np.testing.assert_array_almost_equal(_detrend(times, y), detrend(y, type="linear"))
X = times[..., np.newaxis]
X = np.column_stack((np.ones(X.shape[0]), X))
Expand Down

0 comments on commit b1890be

Please sign in to comment.