From a281d556749b3a968f9858b4db9f803186dfabcd Mon Sep 17 00:00:00 2001
From: Raphael Vallat <raphaelvallat9@gmail.com>
Date: Mon, 20 Jun 2022 10:34:42 -0700
Subject: [PATCH] Black formatting (#83)

* Use black code formatting + GH Actions

* Reformat setup and conf
---
 .github/workflows/black.yml   |   10 +
 docs/changelog.rst            |    1 +
 docs/conf.py                  |   89 +--
 notebooks/run_visbrain.py     |   17 +-
 pyproject.toml                |    2 +
 setup.cfg                     |    2 +-
 setup.py                      |  102 ++--
 yasa/__init__.py              |    5 +-
 yasa/detection.py             | 1085 ++++++++++++++++++++-------------
 yasa/features.py              |  182 +++---
 yasa/heart.py                 |   41 +-
 yasa/hypno.py                 |  106 ++--
 yasa/io.py                    |   17 +-
 yasa/numba.py                 |   24 +-
 yasa/others.py                |   86 +--
 yasa/plotting.py              |  187 +++---
 yasa/sleepstats.py            |   60 +-
 yasa/spectral.py              |  281 +++++----
 yasa/staging.py               |  153 ++---
 yasa/tests/test_detection.py  |  176 +++---
 yasa/tests/test_heart.py      |   13 +-
 yasa/tests/test_hypno.py      |   91 +--
 yasa/tests/test_io.py         |   14 +-
 yasa/tests/test_numba.py      |    7 +-
 yasa/tests/test_others.py     |   77 ++-
 yasa/tests/test_plotting.py   |   28 +-
 yasa/tests/test_sleepstats.py |   44 +-
 yasa/tests/test_spectral.py   |  111 ++--
 yasa/tests/test_staging.py    |   15 +-
 29 files changed, 1767 insertions(+), 1259 deletions(-)
 create mode 100644 .github/workflows/black.yml
 create mode 100644 pyproject.toml

diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml
new file mode 100644
index 0000000..98b2a66
--- /dev/null
+++ b/.github/workflows/black.yml
@@ -0,0 +1,10 @@
+name: Lint
+
+on: [push, pull_request]
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - uses: psf/black@stable
\ No newline at end of file
diff --git a/docs/changelog.rst b/docs/changelog.rst
index 4c99a0a..9e0bd5b 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -20,6 +20,7 @@ c. Added a new dataset containing 8 hours of ECG data. The dataset is in compres
 
 **Improvements**
 
 a. When using an MNE.Raw object, conversion of the data from Volts to micro-Volts is now performed within MNE. `PR 70 <https://github.com/raphaelvallat/yasa/pull/70>`_
+b. Use `black <https://black.readthedocs.io/en/stable/>`_ code formatting.
 
 **Dependencies**
diff --git a/docs/conf.py b/docs/conf.py
index a49b895..f02cafe 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -9,24 +9,26 @@
 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
-sys.path.insert(0, os.path.abspath('sphinxext'))
-extensions = ['sphinx.ext.mathjax',
-              'sphinx.ext.doctest',
-              'sphinx.ext.viewcode',
-              'sphinx.ext.githubpages',
-              'sphinx.ext.autosummary',
-              'sphinx.ext.autodoc',
-              'sphinx.ext.intersphinx',
-              'matplotlib.sphinxext.plot_directive',
-              'sphinx_copybutton',
-              'numpydoc']
+sys.path.insert(0, os.path.abspath("sphinxext"))
+extensions = [
+    "sphinx.ext.mathjax",
+    "sphinx.ext.doctest",
+    "sphinx.ext.viewcode",
+    "sphinx.ext.githubpages",
+    "sphinx.ext.autosummary",
+    "sphinx.ext.autodoc",
+    "sphinx.ext.intersphinx",
+    "matplotlib.sphinxext.plot_directive",
+    "sphinx_copybutton",
+    "numpydoc",
+]
 
 # Generate the API documentation when building
 autosummary_generate = True
 autodoc_default_options = {
-    'members': True,
-    'member-order': 'groupwise',
-    'undoc-members': False,
+    "members": True,
+    "member-order": "groupwise",
+    "undoc-members": False,
     # 'special-members': '__init__',
     # 'exclude-members': '__weakref__'
 }
@@ -49,15 +51,17 @@
 
 # The suffix(es) of source filenames.
 # You can specify multiple suffix as a list of string:
-source_suffix = '.rst'
+source_suffix = ".rst"
 
 # The master toctree document.
-master_doc = 'index'
+master_doc = "index"
 
 # General information about the project.
-project = 'yasa'
-author = 'Raphael Vallat'
-copyright = u'2018-{}, Dr. Raphael Vallat, Center for Human Sleep Science, UC Berkeley'.format(time.strftime("%Y"))
+project = "yasa"
+author = "Raphael Vallat"
+copyright = "2018-{}, Dr. Raphael Vallat, Center for Human Sleep Science, UC Berkeley".format(
+    time.strftime("%Y")
+)
 
 # The version info for the project you're documenting, acts as replacement for
 # |version| and |release|, also used in various other places throughout the
@@ -79,10 +83,10 @@
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # These patterns also affect html_static_path and html_extra_path
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
 
 # The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'sphinx'
+pygments_style = "sphinx"
 
 # If true, `todo` and `todoList` produce output, else they produce nothing.
 todo_include_todos = False
@@ -90,45 +94,46 @@
 # -- Options for HTML output ----------------------------------------------
 
 # Bootstrap theme
-html_theme = 'bootstrap'
+html_theme = "bootstrap"
 html_theme_path = sphinx_bootstrap_theme.get_html_theme_path()
 
 html_theme_options = {
-    'source_link_position': "footer",
+    "source_link_position": "footer",
     # 'navbar_title': ' ',  # we replace this with an image
-    'bootswatch_theme': "flatly",
-    'navbar_sidebarrel': False,
+    "bootswatch_theme": "flatly",
+    "navbar_sidebarrel": False,
     # 'nosidebar': True,
     # 'navbar_site_name': "",
-    'navbar_pagenav': False,
-    'bootstrap_version': "3",
-    'navbar_class': "navbar",
-    'navbar_links': [
+    "navbar_pagenav": False,
+    "bootstrap_version": "3",
+    "navbar_class": "navbar",
+    "navbar_links": [
         ("API", "api"),
         ("Quickstart", "quickstart"),
         ("FAQ", "faq"),
         ("What's new", "changelog"),
-        ("Contribute", "contributing")],
+        ("Contribute", "contributing"),
+    ],
 }
 
-html_logo = 'pictures/yasa_128x128.png'
-html_favicon = 'pictures/favicon.ico'
+html_logo = "pictures/yasa_128x128.png"
+html_favicon = "pictures/favicon.ico"
 
 # -- Options for HTML output ------------------------------------------
 
 # Output file base name for HTML help builder.
-htmlhelp_basename = 'yasadoc'
+htmlhelp_basename = "yasadoc"
 
 html_show_sourcelink = False
 
 # -- Intersphinx ------------------------------------------------
 
 intersphinx_mapping = {
-    'numpy': ('http://docs.scipy.org/doc/numpy/', None),
-    'scipy': ('http://docs.scipy.org/doc/scipy/reference/', None),
-    'pandas': ('https://pandas.pydata.org/pandas-docs/stable/', None),
-    'sklearn': ('https://scikit-learn.org/stable/', None),
-    'matplotlib': ('https://matplotlib.org/', None),
-    'mne': ('https://martinos.org/mne/stable/', None),
-    'seaborn': ('https://seaborn.pydata.org/', None),
-    'pyriemann': ('https://pyriemann.readthedocs.io/en/latest/', None),
-    'tensorpac': ('https://etiennecmb.github.io/tensorpac/', None),
+    "numpy": ("http://docs.scipy.org/doc/numpy/", None),
+    "scipy": ("http://docs.scipy.org/doc/scipy/reference/", None),
+    "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),
+    "sklearn": ("https://scikit-learn.org/stable/", None),
+    "matplotlib": ("https://matplotlib.org/", None),
+    "mne": ("https://martinos.org/mne/stable/", None),
+    "seaborn": ("https://seaborn.pydata.org/", None),
+    "pyriemann": ("https://pyriemann.readthedocs.io/en/latest/", None),
+    "tensorpac": ("https://etiennecmb.github.io/tensorpac/", None),
 }
diff --git a/notebooks/run_visbrain.py b/notebooks/run_visbrain.py
index 11bde18..7e30d8c 100644
--- a/notebooks/run_visbrain.py
+++ b/notebooks/run_visbrain.py
@@ -5,9 +5,9 @@ from yasa import spindles_detect, sw_detect
 
 # Load the data and hypnogram
-data = np.load('data_full_6hrs_100Hz_Cz+Fz+Pz.npz').get('data')
-ch_names = ['Cz', 'Fz', 'Pz']
-hypno = np.load('data_full_6hrs_100Hz_hypno.npz').get('hypno')
+data = np.load("data_full_6hrs_100Hz_Cz+Fz+Pz.npz").get("data")
+ch_names = ["Cz", "Fz", "Pz"]
+hypno = np.load("data_full_6hrs_100Hz_hypno.npz").get("hypno")
 
 # Initialize a Visbrain.gui.Sleep instance
 sl = Sleep(data=data, channels=ch_names, sf=100, hypno=hypno)
@@ -22,23 +22,22 @@ def fcn_spindle(data, sf, time, hypno):
     # sp = spindles_detect(data, sf).summary()
     # NREM sleep only
     sp = spindles_detect(data, sf, hypno=hypno).summary()
-    return (sp[['Start', 'End']].values * sf).astype(int)
+    return (sp[["Start", "End"]].values * sf).astype(int)
 
 
 # Define slow-waves function
 def fcn_sw(data, sf, time, hypno):
-    """Replace Visbrain built-in slow-wave detection by YASA algorithm.
- """ + """Replace Visbrain built-in slow-wave detection by YASA algorithm.""" # On N2 / N3 sleep only # Note that if you want to apply the detection on N3 sleep only, you should # use sw_detect(..., include=(3)).summary() sw = sw_detect(data, sf, hypno=hypno).summary() - return (sw[['Start', 'End']].values * sf).astype(int) + return (sw[["Start", "End"]].values * sf).astype(int) # Replace the native Visbrain detections -sl.replace_detections('spindle', fcn_spindle) -sl.replace_detections('sw', fcn_sw) +sl.replace_detections("spindle", fcn_spindle) +sl.replace_detections("sw", fcn_sw) # Launch the Graphical User Interface sl.show() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..037585e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.black] +line-length = 100 \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 309c3db..5b99974 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,7 +12,7 @@ testpaths = [flake8] max-line-length = 100 -ignore = N806, N803, D107, D200, D205, D400, D401, D412, W504 +ignore = N806, N803, D107, D200, D205, D400, D401, D412, W504, E203 exclude = .git, __pycache__, diff --git a/setup.py b/setup.py index ed89d69..3e6ef15 100644 --- a/setup.py +++ b/setup.py @@ -6,71 +6,73 @@ LONG_DESCRIPTION = """YASA (Yet Another Spindle Algorithm) : fast and robust detection of spindles, slow-waves, and rapid eye movements from sleep EEG recordings.. """ -DISTNAME = 'yasa' -MAINTAINER = 'Raphael Vallat' -MAINTAINER_EMAIL = 'raphaelvallat9@gmail.com' -URL = 'https://github.com/raphaelvallat/yasa/' -LICENSE = 'BSD (3-clause)' -DOWNLOAD_URL = 'https://github.com/raphaelvallat/yasa/' -VERSION = '0.6.1' -PACKAGE_DATA = {'yasa.data.icons': ['*.svg']} +DISTNAME = "yasa" +MAINTAINER = "Raphael Vallat" +MAINTAINER_EMAIL = "raphaelvallat9@gmail.com" +URL = "https://github.com/raphaelvallat/yasa/" +LICENSE = "BSD (3-clause)" +DOWNLOAD_URL = "https://github.com/raphaelvallat/yasa/" +VERSION = "0.6.1" +PACKAGE_DATA = {"yasa.data.icons": ["*.svg"]} INSTALL_REQUIRES = [ - 'numpy', - 'scipy', - 'pandas', - 'matplotlib', - 'seaborn', - 'mne>=0.20.0', - 'numba', - 'outdated', - 'antropy', - 'scikit-learn', - 'tensorpac>=0.6.5', - 'pyriemann>=0.2.7', - 'sleepecg>=0.5.0', - 'lspopt', - 'ipywidgets', - 'joblib' + "numpy", + "scipy", + "pandas", + "matplotlib", + "seaborn", + "mne>=0.20.0", + "numba", + "outdated", + "antropy", + "scikit-learn", + "tensorpac>=0.6.5", + "pyriemann>=0.2.7", + "sleepecg>=0.5.0", + "lspopt", + "ipywidgets", + "joblib", ] PACKAGES = [ - 'yasa', + "yasa", ] CLASSIFIERS = [ - 'Intended Audience :: Science/Research', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'License :: OSI Approved :: BSD License', - 'Operating System :: POSIX', - 'Operating System :: Unix', - 'Operating System :: MacOS' + "Intended Audience :: Science/Research", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "License :: OSI Approved :: BSD License", + "Operating System :: POSIX", + "Operating System :: Unix", + "Operating System :: MacOS", ] try: from setuptools import setup + _has_setuptools = True except ImportError: from distutils.core import setup if __name__ == "__main__": - setup(name=DISTNAME, - author=MAINTAINER, - author_email=MAINTAINER_EMAIL, - maintainer=MAINTAINER, - maintainer_email=MAINTAINER_EMAIL, - description=DESCRIPTION, - long_description=LONG_DESCRIPTION, - 
-          url=URL,
-          version=VERSION,
-          download_url=DOWNLOAD_URL,
-          install_requires=INSTALL_REQUIRES,
-          include_package_data=True,
-          packages=PACKAGES,
-          package_data=PACKAGE_DATA,
-          classifiers=CLASSIFIERS,
-          )
+    setup(
+        name=DISTNAME,
+        author=MAINTAINER,
+        author_email=MAINTAINER_EMAIL,
+        maintainer=MAINTAINER,
+        maintainer_email=MAINTAINER_EMAIL,
+        description=DESCRIPTION,
+        long_description=LONG_DESCRIPTION,
+        license=LICENSE,
+        url=URL,
+        version=VERSION,
+        download_url=DOWNLOAD_URL,
+        install_requires=INSTALL_REQUIRES,
+        include_package_data=True,
+        packages=PACKAGES,
+        package_data=PACKAGE_DATA,
+        classifiers=CLASSIFIERS,
+    )
diff --git a/yasa/__init__.py b/yasa/__init__.py
index 36580fe..bdbb327 100644
--- a/yasa/__init__.py
+++ b/yasa/__init__.py
@@ -12,10 +12,9 @@ from outdated import warn_if_outdated
 
 # Define YASA logger
-logging.basicConfig(format='%(asctime)s | %(levelname)s | %(message)s',
-                    datefmt='%d-%b-%y %H:%M:%S')
+logging.basicConfig(format="%(asctime)s | %(levelname)s | %(message)s", datefmt="%d-%b-%y %H:%M:%S")
 
-__author__ = 'Raphael Vallat <raphaelvallat9@gmail.com>'
+__author__ = "Raphael Vallat <raphaelvallat9@gmail.com>"
 __version__ = "0.6.1"
 
 # Warn if a newer version of YASA is available
diff --git a/yasa/detection.py b/yasa/detection.py
index bbe89fd..8f11360 100644
--- a/yasa/detection.py
+++ b/yasa/detection.py
@@ -20,34 +20,48 @@
 from .spectral import stft_power
 from .numba import _detrend, _rms
 from .io import set_log_level, is_tensorpac_installed, is_pyriemann_installed
-from .others import (moving_transform, trimbothstd, get_centered_indices,
-                     sliding_window, _merge_close, _zerocrossings)
+from .others import (
+    moving_transform,
+    trimbothstd,
+    get_centered_indices,
+    sliding_window,
+    _merge_close,
+    _zerocrossings,
+)
 
-logger = logging.getLogger('yasa')
+logger = logging.getLogger("yasa")
 
-__all__ = ['art_detect', 'spindles_detect', 'SpindlesResults', 'sw_detect', 'SWResults',
-           'rem_detect', 'REMResults']
+__all__ = [
+    "art_detect",
+    "spindles_detect",
+    "SpindlesResults",
+    "sw_detect",
+    "SWResults",
+    "rem_detect",
+    "REMResults",
+]
 
 
 #############################################################################
 # DATA PREPROCESSING
 #############################################################################
 
+
 def _check_data_hypno(data, sf=None, ch_names=None, hypno=None, include=None, check_amp=True):
     """Helper function for preprocessing of data and hypnogram."""
     # 1) Extract data as a 2D NumPy array
     if isinstance(data, mne.io.BaseRaw):
-        sf = data.info['sfreq']  # Extract sampling frequency
+        sf = data.info["sfreq"]  # Extract sampling frequency
         ch_names = data.ch_names  # Extract channel names
         data = data.get_data(units=dict(eeg="uV", emg="uV", eog="uV", ecg="uV"))
     else:
-        assert sf is not None, 'sf must be specified if not using MNE Raw.'
+        assert sf is not None, "sf must be specified if not using MNE Raw."
         if isinstance(sf, np.ndarray):  # Deal with sf = array(100.) --> 100
             sf = float(sf)
         assert isinstance(sf, (int, float)), "sf must be int or float."
         data = np.asarray(data, dtype=np.float64)
-        assert data.ndim in [1, 2], 'data must be 1D (times) or 2D (chan, times).'
+        assert data.ndim in [1, 2], "data must be 1D (times) or 2D (chan, times)."
         if data.ndim == 1:
             # Force to 2D array: (n_chan, n_samples)
             data = data[None, ...]
@@ -55,40 +69,41 @@ def _check_data_hypno(data, sf=None, ch_names=None, hypno=None, include=None, ch
 
     # 2) Check channel names
     if ch_names is None:
-        ch_names = ['CHAN' + str(i).zfill(3) for i in range(n_chan)]
+        ch_names = ["CHAN" + str(i).zfill(3) for i in range(n_chan)]
     else:
         assert len(ch_names) == n_chan
 
     # 3) Check hypnogram
     if hypno is not None:
         hypno = np.asarray(hypno, dtype=int)
-        assert hypno.ndim == 1, 'Hypno must be one dimensional.'
-        assert hypno.size == n_samples, 'Hypno must have same size as data.'
+        assert hypno.ndim == 1, "Hypno must be one dimensional."
+        assert hypno.size == n_samples, "Hypno must have same size as data."
         unique_hypno = np.unique(hypno)
-        logger.info('Number of unique values in hypno = %i', unique_hypno.size)
-        assert include is not None, 'include cannot be None if hypno is given'
+        logger.info("Number of unique values in hypno = %i", unique_hypno.size)
+        assert include is not None, "include cannot be None if hypno is given"
         include = np.atleast_1d(np.asarray(include))
-        assert include.size >= 1, '`include` must have at least one element.'
-        assert hypno.dtype.kind == include.dtype.kind, ('hypno and include must have same dtype')
-        assert np.in1d(hypno, include).any(), ('None of the stages specified '
-                                               'in `include` are present in '
-                                               'hypno.')
+        assert include.size >= 1, "`include` must have at least one element."
+        assert hypno.dtype.kind == include.dtype.kind, "hypno and include must have same dtype"
+        assert np.in1d(hypno, include).any(), (
+            "None of the stages specified " "in `include` are present in " "hypno."
+        )
 
     # 4) Check data amplitude
-    logger.info('Number of samples in data = %i', n_samples)
-    logger.info('Sampling frequency = %.2f Hz', sf)
-    logger.info('Data duration = %.2f seconds', n_samples / sf)
+    logger.info("Number of samples in data = %i", n_samples)
+    logger.info("Sampling frequency = %.2f Hz", sf)
+    logger.info("Data duration = %.2f seconds", n_samples / sf)
     all_ptp = np.ptp(data, axis=-1)
     all_trimstd = trimbothstd(data, cut=0.05)
     bad_chan = np.zeros(n_chan, dtype=bool)
     for i in range(n_chan):
-        logger.info('Trimmed standard deviation of %s = %.4f uV' % (ch_names[i], all_trimstd[i]))
-        logger.info('Peak-to-peak amplitude of %s = %.4f uV' % (ch_names[i], all_ptp[i]))
-        if check_amp and not(0.1 < all_trimstd[i] < 1e3):
-            logger.error('Wrong data amplitude for %s '
-                         '(trimmed STD = %.3f). Unit of data MUST be uV! '
-                         'Channel will be skipped.'
-                         % (ch_names[i], all_trimstd[i]))
+        logger.info("Trimmed standard deviation of %s = %.4f uV" % (ch_names[i], all_trimstd[i]))
+        logger.info("Peak-to-peak amplitude of %s = %.4f uV" % (ch_names[i], all_ptp[i]))
+        if check_amp and not (0.1 < all_trimstd[i] < 1e3):
+            logger.error(
+                "Wrong data amplitude for %s "
+                "(trimmed STD = %.3f). Unit of data MUST be uV! "
+                "Channel will be skipped." % (ch_names[i], all_trimstd[i])
+            )
             bad_chan[i] = True
 
     # 5) Create sleep stage vector mask
@@ -128,50 +143,56 @@ def _check_mask(self, mask):
         assert mask.size == n_events, "Mask.size must be the number of detected events."
         return mask
 
-    def summary(self, event_type, grp_chan=False, grp_stage=False, aggfunc='mean', sort=True,
-                mask=None):
+    def summary(
+        self, event_type, grp_chan=False, grp_stage=False, aggfunc="mean", sort=True, mask=None
+    ):
         """Summary"""
         # Check masking
         mask = self._check_mask(mask)
 
         # Define grouping
         grouper = []
-        if grp_stage is True and 'Stage' in self._events:
-            grouper.append('Stage')
-        if grp_chan is True and 'Channel' in self._events:
-            grouper.append('Channel')
+        if grp_stage is True and "Stage" in self._events:
+            grouper.append("Stage")
+        if grp_chan is True and "Channel" in self._events:
+            grouper.append("Channel")
         if not len(grouper):
             # Return a copy of self._events after masking, without grouping
             return self._events.loc[mask, :].copy()
 
-        if event_type == 'spindles':
-            aggdict = {'Start': 'count',
-                       'Duration': aggfunc,
-                       'Amplitude': aggfunc,
-                       'RMS': aggfunc,
-                       'AbsPower': aggfunc,
-                       'RelPower': aggfunc,
-                       'Frequency': aggfunc,
-                       'Oscillations': aggfunc,
-                       'Symmetry': aggfunc}
+        if event_type == "spindles":
+            aggdict = {
+                "Start": "count",
+                "Duration": aggfunc,
+                "Amplitude": aggfunc,
+                "RMS": aggfunc,
+                "AbsPower": aggfunc,
+                "RelPower": aggfunc,
+                "Frequency": aggfunc,
+                "Oscillations": aggfunc,
+                "Symmetry": aggfunc,
+            }
 
             # if 'SOPhase' in self._events:
             #     from scipy.stats import circmean
             #     aggdict['SOPhase'] = lambda x: circmean(x, low=-np.pi, high=np.pi)
 
-        elif event_type == 'sw':
-            aggdict = {'Start': 'count',
-                       'Duration': aggfunc,
-                       'ValNegPeak': aggfunc,
-                       'ValPosPeak': aggfunc,
-                       'PTP': aggfunc,
-                       'Slope': aggfunc,
-                       'Frequency': aggfunc}
-
-            if 'PhaseAtSigmaPeak' in self._events:
+        elif event_type == "sw":
+            aggdict = {
+                "Start": "count",
+                "Duration": aggfunc,
+                "ValNegPeak": aggfunc,
+                "ValPosPeak": aggfunc,
+                "PTP": aggfunc,
+                "Slope": aggfunc,
+                "Frequency": aggfunc,
+            }
+
+            if "PhaseAtSigmaPeak" in self._events:
                 from scipy.stats import circmean
-                aggdict['PhaseAtSigmaPeak'] = lambda x: circmean(x, low=-np.pi, high=np.pi)
-                aggdict['ndPAC'] = aggfunc
+
+                aggdict["PhaseAtSigmaPeak"] = lambda x: circmean(x, low=-np.pi, high=np.pi)
+                aggdict["ndPAC"] = aggfunc
 
             if "CooccurringSpindle" in self._events:
                 # We do not average "CooccurringSpindlePeak"
@@ -179,22 +200,24 @@ def summary(self, event_type, grp_chan=False, grp_stage=False, aggfunc='mean', s
                 aggdict["DistanceSpindleToSW"] = aggfunc
 
         else:  # REM
-            aggdict = {'Start': 'count',
-                       'Duration': aggfunc,
-                       'LOCAbsValPeak': aggfunc,
-                       'ROCAbsValPeak': aggfunc,
-                       'LOCAbsRiseSlope': aggfunc,
-                       'ROCAbsRiseSlope': aggfunc,
-                       'LOCAbsFallSlope': aggfunc,
-                       'ROCAbsFallSlope': aggfunc}
+            aggdict = {
+                "Start": "count",
+                "Duration": aggfunc,
+                "LOCAbsValPeak": aggfunc,
+                "ROCAbsValPeak": aggfunc,
+                "LOCAbsRiseSlope": aggfunc,
+                "ROCAbsRiseSlope": aggfunc,
+                "LOCAbsFallSlope": aggfunc,
+                "ROCAbsFallSlope": aggfunc,
+            }
 
         # Apply grouping, after masking
         df_grp = self._events.loc[mask, :].groupby(grouper, sort=sort, as_index=False).agg(aggdict)
-        df_grp = df_grp.rename(columns={'Start': 'Count'})
+        df_grp = df_grp.rename(columns={"Start": "Count"})
 
         # Calculate density (= number per min of each stage)
         if self._hypno is not None and grp_stage is True:
-            stages = np.unique(self._events['Stage'])
+            stages = np.unique(self._events["Stage"])
             dur = {}
             for st in stages:
                 # Get duration in minutes of each stage present in dataframe
@@ -202,37 +225,42 @@ def summary(self, event_type, grp_chan=False, grp_stage=False, aggfunc='mean', s
 
             # Insert new density column in grouped dataframe after count
             df_grp.insert(
-                loc=df_grp.columns.get_loc('Count') + 1, column='Density',
-                value=df_grp.apply(lambda rw: rw['Count'] / dur[rw['Stage']], axis=1))
+                loc=df_grp.columns.get_loc("Count") + 1,
+                column="Density",
+                value=df_grp.apply(lambda rw: rw["Count"] / dur[rw["Stage"]], axis=1),
+            )
 
         return df_grp.set_index(grouper)
 
     def get_mask(self):
         """get_mask"""
         from yasa.others import _index_to_events
+
         mask = np.zeros(self._data.shape, dtype=int)
-        for i in self._events['IdxChannel'].unique():
-            ev_chan = self._events[self._events['IdxChannel'] == i]
-            idx_ev = _index_to_events(
-                ev_chan[['Start', 'End']].to_numpy() * self._sf)
+        for i in self._events["IdxChannel"].unique():
+            ev_chan = self._events[self._events["IdxChannel"] == i]
+            idx_ev = _index_to_events(ev_chan[["Start", "End"]].to_numpy() * self._sf)
             mask[i, idx_ev] = 1
         return np.squeeze(mask)
 
-    def get_sync_events(self, center, time_before, time_after, filt=(None, None), mask=None,
-                        as_dataframe=True):
+    def get_sync_events(
+        self, center, time_before, time_after, filt=(None, None), mask=None, as_dataframe=True
+    ):
         """Get_sync_events (not for REM, spindles & SW only)"""
         from yasa.others import get_centered_indices
+
        assert time_before >= 0
         assert time_after >= 0
         bef = int(self._sf * time_before)
         aft = int(self._sf * time_after)
         # TODO: Step size is determined by sf: 0.01 sec at 100 Hz, 0.002 sec at
         # 500 Hz, 0.00390625 sec at 256 Hz. Should we add resample=100 (Hz) or step_size=0.01?
-        time = np.arange(-bef, aft + 1, dtype='int') / self._sf
+        time = np.arange(-bef, aft + 1, dtype="int") / self._sf
 
         if any(filt):
             data = mne.filter.filter_data(
-                self._data, self._sf, l_freq=filt[0], h_freq=filt[1], method='fir', verbose=False)
+                self._data, self._sf, l_freq=filt[0], h_freq=filt[1], method="fir", verbose=False
+            )
         else:
             data = self._data
 
@@ -242,18 +270,19 @@ def get_sync_events(self, center, time_before, time_after, filt=(None, None), ma
 
         output = []
 
-        for i in masked_events['IdxChannel'].unique():
+        for i in masked_events["IdxChannel"].unique():
             # Copy is required to merge with the stage later on
-            ev_chan = masked_events[masked_events['IdxChannel'] == i].copy()
-            ev_chan['Event'] = np.arange(ev_chan.shape[0])
+            ev_chan = masked_events[masked_events["IdxChannel"] == i].copy()
+            ev_chan["Event"] = np.arange(ev_chan.shape[0])
             peaks = (ev_chan[center] * self._sf).astype(int).to_numpy()
             # Get centered indices
             idx, idx_valid = get_centered_indices(data[i, :], peaks, bef, aft)
             # If no good epochs are returned raise a warning
             if len(idx_valid) == 0:
                 logger.error(
-                    'Time before and/or time after exceed data bounds, please '
-                    'lower the temporal window around center. Skipping channel.')
+                    "Time before and/or time after exceed data bounds, please "
+                    "lower the temporal window around center. Skipping channel."
+                )
                 continue
 
             # Get data at indices and time vector
@@ -266,15 +295,15 @@ def get_sync_events(self, center, time_before, time_after, filt=(None, None), ma
 
             # Convert to long-format dataframe
             df_chan = pd.DataFrame(amps.T)
-            df_chan['Time'] = time
+            df_chan["Time"] = time
             # Convert to long-format
-            df_chan = df_chan.melt(id_vars='Time', var_name='Event', value_name='Amplitude')
+            df_chan = df_chan.melt(id_vars="Time", var_name="Event", value_name="Amplitude")
             # Append stage
-            if 'Stage' in masked_events:
-                df_chan = df_chan.merge(ev_chan[['Event', 'Stage']].iloc[idx_valid])
+            if "Stage" in masked_events:
+                df_chan = df_chan.merge(ev_chan[["Event", "Stage"]].iloc[idx_valid])
             # Append channel name
-            df_chan['Channel'] = ev_chan['Channel'].iloc[0]
-            df_chan['IdxChannel'] = i
+            df_chan["Channel"] = ev_chan["Channel"].iloc[0]
+            df_chan["IdxChannel"] = i
             # Append to master dataframe
             output.append(df_chan)
 
@@ -296,7 +325,7 @@ def _coincidence(x, y):
             coincidence = (x * y).sum()
             if scaled:
                 # Handle division by zero error
-                denom = (x.sum() * y.sum())
+                denom = x.sum() * y.sum()
                 if denom == 0:
                     coincidence = np.nan
                 else:
@@ -312,31 +341,42 @@ def _coincidence(x, y):
 
         return coinc_mat
 
-    def plot_average(self, event_type, center='Peak', hue='Channel', time_before=1,
-                     time_after=1, filt=(None, None), mask=None, figsize=(6, 4.5), **kwargs):
+    def plot_average(
+        self,
+        event_type,
+        center="Peak",
+        hue="Channel",
+        time_before=1,
+        time_after=1,
+        filt=(None, None),
+        mask=None,
+        figsize=(6, 4.5),
+        **kwargs,
+    ):
         """Plot the average event (not for REM, spindles & SW only)"""
         import seaborn as sns
         import matplotlib.pyplot as plt
 
-        df_sync = self.get_sync_events(center=center, time_before=time_before,
-                                       time_after=time_after, filt=filt, mask=mask)
+        df_sync = self.get_sync_events(
+            center=center, time_before=time_before, time_after=time_after, filt=filt, mask=mask
+        )
         assert not df_sync.empty, "Could not calculate event-locked data."
-        assert hue in ['Stage', 'Channel'], "hue must be 'Channel' or 'Stage'"
+        assert hue in ["Stage", "Channel"], "hue must be 'Channel' or 'Stage'"
         assert hue in df_sync.columns, "%s is not present in data." % hue
-        if event_type == 'spindles':
+        if event_type == "spindles":
             title = "Average spindle"
         else:  # "sw":
             title = "Average SW"
 
         # Start figure
         fig, ax = plt.subplots(1, 1, figsize=figsize)
-        sns.lineplot(data=df_sync, x='Time', y='Amplitude', hue=hue, ax=ax, **kwargs)
+        sns.lineplot(data=df_sync, x="Time", y="Amplitude", hue=hue, ax=ax, **kwargs)
         # ax.legend(frameon=False, loc='lower right')
-        ax.set_xlim(df_sync['Time'].min(), df_sync['Time'].max())
+        ax.set_xlim(df_sync["Time"].min(), df_sync["Time"].max())
         ax.set_title(title)
-        ax.set_xlabel('Time (sec)')
-        ax.set_ylabel('Amplitude (uV)')
+        ax.set_xlabel("Time (sec)")
+        ax.set_ylabel("Amplitude (uV)")
         return ax
 
     def plot_detection(self):
@@ -362,19 +402,15 @@ def plot_detection(self):
 
         # Plot
         fig, ax = plt.subplots(figsize=(12, 4))
-        plt.plot(times[xrng], self._data[0, xrng], 'k', lw=1)
-        plt.plot(times[xrng], highlight[0, xrng], 'indianred')
-        plt.xlabel('Time (seconds)')
-        plt.ylabel('Amplitude (uV)')
+        plt.plot(times[xrng], self._data[0, xrng], "k", lw=1)
+        plt.plot(times[xrng], highlight[0, xrng], "indianred")
+        plt.xlabel("Time (seconds)")
+        plt.ylabel("Amplitude (uV)")
         fig.canvas.header_visible = False
         fig.tight_layout()
 
         # WIDGETS
-        layout = ipy.Layout(
-            width="50%",
-            justify_content='center',
-            align_items='center'
-        )
+        layout = ipy.Layout(width="50%", justify_content="center", align_items="center")
 
         sl_ep = ipy.IntSlider(
             min=0,
@@ -391,24 +427,23 @@ def plot_detection(self):
             step=25,
             value=150,
             layout=layout,
-            orientation='horizontal',
-            description="Amplitude:"
+            orientation="horizontal",
+            description="Amplitude:",
         )
 
         dd_ch = ipy.Dropdown(
-            options=self._ch_names, value=self._ch_names[0],
-            description='Channel:'
+            options=self._ch_names, value=self._ch_names[0], description="Channel:"
         )
 
         dd_win = ipy.Dropdown(
             options=[1, 5, 10, 30, 60],
             value=win_size,
-            description='Window size:',
+            description="Window size:",
         )
 
         dd_check = ipy.Checkbox(
             value=False,
-            description='Filtered',
+            description="Filtered",
         )
 
         def update(epoch, amplitude, channel, win_size, filt):
@@ -428,8 +463,9 @@ def update(epoch, amplitude, channel, win_size, filt):
                 pass
             ax.set_ylim([-amplitude, amplitude])
 
-        return ipy.interact(update, epoch=sl_ep, amplitude=sl_amp,
-                            channel=dd_ch, win_size=dd_win, filt=dd_check)
+        return ipy.interact(
+            update, epoch=sl_ep, amplitude=sl_amp, channel=dd_ch, win_size=dd_win, filt=dd_check
+        )
 
 
 #############################################################################
@@ -437,11 +473,21 @@ def update(epoch, amplitude, channel, win_size, filt):
 #############################################################################
 
 
-def spindles_detect(data, sf=None, ch_names=None, hypno=None,
-                    include=(1, 2, 3), freq_sp=(12, 15), freq_broad=(1, 30),
-                    duration=(0.5, 2), min_distance=500,
-                    thresh={'rel_pow': 0.2, 'corr': 0.65, 'rms': 1.5},
-                    multi_only=False, remove_outliers=False, verbose=False):
+def spindles_detect(
+    data,
+    sf=None,
+    ch_names=None,
+    hypno=None,
+    include=(1, 2, 3),
+    freq_sp=(12, 15),
+    freq_broad=(1, 30),
+    duration=(0.5, 2),
+    min_distance=500,
+    thresh={"rel_pow": 0.2, "corr": 0.65, "rms": 1.5},
+    multi_only=False,
+    remove_outliers=False,
+    verbose=False,
+):
     """Spindles detection.
 
     Parameters
@@ -614,44 +660,52 @@ def spindles_detect(data, sf=None, ch_names=None, hypno=None,
     """
     set_log_level(verbose)
 
-    (data, sf, ch_names, hypno, include, mask, n_chan, n_samples, bad_chan
-     ) = _check_data_hypno(data, sf, ch_names, hypno, include)
+    (data, sf, ch_names, hypno, include, mask, n_chan, n_samples, bad_chan) = _check_data_hypno(
+        data, sf, ch_names, hypno, include
+    )
 
     # If all channels are bad
     if sum(bad_chan) == n_chan:
-        logger.warning('All channels have bad amplitude. Returning None.')
+        logger.warning("All channels have bad amplitude. Returning None.")
         return None
 
     # Check detection thresholds
-    if 'rel_pow' not in thresh.keys():
-        thresh['rel_pow'] = 0.20
-    if 'corr' not in thresh.keys():
-        thresh['corr'] = 0.65
-    if 'rms' not in thresh.keys():
-        thresh['rms'] = 1.5
-    do_rel_pow = thresh['rel_pow'] not in [None, "none", "None"]
-    do_corr = thresh['corr'] not in [None, "none", "None"]
-    do_rms = thresh['rms'] not in [None, "none", "None"]
+    if "rel_pow" not in thresh.keys():
+        thresh["rel_pow"] = 0.20
+    if "corr" not in thresh.keys():
+        thresh["corr"] = 0.65
+    if "rms" not in thresh.keys():
+        thresh["rms"] = 1.5
+    do_rel_pow = thresh["rel_pow"] not in [None, "none", "None"]
+    do_corr = thresh["corr"] not in [None, "none", "None"]
+    do_rms = thresh["rms"] not in [None, "none", "None"]
     n_thresh = sum([do_rel_pow, do_corr, do_rms])
-    assert n_thresh >= 1, 'At least one threshold must be defined.'
+    assert n_thresh >= 1, "At least one threshold must be defined."
 
     # Filtering
     nfast = next_fast_len(n_samples)
     # 1) Broadband bandpass filter (optional -- careful of lower freq for PAC)
-    data_broad = filter_data(data, sf, freq_broad[0], freq_broad[1], method='fir', verbose=0)
+    data_broad = filter_data(data, sf, freq_broad[0], freq_broad[1], method="fir", verbose=0)
     # 2) Sigma bandpass filter
     # The width of the transition band is set to 1.5 Hz on each side,
     # meaning that for freq_sp = (12, 15 Hz), the -6 dB points are located at
     # 11.25 and 15.75 Hz.
     data_sigma = filter_data(
-        data, sf, freq_sp[0], freq_sp[1], l_trans_bandwidth=1.5, h_trans_bandwidth=1.5,
-        method='fir', verbose=0)
+        data,
+        sf,
+        freq_sp[0],
+        freq_sp[1],
+        l_trans_bandwidth=1.5,
+        h_trans_bandwidth=1.5,
+        method="fir",
+        verbose=0,
+    )
 
     # Hilbert power (to define the instantaneous frequency / power)
     analytic = signal.hilbert(data_sigma, N=nfast)[:, :n_samples]
     inst_phase = np.angle(analytic)
     inst_pow = np.square(np.abs(analytic))
-    inst_freq = (sf / (2 * np.pi) * np.diff(inst_phase, axis=-1))
+    inst_freq = sf / (2 * np.pi) * np.diff(inst_phase, axis=-1)
 
     # Extract the SO signal for coupling
     # if coupling:
@@ -680,7 +734,8 @@ def spindles_detect(data, sf=None, ch_names=None, hypno=None,
         # Note that even if the threshold is None we still need to calculate it
         # for the individual spindles parameter (RelPow).
         f, t, Sxx = stft_power(
-            data_broad[i, :], sf, window=2, step=.2, band=freq_broad, interp=False, norm=True)
+            data_broad[i, :], sf, window=2, step=0.2, band=freq_broad, interp=False, norm=True
+        )
         idx_sigma = np.logical_and(f >= freq_sp[0], f <= freq_sp[1])
         rel_pow = Sxx[idx_sigma].sum(0)
 
         # Let's interpolate `rel_pow` to get one value per sample
         # Note that we could also have used the `interp=True` in the
         # `stft_power` function, however 2D interpolation is much slower than
         # 1D interpolation.
-            func = interp1d(t, rel_pow, kind='cubic', bounds_error=False, fill_value=0)
+            func = interp1d(t, rel_pow, kind="cubic", bounds_error=False, fill_value=0)
             t = np.arange(n_samples) / sf
             rel_pow = func(t)
 
         if do_corr:
-            _, mcorr = moving_transform(x=data_sigma[i, :], y=data_broad[i, :], sf=sf, window=.3,
-                                        step=.1, method='corr', interp=True)
+            _, mcorr = moving_transform(
+                x=data_sigma[i, :],
+                y=data_broad[i, :],
+                sf=sf,
+                window=0.3,
+                step=0.1,
+                method="corr",
+                interp=True,
+            )
         if do_rms:
-            _, mrms = moving_transform(x=data_sigma[i, :], sf=sf, window=.3, step=.1, method='rms',
-                                       interp=True)
+            _, mrms = moving_transform(
+                x=data_sigma[i, :], sf=sf, window=0.3, step=0.1, method="rms", interp=True
+            )
             # Let's define the thresholds
             if hypno is None:
-                thresh_rms = mrms.mean() + thresh['rms'] * trimbothstd(mrms, cut=0.10)
+                thresh_rms = mrms.mean() + thresh["rms"] * trimbothstd(mrms, cut=0.10)
             else:
-                thresh_rms = mrms[mask].mean() + thresh['rms'] * trimbothstd(mrms[mask], cut=0.10)
+                thresh_rms = mrms[mask].mean() + thresh["rms"] * trimbothstd(mrms[mask], cut=0.10)
             # Avoid too high threshold caused by Artefacts / Motion during Wake
             thresh_rms = min(thresh_rms, 10)
-            logger.info('Moving RMS threshold = %.3f', thresh_rms)
+            logger.info("Moving RMS threshold = %.3f", thresh_rms)
 
         # Boolean vector of supra-threshold indices
         idx_sum = np.zeros(n_samples)
        if do_rel_pow:
-            idx_rel_pow = (rel_pow >= thresh['rel_pow']).astype(int)
+            idx_rel_pow = (rel_pow >= thresh["rel_pow"]).astype(int)
             idx_sum += idx_rel_pow
-            logger.info('N supra-theshold relative power = %i', idx_rel_pow.sum())
+            logger.info("N supra-threshold relative power = %i", idx_rel_pow.sum())
         if do_corr:
-            idx_mcorr = (mcorr >= thresh['corr']).astype(int)
+            idx_mcorr = (mcorr >= thresh["corr"]).astype(int)
             idx_sum += idx_mcorr
-            logger.info('N supra-theshold moving corr = %i', idx_mcorr.sum())
+            logger.info("N supra-threshold moving corr = %i", idx_mcorr.sum())
         if do_rms:
             idx_mrms = (mrms >= thresh_rms).astype(int)
             idx_sum += idx_mrms
-            logger.info('N supra-theshold moving RMS = %i', idx_mrms.sum())
+            logger.info("N supra-threshold moving RMS = %i", idx_mrms.sum())
 
         # Make sure that we do not detect spindles outside mask
         if hypno is not None:
@@ -733,7 +796,7 @@ def spindles_detect(data, sf=None, ch_names=None, hypno=None,
         # Sampling frequency = 256 Hz --> w = 25 samples = 97 ms
         w = int(0.1 * sf)
         # Critical bugfix March 2022, see https://github.com/raphaelvallat/yasa/pull/55
-        idx_sum = np.convolve(idx_sum, np.ones(w), mode='same') / w
+        idx_sum = np.convolve(idx_sum, np.ones(w), mode="same") / w
         # And we then find indices that are strictly greater than 2, i.e. we
         # find the 'true' beginning and 'true' end of the events by finding
         # where at least two out of the three thresholds were crossed.
@@ -741,7 +804,7 @@ def spindles_detect(data, sf=None, ch_names=None, hypno=None,
 
         # If no events are found, skip to next channel
         if not len(where_sp):
-            logger.warning('No spindle were found in channel %s.', ch_names[i])
+            logger.warning("No spindles were found in channel %s.", ch_names[i])
             continue
 
         # Merge events that are too close
@@ -759,7 +822,7 @@ def spindles_detect(data, sf=None, ch_names=None, hypno=None,
 
         # If no events of good duration are found, skip to next channel
         if all(~good_dur):
-            logger.warning('No spindle were found in channel %s.', ch_names[i])
+            logger.warning("No spindles were found in channel %s.", ch_names[i])
             continue
 
         # Initialize empty variables
@@ -796,7 +859,8 @@ def spindles_detect(data, sf=None, ch_names=None, hypno=None,
 
             # Number of oscillations
             peaks, peaks_params = signal.find_peaks(
-                sp_det, distance=distance, prominence=(None, None))
+                sp_det, distance=distance, prominence=(None, None)
+            )
             sp_osc[j] = len(peaks)
 
             # For frequency and amplitude, we can also optionally use these
@@ -807,7 +871,7 @@ def spindles_detect(data, sf=None, ch_names=None, hypno=None,
 
             # Peak location & symmetry index
             # pk is expressed in sample since the beginning of the spindle
-            pk = peaks[peaks_params['prominences'].argmax()]
+            pk = peaks[peaks_params["prominences"].argmax()]
             sp_pro[j] = sp_start[j] + pk / sf
             sp_sym[j] = pk / sp_det.size
 
@@ -820,70 +884,82 @@ def spindles_detect(data, sf=None, ch_names=None, hypno=None,
             sp_sta[j] = hypno[sp[j]][0]
 
         # Create a dataframe
-        sp_params = {'Start': sp_start,
-                     'Peak': sp_pro,
-                     'End': sp_end,
-                     'Duration': sp_dur,
-                     'Amplitude': sp_amp,
-                     'RMS': sp_rms,
-                     'AbsPower': sp_abs,
-                     'RelPower': sp_rel,
-                     'Frequency': sp_freq,
-                     'Oscillations': sp_osc,
-                     'Symmetry': sp_sym,
-                     # 'SOPhase': sp_cou,
-                     'Stage': sp_sta}
+        sp_params = {
+            "Start": sp_start,
+            "Peak": sp_pro,
+            "End": sp_end,
+            "Duration": sp_dur,
+            "Amplitude": sp_amp,
+            "RMS": sp_rms,
+            "AbsPower": sp_abs,
+            "RelPower": sp_rel,
+            "Frequency": sp_freq,
+            "Oscillations": sp_osc,
+            "Symmetry": sp_sym,
+            # 'SOPhase': sp_cou,
+            "Stage": sp_sta,
+        }
 
         df_chan = pd.DataFrame(sp_params)[good_dur]
 
         # We need at least 50 detected spindles to apply the Isolation Forest.
         if remove_outliers and df_chan.shape[0] >= 50:
-            col_keep = ['Duration', 'Amplitude', 'RMS', 'AbsPower', 'RelPower',
-                        'Frequency', 'Oscillations', 'Symmetry']
+            col_keep = [
+                "Duration",
+                "Amplitude",
+                "RMS",
+                "AbsPower",
+                "RelPower",
+                "Frequency",
+                "Oscillations",
+                "Symmetry",
+            ]
             ilf = IsolationForest(
-                contamination='auto', max_samples='auto', verbose=0, random_state=42)
+                contamination="auto", max_samples="auto", verbose=0, random_state=42
+            )
             good = ilf.fit_predict(df_chan[col_keep])
             good[good == -1] = 0
-            logger.info('%i outliers were removed in channel %s.'
-                        % ((good == 0).sum(), ch_names[i]))
+            logger.info(
+                "%i outliers were removed in channel %s." % ((good == 0).sum(), ch_names[i])
+            )
             # Remove outliers from DataFrame
             df_chan = df_chan[good.astype(bool)]
 
-        logger.info('%i spindles were found in channel %s.'
-                    % (df_chan.shape[0], ch_names[i]))
+        logger.info("%i spindles were found in channel %s." % (df_chan.shape[0], ch_names[i]))
 
         # ####################################################################
         # END SINGLE CHANNEL DETECTION
         # ####################################################################
 
-        df_chan['Channel'] = ch_names[i]
-        df_chan['IdxChannel'] = i
+        df_chan["Channel"] = ch_names[i]
+        df_chan["IdxChannel"] = i
         df = pd.concat([df, df_chan], axis=0, ignore_index=True)
 
     # If no spindles were detected, return None
     if df.empty:
-        logger.warning('No spindles were found in data. Returning None.')
+        logger.warning("No spindles were found in data. Returning None.")
         return None
 
     # Remove useless columns
     to_drop = []
     if hypno is None:
-        to_drop.append('Stage')
+        to_drop.append("Stage")
     else:
-        df['Stage'] = df['Stage'].astype(int)
+        df["Stage"] = df["Stage"].astype(int)
     # if not coupling:
     #     to_drop.append('SOPhase')
     if len(to_drop):
         df = df.drop(columns=to_drop)
 
     # Find spindles that are present on at least two channels
-    if multi_only and df['Channel'].nunique() > 1:
+    if multi_only and df["Channel"].nunique() > 1:
         # We round to the nearest second
         idx_good = np.logical_or(
-            df['Start'].round(0).duplicated(keep=False),
-            df['End'].round(0).duplicated(keep=False)).to_list()
+            df["Start"].round(0).duplicated(keep=False), df["End"].round(0).duplicated(keep=False)
+        ).to_list()
         df = df[idx_good].reset_index(drop=True)
 
-    return SpindlesResults(events=df, data=data, sf=sf, ch_names=ch_names,
-                           hypno=hypno, data_filt=data_sigma)
+    return SpindlesResults(
+        events=df, data=data, sf=sf, ch_names=ch_names, hypno=hypno, data_filt=data_sigma
+    )
 
 
 class SpindlesResults(_DetectionResults):
@@ -908,7 +984,7 @@ class SpindlesResults(_DetectionResults):
     def __init__(self, events, data, sf, ch_names, hypno, data_filt):
         super().__init__(events, data, sf, ch_names, hypno, data_filt)
 
-    def summary(self, grp_chan=False, grp_stage=False, mask=None, aggfunc='mean', sort=True):
+    def summary(self, grp_chan=False, grp_stage=False, mask=None, aggfunc="mean", sort=True):
         """Return a summary of the spindles detection, optionally grouped across channels and/or
         stage.
 
@@ -928,8 +1004,14 @@ def summary(self, grp_chan=False, grp_stage=False, mask=None, aggfunc='mean', so
         sort : bool
             If True, sort group keys when grouping.
         """
-        return super().summary(event_type='spindles', grp_chan=grp_chan, grp_stage=grp_stage,
-                               aggfunc=aggfunc, sort=sort, mask=mask)
+        return super().summary(
+            event_type="spindles",
+            grp_chan=grp_chan,
+            grp_stage=grp_stage,
+            aggfunc=aggfunc,
+            sort=sort,
+            mask=mask,
+        )
 
     def get_coincidence_matrix(self, scaled=True):
         """Return the (scaled) coincidence matrix.
@@ -986,8 +1068,15 @@ def get_mask(self):
         """
         return super().get_mask()
 
-    def get_sync_events(self, center='Peak', time_before=1, time_after=1, filt=(None, None),
-                        mask=None, as_dataframe=True):
+    def get_sync_events(
+        self,
+        center="Peak",
+        time_before=1,
+        time_after=1,
+        filt=(None, None),
+        mask=None,
+        as_dataframe=True,
+    ):
         """
         Return the raw or filtered data of each detected event after centering to a specific
         timepoint.
@@ -1026,12 +1115,26 @@ def get_sync_events(self, center='Peak', time_before=1, time_after=1, filt=(None
             'IdxChannel' : Index of channel in data
             'Stage': Sleep stage in which the events occurred (if available)
         """
-        return super().get_sync_events(center=center, time_before=time_before,
-                                       time_after=time_after, filt=filt, mask=mask,
-                                       as_dataframe=as_dataframe)
+        return super().get_sync_events(
+            center=center,
+            time_before=time_before,
+            time_after=time_after,
+            filt=filt,
+            mask=mask,
+            as_dataframe=as_dataframe,
+        )
 
-    def plot_average(self, center='Peak', hue='Channel', time_before=1,
-                     time_after=1, filt=(None, None), mask=None, figsize=(6, 4.5), **kwargs):
+    def plot_average(
+        self,
+        center="Peak",
+        hue="Channel",
+        time_before=1,
+        time_after=1,
+        filt=(None, None),
+        mask=None,
+        figsize=(6, 4.5),
+        **kwargs,
+    ):
         """
         Plot the average spindle.
 
@@ -1060,10 +1163,17 @@ def plot_average(self, center='Peak', hue='Channel', time_before=1,
         **kwargs : dict
             Optional arguments that are passed to :py:func:`seaborn.lineplot`.
         """
-        return super().plot_average(event_type='spindles', center=center,
-                                    hue=hue, time_before=time_before,
-                                    time_after=time_after, filt=filt, mask=mask,
-                                    figsize=figsize, **kwargs)
+        return super().plot_average(
+            event_type="spindles",
+            center=center,
+            hue=hue,
+            time_before=time_before,
+            time_after=time_after,
+            filt=filt,
+            mask=mask,
+            figsize=figsize,
+            **kwargs,
+        )
 
     def plot_detection(self):
         """Plot an overlay of the detected spindles on the EEG signal.
@@ -1085,11 +1195,23 @@ def plot_detection(self):
 #############################################################################
 
 
-def sw_detect(data, sf=None, ch_names=None, hypno=None, include=(2, 3), freq_sw=(0.3, 1.5),
-              dur_neg=(0.3, 1.5), dur_pos=(0.1, 1), amp_neg=(40, 200), amp_pos=(10, 150),
-              amp_ptp=(75, 350), coupling=False,
-              coupling_params={"freq_sp": (12, 16), "time": 1, "p": 0.05},
-              remove_outliers=False, verbose=False):
+def sw_detect(
+    data,
+    sf=None,
+    ch_names=None,
+    hypno=None,
+    include=(2, 3),
+    freq_sw=(0.3, 1.5),
+    dur_neg=(0.3, 1.5),
+    dur_pos=(0.1, 1),
+    amp_neg=(40, 200),
+    amp_pos=(10, 150),
+    amp_ptp=(75, 350),
+    coupling=False,
+    coupling_params={"freq_sp": (12, 16), "time": 1, "p": 0.05},
+    remove_outliers=False,
+    verbose=False,
+):
     """Slow-waves detection.
 
     Parameters
@@ -1297,12 +1419,13 @@ def sw_detect(data, sf=None, ch_names=None, hypno=None, include=(2, 3), freq_sw=
     """
     set_log_level(verbose)
 
-    (data, sf, ch_names, hypno, include, mask, n_chan, n_samples, bad_chan
-     ) = _check_data_hypno(data, sf, ch_names, hypno, include)
+    (data, sf, ch_names, hypno, include, mask, n_chan, n_samples, bad_chan) = _check_data_hypno(
+        data, sf, ch_names, hypno, include
+    )
 
     # If all channels are bad
     if sum(bad_chan) == n_chan:
-        logger.warning('All channels have bad amplitude. Returning None.')
+        logger.warning("All channels have bad amplitude. Returning None.")
Returning None.") return None # Define time vector @@ -1312,13 +1435,21 @@ def sw_detect(data, sf=None, ch_names=None, hypno=None, include=(2, 3), freq_sw= # Bandpass filter nfast = next_fast_len(n_samples) data_filt = filter_data( - data, sf, freq_sw[0], freq_sw[1], method='fir', verbose=0, l_trans_bandwidth=0.2, - h_trans_bandwidth=0.2) + data, + sf, + freq_sw[0], + freq_sw[1], + method="fir", + verbose=0, + l_trans_bandwidth=0.2, + h_trans_bandwidth=0.2, + ) # Extract the spindles-related sigma signal for coupling if coupling: is_tensorpac_installed() import tensorpac.methods as tpm + # The width of the transition band is set to 1.5 Hz on each side, # meaning that for freq_sp = (12, 15 Hz), the -6 dB points are located # at 11.25 and 15.75 Hz. The frequency band for the amplitude signal @@ -1329,10 +1460,17 @@ def sw_detect(data, sf=None, ch_names=None, hypno=None, include=(2, 3), freq_sw= assert "freq_sp" in coupling_params.keys() assert "time" in coupling_params.keys() assert "p" in coupling_params.keys() - freq_sp = coupling_params['freq_sp'] + freq_sp = coupling_params["freq_sp"] data_sp = filter_data( - data, sf, freq_sp[0], freq_sp[1], method='fir', l_trans_bandwidth=1.5, - h_trans_bandwidth=1.5, verbose=0) + data, + sf, + freq_sp[0], + freq_sp[1], + method="fir", + l_trans_bandwidth=1.5, + h_trans_bandwidth=1.5, + verbose=0, + ) # Now extract the instantaneous phase/amplitude using Hilbert transform sw_pha = np.angle(signal.hilbert(data_filt, N=nfast)[:, :n_samples]) sp_amp = np.abs(signal.hilbert(data_sp, N=nfast)[:, :n_samples]) @@ -1359,7 +1497,7 @@ def sw_detect(data, sf=None, ch_names=None, hypno=None, include=(2, 3), freq_sw= # If no peaks are detected, return None if len(idx_neg_peaks) == 0 or len(idx_pos_peaks) == 0: - logger.warning('No SW were found in channel %s.', ch_names[i]) + logger.warning("No SW were found in channel %s.", ch_names[i]) continue # Make sure that the last detected peak is a positive one @@ -1374,12 +1512,12 @@ def sw_detect(data, sf=None, ch_names=None, hypno=None, include=(2, 3), freq_sw= idx_pos_peaks = idx_neg_peaks + closest_pos_peaks # Now we compute the PTP amplitude and keep only the good peaks - sw_ptp = (np.abs(data_filt[i, idx_neg_peaks]) + data_filt[i, idx_pos_peaks]) + sw_ptp = np.abs(data_filt[i, idx_neg_peaks]) + data_filt[i, idx_pos_peaks] good_ptp = np.logical_and(sw_ptp > amp_ptp[0], sw_ptp < amp_ptp[1]) # If good_ptp is all False if all(~good_ptp): - logger.warning('No SW were found in channel %s.', ch_names[i]) + logger.warning("No SW were found in channel %s.", ch_names[i]) continue sw_ptp = sw_ptp[good_ptp] @@ -1428,28 +1566,30 @@ def sw_detect(data, sf=None, ch_names=None, hypno=None, include=(2, 3), freq_sw= sw_sta = np.zeros(sw_dur.shape) # And we apply a set of thresholds to remove bad slow waves - good_sw = np.logical_and.reduce(( - # Data edges - previous_neg_zc != 0, - following_neg_zc != 0, - previous_pos_zc != 0, - following_pos_zc != 0, - # Duration criteria - sw_dur == sw_dur_both_phase, # dur = negative + positive - sw_dur <= dur_neg[1] + dur_pos[1], # dur < max(neg) + max(pos) - sw_dur >= dur_neg[0] + dur_pos[0], # dur > min(neg) + min(pos) - neg_phase_dur > dur_neg[0], - neg_phase_dur < dur_neg[1], - pos_phase_dur > dur_pos[0], - pos_phase_dur < dur_pos[1], - # Sanity checks - sw_midcrossing > sw_start, - sw_midcrossing < sw_end, - sw_slope > 0, - )) + good_sw = np.logical_and.reduce( + ( + # Data edges + previous_neg_zc != 0, + following_neg_zc != 0, + previous_pos_zc != 0, + following_pos_zc != 0, + # Duration 
+                sw_dur == sw_dur_both_phase,  # dur = negative + positive
+                sw_dur <= dur_neg[1] + dur_pos[1],  # dur < max(neg) + max(pos)
+                sw_dur >= dur_neg[0] + dur_pos[0],  # dur > min(neg) + min(pos)
+                neg_phase_dur > dur_neg[0],
+                neg_phase_dur < dur_neg[1],
+                pos_phase_dur > dur_pos[0],
+                pos_phase_dur < dur_pos[1],
+                # Sanity checks
+                sw_midcrossing > sw_start,
+                sw_midcrossing < sw_end,
+                sw_slope > 0,
+            )
+        )
 
         if all(~good_sw):
-            logger.warning('No SW were found in channel %s.', ch_names[i])
+            logger.warning("No SW were found in channel %s.", ch_names[i])
             continue
 
         # Filter good events
@@ -1466,28 +1606,33 @@ def sw_detect(data, sf=None, ch_names=None, hypno=None, include=(2, 3), freq_sw=
         sw_sta = sw_sta[good_sw]
 
         # Create a dictionary
-        sw_params = OrderedDict({
-            'Start': sw_start,
-            'NegPeak': sw_idx_neg,
-            'MidCrossing': sw_midcrossing,
-            'PosPeak': sw_idx_pos,
-            'End': sw_end,
-            'Duration': sw_dur,
-            'ValNegPeak': data_filt[i, idx_neg_peaks],
-            'ValPosPeak': data_filt[i, idx_pos_peaks],
-            'PTP': sw_ptp,
-            'Slope': sw_slope,
-            'Frequency': 1 / sw_dur,
-            'Stage': sw_sta,
-        })
+        sw_params = OrderedDict(
+            {
+                "Start": sw_start,
+                "NegPeak": sw_idx_neg,
+                "MidCrossing": sw_midcrossing,
+                "PosPeak": sw_idx_pos,
+                "End": sw_end,
+                "Duration": sw_dur,
+                "ValNegPeak": data_filt[i, idx_neg_peaks],
+                "ValPosPeak": data_filt[i, idx_pos_peaks],
+                "PTP": sw_ptp,
+                "Slope": sw_slope,
+                "Frequency": 1 / sw_dur,
+                "Stage": sw_sta,
+            }
+        )
 
         # Add phase (in radians) of slow-oscillation signal at maximum
         # spindles-related sigma amplitude within XX-second centered epochs.
         if coupling:
             # Get phase and amplitude for each centered epoch
-            time_before = time_after = coupling_params['time']
-            assert float(sf * time_before).is_integer(), (
-                "Invalid time parameter for coupling. Must be a whole number of samples.")
+            time_before = time_after = coupling_params["time"]
+            assert float(
+                sf * time_before
+            ).is_integer(), (
+                "Invalid time parameter for coupling. Must be a whole number of samples."
+            )
             bef = int(sf * time_before)
             aft = int(sf * time_after)
             # Center of each epoch is defined as the negative peak of the SW
@@ -1501,69 +1646,74 @@ def sw_detect(data, sf=None, ch_names=None, hypno=None, include=(2, 3), freq_sw=
             # Now we need to append it back to the original unmasked shape
             # to avoid error when idx.shape[0] != idx_valid.shape, i.e.
             # some epochs were out of data bounds.
-            sw_params['SigmaPeak'] = np.ones(n_peaks) * np.nan
+            sw_params["SigmaPeak"] = np.ones(n_peaks) * np.nan
             # Timestamp at sigma peak, expressed in seconds from negative peak
             # e.g. -0.39, 0.5, 1, 2 -- limits are [time_before, time_after]
             time_sigpk = (idx_max_amp - bef) / sf
             # convert to absolute time from beginning of the recording
             # time_sigpk only includes valid epoch
             time_sigpk_abs = sw_idx_neg[idx_valid] + time_sigpk
-            sw_params['SigmaPeak'][idx_valid] = time_sigpk_abs
+            sw_params["SigmaPeak"][idx_valid] = time_sigpk_abs
 
             # 2) PhaseAtSigmaPeak
             # Find SW phase at max sigma amplitude in epoch
             pha_at_max = np.squeeze(np.take_along_axis(sw_pha_ev, idx_max_amp[..., None], axis=1))
-            sw_params['PhaseAtSigmaPeak'] = np.ones(n_peaks) * np.nan
-            sw_params['PhaseAtSigmaPeak'][idx_valid] = pha_at_max
+            sw_params["PhaseAtSigmaPeak"] = np.ones(n_peaks) * np.nan
+            sw_params["PhaseAtSigmaPeak"][idx_valid] = pha_at_max
 
             # 3) Normalized Direct PAC, with thresholding
             # Unreliable values are set to 0
-            ndp = np.squeeze(tpm.norm_direct_pac(
-                sw_pha_ev[None, ...], sp_amp_ev[None, ...], p=coupling_params['p']))
-            sw_params['ndPAC'] = np.ones(n_peaks) * np.nan
-            sw_params['ndPAC'][idx_valid] = ndp
+            ndp = np.squeeze(
+                tpm.norm_direct_pac(
+                    sw_pha_ev[None, ...], sp_amp_ev[None, ...], p=coupling_params["p"]
+                )
+            )
+            sw_params["ndPAC"] = np.ones(n_peaks) * np.nan
+            sw_params["ndPAC"][idx_valid] = ndp
 
         # Make sure that Stage is the last column of the dataframe
-        sw_params.move_to_end('Stage')
+        sw_params.move_to_end("Stage")
 
         # Convert to dataframe, keeping only good events
         df_chan = pd.DataFrame(sw_params)
 
         # Remove all duplicates
-        df_chan = df_chan.drop_duplicates(subset=['Start'], keep=False)
-        df_chan = df_chan.drop_duplicates(subset=['End'], keep=False)
+        df_chan = df_chan.drop_duplicates(subset=["Start"], keep=False)
+        df_chan = df_chan.drop_duplicates(subset=["End"], keep=False)
 
         # We need at least 50 detected slow waves to apply the Isolation Forest
         if remove_outliers and df_chan.shape[0] >= 50:
-            col_keep = ['Duration', 'ValNegPeak', 'ValPosPeak', 'PTP', 'Slope', 'Frequency']
-            ilf = IsolationForest(contamination='auto', max_samples='auto',
-                                  verbose=0, random_state=42)
+            col_keep = ["Duration", "ValNegPeak", "ValPosPeak", "PTP", "Slope", "Frequency"]
+            ilf = IsolationForest(
+                contamination="auto", max_samples="auto", verbose=0, random_state=42
+            )
             good = ilf.fit_predict(df_chan[col_keep])
             good[good == -1] = 0
-            logger.info('%i outliers were removed in channel %s.'
-                        % ((good == 0).sum(), ch_names[i]))
+            logger.info(
+                "%i outliers were removed in channel %s." % ((good == 0).sum(), ch_names[i])
+            )
             # Remove outliers from DataFrame
             df_chan = df_chan[good.astype(bool)]
 
-        logger.info('%i slow-waves were found in channel %s.'
-                    % (df_chan.shape[0], ch_names[i]))
+        logger.info("%i slow-waves were found in channel %s." % (df_chan.shape[0], ch_names[i]))
 
         # ####################################################################
         # END SINGLE CHANNEL DETECTION
         # ####################################################################
 
-        df_chan['Channel'] = ch_names[i]
-        df_chan['IdxChannel'] = i
+        df_chan["Channel"] = ch_names[i]
+        df_chan["IdxChannel"] = i
         df = pd.concat([df, df_chan], axis=0, ignore_index=True)
 
     # If no SW were detected, return None
     if df.empty:
-        logger.warning('No SW were found in data. Returning None.')
+        logger.warning("No SW were found in data. Returning None.")
Returning None.") return None if hypno is None: - df = df.drop(columns=['Stage']) + df = df.drop(columns=["Stage"]) else: - df['Stage'] = df['Stage'].astype(int) + df["Stage"] = df["Stage"].astype(int) - return SWResults(events=df, data=data, sf=sf, ch_names=ch_names, - hypno=hypno, data_filt=data_filt) + return SWResults( + events=df, data=data, sf=sf, ch_names=ch_names, hypno=hypno, data_filt=data_filt + ) class SWResults(_DetectionResults): @@ -1588,7 +1738,7 @@ class SWResults(_DetectionResults): def __init__(self, events, data, sf, ch_names, hypno, data_filt): super().__init__(events, data, sf, ch_names, hypno, data_filt) - def summary(self, grp_chan=False, grp_stage=False, mask=None, aggfunc='mean', sort=True): + def summary(self, grp_chan=False, grp_stage=False, mask=None, aggfunc="mean", sort=True): """Return a summary of the SW detection, optionally grouped across channels and/or stage. @@ -1606,8 +1756,14 @@ def summary(self, grp_chan=False, grp_stage=False, mask=None, aggfunc='mean', so sort : bool If True, sort group keys when grouping. """ - return super().summary(event_type='sw', grp_chan=grp_chan, grp_stage=grp_stage, - aggfunc=aggfunc, sort=sort, mask=mask) + return super().summary( + event_type="sw", + grp_chan=grp_chan, + grp_stage=grp_stage, + aggfunc=aggfunc, + sort=sort, + mask=mask, + ) def find_cooccurring_spindles(self, spindles, lookaround=1.2): """Given a spindles detection summary dataframe, find slow-waves that co-occur with @@ -1663,13 +1819,13 @@ def find_cooccurring_spindles(self, spindles, lookaround=1.2): cooccurring_spindle_peaks = [] # Find intersecting channels - common_ch = np.intersect1d(self._events['Channel'].unique(), spindles['Channel'].unique()) + common_ch = np.intersect1d(self._events["Channel"].unique(), spindles["Channel"].unique()) assert len(common_ch), "No common channel(s) were found." # Loop across channels - for chan in self._events['Channel'].unique(): + for chan in self._events["Channel"].unique(): sw_chan_peaks = self._events[self._events["Channel"] == chan]["NegPeak"].to_numpy() - sp_chan_peaks = spindles[spindles["Channel"] == chan]['Peak'].to_numpy() + sp_chan_peaks = spindles[spindles["Channel"] == chan]["Peak"].to_numpy() # Loop across individual slow-waves for sw_negpeak in sw_chan_peaks: start = sw_negpeak - lookaround @@ -1687,7 +1843,7 @@ def find_cooccurring_spindles(self, spindles, lookaround=1.2): # Add columns to self._events: IN-PLACE MODIFICATION! self._events["CooccurringSpindle"] = ~np.isnan(distance_sp_to_sw_peak) self._events["CooccurringSpindlePeak"] = cooccurring_spindle_peaks - self._events['DistanceSpindleToSW'] = distance_sp_to_sw_peak + self._events["DistanceSpindleToSW"] = distance_sp_to_sw_peak def get_coincidence_matrix(self, scaled=True): """Return the (scaled) coincidence matrix. @@ -1744,8 +1900,15 @@ def get_mask(self): """ return super().get_mask() - def get_sync_events(self, center='NegPeak', time_before=0.4, time_after=0.8, filt=(None, None), - mask=None, as_dataframe=True): + def get_sync_events( + self, + center="NegPeak", + time_before=0.4, + time_after=0.8, + filt=(None, None), + mask=None, + as_dataframe=True, + ): """ Return the raw data of each detected event after centering to a specific timepoint. 
@@ -1784,11 +1947,25 @@ def get_sync_events(self, center='NegPeak', time_before=0.4, time_after=0.8, fil
             'Stage': Sleep stage in which the events occurred (if available)
         """
         return super().get_sync_events(
-            center=center, time_before=time_before, time_after=time_after, filt=filt, mask=mask,
-            as_dataframe=as_dataframe)
+            center=center,
+            time_before=time_before,
+            time_after=time_after,
+            filt=filt,
+            mask=mask,
+            as_dataframe=as_dataframe,
+        )

-    def plot_average(self, center='NegPeak', hue='Channel', time_before=0.4, time_after=0.8,
-                     filt=(None, None), mask=None, figsize=(6, 4.5), **kwargs):
+    def plot_average(
+        self,
+        center="NegPeak",
+        hue="Channel",
+        time_before=0.4,
+        time_after=0.8,
+        filt=(None, None),
+        mask=None,
+        figsize=(6, 4.5),
+        **kwargs,
+    ):
         """
         Plot the average slow-wave.

@@ -1818,8 +1995,16 @@ def plot_average(self, center='NegPeak', hue='Channel', time_before=0.4, time_af
             Optional arguments that are passed to :py:func:`seaborn.lineplot`.
         """
         return super().plot_average(
-            event_type='sw', center=center, hue=hue, time_before=time_before,
-            time_after=time_after, filt=filt, mask=mask, figsize=figsize, **kwargs)
+            event_type="sw",
+            center=center,
+            hue=hue,
+            time_before=time_before,
+            time_after=time_after,
+            filt=filt,
+            mask=mask,
+            figsize=figsize,
+            **kwargs,
+        )

     def plot_detection(self):
         """Plot an overlay of the detected slow-waves on the EEG signal.

@@ -1841,8 +2026,18 @@
 #############################################################################

-def rem_detect(loc, roc, sf, hypno=None, include=4, amplitude=(50, 325), duration=(0.3, 1.2),
-               freq_rem=(0.5, 5), remove_outliers=False, verbose=False):
+def rem_detect(
+    loc,
+    roc,
+    sf,
+    hypno=None,
+    include=4,
+    amplitude=(50, 325),
+    duration=(0.3, 1.2),
+    freq_rem=(0.5, 5),
+    remove_outliers=False,
+    verbose=False,
+):
     """Rapid eye movements (REMs) detection.

     This detection requires both the left EOG (LOC) and right EOG (ROC).

@@ -1972,18 +2167,18 @@
     # Safety checks
     loc = np.squeeze(np.asarray(loc, dtype=np.float64))
     roc = np.squeeze(np.asarray(roc, dtype=np.float64))
-    assert loc.ndim == 1, 'LOC must be 1D.'
-    assert roc.ndim == 1, 'ROC must be 1D.'
-    assert loc.size == roc.size, 'LOC and ROC must have the same size.'
+    assert loc.ndim == 1, "LOC must be 1D."
+    assert roc.ndim == 1, "ROC must be 1D."
+    assert loc.size == roc.size, "LOC and ROC must have the same size."
     data = np.vstack((loc, roc))

-    (data, sf, ch_names, hypno, include, mask, n_chan, n_samples, bad_chan
-     ) = _check_data_hypno(data, sf, ['LOC', 'ROC'], hypno, include)
+    (data, sf, ch_names, hypno, include, mask, n_chan, n_samples, bad_chan) = _check_data_hypno(
+        data, sf, ["LOC", "ROC"], hypno, include
+    )

     # If all channels are bad
     if any(bad_chan):
-        logger.warning('At least one channel has bad amplitude. '
-                       'Returning None.')
+        logger.warning("At least one channel has bad amplitude. " "Returning None.")
         return None

     # Bandpass filter
@@ -1997,9 +2192,14 @@
     # - distance: required distance in samples between neighboring peaks.
     # - prominence: required prominence of peaks.
     # - wlen: limit search for bases to a specific window.
-    hmin, hmax = amplitude[0]**2, amplitude[1]**2
-    pks, pks_params = signal.find_peaks(negp, height=(hmin, hmax), distance=(duration[0] * sf),
-                                        prominence=(0.8 * hmin), wlen=(duration[1] * sf))
+    hmin, hmax = amplitude[0] ** 2, amplitude[1] ** 2
+    pks, pks_params = signal.find_peaks(
+        negp,
+        height=(hmin, hmax),
+        distance=(duration[0] * sf),
+        prominence=(0.8 * hmin),
+        wlen=(duration[1] * sf),
+    )

     # Intersect with sleep stage vector
     # We do that before calculating the features in order to gain some time
@@ -2010,78 +2210,98 @@
     # If no peaks are detected, return None
     if len(pks) == 0:
-        logger.warning('No REMs were found in data. Returning None.')
+        logger.warning("No REMs were found in data. Returning None.")
         return None

     # Hypnogram
     if hypno is not None:
         # The sleep stage at the beginning of the REM is considered.
-        rem_sta = hypno[pks_params['left_bases']]
+        rem_sta = hypno[pks_params["left_bases"]]
     else:
         rem_sta = np.zeros(pks.shape)

     # Calculate time features
-    pks_params['Start'] = pks_params['left_bases'] / sf
-    pks_params['Peak'] = pks / sf
-    pks_params['End'] = pks_params['right_bases'] / sf
-    pks_params['Duration'] = pks_params['End'] - pks_params['Start']
+    pks_params["Start"] = pks_params["left_bases"] / sf
+    pks_params["Peak"] = pks / sf
+    pks_params["End"] = pks_params["right_bases"] / sf
+    pks_params["Duration"] = pks_params["End"] - pks_params["Start"]
     # Time points in minutes (HH:MM:SS)
     # pks_params['StartMin'] = pd.to_timedelta(pks_params['Start'], unit='s').dt.round('s')  # noqa
     # pks_params['PeakMin'] = pd.to_timedelta(pks_params['Peak'], unit='s').dt.round('s')  # noqa
     # pks_params['EndMin'] = pd.to_timedelta(pks_params['End'], unit='s').dt.round('s')  # noqa
     # Absolute LOC / ROC value at peak (filtered)
-    pks_params['LOCAbsValPeak'] = abs(data_filt[0, pks])
-    pks_params['ROCAbsValPeak'] = abs(data_filt[1, pks])
+    pks_params["LOCAbsValPeak"] = abs(data_filt[0, pks])
+    pks_params["ROCAbsValPeak"] = abs(data_filt[1, pks])
     # Absolute rising and falling slope
-    dist_pk_left = (pks - pks_params['left_bases']) / sf
-    dist_pk_right = (pks_params['right_bases'] - pks) / sf
-    locrs = (data_filt[0, pks] - data_filt[0, pks_params['left_bases']]) / dist_pk_left
-    rocrs = (data_filt[1, pks] - data_filt[1, pks_params['left_bases']]) / dist_pk_left
-    locfs = (data_filt[0, pks_params['right_bases']] - data_filt[0, pks]) / dist_pk_right
-    rocfs = (data_filt[1, pks_params['right_bases']] - data_filt[1, pks]) / dist_pk_right
-    pks_params['LOCAbsRiseSlope'] = abs(locrs)
-    pks_params['ROCAbsRiseSlope'] = abs(rocrs)
-    pks_params['LOCAbsFallSlope'] = abs(locfs)
-    pks_params['ROCAbsFallSlope'] = abs(rocfs)
-    pks_params['Stage'] = rem_sta  # Sleep stage
+    dist_pk_left = (pks - pks_params["left_bases"]) / sf
+    dist_pk_right = (pks_params["right_bases"] - pks) / sf
+    locrs = (data_filt[0, pks] - data_filt[0, pks_params["left_bases"]]) / dist_pk_left
+    rocrs = (data_filt[1, pks] - data_filt[1, pks_params["left_bases"]]) / dist_pk_left
+    locfs = (data_filt[0, pks_params["right_bases"]] - data_filt[0, pks]) / dist_pk_right
+    rocfs = (data_filt[1, pks_params["right_bases"]] - data_filt[1, pks]) / dist_pk_right
+    pks_params["LOCAbsRiseSlope"] = abs(locrs)
+    pks_params["ROCAbsRiseSlope"] = abs(rocrs)
+    pks_params["LOCAbsFallSlope"] = abs(locfs)
+    pks_params["ROCAbsFallSlope"] = abs(rocfs)
+    pks_params["Stage"] = rem_sta  # Sleep stage

     # Convert to Pandas DataFrame
     df = pd.DataFrame(pks_params)

     # Make sure that the sign of ROC and LOC is opposite
-    df['IsOppositeSign'] = (np.sign(data_filt[1, pks]) != np.sign(data_filt[0, pks]))
+    df["IsOppositeSign"] = np.sign(data_filt[1, pks]) != np.sign(data_filt[0, pks])
     df = df[np.sign(data_filt[1, pks]) != np.sign(data_filt[0, pks])]

     # Remove bad duration
     tmin, tmax = duration
-    good_dur = np.logical_and(pks_params['Duration'] >= tmin, pks_params['Duration'] < tmax)
+    good_dur = np.logical_and(pks_params["Duration"] >= tmin, pks_params["Duration"] < tmax)
     df = df[good_dur]

     # Keep only useful channels
-    df = df[['Start', 'Peak', 'End', 'Duration', 'LOCAbsValPeak', 'ROCAbsValPeak',
-             'LOCAbsRiseSlope', 'ROCAbsRiseSlope', 'LOCAbsFallSlope', 'ROCAbsFallSlope', 'Stage']]
+    df = df[
+        [
+            "Start",
+            "Peak",
+            "End",
+            "Duration",
+            "LOCAbsValPeak",
+            "ROCAbsValPeak",
+            "LOCAbsRiseSlope",
+            "ROCAbsRiseSlope",
+            "LOCAbsFallSlope",
+            "ROCAbsFallSlope",
+            "Stage",
+        ]
+    ]

     if hypno is None:
-        df = df.drop(columns=['Stage'])
+        df = df.drop(columns=["Stage"])
     else:
-        df['Stage'] = df['Stage'].astype(int)
+        df["Stage"] = df["Stage"].astype(int)

     # We need at least 50 detected REMs to apply the Isolation Forest.
     if remove_outliers and df.shape[0] >= 50:
-        col_keep = ['Duration', 'LOCAbsValPeak', 'ROCAbsValPeak', 'LOCAbsRiseSlope',
-                    'ROCAbsRiseSlope', 'LOCAbsFallSlope', 'ROCAbsFallSlope']
-        ilf = IsolationForest(contamination='auto', max_samples='auto',
-                              verbose=0, random_state=42)
+        col_keep = [
+            "Duration",
+            "LOCAbsValPeak",
+            "ROCAbsValPeak",
+            "LOCAbsRiseSlope",
+            "ROCAbsRiseSlope",
+            "LOCAbsFallSlope",
+            "ROCAbsFallSlope",
+        ]
+        ilf = IsolationForest(contamination="auto", max_samples="auto", verbose=0, random_state=42)
         good = ilf.fit_predict(df[col_keep])
         good[good == -1] = 0
-        logger.info('%i outliers were removed.', (good == 0).sum())
+        logger.info("%i outliers were removed.", (good == 0).sum())
         # Remove outliers from DataFrame
         df = df[good.astype(bool)]

-    logger.info('%i REMs were found in data.', df.shape[0])
+    logger.info("%i REMs were found in data.", df.shape[0])
     df = df.reset_index(drop=True)

-    return REMResults(events=df, data=data, sf=sf, ch_names=ch_names,
-                      hypno=hypno, data_filt=data_filt)
+    return REMResults(
+        events=df, data=data, sf=sf, ch_names=ch_names, hypno=hypno, data_filt=data_filt
+    )


 class REMResults(_DetectionResults):
@@ -2108,7 +2328,7 @@ class REMResults(_DetectionResults):
     def __init__(self, events, data, sf, ch_names, hypno, data_filt):
         super().__init__(events, data, sf, ch_names, hypno, data_filt)

-    def summary(self, grp_stage=False, mask=None, aggfunc='mean', sort=True):
+    def summary(self, grp_stage=False, mask=None, aggfunc="mean", sort=True):
         """Return a summary of the REM detection, optionally grouped across stage.

         Parameters
@@ -2126,8 +2346,14 @@ def summary(self, grp_stage=False, mask=None, aggfunc='mean', sort=True):
         """
         # ``grp_chan`` is always False for REM detection because the
         # REMs are always detected on a combination of LOC and ROC.
-        return super().summary(event_type='rem', grp_chan=False, grp_stage=grp_stage,
-                               aggfunc=aggfunc, sort=sort, mask=mask)
+        return super().summary(
+            event_type="rem",
+            grp_chan=False,
+            grp_stage=grp_stage,
+            aggfunc=aggfunc,
+            sort=sort,
+            mask=mask,
+        )

     def get_mask(self):
         """Return a boolean array indicating for each sample in data if this
@@ -2135,14 +2361,15 @@ def get_mask(self):
         """
         # We cannot use super() because "Channel" is not present in _events.
         from yasa.others import _index_to_events
+
         mask = np.zeros(self._data.shape, dtype=int)
-        idx_ev = _index_to_events(
-            self._events[['Start', 'End']].to_numpy() * self._sf)
+        idx_ev = _index_to_events(self._events[["Start", "End"]].to_numpy() * self._sf)
         mask[:, idx_ev] = 1
         return mask

-    def get_sync_events(self, center='Peak', time_before=0.4, time_after=0.4,
-                        filt=(None, None), mask=None):
+    def get_sync_events(
+        self, center="Peak", time_before=0.4, time_after=0.4, filt=(None, None), mask=None
+    ):
         """
         Return the raw or filtered data of each detected event after centering to a specific
         timepoint.

@@ -2177,6 +2404,7 @@ def get_sync_events(self, center='Peak', time_before=0.4, time_after=0.4,
         'IdxChannel' : Index of channel in data
         """
         from yasa.others import get_centered_indices
+
         assert time_before >= 0
         assert time_after >= 0
         bef = int(self._sf * time_before)
@@ -2184,7 +2412,8 @@ def get_sync_events(self, center='Peak', time_before=0.4, time_after=0.4,

         if any(filt):
             data = mne.filter.filter_data(
-                self._data, self._sf, l_freq=filt[0], h_freq=filt[1], method='fir', verbose=False)
+                self._data, self._sf, l_freq=filt[0], h_freq=filt[1], method="fir", verbose=False
+            )
         else:
             data = self._data

@@ -2192,15 +2421,16 @@ def get_sync_events(self, center='Peak', time_before=0.4, time_after=0.4,
         mask = self._check_mask(mask)
         masked_events = self._events.loc[mask, :]

-        time = np.arange(-bef, aft + 1, dtype='int') / self._sf
+        time = np.arange(-bef, aft + 1, dtype="int") / self._sf
         # Get location of peaks in data
         peaks = (masked_events[center] * self._sf).astype(int).to_numpy()
         # Get centered indices (here we could use second channel as well).
         idx, idx_valid = get_centered_indices(data[0, :], peaks, bef, aft)
         # If no good epochs are returned, raise an error
         assert len(idx_valid), (
-            'Time before and/or time after exceed data bounds, please '
-            'lower the temporal window around center.')
+            "Time before and/or time after exceed data bounds, please "
+            "lower the temporal window around center."
+        )

         # Initialize empty dataframe
         df_sync = pd.DataFrame()

@@ -2209,16 +2439,24 @@ def get_sync_events(self, center='Peak', time_before=0.4, time_after=0.4,
         for i, ch in enumerate(self._ch_names):
             amps = data[i, idx]
             df_chan = pd.DataFrame(amps.T)
-            df_chan['Time'] = time
-            df_chan = df_chan.melt(id_vars='Time', var_name='Event', value_name='Amplitude')
-            df_chan['Channel'] = ch
-            df_chan['IdxChannel'] = i
+            df_chan["Time"] = time
+            df_chan = df_chan.melt(id_vars="Time", var_name="Event", value_name="Amplitude")
+            df_chan["Channel"] = ch
+            df_chan["IdxChannel"] = i
             df_sync = pd.concat([df_sync, df_chan], axis=0, ignore_index=True)

         return df_sync

-    def plot_average(self, center='Peak', time_before=0.4, time_after=0.4, filt=(None, None),
-                     mask=None, figsize=(6, 4.5), **kwargs):
+    def plot_average(
+        self,
+        center="Peak",
+        time_before=0.4,
+        time_after=0.4,
+        filt=(None, None),
+        mask=None,
+        figsize=(6, 4.5),
+        **kwargs,
+    ):
         """
         Plot the average REM.
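The REM detection follows the same pattern as the slow-wave API. A minimal usage sketch (illustrative only, not part of the diff; `loc`, `roc`, `sf` and `hypno` are assumed to be 1D LOC/ROC EOG arrays in uV, the sampling frequency, and a per-sample hypnogram):

    import yasa

    rem = yasa.rem_detect(loc, roc, sf, hypno=hypno, include=4)  # REMResults, or None
    if rem is not None:
        print(rem.summary(grp_stage=True))  # grp_chan is always False for REMs
        ax = rem.plot_average(center="Peak", time_before=0.4, time_after=0.4)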
@@ -2247,17 +2485,18 @@ def plot_average(self, center='Peak', time_before=0.4, time_after=0.4, filt=(Non
         import seaborn as sns
         import matplotlib.pyplot as plt

-        df_sync = self.get_sync_events(center=center, time_before=time_before,
-                                       time_after=time_after, filt=filt, mask=mask)
+        df_sync = self.get_sync_events(
+            center=center, time_before=time_before, time_after=time_after, filt=filt, mask=mask
+        )
         # Start figure
         fig, ax = plt.subplots(1, 1, figsize=figsize)
-        sns.lineplot(data=df_sync, x='Time', y='Amplitude', hue='Channel', ax=ax, **kwargs)
+        sns.lineplot(data=df_sync, x="Time", y="Amplitude", hue="Channel", ax=ax, **kwargs)
         # ax.legend(frameon=False, loc='lower right')
-        ax.set_xlim(df_sync['Time'].min(), df_sync['Time'].max())
+        ax.set_xlim(df_sync["Time"].min(), df_sync["Time"].max())
         ax.set_title("Average REM")
-        ax.set_xlabel('Time (sec)')
-        ax.set_ylabel('Amplitude (uV)')
+        ax.set_xlabel("Time (sec)")
+        ax.set_ylabel("Amplitude (uV)")
         return ax


@@ -2266,8 +2505,17 @@ def plot_average(self, center='Peak', time_before=0.4, time_after=0.4, filt=(Non
 #############################################################################

-def art_detect(data, sf=None, window=5, hypno=None, include=(1, 2, 3, 4),
-               method='covar', threshold=3, n_chan_reject=1, verbose=False):
+def art_detect(
+    data,
+    sf=None,
+    window=5,
+    hypno=None,
+    include=(1, 2, 3, 4),
+    method="covar",
+    threshold=3,
+    n_chan_reject=1,
+    verbose=False,
+):
     r"""
     Automatic artifact rejection.

@@ -2446,23 +2694,24 @@ def art_detect(data, sf=None, window=5, hypno=None, include=(1, 2, 3, 4),
     ###########################################################################
     set_log_level(verbose)

-    (data, sf, _, hypno, include, _, n_chan, n_samples, _
-     ) = _check_data_hypno(data, sf, ch_names=None, hypno=hypno, include=include, check_amp=False)
+    (data, sf, _, hypno, include, _, n_chan, n_samples, _) = _check_data_hypno(
+        data, sf, ch_names=None, hypno=hypno, include=include, check_amp=False
+    )

-    assert isinstance(n_chan_reject, int), 'n_chan_reject must be int.'
-    assert n_chan_reject >= 1, 'n_chan_reject must be >= 1.'
-    assert n_chan_reject <= n_chan, 'n_chan_reject must be <= n_chan.'
+    assert isinstance(n_chan_reject, int), "n_chan_reject must be int."
+    assert n_chan_reject >= 1, "n_chan_reject must be >= 1."
+    assert n_chan_reject <= n_chan, "n_chan_reject must be <= n_chan."

     # Safety check: sampling frequency and window
-    assert isinstance(sf, (int, float)), 'sf must be int or float'
-    assert isinstance(window, (int, float)), 'window must be int or float'
+    assert isinstance(sf, (int, float)), "sf must be int or float"
+    assert isinstance(window, (int, float)), "window must be int or float"
     if isinstance(sf, float):
-        assert sf.is_integer(), 'sf must be a whole number.'
+        assert sf.is_integer(), "sf must be a whole number."
         sf = int(sf)
     win_sec = window
     window = win_sec * sf  # Convert window to samples
     if isinstance(window, float):
-        assert window.is_integer(), 'window * sf must be a whole number.'
+        assert window.is_integer(), "window * sf must be a whole number."
         window = int(window)

     # Safety check: hypnogram
@@ -2474,24 +2723,27 @@ def art_detect(data, sf=None, window=5, hypno=None, include=(1, 2, 3, 4),
     # Safety checks: methods
     assert isinstance(method, str), "method must be a string."
     method = method.lower()
-    if method in ['cov', 'covar', 'covariance', 'riemann', 'potato']:
-        method = 'covar'
+    if method in ["cov", "covar", "covariance", "riemann", "potato"]:
+        method = "covar"
         is_pyriemann_installed()
         from pyriemann.estimation import Covariances, Shrinkage
         from pyriemann.clustering import Potato
+
         # Must have at least 4 channels to use method='covar'
         if n_chan <= 4:
-            logger.warning("Must have at least 4 channels for method='covar'. "
-                           "Automatically switching to method='std'.")
-            method = 'std'
+            logger.warning(
+                "Must have at least 4 channels for method='covar'. "
+                "Automatically switching to method='std'."
+            )
+            method = "std"

     ###########################################################################
     # START THE REJECTION
     ###########################################################################
     # Remove flat channels
-    isflat = (np.nanstd(data, axis=-1) == 0)
+    isflat = np.nanstd(data, axis=-1) == 0
     if isflat.any():
-        logger.warning('Flat channel(s) were found and removed in data.')
+        logger.warning("Flat channel(s) were found and removed in data.")
         data = data[~isflat]
         n_chan = data.shape[0]

@@ -2509,13 +2761,13 @@ def art_detect(data, sf=None, window=5, hypno=None, include=(1, 2, 3, 4),
     n_flat_epochs = where_flat_epochs.size

     # Now let's make sure that we have an hypnogram and an include variable
-    if 'hypno_win' not in locals():
+    if "hypno_win" not in locals():
         # [-2, -2, -2, -2, ...], where -2 stands for unscored
-        hypno_win = -2 * np.ones(n_epochs, dtype='float')
-        include = np.array([-2], dtype='float')
+        hypno_win = -2 * np.ones(n_epochs, dtype="float")
+        include = np.array([-2], dtype="float")

     # We want to make sure that hypno-win and n_epochs have EXACTLY same shape
-    assert n_epochs == hypno_win.shape[-1], 'Hypno and epochs do not match.'
+    assert n_epochs == hypno_win.shape[-1], "Hypno and epochs do not match."
     # Finally, we make sure not to include any flat epochs in calculation
     # just using a random number that is unlikely to be picked by users
@@ -2523,19 +2775,19 @@ def art_detect(data, sf=None, window=5, hypno=None, include=(1, 2, 3, 4),
     hypno_win[where_flat_epochs] = -111991

     # Add logger info
-    logger.info('Number of channels in data = %i', n_chan)
-    logger.info('Number of samples in data = %i', n_samples)
-    logger.info('Sampling frequency = %.2f Hz', sf)
-    logger.info('Data duration = %.2f seconds', n_samples / sf)
-    logger.info('Number of epochs = %i' % n_epochs)
-    logger.info('Artifact window = %.2f seconds' % win_sec)
-    logger.info('Method = %s' % method)
-    logger.info('Threshold = %.2f standard deviations' % threshold)
+    logger.info("Number of channels in data = %i", n_chan)
+    logger.info("Number of samples in data = %i", n_samples)
+    logger.info("Sampling frequency = %.2f Hz", sf)
+    logger.info("Data duration = %.2f seconds", n_samples / sf)
+    logger.info("Number of epochs = %i" % n_epochs)
+    logger.info("Artifact window = %.2f seconds" % win_sec)
+    logger.info("Method = %s" % method)
+    logger.info("Threshold = %.2f standard deviations" % threshold)

     # Create empty `hypno_art` vector (1 sample = 1 epoch)
-    epoch_is_art = np.zeros(n_epochs, dtype='int')
+    epoch_is_art = np.zeros(n_epochs, dtype="int")

-    if method == 'covar':
+    if method == "covar":
         # Calculate the covariance matrices,
         # shape (n_epochs, n_chan, n_chan)
         covmats = Covariances().fit_transform(epochs)
@@ -2543,10 +2795,11 @@ def art_detect(data, sf=None, window=5, hypno=None, include=(1, 2, 3, 4),
         covmats = Shrinkage().fit_transform(covmats)

         # Define Potato instance: 0 = clean, 1 = art
         # To increase speed we set the max number of iterations from 100 to 10
-        potato = Potato(metric='riemann', threshold=threshold, pos_label=0,
-                        neg_label=1, n_iter_max=10)
+        potato = Potato(
+            metric="riemann", threshold=threshold, pos_label=0, neg_label=1, n_iter_max=10
+        )
         # Create empty z-scores output (n_epochs)
-        zscores = np.zeros(n_epochs, dtype='float') * np.nan
+        zscores = np.zeros(n_epochs, dtype="float") * np.nan

         for stage in include:
             where_stage = np.where(hypno_win == stage)[0]
@@ -2555,9 +2808,11 @@ def art_detect(data, sf=None, window=5, hypno=None, include=(1, 2, 3, 4),
             if where_stage.size < 30:
                 if hypno is not None:
                     # Only show warning if user actually passed an hypnogram
-                    logger.warning(f"At least 30 epochs are required to "
-                                   f"calculate z-score. Skipping "
-                                   f"stage {stage}")
+                    logger.warning(
+                        f"At least 30 epochs are required to "
+                        f"calculate z-score. Skipping "
+                        f"stage {stage}"
+                    )
                 continue
             # Apply Potato algorithm, extract z-scores and labels
             zs = potato.fit_transform(covmats[where_stage])
@@ -2565,20 +2820,22 @@ def art_detect(data, sf=None, window=5, hypno=None, include=(1, 2, 3, 4),
             if hypno is not None:
                 # Only shown if user actually passed an hypnogram
                 perc_reject = 100 * (art.sum() / art.size)
-                text = (f"Stage {stage}: {art.sum()} / {art.size} "
-                        f"epochs rejected ({perc_reject:.2f}%)")
+                text = (
+                    f"Stage {stage}: {art.sum()} / {art.size} "
+                    f"epochs rejected ({perc_reject:.2f}%)"
+                )
                 logger.info(text)
             # Append to global vector
             epoch_is_art[where_stage] = art
             zscores[where_stage] = zs

-    elif method in ['std', 'sd']:
+    elif method in ["std", "sd"]:
         # Calculate log-transformed standard dev in each epoch
         # We add 1 to avoid log warning if std is zero (e.g. flat line)
         # (n_epochs, n_chan)
         std_epochs = np.log(np.nanstd(epochs, axis=-1) + 1)
         # Create empty zscores output (n_epochs, n_chan)
-        zscores = np.zeros((n_epochs, n_chan), dtype='float') * np.nan
+        zscores = np.zeros((n_epochs, n_chan), dtype="float") * np.nan

         for stage in include:
             where_stage = np.where(hypno_win == stage)[0]
             # At least 30 epochs are required to calculate z-scores
@@ -2586,9 +2843,11 @@ def art_detect(data, sf=None, window=5, hypno=None, include=(1, 2, 3, 4),
             if where_stage.size < 30:
                 if hypno is not None:
                     # Only show warning if user actually passed an hypnogram
-                    logger.warning(f"At least 30 epochs are required to "
-                                   f"calculate z-score. Skipping "
-                                   f"stage {stage}")
+                    logger.warning(
+                        f"At least 30 epochs are required to "
+                        f"calculate z-score. Skipping "
+                        f"stage {stage}"
+                    )
                 continue
             # Calculate z-scores of STD for each channel x stage
             c_mean = np.nanmean(std_epochs[where_stage], axis=0, keepdims=True)
@@ -2600,8 +2859,10 @@ def art_detect(data, sf=None, window=5, hypno=None, include=(1, 2, 3, 4),
             if hypno is not None:
                 # Only shown if user actually passed an hypnogram
                 perc_reject = 100 * (art.sum() / art.size)
-                text = (f"Stage {stage}: {art.sum()} / {art.size} "
-                        f"epochs rejected ({perc_reject:.2f}%)")
+                text = (
+                    f"Stage {stage}: {art.sum()} / {art.size} "
+                    f"epochs rejected ({perc_reject:.2f}%)"
+                )
                 logger.info(text)
             # Append to global vector
             epoch_is_art[where_stage] = art
@@ -2609,13 +2870,15 @@ def art_detect(data, sf=None, window=5, hypno=None, include=(1, 2, 3, 4),
     # Mark flat epochs as artefacts
     if n_flat_epochs > 0:
-        logger.info(f"Rejecting {n_flat_epochs} epochs with >=50% of channels "
-                    f"that are flat. Z-scores set to np.nan for these epochs.")
+        logger.info(
+            f"Rejecting {n_flat_epochs} epochs with >=50% of channels "
+            f"that are flat. Z-scores set to np.nan for these epochs."
+        )
         epoch_is_art[where_flat_epochs] = 1

     # Log total percentage of epochs rejected
     perc_reject = 100 * (epoch_is_art.sum() / n_epochs)
-    text = (f"TOTAL: {epoch_is_art.sum()} / {n_epochs} epochs rejected ({perc_reject:.2f}%)")
+    text = f"TOTAL: {epoch_is_art.sum()} / {n_epochs} epochs rejected ({perc_reject:.2f}%)"
     logger.info(text)

     # Convert epoch_is_art to boolean [0, 0, 1] --> [False, False, True]
diff --git a/yasa/features.py b/yasa/features.py
index 7a6b5b2..3400504 100644
--- a/yasa/features.py
+++ b/yasa/features.py
@@ -26,13 +26,14 @@
 import scipy.stats as sp_stats

-logger = logging.getLogger('yasa')
+logger = logging.getLogger("yasa")

-__all__ = ['compute_features_stage']
+__all__ = ["compute_features_stage"]


-def compute_features_stage(raw, hypno, max_freq=35, spindles_params=dict(),
-                           sw_params=dict(), do_1f=True):
+def compute_features_stage(
+    raw, hypno, max_freq=35, spindles_params=dict(), sw_params=dict(), do_1f=True
+):
     """Calculate a set of features for each sleep stage from PSG data.

     Features are calculated for N2, N3, NREM (= N2 + N3) and REM sleep.
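As a usage sketch of the function reformatted here (illustrative, not part of the diff; `raw` is assumed to be an mne.io.Raw PSG recording and `hypno` an integer hypnogram upsampled to the same sampling frequency):

    import mne
    from yasa.features import compute_features_stage

    # One row per (stage, channel): bandpower, spindle, slow-wave and entropy features
    df_feat = compute_features_stage(raw, hypno, max_freq=35)
    print(df_feat.head())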
@@ -96,40 +97,45 @@ def compute_features_stage(raw, hypno, max_freq=35, spindles_params=dict(),
         bands.append(tuple((b, freqs[i + 1], "%s-%s" % (b, freqs[i + 1]))))
     # Append traditional bands
     bands_classic = [
-        (0.5, 1, 'slowdelta'), (1, 4, 'fastdelta'), (0.5, 4, 'delta'),
-        (4, 8, 'theta'), (8, 12, 'alpha'), (12, 16, 'sigma'), (16, 30, 'beta'),
-        (30, max_freq, 'gamma')]
+        (0.5, 1, "slowdelta"),
+        (1, 4, "fastdelta"),
+        (0.5, 4, "delta"),
+        (4, 8, "theta"),
+        (8, 12, "alpha"),
+        (12, 16, "sigma"),
+        (16, 30, "beta"),
+        (30, max_freq, "gamma"),
+    ]
     bands = bands_classic + bands

     # Find min and maximum frequencies. These will be used for bandpass-filter
     # and 1/f adjustment of bandpower. l_freq = 0.5 and h_freq = 35 Hz.
-    all_freqs_sorted = np.sort(np.unique(
-        [b[0] for b in bands] + [b[1] for b in bands]))
+    all_freqs_sorted = np.sort(np.unique([b[0] for b in bands] + [b[1] for b in bands]))
     l_freq = all_freqs_sorted[0]
     h_freq = all_freqs_sorted[-1]

     # Mapping dictionary from integer to string for sleep stages (2 --> N2)
     stage_mapping = {
-        -2: 'Unscored',
-        -1: 'Artefact',
-        0: 'Wake',
-        1: 'N1',
-        2: 'N2',
-        3: 'N3',
-        4: 'REM',
-        6: 'NREM',
-        7: 'WN'  # Whole night = N2 + N3 + REM
+        -2: "Unscored",
+        -1: "Artefact",
+        0: "Wake",
+        1: "N1",
+        2: "N2",
+        3: "N3",
+        4: "REM",
+        6: "NREM",
+        7: "WN",  # Whole night = N2 + N3 + REM
     }

     # Hypnogram check + calculate NREM hypnogram
     hypno = np.asarray(hypno, dtype=int)
-    assert hypno.ndim == 1, 'Hypno must be one dimensional.'
+    assert hypno.ndim == 1, "Hypno must be one dimensional."
     unique_hypno = np.unique(hypno)
-    logger.info('Number of unique values in hypno = %i', unique_hypno.size)
+    logger.info("Number of unique values in hypno = %i", unique_hypno.size)
     # IMPORTANT: NREM is defined as N2 + N3, excluding N1 sleep.
     hypno_NREM = pd.Series(hypno).replace({2: 6, 3: 6}).to_numpy()
-    minutes_of_NREM = (hypno_NREM == 6).sum() / (60 * raw.info['sfreq'])
+    minutes_of_NREM = (hypno_NREM == 6).sum() / (60 * raw.info["sfreq"])
     # WN = Whole night = N2 + N3 + REM (excluding N1)
     hypno_WN = pd.Series(hypno).replace({2: 7, 3: 7, 4: 7}).to_numpy()
@@ -146,10 +152,10 @@ def compute_features_stage(raw, hypno, max_freq=35, spindles_params=dict(),
         raw_eeg.drop_channels(chan_flat)

     # Remove suffix from channels: C4-M1 --> C4
-    chan_nosuffix = [c.split('-')[0] for c in raw_eeg.ch_names]
+    chan_nosuffix = [c.split("-")[0] for c in raw_eeg.ch_names]
     raw_eeg.rename_channels(dict(zip(raw_eeg.ch_names, chan_nosuffix)))
     # Rename P7/T5 --> P7
-    chan_noslash = [c.split('/')[0] for c in raw_eeg.ch_names]
+    chan_noslash = [c.split("/")[0] for c in raw_eeg.ch_names]
     raw_eeg.rename_channels(dict(zip(raw_eeg.ch_names, chan_noslash)))
     chan = raw_eeg.ch_names

@@ -159,9 +165,9 @@ def compute_features_stage(raw, hypno, max_freq=35, spindles_params=dict(),

     # Extract data and sf
     data = raw_eeg.get_data(units=dict(eeg="uV", emg="uV", eog="uV", ecg="uV"))
-    sf = raw_eeg.info['sfreq']
-    assert data.ndim == 2, 'data must be 2D (chan, times).'
-    assert hypno.size == data.shape[1], 'Hypno must have same size as data.'
+    sf = raw_eeg.info["sfreq"]
+    assert data.ndim == 2, "data must be 2D (chan, times)."
+    assert hypno.size == data.shape[1], "Hypno must have same size as data."
     # #########################################################################
     # 2) SPECTRAL POWER
@@ -176,10 +182,10 @@ def compute_features_stage(raw, hypno, max_freq=35, spindles_params=dict(),
     df_bp_NREM = yasa.bandpower(raw_eeg, hypno=hypno_NREM, bands=bands, include=6)
     df_bp_WN = yasa.bandpower(raw_eeg, hypno=hypno_WN, bands=bands, include=7)
     df_bp = pd.concat([df_bp, df_bp_NREM, df_bp_WN], axis=0)
-    df_bp.drop(columns=['TotalAbsPow', 'FreqRes', 'Relative'], inplace=True)
-    df_bp = df_bp.add_prefix('bp_').reset_index()
+    df_bp.drop(columns=["TotalAbsPow", "FreqRes", "Relative"], inplace=True)
+    df_bp = df_bp.add_prefix("bp_").reset_index()
     # Replace 2 --> N2
-    df_bp['Stage'] = df_bp['Stage'].map(stage_mapping)
+    df_bp["Stage"] = df_bp["Stage"].map(stage_mapping)
     # Assert that there are no negative values (see below issue on 1/f)
     assert not (df_bp._get_numeric_data() < 0).any().any()
     df_bp.columns = df_bp.columns.str.lower()

@@ -203,26 +209,25 @@ def compute_features_stage(raw, hypno, max_freq=35, spindles_params=dict(),
             continue
         # Calculate aperiodic / oscillatory PSD + slope
         freqs, _, psd_osc, fit_params = yasa.irasa(
-            data_stage, sf, ch_names=chan, band=(l_freq, h_freq),
-            win_sec=4)
+            data_stage, sf, ch_names=chan, band=(l_freq, h_freq), win_sec=4
+        )
         # Make sure that we don't have any negative values in PSD
         # See https://github.com/raphaelvallat/yasa/issues/29
         psd_osc = psd_osc - psd_osc.min(axis=-1, keepdims=True)
         # Calculate bandpower
-        bp = yasa.bandpower_from_psd(psd_osc, freqs, ch_names=chan,
-                                     bands=bands)
+        bp = yasa.bandpower_from_psd(psd_osc, freqs, ch_names=chan, bands=bands)
         # Add 1/f slope to dataframe and sleep stage
-        bp['1f_slope'] = np.abs(fit_params['Slope'].to_numpy())
+        bp["1f_slope"] = np.abs(fit_params["Slope"].to_numpy())
         bp.insert(loc=0, column="Stage", value=stage_mapping[stage])
         df_bp_1f.append(bp)

     # Convert to a dataframe
     df_bp_1f = pd.concat(df_bp_1f)
     # Remove the TotalAbsPower column, incorrect because of negative values
-    df_bp_1f.drop(columns=['TotalAbsPow', 'FreqRes', 'Relative'],
-                  inplace=True)
-    df_bp_1f.columns = [c if c in ['Stage', 'Chan', '1f_slope']
-                        else 'bp_adj_' + c for c in df_bp_1f.columns]
+    df_bp_1f.drop(columns=["TotalAbsPow", "FreqRes", "Relative"], inplace=True)
+    df_bp_1f.columns = [
+        c if c in ["Stage", "Chan", "1f_slope"] else "bp_adj_" + c for c in df_bp_1f.columns
+    ]
     assert not (df_bp_1f._get_numeric_data() < 0).any().any()
     df_bp_1f.columns = df_bp_1f.columns.str.lower()

@@ -243,24 +248,25 @@ def compute_features_stage(raw, hypno, max_freq=35, spindles_params=dict(),
     sp = yasa.spindles_detect(raw_eeg, hypno=hypno, **spindles_params)
     df_sp = sp.summary(grp_chan=True, grp_stage=True).reset_index()
-    df_sp['Stage'] = df_sp['Stage'].map(stage_mapping)
+    df_sp["Stage"] = df_sp["Stage"].map(stage_mapping)

     # Aggregate using the mean (adding NREM = N2 + N3)
     df_sp = sp.summary(grp_chan=True, grp_stage=True)
     df_sp_NREM = sp.summary(grp_chan=True).reset_index()
-    df_sp_NREM['Stage'] = 6
-    df_sp_NREM.set_index(['Stage', 'Channel'], inplace=True)
-    density_NREM = df_sp_NREM['Count'] / minutes_of_NREM
-    df_sp_NREM.insert(loc=1, column='Density', value=density_NREM.to_numpy())
+    df_sp_NREM["Stage"] = 6
+    df_sp_NREM.set_index(["Stage", "Channel"], inplace=True)
+    density_NREM = df_sp_NREM["Count"] / minutes_of_NREM
+    df_sp_NREM.insert(loc=1, column="Density", value=density_NREM.to_numpy())
     df_sp = pd.concat([df_sp, df_sp_NREM], axis=0)
-    df_sp.columns = ['sp_' + c if c in ['Count', 'Density'] else
-                     'sp_mean_' + c for c in df_sp.columns]
+    df_sp.columns = [
+        "sp_" + c if c in ["Count", "Density"] else "sp_mean_" + c for c in df_sp.columns
+    ]

     # Prepare to export
     df_sp.reset_index(inplace=True)
-    df_sp['Stage'] = df_sp['Stage'].map(stage_mapping)
+    df_sp["Stage"] = df_sp["Stage"].map(stage_mapping)
     df_sp.columns = df_sp.columns.str.lower()
-    df_sp.rename(columns={'channel': 'chan'}, inplace=True)
+    df_sp.rename(columns={"channel": "chan"}, inplace=True)

     # #########################################################################
     # 4) SLOW-WAVES DETECTION & SW-Sigma COUPLING
@@ -280,38 +286,39 @@ def compute_features_stage(raw, hypno, max_freq=35, spindles_params=dict(),
     df_sw = sw.summary(grp_chan=True, grp_stage=True)
     # Add NREM
     df_sw_NREM = sw.summary(grp_chan=True).reset_index()
-    df_sw_NREM['Stage'] = 6
-    df_sw_NREM.set_index(['Stage', 'Channel'], inplace=True)
-    density_NREM = df_sw_NREM['Count'] / minutes_of_NREM
-    df_sw_NREM.insert(loc=1, column='Density', value=density_NREM.to_numpy())
+    df_sw_NREM["Stage"] = 6
+    df_sw_NREM.set_index(["Stage", "Channel"], inplace=True)
+    density_NREM = df_sw_NREM["Count"] / minutes_of_NREM
+    df_sw_NREM.insert(loc=1, column="Density", value=density_NREM.to_numpy())
     df_sw = pd.concat([df_sw, df_sw_NREM])
-    df_sw = df_sw[['Count', 'Density', 'Duration', 'PTP', 'Frequency', 'ndPAC']]
-    df_sw.columns = ['sw_' + c if c in ['Count', 'Density'] else
-                     'sw_mean_' + c for c in df_sw.columns]
+    df_sw = df_sw[["Count", "Density", "Duration", "PTP", "Frequency", "ndPAC"]]
+    df_sw.columns = [
+        "sw_" + c if c in ["Count", "Density"] else "sw_mean_" + c for c in df_sw.columns
+    ]

     # Aggregate using the coefficient of variation
     # The CV is a normalized (unitless) standard deviation. Lower
     # values mean that slow-waves are more similar to each other.
     # We keep only specific columns of interest. Not duration because it
     # is highly correlated with frequency (r=0.99).
-    df_sw_cv = sw.summary(
-        grp_chan=True, grp_stage=True, aggfunc=sp_stats.variation
-    )[['PTP', 'Frequency', 'ndPAC']]
+    df_sw_cv = sw.summary(grp_chan=True, grp_stage=True, aggfunc=sp_stats.variation)[
+        ["PTP", "Frequency", "ndPAC"]
+    ]
     # Add NREM
-    df_sw_cv_NREM = sw.summary(
-        grp_chan=True, grp_stage=False, aggfunc=sp_stats.variation
-    )[['PTP', 'Frequency', 'ndPAC']].reset_index()
-    df_sw_cv_NREM['Stage'] = 6
-    df_sw_cv_NREM.set_index(['Stage', 'Channel'], inplace=True)
+    df_sw_cv_NREM = sw.summary(grp_chan=True, grp_stage=False, aggfunc=sp_stats.variation)[
+        ["PTP", "Frequency", "ndPAC"]
+    ].reset_index()
+    df_sw_cv_NREM["Stage"] = 6
+    df_sw_cv_NREM.set_index(["Stage", "Channel"], inplace=True)
     df_sw_cv = pd.concat([df_sw_cv, df_sw_cv_NREM], axis=0)
-    df_sw_cv.columns = ['sw_cv_' + c for c in df_sw_cv.columns]
+    df_sw_cv.columns = ["sw_cv_" + c for c in df_sw_cv.columns]

     # Combine the mean and CV into a single dataframe
     df_sw = df_sw.join(df_sw_cv).reset_index()
-    df_sw['Stage'] = df_sw['Stage'].map(stage_mapping)
+    df_sw["Stage"] = df_sw["Stage"].map(stage_mapping)
     df_sw.columns = df_sw.columns.str.lower()
-    df_sw.rename(columns={'channel': 'chan'}, inplace=True)
+    df_sw.rename(columns={"channel": "chan"}, inplace=True)

     # #########################################################################
     # 5) ENTROPY & FRACTAL DIMENSION
@@ -321,13 +328,18 @@ def compute_features_stage(raw, hypno, max_freq=35, spindles_params=dict(),

     # Filter data in the delta band and calculate envelope for CVE
     data_delta = mne.filter.filter_data(
-        data, sfreq=sf, l_freq=0.5, h_freq=4, l_trans_bandwidth=0.2,
-        h_trans_bandwidth=0.2, verbose=False)
+        data,
+        sfreq=sf,
+        l_freq=0.5,
+        h_freq=4,
+        l_trans_bandwidth=0.2,
+        h_trans_bandwidth=0.2,
+        verbose=False,
+    )
     env_delta = np.abs(sp_sig.hilbert(data_delta))

     # Initialize dataframe
-    idx_ent = pd.MultiIndex.from_product(
-        [[2, 3, 4, 6, 7], chan], names=['stage', 'chan'])
+    idx_ent = pd.MultiIndex.from_product([[2, 3, 4, 6, 7], chan], names=["stage", "chan"])
     df_ent = pd.DataFrame(index=idx_ent)

     for stage in [2, 3, 4, 6, 7]:
@@ -356,15 +368,15 @@ def compute_features_stage(raw, hypno, max_freq=35, spindles_params=dict(),
         # - Sample / app entropy not implemented because it is too slow to
         #   calculate.
         from numpy import apply_along_axis as aal
-        df_ent.loc[stage, 'ent_svd'] = aal(
-            ant.svd_entropy, axis=1, arr=data_stage, normalize=True)
-        df_ent.loc[stage, 'ent_perm'] = aal(
-            ant.perm_entropy, axis=1, arr=data_stage, normalize=True)
-        df_ent.loc[stage, 'ent_spec'] = ant.spectral_entropy(
-            data_stage, sf, method="welch", nperseg=(5 * int(sf)),
-            normalize=True, axis=1)
-        df_ent.loc[stage, 'ent_higuchi'] = aal(
-            ant.higuchi_fd, axis=1, arr=data_stage)
+
+        df_ent.loc[stage, "ent_svd"] = aal(ant.svd_entropy, axis=1, arr=data_stage, normalize=True)
+        df_ent.loc[stage, "ent_perm"] = aal(
+            ant.perm_entropy, axis=1, arr=data_stage, normalize=True
+        )
+        df_ent.loc[stage, "ent_spec"] = ant.spectral_entropy(
+            data_stage, sf, method="welch", nperseg=(5 * int(sf)), normalize=True, axis=1
+        )
+        df_ent.loc[stage, "ent_higuchi"] = aal(ant.higuchi_fd, axis=1, arr=data_stage)

         # We also add the coefficient of variation of the delta envelope
         # (CVE), a measure of "slow-wave stability".
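The CVE itself is computed in the next hunk; as a standalone sketch of that same computation (assumes `data_delta` is the delta-filtered signal of shape (n_chan, n_times), as produced above):

    import numpy as np
    import scipy.signal as sp_sig
    import scipy.stats as sp_stats

    env_delta = np.abs(sp_sig.hilbert(data_delta))       # Hilbert envelope of delta band
    denom = np.sqrt(4 / np.pi - 1)                       # approx 0.5227, see hunk below
    cve = sp_stats.variation(env_delta, axis=1) / denom  # lower = more sinusoidal delta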
@@ -372,22 +384,18 @@ def compute_features_stage(raw, hypno, max_freq=35, spindles_params=dict(),
         # Lower values = more stable slow-waves (= more sinusoidal)
         denom = np.sqrt(4 / np.pi - 1)  # approx 0.5227
         cve = sp_stats.variation(env_stage_delta, axis=1) / denom
-        df_ent.loc[stage, 'ent_cve_delta'] = cve
+        df_ent.loc[stage, "ent_cve_delta"] = cve
         # Other metrics of slow-wave (= delta) stability
-        df_ent.loc[stage, 'ent_higuchi_delta'] = aal(
-            ant.higuchi_fd, axis=1, arr=data_stage_delta)
+        df_ent.loc[stage, "ent_higuchi_delta"] = aal(ant.higuchi_fd, axis=1, arr=data_stage_delta)

     df_ent = df_ent.dropna(how="all").reset_index()
-    df_ent['stage'] = df_ent['stage'].map(stage_mapping)
+    df_ent["stage"] = df_ent["stage"].map(stage_mapping)

     # #########################################################################
     # 6) MERGE ALL DATAFRAMES
     # #########################################################################

-    df = (df_bp
-          .merge(df_sp, how='outer')
-          .merge(df_sw, how='outer')
-          .merge(df_ent, how='outer'))
+    df = df_bp.merge(df_sp, how="outer").merge(df_sw, how="outer").merge(df_ent, how="outer")

-    return df.set_index(['stage', 'chan'])
+    return df.set_index(["stage", "chan"])
diff --git a/yasa/heart.py b/yasa/heart.py
index b6c0a56..2d080ab 100644
--- a/yasa/heart.py
+++ b/yasa/heart.py
@@ -12,13 +12,22 @@
 from .detection import _check_data_hypno
 from .io import set_log_level, is_sleepecg_installed

-logger = logging.getLogger('yasa')
-
-__all__ = ['hrv_stage']
-
-
-def hrv_stage(data, sf, *, hypno=None, include=(2, 3, 4), threshold="2min", equal_length=False,
-              rr_limit=(400, 2000), verbose=False):
+logger = logging.getLogger("yasa")
+
+__all__ = ["hrv_stage"]
+
+
+def hrv_stage(
+    data,
+    sf,
+    *,
+    hypno=None,
+    include=(2, 3, 4),
+    threshold="2min",
+    equal_length=False,
+    rr_limit=(400, 2000),
+    verbose=False,
+):
     """Calculate heart rate and heart rate variability (HRV) features from an ECG.

     By default, the cardiac features are calculated for each period of N2, N3 or REM sleep that
@@ -112,14 +121,16 @@ def hrv_stage(data, sf, *, hypno=None, include=(2, 3, 4), threshold="2min", equa
     if isinstance(hypno, type(None)):
         logger.warning(
             "No hypnogram was passed. The entire recording will be used, i.e. "
-            "hypno will be set to np.zeros(data.size) and include will be set to 0.")
+            "hypno will be set to np.zeros(data.size) and include will be set to 0."
+        )
         data = np.asarray(data, dtype=np.float64)
         hypno = np.zeros(max(data.shape), dtype=int)
         include = 0

     # Safety check
-    (data, sf, _, hypno, include, _, n_chan, n_samples, _
-     ) = _check_data_hypno(data, sf, None, hypno, include, check_amp=False)
+    (data, sf, _, hypno, include, _, n_chan, n_samples, _) = _check_data_hypno(
+        data, sf, None, hypno, include, check_amp=False
+    )
     assert n_chan == 1, "data must be a 1D ECG array."
     data = np.squeeze(data)

@@ -128,8 +139,8 @@ def hrv_stage(data, sf, *, hypno=None, include=(2, 3, 4), threshold="2min", equa
     assert epochs.shape[0] > 0, f"No epochs longer than {threshold} found in hypnogram."
     epochs = epochs[epochs["values"].isin(include)].reset_index(drop=True)

     # Sort by stage and add epoch number
-    epochs = epochs.sort_values(by=['values', 'start'])
-    epochs['epoch'] = epochs.groupby("values")["start"].transform(lambda x: range(len(x)))
+    epochs = epochs.sort_values(by=["values", "start"])
+    epochs["epoch"] = epochs.groupby("values")["start"].transform(lambda x: range(len(x)))
     epochs = epochs.set_index(["values", "epoch"])

     # Loop over epochs
@@ -170,11 +181,11 @@ def hrv_stage(data, sf, *, hypno=None, include=(2, 3, 4), threshold="2min", equa
             hr = 60000 / rri
             epochs.loc[idx, "hr_mean"] = np.mean(hr)
             epochs.loc[idx, "hr_std"] = np.std(hr, ddof=1)
-            epochs.loc[idx, "hrv_rmssd"] = np.sqrt(np.mean(np.diff(rri)**2))
+            epochs.loc[idx, "hrv_rmssd"] = np.sqrt(np.mean(np.diff(rri) ** 2))

     # Convert start and duration to seconds
-    epochs['start'] /= sf
-    epochs['length'] /= sf
+    epochs["start"] /= sf
+    epochs["length"] /= sf
     epochs = epochs.rename(columns={"length": "duration"})
     return epochs, rpeaks
diff --git a/yasa/hypno.py b/yasa/hypno.py
index f3e74bd..6fbdde2 100644
--- a/yasa/hypno.py
+++ b/yasa/hypno.py
@@ -33,21 +33,44 @@
 import pandas as pd
 from .io import set_log_level

-__all__ = ['hypno_str_to_int', 'hypno_int_to_str', 'hypno_upsample_to_sf',
-           'hypno_upsample_to_data', 'hypno_find_periods', 'load_profusion_hypno']
+__all__ = [
+    "hypno_str_to_int",
+    "hypno_int_to_str",
+    "hypno_upsample_to_sf",
+    "hypno_upsample_to_data",
+    "hypno_find_periods",
+    "load_profusion_hypno",
+]

-logger = logging.getLogger('yasa')
+logger = logging.getLogger("yasa")

 #############################################################################
 # STR <--> INT CONVERSION
 #############################################################################

-def hypno_str_to_int(hypno, mapping_dict={'w': 0, 'wake': 0, 'n1': 1, 's1': 1,
-                                          'n2': 2, 's2': 2, 'n3': 3, 's3': 3,
-                                          's4': 3, 'r': 4, 'rem': 4, 'art': -1,
-                                          'mt': -1, 'uns': -2, 'nd': -2}):
+
+def hypno_str_to_int(
+    hypno,
+    mapping_dict={
+        "w": 0,
+        "wake": 0,
+        "n1": 1,
+        "s1": 1,
+        "n2": 2,
+        "s2": 2,
+        "n3": 3,
+        "s3": 3,
+        "s4": 3,
+        "r": 4,
+        "rem": 4,
+        "art": -1,
+        "mt": -1,
+        "uns": -2,
+        "nd": -2,
+    },
+):
     """Convert a string hypnogram array to integer.

     ['W', 'N2', 'N2', 'N3', 'R'] ==> [0, 2, 2, 3, 4]
@@ -67,14 +90,15 @@ def hypno_str_to_int(hypno, mapping_dict={'w': 0, 'wake': 0, 'n1': 1, 's1': 1,
     hypno : array_like
         The corresponding integer hypnogram.
     """
-    assert isinstance(hypno, (list, np.ndarray, pd.Series)), 'Not an array.'
+    assert isinstance(hypno, (list, np.ndarray, pd.Series)), "Not an array."
     hypno = pd.Series(np.asarray(hypno, dtype=str))
-    assert not hypno.str.isnumeric().any(), 'Hypno contains numeric values.'
+    assert not hypno.str.isnumeric().any(), "Hypno contains numeric values."
     return hypno.str.lower().map(mapping_dict).values


-def hypno_int_to_str(hypno, mapping_dict={0: 'W', 1: 'N1', 2: 'N2', 3: 'N3',
-                                          4: 'R', -1: 'Art', -2: 'Uns'}):
+def hypno_int_to_str(
+    hypno, mapping_dict={0: "W", 1: "N1", 2: "N2", 3: "N3", 4: "R", -1: "Art", -2: "Uns"}
+):
     """Convert an integer hypnogram array to a string array.

     [0, 2, 2, 3, 4] ==> ['W', 'N2', 'N2', 'N3', 'R']
@@ -94,10 +118,11 @@ def hypno_int_to_str(hypno, mapping_dict={0: 'W', 1: 'N1', 2: 'N2', 3: 'N3',
     hypno : array_like
         The corresponding string hypnogram.
     """
-    assert isinstance(hypno, (list, np.ndarray, pd.Series)), 'Not an array.'
+    assert isinstance(hypno, (list, np.ndarray, pd.Series)), "Not an array."
     hypno = pd.Series(np.asarray(hypno, dtype=int))
     return hypno.map(mapping_dict).values

+
 #############################################################################
 # UPSAMPLING
 #############################################################################

@@ -126,8 +151,8 @@
         The hypnogram, upsampled to ``sf_data``.
     """
     repeats = sf_data / sf_hypno
-    assert sf_hypno <= sf_data, 'sf_hypno must be less than sf_data.'
-    assert repeats.is_integer(), 'sf_hypno / sf_data must be a whole number.'
+    assert sf_hypno <= sf_data, "sf_hypno must be less than sf_data."
+    assert repeats.is_integer(), "sf_hypno / sf_data must be a whole number."
     assert isinstance(hypno, (list, np.ndarray, pd.Series))
     return np.repeat(np.asarray(hypno), repeats)

@@ -156,11 +181,11 @@
     """
     # Check if data is an MNE raw object
     if isinstance(data, mne.io.BaseRaw):
-        sf = data.info['sfreq']
+        sf = data.info["sfreq"]
         data = data.times  # 1D array and does not require to preload data
     data = np.asarray(data)
     hypno = np.asarray(hypno)
-    assert hypno.ndim == 1, 'Hypno must be 1D.'
+    assert hypno.ndim == 1, "Hypno must be 1D."
     npts_hyp = hypno.size
     npts_data = max(data.shape)  # Support for 2D data
     if npts_hyp < npts_data:
@@ -168,22 +193,30 @@
         npts_diff = npts_data - npts_hyp
         if sf is not None:
             dur_diff = npts_diff / sf
-            logger.warning('Hypnogram is SHORTER than data by %.2f seconds. '
-                           'Padding hypnogram with last value to match data.size.' % dur_diff)
+            logger.warning(
+                "Hypnogram is SHORTER than data by %.2f seconds. "
+                "Padding hypnogram with last value to match data.size." % dur_diff
+            )
         else:
-            logger.warning('Hypnogram is SHORTER than data by %i samples. '
-                           'Padding hypnogram with last value to match data.size.' % npts_diff)
-        hypno = np.pad(hypno, (0, npts_diff), mode='edge')
+            logger.warning(
+                "Hypnogram is SHORTER than data by %i samples. "
+                "Padding hypnogram with last value to match data.size." % npts_diff
+            )
+        hypno = np.pad(hypno, (0, npts_diff), mode="edge")
     elif npts_hyp > npts_data:
         # Hypnogram is longer than data
         npts_diff = npts_hyp - npts_data
         if sf is not None:
             dur_diff = npts_diff / sf
-            logger.warning('Hypnogram is LONGER than data by %.2f seconds. '
-                           'Cropping hypnogram to match data.size.' % dur_diff)
+            logger.warning(
+                "Hypnogram is LONGER than data by %.2f seconds. "
+                "Cropping hypnogram to match data.size." % dur_diff
+            )
         else:
-            logger.warning('Hypnogram is LONGER than data by %i samples. '
-                           'Cropping hypnogram to match data.size.' % npts_diff)
+            logger.warning(
+                "Hypnogram is LONGER than data by %i samples. "
+                "Cropping hypnogram to match data.size." % npts_diff
+            )
         hypno = hypno[0:npts_data]
     return hypno

@@ -230,7 +263,7 @@
     """
     set_log_level(verbose)
     if isinstance(data, mne.io.BaseRaw):
-        sf_data = data.info['sfreq']
+        sf_data = data.info["sfreq"]
         data = data.times
     hypno_up = hypno_upsample_to_sf(hypno=hypno, sf_hypno=sf_hypno, sf_data=sf_data)
     return hypno_fit_to_data(hypno=hypno_up, data=data, sf=sf_data)

@@ -240,6 +273,7 @@
 # HYPNO LOADING
 #############################################################################

+
 def load_profusion_hypno(fname, replace=True):  # pragma: no cover
     """
     Load a Compumedics Profusion hypnogram (.xml).
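The string/integer converters reformatted above round-trip as in this short sketch (values taken from their docstrings; illustrative, not part of the diff):

    import yasa

    hypno_int = yasa.hypno_str_to_int(["W", "N2", "N2", "N3", "R"])  # -> [0, 2, 2, 3, 4]
    hypno_str = yasa.hypno_int_to_str(hypno_int)  # -> ['W', 'N2', 'N2', 'N3', 'R']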
@@ -277,6 +311,7 @@ def load_profusion_hypno(fname, replace=True):  # pragma: no cover
     # >>> annotations["Start"] = annotations["Start"].astype(float)
     # >>> annotations["Duration"] = annotations["Duration"].astype(float)
     import xml.etree.ElementTree as ET
+
     tree = ET.parse(fname)
     root = tree.getroot()
     epoch_length = float(root[0].text)
@@ -290,6 +325,7 @@ def load_profusion_hypno(fname, replace=True):  # pragma: no cover
         hypno = pd.Series(hypno).replace({4: 3, 5: 4}).to_numpy()
     return hypno, sf_hyp

+
 #############################################################################
 # PERIODS & CYCLES
 #############################################################################

@@ -400,7 +436,7 @@ def hypno_find_periods(hypno, sf_hypno, threshold="5min", equal_length=False):
     # Find run starts
     # https://gist.github.com/alimanfoo/c5977e87111abe8127453b21204c1065
-    assert isinstance(hypno, (list, np.ndarray, pd.Series)), 'hypno must be an array.'
+    assert isinstance(hypno, (list, np.ndarray, pd.Series)), "hypno must be an array."
     x = np.asarray(hypno)
     n = x.shape[0]
     loc_run_start = np.empty(n, dtype=bool)
@@ -411,10 +447,10 @@ def hypno_find_periods(hypno, sf_hypno, threshold="5min", equal_length=False):
     run_values = x[loc_run_start]
     # Find run lengths
     run_lengths = np.diff(np.append(run_starts, n))
-    seq = pd.DataFrame({'values': run_values, 'start': run_starts, 'length': run_lengths})
+    seq = pd.DataFrame({"values": run_values, "start": run_starts, "length": run_lengths})

     # Remove runs that are shorter than threshold
-    seq = seq[seq['length'] >= thr_samp].reset_index(drop=True)
+    seq = seq[seq["length"] >= thr_samp].reset_index(drop=True)

     if not equal_length:
         return seq
@@ -424,19 +460,19 @@ def hypno_find_periods(hypno, sf_hypno, threshold="5min", equal_length=False):
     new_seq = {"values": [], "start": [], "length": []}

     for i, row in seq.iterrows():
-        quotient, remainder = np.divmod(row['length'], thr_samp)
-        new_start = row['start']
+        quotient, remainder = np.divmod(row["length"], thr_samp)
+        new_start = row["start"]
         if quotient > 0:
             while quotient != 0:
-                new_seq["values"].append(row['values'])
+                new_seq["values"].append(row["values"])
                 new_seq["start"].append(new_start)
                 new_seq["length"].append(thr_samp)
                 new_start += thr_samp
                 quotient -= 1
         else:
-            new_seq["values"].append(row['values'])
-            new_seq["start"].append(row['start'])
-            new_seq["length"].append(row['length'])
+            new_seq["values"].append(row["values"])
+            new_seq["start"].append(row["start"])
+            new_seq["length"].append(row["length"])

     new_seq = pd.DataFrame(new_seq)
     return new_seq
diff --git a/yasa/io.py b/yasa/io.py
index 4b8848e..b9a6be5 100644
--- a/yasa/io.py
+++ b/yasa/io.py
@@ -3,8 +3,13 @@
 import logging

-LOGGING_TYPES = dict(DEBUG=logging.DEBUG, INFO=logging.INFO, WARNING=logging.WARNING,
-                     ERROR=logging.ERROR, CRITICAL=logging.CRITICAL)
+LOGGING_TYPES = dict(
+    DEBUG=logging.DEBUG,
+    INFO=logging.INFO,
+    WARNING=logging.WARNING,
+    ERROR=logging.ERROR,
+    CRITICAL=logging.CRITICAL,
+)


 def set_log_level(verbose=None):
@@ -19,16 +24,16 @@ def set_log_level(verbose=None):
         The verbosity of messages to print. If a str, it can be either
         PROFILER, DEBUG, INFO, WARNING, ERROR, or CRITICAL.
""" - logger = logging.getLogger('yasa') + logger = logging.getLogger("yasa") if isinstance(verbose, bool): - verbose = 'INFO' if verbose else 'WARNING' + verbose = "INFO" if verbose else "WARNING" if isinstance(verbose, str): - if (verbose.upper() in LOGGING_TYPES): + if verbose.upper() in LOGGING_TYPES: verbose = verbose.upper() verbose = LOGGING_TYPES[verbose] logger.setLevel(verbose) else: - raise ValueError("verbose must be in %s" % ', '.join(LOGGING_TYPES)) + raise ValueError("verbose must be in %s" % ", ".join(LOGGING_TYPES)) def is_tensorpac_installed(): diff --git a/yasa/numba.py b/yasa/numba.py index c62e07f..8495cce 100644 --- a/yasa/numba.py +++ b/yasa/numba.py @@ -12,7 +12,7 @@ ############################################################################# -@jit('float64(float64[:], float64[:])', nopython=True) +@jit("float64(float64[:], float64[:])", nopython=True) def _corr(x, y): """Fast Pearson correlation.""" n = x.size @@ -21,7 +21,7 @@ def _corr(x, y): for i in range(n): xm = x[i] - mx ym = y[i] - my - r_num += (xm * ym) + r_num += xm * ym xm2s += xm**2 ym2s += ym**2 r_d1 = np.sqrt(xm2s) @@ -30,7 +30,7 @@ def _corr(x, y): return r_num / r_den -@jit('float64(float64[:], float64[:])', nopython=True) +@jit("float64(float64[:], float64[:])", nopython=True) def _covar(x, y): """Fast Covariance.""" n = x.size @@ -39,25 +39,24 @@ def _covar(x, y): for i in range(n): xm = x[i] - mx ym = y[i] - my - cov += (xm * ym) + cov += xm * ym return cov / (n - 1) -@jit('float64(float64[:])', nopython=True) +@jit("float64(float64[:])", nopython=True) def _rms(x): """Fast root mean square.""" n = x.size ms = 0 for i in range(n): - ms += x[i]**2 + ms += x[i] ** 2 ms /= n return np.sqrt(ms) -@jit('float64(float64[:], float64[:])', nopython=True) +@jit("float64(float64[:], float64[:])", nopython=True) def _slope_lstsq(x, y): - """Slope of a 1D least-squares regression. - """ + """Slope of a 1D least-squares regression.""" n_times = x.shape[0] sx2 = 0 sx = 0 @@ -68,15 +67,14 @@ def _slope_lstsq(x, y): sx += x[j] sxy += x[j] * y[j] sy += y[j] - den = n_times * sx2 - (sx ** 2) + den = n_times * sx2 - (sx**2) num = n_times * sxy - sx * sy return num / den -@jit('float64[:](float64[:], float64[:])', nopython=True) +@jit("float64[:](float64[:], float64[:])", nopython=True) def _detrend(x, y): - """Fast linear detrending. - """ + """Fast linear detrending.""" slope = _slope_lstsq(x, y) intercept = y.mean() - x.mean() * slope return y - (x * slope + intercept) diff --git a/yasa/others.py b/yasa/others.py index 9f2483d..2a1c6a0 100644 --- a/yasa/others.py +++ b/yasa/others.py @@ -6,10 +6,9 @@ from scipy.interpolate import interp1d from .numba import _slope_lstsq, _covar, _corr, _rms -logger = logging.getLogger('yasa') +logger = logging.getLogger("yasa") -__all__ = ['moving_transform', 'trimbothstd', 'sliding_window', - 'get_centered_indices'] +__all__ = ["moving_transform", "trimbothstd", "sliding_window", "get_centered_indices"] def _merge_close(index, min_distance_ms, sf): @@ -35,7 +34,7 @@ def _merge_close(index, min_distance_ms, sf): Original code imported from the Visbrain package. """ # Convert min_distance_ms - min_distance = min_distance_ms / 1000. 
* sf + min_distance = min_distance_ms / 1000.0 * sf idx_diff = np.diff(index) condition = idx_diff > 1 idx_distance = np.where(condition)[0] @@ -43,8 +42,7 @@ def _merge_close(index, min_distance_ms, sf): bad = idx_distance[np.where(distance < min_distance)[0]] # Fill gap between events separated with less than min_distance_ms if len(bad) > 0: - fill = np.hstack([np.arange(index[j] + 1, index[j + 1]) - for i, j in enumerate(bad)]) + fill = np.hstack([np.arange(index[j] + 1, index[j + 1]) for i, j in enumerate(bad)]) f_index = np.sort(np.append(index, fill)) return f_index else: @@ -77,8 +75,7 @@ def _index_to_events(x): return index -def moving_transform(x, y=None, sf=100, window=.3, step=.1, method='corr', - interp=False): +def moving_transform(x, y=None, sf=100, window=0.3, step=0.1, method="corr", interp=False): """Moving transformation of one or two time-series. Parameters @@ -127,8 +124,17 @@ def moving_transform(x, y=None, sf=100, window=.3, step=.1, method='corr', Wonambi package (https://github.com/wonambi-python/wonambi). """ # Safety checks - assert method in ['mean', 'min', 'max', 'ptp', 'rms', - 'prop_above_zero', 'slope', 'covar', 'corr'] + assert method in [ + "mean", + "min", + "max", + "ptp", + "rms", + "prop_above_zero", + "slope", + "covar", + "corr", + ] x = np.asarray(x, dtype=np.float64) if y is not None: y = np.asarray(y, dtype=np.float64) @@ -154,55 +160,63 @@ def moving_transform(x, y=None, sf=100, window=.3, step=.1, method='corr', # beg, end = beg[mask], end[mask] t = np.column_stack((beg, end)).mean(1) / sf - if method == 'mean': + if method == "mean": + def func(x): return np.mean(x) - elif method == 'min': + elif method == "min": + def func(x): return np.min(x) - elif method == 'max': + elif method == "max": + def func(x): return np.max(x) - elif method == 'ptp': + elif method == "ptp": + def func(x): return np.ptp(x) - elif method == 'prop_above_zero': + elif method == "prop_above_zero": + def func(x): return np.count_nonzero(x >= 0) / x.size - elif method == 'slope': + elif method == "slope": + def func(x): times = np.arange(x.size, dtype=np.float64) / sf return _slope_lstsq(times, x) - elif method == 'covar': + elif method == "covar": + def func(x, y): return _covar(x, y) - elif method == 'corr': + elif method == "corr": + def func(x, y): return _corr(x, y) else: + def func(x): return _rms(x) # Now loop over successive epochs - if method in ['covar', 'corr']: + if method in ["covar", "corr"]: for i in range(idx.size): - out[i] = func(x[beg[i]:end[i]], y[beg[i]:end[i]]) + out[i] = func(x[beg[i] : end[i]], y[beg[i] : end[i]]) else: for i in range(idx.size): - out[i] = func(x[beg[i]:end[i]]) + out[i] = func(x[beg[i] : end[i]]) # Finally interpolate if interp and step != 1 / sf: - f = interp1d(t, out, kind='cubic', bounds_error=False, - fill_value=0, assume_sorted=True) + f = interp1d(t, out, kind="cubic", bounds_error=False, fill_value=0, assume_sorted=True) t = np.arange(n) / sf out = f(t) @@ -353,36 +367,34 @@ def sliding_window(data, sf, window, step=None, axis=-1): [-51, 3, 31, -99, 33, -47, 5, -97, -47, 90]]]) """ from numpy.lib.stride_tricks import as_strided + assert axis <= data.ndim, "Axis value out of range." 
-    assert isinstance(sf, (int, float)), 'sf must be int or float'
-    assert isinstance(window, (int, float)), 'window must be int or float'
-    assert isinstance(step, (int, float, type(None))), ('step must be int, '
-                                                        'float or None.')
+    assert isinstance(sf, (int, float)), "sf must be int or float"
+    assert isinstance(window, (int, float)), "window must be int or float"
+    assert isinstance(step, (int, float, type(None))), "step must be int, " "float or None."
     if isinstance(sf, float):
-        assert sf.is_integer(), 'sf must be a whole number.'
+        assert sf.is_integer(), "sf must be a whole number."
         sf = int(sf)
-    assert isinstance(axis, int), 'axis must be int.'
+    assert isinstance(axis, int), "axis must be int."

     # window and step in samples instead of points
     window *= sf
     step = window if step is None else step * sf

     if isinstance(window, float):
-        assert window.is_integer(), 'window * sf must be a whole number.'
+        assert window.is_integer(), "window * sf must be a whole number."
         window = int(window)

     if isinstance(step, float):
-        assert step.is_integer(), 'step * sf must be a whole number.'
+        assert step.is_integer(), "step * sf must be a whole number."
         step = int(step)

     assert step >= 1, "Stepsize may not be zero or negative."
-    assert window < data.shape[axis], ("Sliding window size may not exceed "
-                                       "size of selected axis")
+    assert window < data.shape[axis], "Sliding window size may not exceed " "size of selected axis"

     # Define output shape
     shape = list(data.shape)
-    shape[axis] = np.floor(data.shape[axis] / step - window / step + 1
-                           ).astype(int)
+    shape[axis] = np.floor(data.shape[axis] / step - window / step + 1).astype(int)
     shape.append(window)

     # Calculate strides and time vector
@@ -453,13 +465,13 @@ def get_centered_indices(data, idx, npts_before, npts_after):
     npts_before = int(npts_before)
     npts_after = int(npts_after)
     data = np.asarray(data)
-    idx = np.asarray(idx, dtype='int')
+    idx = np.asarray(idx, dtype="int")
     assert idx.ndim == 1, "idx must be 1D."
     assert data.ndim == 1, "data must be 1D."

     def rng(x):
         """Create a range before and after a given value."""
-        return np.arange(x - npts_before, x + npts_after + 1, dtype='int')
+        return np.arange(x - npts_before, x + npts_after + 1, dtype="int")

     idx_ep = np.apply_along_axis(rng, 1, idx[..., np.newaxis])
     # We drop the events for which the indices exceed data
diff --git a/yasa/plotting.py b/yasa/plotting.py
index 1b61aa3..18095a9 100644
--- a/yasa/plotting.py
+++ b/yasa/plotting.py
@@ -9,10 +9,10 @@
 from lspopt import spectrogram_lspopt
 from matplotlib.colors import Normalize, ListedColormap

-__all__ = ['plot_hypnogram', 'plot_spectrogram', 'topoplot']
+__all__ = ["plot_hypnogram", "plot_spectrogram", "topoplot"]


-def plot_hypnogram(hypno, sf_hypno=1/30, lw=1.5, figsize=(9, 3)):
+def plot_hypnogram(hypno, sf_hypno=1 / 30, lw=1.5, figsize=(9, 3)):
     """
     Plot a hypnogram.

@@ -58,15 +58,15 @@ def plot_hypnogram(hypno, sf_hypno=1/30, lw=1.5, figsize=(9, 3)):
     >>> ax = yasa.plot_hypnogram(hypno)
     """
     # Increase font size while preserving original
-    old_fontsize = plt.rcParams['font.size']
-    plt.rcParams.update({'font.size': 18})
+    old_fontsize = plt.rcParams["font.size"]
+    plt.rcParams.update({"font.size": 18})

     # Safety checks
-    assert isinstance(hypno, (np.ndarray, pd.Series, list)), 'hypno must be an array.'
+    assert isinstance(hypno, (np.ndarray, pd.Series, list)), "hypno must be an array."
     hypno = np.asarray(hypno).astype(int)
     assert (hypno >= -2).all() and (hypno <= 4).all(), "hypno values must be between -2 to 4."
- assert hypno.ndim == 1, 'hypno must be a 1D array.' - assert isinstance(sf_hypno, (int, float)), 'sf must be int or float.' + assert hypno.ndim == 1, "hypno must be a 1D array." + assert isinstance(sf_hypno, (int, float)), "sf must be int or float." t_hyp = np.arange(hypno.size) / (sf_hypno * 3600) # Make sure that REM is displayed after Wake @@ -77,40 +77,42 @@ def plot_hypnogram(hypno, sf_hypno=1/30, lw=1.5, figsize=(9, 3)): fig, ax0 = plt.subplots(nrows=1, figsize=figsize) # Hypnogram (top axis) - ax0.step(t_hyp, -1 * hypno, color='k', lw=lw) - ax0.step(t_hyp, -1 * hypno_rem, color='red', lw=lw) - ax0.step(t_hyp, -1 * hypno_art_uns, color='grey', lw=lw) + ax0.step(t_hyp, -1 * hypno, color="k", lw=lw) + ax0.step(t_hyp, -1 * hypno_rem, color="red", lw=lw) + ax0.step(t_hyp, -1 * hypno_art_uns, color="grey", lw=lw) if -2 in hypno and -1 in hypno: # Both Unscored and Artefacts are present ax0.set_yticks([2, 1, 0, -1, -2, -3, -4]) - ax0.set_yticklabels(['Uns', 'Art', 'W', 'R', 'N1', 'N2', 'N3']) + ax0.set_yticklabels(["Uns", "Art", "W", "R", "N1", "N2", "N3"]) ax0.set_ylim(-4.5, 2.5) elif -2 in hypno and -1 not in hypno: # Only Unscored are present ax0.set_yticks([2, 0, -1, -2, -3, -4]) - ax0.set_yticklabels(['Uns', 'W', 'R', 'N1', 'N2', 'N3']) + ax0.set_yticklabels(["Uns", "W", "R", "N1", "N2", "N3"]) ax0.set_ylim(-4.5, 2.5) elif -2 not in hypno and -1 in hypno: # Only Artefacts are present ax0.set_yticks([1, 0, -1, -2, -3, -4]) - ax0.set_yticklabels(['Art', 'W', 'R', 'N1', 'N2', 'N3']) + ax0.set_yticklabels(["Art", "W", "R", "N1", "N2", "N3"]) ax0.set_ylim(-4.5, 1.5) else: # No artefacts or Unscored ax0.set_yticks([0, -1, -2, -3, -4]) - ax0.set_yticklabels(['W', 'R', 'N1', 'N2', 'N3']) + ax0.set_yticklabels(["W", "R", "N1", "N2", "N3"]) ax0.set_ylim(-4.5, 0.5) ax0.set_xlim(0, t_hyp.max()) - ax0.set_ylabel('Stage') - ax0.set_xlabel('Time [hrs]') - ax0.spines['right'].set_visible(False) - ax0.spines['top'].set_visible(False) + ax0.set_ylabel("Stage") + ax0.set_xlabel("Time [hrs]") + ax0.spines["right"].set_visible(False) + ax0.spines["top"].set_visible(False) # Revert font-size - plt.rcParams.update({'font.size': old_fontsize}) + plt.rcParams.update({"font.size": old_fontsize}) return ax0 -def plot_spectrogram(data, sf, hypno=None, win_sec=30, fmin=0.5, fmax=25, - trimperc=2.5, cmap='RdBu_r'): + +def plot_spectrogram( + data, sf, hypno=None, win_sec=30, fmin=0.5, fmax=25, trimperc=2.5, cmap="RdBu_r" +): """ Plot a full-night multi-taper spectrogram, optionally with the hypnogram on top. @@ -198,22 +200,22 @@ def plot_spectrogram(data, sf, hypno=None, win_sec=30, fmin=0.5, fmax=25, >>> fig = yasa.plot_spectrogram(data, sf, hypno, cmap='Spectral_r') """ # Increase font size while preserving original - old_fontsize = plt.rcParams['font.size'] - plt.rcParams.update({'font.size': 18}) + old_fontsize = plt.rcParams["font.size"] + plt.rcParams.update({"font.size": 18}) # Safety checks - assert isinstance(data, np.ndarray), 'Data must be a 1D NumPy array.' - assert isinstance(sf, (int, float)), 'sf must be int or float.' - assert data.ndim == 1, 'Data must be a 1D (single-channel) NumPy array.' - assert isinstance(win_sec, (int, float)), 'win_sec must be int or float.' - assert isinstance(fmin, (int, float)), 'fmin must be int or float.' - assert isinstance(fmax, (int, float)), 'fmax must be int or float.' - assert fmin < fmax, 'fmin must be strictly inferior to fmax.' - assert fmax < sf / 2, 'fmax must be less than Nyquist (sf / 2).' 
+ assert isinstance(data, np.ndarray), "Data must be a 1D NumPy array." + assert isinstance(sf, (int, float)), "sf must be int or float." + assert data.ndim == 1, "Data must be a 1D (single-channel) NumPy array." + assert isinstance(win_sec, (int, float)), "win_sec must be int or float." + assert isinstance(fmin, (int, float)), "fmin must be int or float." + assert isinstance(fmax, (int, float)), "fmax must be int or float." + assert fmin < fmax, "fmin must be strictly inferior to fmax." + assert fmax < sf / 2, "fmax must be less than Nyquist (sf / 2)." # Calculate multi-taper spectrogram nperseg = int(win_sec * sf) - assert data.size > 2 * nperseg, 'Data length must be at least 2 * win_sec.' + assert data.size > 2 * nperseg, "Data length must be at least 2 * win_sec." f, t, Sxx = spectrogram_lspopt(data, sf, nperseg=nperseg, noverlap=0) Sxx = 10 * np.log10(Sxx) # Convert uV^2 / Hz --> dB / Hz @@ -231,70 +233,84 @@ def plot_spectrogram(data, sf, hypno=None, win_sec=30, fmin=0.5, fmax=25, fig, ax = plt.subplots(nrows=1, figsize=(12, 4)) im = ax.pcolormesh(t, f, Sxx, norm=norm, cmap=cmap, antialiased=True, shading="auto") ax.set_xlim(0, t.max()) - ax.set_ylabel('Frequency [Hz]') - ax.set_xlabel('Time [hrs]') + ax.set_ylabel("Frequency [Hz]") + ax.set_xlabel("Time [hrs]") # Add colorbar cbar = fig.colorbar(im, ax=ax, shrink=0.95, fraction=0.1, aspect=25) - cbar.ax.set_ylabel('Log Power (dB / Hz)', rotation=270, labelpad=20) + cbar.ax.set_ylabel("Log Power (dB / Hz)", rotation=270, labelpad=20) return fig else: hypno = np.asarray(hypno).astype(int) - assert hypno.ndim == 1, 'Hypno must be 1D.' - assert hypno.size == data.size, 'Hypno must have the same sf as data.' + assert hypno.ndim == 1, "Hypno must be 1D." + assert hypno.size == data.size, "Hypno must have the same sf as data." 
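Usage sketch for `plot_spectrogram`, following the docstring example above. Per the safety checks just reformatted, `hypno` must be upsampled to one stage value per data sample; the all-Wake hypnogram here is a placeholder:

import numpy as np
import yasa

sf = 100
data = np.random.randn(3600 * sf)  # 1 h of synthetic single-channel EEG (uV)

# Spectrogram only
fig = yasa.plot_spectrogram(data, sf)

# With a hypnogram on top (one value per sample, here all Wake = 0)
hypno = np.zeros(data.size, dtype=int)
fig = yasa.plot_spectrogram(data, sf, hypno=hypno, cmap="Spectral_r")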
t_hyp = np.arange(hypno.size) / (sf * 3600) # Make sure that REM is displayed after Wake hypno = pd.Series(hypno).map({-2: -2, -1: -1, 0: 0, 1: 2, 2: 3, 3: 4, 4: 1}).values hypno_rem = np.ma.masked_not_equal(hypno, 1) fig, (ax0, ax1) = plt.subplots( - nrows=2, figsize=(12, 6), gridspec_kw={'height_ratios': [1, 2]}) + nrows=2, figsize=(12, 6), gridspec_kw={"height_ratios": [1, 2]} + ) plt.subplots_adjust(hspace=0.1) # Hypnogram (top axis) - ax0.step(t_hyp, -1 * hypno, color='k') - ax0.step(t_hyp, -1 * hypno_rem, color='r') + ax0.step(t_hyp, -1 * hypno, color="k") + ax0.step(t_hyp, -1 * hypno_rem, color="r") if -2 in hypno and -1 in hypno: # Both Unscored and Artefacts are present ax0.set_yticks([2, 1, 0, -1, -2, -3, -4]) - ax0.set_yticklabels(['Uns', 'Art', 'W', 'R', 'N1', 'N2', 'N3']) + ax0.set_yticklabels(["Uns", "Art", "W", "R", "N1", "N2", "N3"]) ax0.set_ylim(-4.5, 2.5) elif -2 in hypno and -1 not in hypno: # Only Unscored are present ax0.set_yticks([2, 0, -1, -2, -3, -4]) - ax0.set_yticklabels(['Uns', 'W', 'R', 'N1', 'N2', 'N3']) + ax0.set_yticklabels(["Uns", "W", "R", "N1", "N2", "N3"]) ax0.set_ylim(-4.5, 2.5) elif -2 not in hypno and -1 in hypno: # Only Artefacts are present ax0.set_yticks([1, 0, -1, -2, -3, -4]) - ax0.set_yticklabels(['Art', 'W', 'R', 'N1', 'N2', 'N3']) + ax0.set_yticklabels(["Art", "W", "R", "N1", "N2", "N3"]) ax0.set_ylim(-4.5, 1.5) else: # No artefacts or Unscored ax0.set_yticks([0, -1, -2, -3, -4]) - ax0.set_yticklabels(['W', 'R', 'N1', 'N2', 'N3']) + ax0.set_yticklabels(["W", "R", "N1", "N2", "N3"]) ax0.set_ylim(-4.5, 0.5) ax0.set_xlim(0, t_hyp.max()) - ax0.set_ylabel('Stage') + ax0.set_ylabel("Stage") ax0.xaxis.set_visible(False) - ax0.spines['right'].set_visible(False) - ax0.spines['top'].set_visible(False) + ax0.spines["right"].set_visible(False) + ax0.spines["top"].set_visible(False) # Spectrogram (bottom axis) im = ax1.pcolormesh(t, f, Sxx, norm=norm, cmap=cmap, antialiased=True, shading="auto") ax1.set_xlim(0, t.max()) - ax1.set_ylabel('Frequency [Hz]') - ax1.set_xlabel('Time [hrs]') + ax1.set_ylabel("Frequency [Hz]") + ax1.set_xlabel("Time [hrs]") # Revert font-size - plt.rcParams.update({'font.size': old_fontsize}) + plt.rcParams.update({"font.size": old_fontsize}) return fig -def topoplot(data, montage="standard_1020", vmin=None, vmax=None, mask=None, title=None, - cmap=None, n_colors=100, cbar_title=None, cbar_ticks=None, figsize=(4, 4), dpi=80, - fontsize=14, **kwargs): +def topoplot( + data, + montage="standard_1020", + vmin=None, + vmax=None, + mask=None, + title=None, + cmap=None, + n_colors=100, + cbar_title=None, + cbar_ticks=None, + figsize=(4, 4), + dpi=80, + fontsize=14, + **kwargs +): """ Topoplot. @@ -370,19 +386,19 @@ def topoplot(data, montage="standard_1020", vmin=None, vmax=None, mask=None, tit ... 
cbar_title="Pearson correlation") """ # Increase font size while preserving original - old_fontsize = plt.rcParams['font.size'] - plt.rcParams.update({'font.size': fontsize}) - plt.rcParams.update({'savefig.bbox': 'tight'}) - plt.rcParams.update({'savefig.transparent': 'True'}) + old_fontsize = plt.rcParams["font.size"] + plt.rcParams.update({"font.size": fontsize}) + plt.rcParams.update({"savefig.bbox": "tight"}) + plt.rcParams.update({"savefig.transparent": "True"}) # Make sure we don't do any in-place modification - assert isinstance(data, pd.Series), 'Data must be a Pandas Series' + assert isinstance(data, pd.Series), "Data must be a Pandas Series" data = data.copy() # Add mask, if present if mask is not None: - assert isinstance(mask, pd.Series), 'mask must be a Pandas Series' - assert mask.dtype.kind in 'bi', "mask must be True/False or 0/1." + assert isinstance(mask, pd.Series), "mask must be a Pandas Series" + assert mask.dtype.kind in "bi", "mask must be True/False or 0/1." else: mask = pd.Series(1, index=data.index, name="mask") @@ -390,11 +406,11 @@ def topoplot(data, montage="standard_1020", vmin=None, vmax=None, mask=None, tit data = data.to_frame().join(mask, how="left") # Preprocess channel names: C4-M1 --> C4 - data.index = data.index.str.split('-').str.get(0) + data.index = data.index.str.split("-").str.get(0) # Define electrodes coordinates - Info = mne.create_info(data.index.tolist(), sfreq=100, ch_types='eeg') - Info.set_montage(montage, match_case=False, on_missing='ignore') + Info = mne.create_info(data.index.tolist(), sfreq=100, ch_types="eeg") + Info.set_montage(montage, match_case=False, on_missing="ignore") chan = Info.ch_names # Define vmin and vmax @@ -406,36 +422,43 @@ def topoplot(data, montage="standard_1020", vmin=None, vmax=None, mask=None, tit # Choose and discretize colormap if cmap is None: if vmin < 0 and vmax <= 0: - cmap = 'mako' + cmap = "mako" elif vmin < 0 and vmax > 0: - cmap = 'Spectral_r' + cmap = "Spectral_r" elif vmin >= 0 and vmax > 0: - cmap = 'rocket_r' + cmap = "rocket_r" cmap = ListedColormap(sns.color_palette(cmap, n_colors).as_hex()) - if 'sensors' not in kwargs: - kwargs['sensors'] = False - if 'res' not in kwargs: - kwargs['res'] = 256 - if 'names' not in kwargs: - kwargs['names'] = chan - if 'show_names' not in kwargs: - kwargs['show_names'] = True - if 'mask_params' not in kwargs: - kwargs['mask_params'] = dict(marker=None) + if "sensors" not in kwargs: + kwargs["sensors"] = False + if "res" not in kwargs: + kwargs["res"] = 256 + if "names" not in kwargs: + kwargs["names"] = chan + if "show_names" not in kwargs: + kwargs["show_names"] = True + if "mask_params" not in kwargs: + kwargs["mask_params"] = dict(marker=None) # Hidden feature: if names='values', show the actual values. 
- if kwargs['names'] == 'values': - kwargs['names'] = data.iloc[:, 0][chan].round(2).to_numpy() + if kwargs["names"] == "values": + kwargs["names"] = data.iloc[:, 0][chan].round(2).to_numpy() # Start the plot with sns.axes_style("white"): fig, ax = plt.subplots(figsize=figsize, dpi=dpi) im, _ = mne.viz.plot_topomap( - data=data.iloc[:, 0][chan], pos=Info, vmin=vmin, vmax=vmax, - mask=data.iloc[:, 1][chan], cmap=cmap, show=False, axes=ax, - **kwargs) + data=data.iloc[:, 0][chan], + pos=Info, + vmin=vmin, + vmax=vmax, + mask=data.iloc[:, 1][chan], + cmap=cmap, + show=False, + axes=ax, + **kwargs + ) if title is not None: ax.set_title(title) @@ -449,5 +472,5 @@ def topoplot(data, montage="standard_1020", vmin=None, vmax=None, mask=None, tit cbar.set_label(cbar_title) # Revert font-size - plt.rcParams.update({'font.size': old_fontsize}) + plt.rcParams.update({"font.size": old_fontsize}) return fig diff --git a/yasa/sleepstats.py b/yasa/sleepstats.py index 1baa13e..3cb5ac2 100644 --- a/yasa/sleepstats.py +++ b/yasa/sleepstats.py @@ -5,13 +5,14 @@ import numpy as np import pandas as pd -__all__ = ['transition_matrix', 'sleep_statistics'] +__all__ = ["transition_matrix", "sleep_statistics"] ############################################################################# # TRANSITION MATRIX ############################################################################# + def transition_matrix(hypno): """Create a state-transition matrix from an hypnogram. @@ -108,12 +109,13 @@ def transition_matrix(hypno): # Convert to a Pandas DataFrame counts = pd.DataFrame(counts, index=unique, columns=unique) probs = pd.DataFrame(probs, index=unique, columns=unique) - counts.index.name = 'From Stage' - probs.index.name = 'From Stage' - counts.columns.name = 'To Stage' - probs.columns.name = 'To Stage' + counts.index.name = "From Stage" + probs.index.name = "From Stage" + counts.columns.name = "To Stage" + probs.columns.name = "To Stage" return counts, probs + ############################################################################# # SLEEP STATISTICS ############################################################################# @@ -217,46 +219,46 @@ def sleep_statistics(hypno, sf_hyp): """ stats = {} hypno = np.asarray(hypno) - assert hypno.ndim == 1, 'hypno must have only one dimension.' - assert hypno.size > 1, 'hypno must have at least two elements.' + assert hypno.ndim == 1, "hypno must have only one dimension." + assert hypno.size > 1, "hypno must have at least two elements." # TIB, first and last sleep - stats['TIB'] = len(hypno) + stats["TIB"] = len(hypno) first_sleep = np.where(hypno > 0)[0][0] last_sleep = np.where(hypno > 0)[0][-1] # Crop to SPT - hypno_s = hypno[first_sleep:(last_sleep + 1)] - stats['SPT'] = hypno_s.size - stats['WASO'] = hypno_s[hypno_s == 0].size + hypno_s = hypno[first_sleep : (last_sleep + 1)] + stats["SPT"] = hypno_s.size + stats["WASO"] = hypno_s[hypno_s == 0].size # Before YASA v0.5.0, TST was calculated as SPT - WASO, meaning that Art # and Unscored epochs were included. TST is now restrained to sleep stages. 
- stats['TST'] = hypno_s[hypno_s > 0].size + stats["TST"] = hypno_s[hypno_s > 0].size # Duration of each sleep stages - stats['N1'] = hypno[hypno == 1].size - stats['N2'] = hypno[hypno == 2].size - stats['N3'] = hypno[hypno == 3].size - stats['REM'] = hypno[hypno == 4].size - stats['NREM'] = stats['N1'] + stats['N2'] + stats['N3'] + stats["N1"] = hypno[hypno == 1].size + stats["N2"] = hypno[hypno == 2].size + stats["N3"] = hypno[hypno == 3].size + stats["REM"] = hypno[hypno == 4].size + stats["NREM"] = stats["N1"] + stats["N2"] + stats["N3"] # Sleep stage latencies -- only relevant if hypno is cropped to TIB - stats['SOL'] = first_sleep - stats['Lat_N1'] = np.where(hypno == 1)[0].min() if 1 in hypno else np.nan - stats['Lat_N2'] = np.where(hypno == 2)[0].min() if 2 in hypno else np.nan - stats['Lat_N3'] = np.where(hypno == 3)[0].min() if 3 in hypno else np.nan - stats['Lat_REM'] = np.where(hypno == 4)[0].min() if 4 in hypno else np.nan + stats["SOL"] = first_sleep + stats["Lat_N1"] = np.where(hypno == 1)[0].min() if 1 in hypno else np.nan + stats["Lat_N2"] = np.where(hypno == 2)[0].min() if 2 in hypno else np.nan + stats["Lat_N3"] = np.where(hypno == 3)[0].min() if 3 in hypno else np.nan + stats["Lat_REM"] = np.where(hypno == 4)[0].min() if 4 in hypno else np.nan # Convert to minutes for key, value in stats.items(): stats[key] = value / (60 * sf_hyp) # Percentage - stats['%N1'] = 100 * stats['N1'] / stats['TST'] - stats['%N2'] = 100 * stats['N2'] / stats['TST'] - stats['%N3'] = 100 * stats['N3'] / stats['TST'] - stats['%REM'] = 100 * stats['REM'] / stats['TST'] - stats['%NREM'] = 100 * stats['NREM'] / stats['TST'] - stats['SE'] = 100 * stats['TST'] / stats['TIB'] - stats['SME'] = 100 * stats['TST'] / stats['SPT'] + stats["%N1"] = 100 * stats["N1"] / stats["TST"] + stats["%N2"] = 100 * stats["N2"] / stats["TST"] + stats["%N3"] = 100 * stats["N3"] / stats["TST"] + stats["%REM"] = 100 * stats["REM"] / stats["TST"] + stats["%NREM"] = 100 * stats["NREM"] / stats["TST"] + stats["SE"] = 100 * stats["TST"] / stats["TIB"] + stats["SME"] = 100 * stats["TST"] / stats["SPT"] return stats diff --git a/yasa/spectral.py b/yasa/spectral.py index 9bb909e..374835f 100644 --- a/yasa/spectral.py +++ b/yasa/spectral.py @@ -11,17 +11,30 @@ from scipy.interpolate import RectBivariateSpline from .io import set_log_level -logger = logging.getLogger('yasa') - -__all__ = ['bandpower', 'bandpower_from_psd', 'bandpower_from_psd_ndarray', - 'irasa', 'stft_power'] - - -def bandpower(data, sf=None, ch_names=None, hypno=None, include=(2, 3), - win_sec=4, relative=True, bandpass=False, - bands=[(0.5, 4, 'Delta'), (4, 8, 'Theta'), (8, 12, 'Alpha'), - (12, 16, 'Sigma'), (16, 30, 'Beta'), (30, 40, 'Gamma')], - kwargs_welch=dict(average='median', window='hamming')): +logger = logging.getLogger("yasa") + +__all__ = ["bandpower", "bandpower_from_psd", "bandpower_from_psd_ndarray", "irasa", "stft_power"] + + +def bandpower( + data, + sf=None, + ch_names=None, + hypno=None, + include=(2, 3), + win_sec=4, + relative=True, + bandpass=False, + bands=[ + (0.5, 4, "Delta"), + (4, 8, "Theta"), + (8, 12, "Alpha"), + (12, 16, "Sigma"), + (16, 30, "Beta"), + (30, 40, "Gamma"), + ], + kwargs_welch=dict(average="median", window="hamming"), +): """ Calculate the Welch bandpower for each channel and, if specified, for each sleep stage. 
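Usage sketch for the reformatted `bandpower` signature above; channel names are illustrative. Without `hypno` it returns one row of relative band powers per channel; with `hypno` and `include` it returns a (Stage, Chan)-indexed DataFrame, as the body below shows:

import numpy as np
import yasa

sf = 100
data = np.random.randn(2, 600 * sf)  # 2 channels x 10 min, in uV

# Whole-recording relative band powers, one row per channel
bp = yasa.bandpower(data, sf=sf, ch_names=["Cz", "Fz"])

# Per-stage band powers (dummy all-N2 hypnogram, one value per sample)
hypno = 2 * np.ones(data.shape[1], dtype=int)
bp_n2 = yasa.bandpower(data, sf=sf, ch_names=["Cz", "Fz"], hypno=hypno, include=2)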
@@ -88,74 +101,85 @@ def bandpower(data, sf=None, ch_names=None, hypno=None, include=(2, 3), https://github.com/raphaelvallat/yasa/blob/master/notebooks/08_bandpower.ipynb """ # Type checks - assert isinstance(bands, list), 'bands must be a list of tuple(s)' - assert isinstance(relative, bool), 'relative must be a boolean' - assert isinstance(bandpass, bool), 'bandpass must be a boolean' + assert isinstance(bands, list), "bands must be a list of tuple(s)" + assert isinstance(relative, bool), "relative must be a boolean" + assert isinstance(bandpass, bool), "bandpass must be a boolean" # Check if input data is a MNE Raw object if isinstance(data, mne.io.BaseRaw): - sf = data.info['sfreq'] # Extract sampling frequency + sf = data.info["sfreq"] # Extract sampling frequency ch_names = data.ch_names # Extract channel names data = data.get_data(units=dict(eeg="uV", emg="uV", eog="uV", ecg="uV")) _, npts = data.shape else: # Safety checks - assert isinstance(data, np.ndarray), 'Data must be a numpy array.' + assert isinstance(data, np.ndarray), "Data must be a numpy array." data = np.atleast_2d(data) - assert data.ndim == 2, 'Data must be of shape (nchan, n_samples).' + assert data.ndim == 2, "Data must be of shape (nchan, n_samples)." nchan, npts = data.shape # assert nchan < npts, 'Data must be of shape (nchan, n_samples).' - assert sf is not None, 'sf must be specified if passing a numpy array.' + assert sf is not None, "sf must be specified if passing a numpy array." assert isinstance(sf, (int, float)) if ch_names is None: - ch_names = ['CHAN' + str(i).zfill(3) for i in range(nchan)] + ch_names = ["CHAN" + str(i).zfill(3) for i in range(nchan)] else: ch_names = np.atleast_1d(np.asarray(ch_names, dtype=str)) - assert ch_names.ndim == 1, 'ch_names must be 1D.' - assert len(ch_names) == nchan, 'ch_names must match data.shape[0].' + assert ch_names.ndim == 1, "ch_names must be 1D." + assert len(ch_names) == nchan, "ch_names must match data.shape[0]." if bandpass: # Apply FIR bandpass filter all_freqs = np.hstack([[b[0], b[1]] for b in bands]) fmin, fmax = min(all_freqs), max(all_freqs) - data = mne.filter.filter_data(data.astype('float64'), sf, fmin, fmax, verbose=0) + data = mne.filter.filter_data(data.astype("float64"), sf, fmin, fmax, verbose=0) win = int(win_sec * sf) # nperseg if hypno is None: # Calculate the PSD over the whole data freqs, psd = signal.welch(data, sf, nperseg=win, **kwargs_welch) - return bandpower_from_psd( - psd, freqs, ch_names, bands=bands, relative=relative).set_index('Chan') + return bandpower_from_psd(psd, freqs, ch_names, bands=bands, relative=relative).set_index( + "Chan" + ) else: # Per each sleep stage defined in ``include``. hypno = np.asarray(hypno) - assert include is not None, 'include cannot be None if hypno is given' + assert include is not None, "include cannot be None if hypno is given" include = np.atleast_1d(np.asarray(include)) - assert hypno.ndim == 1, 'Hypno must be a 1D array.' - assert hypno.size == npts, 'Hypno must have same size as data.shape[1]' - assert include.size >= 1, '`include` must have at least one element.' - assert hypno.dtype.kind == include.dtype.kind, 'hypno and include must have same dtype' - assert np.in1d(hypno, include).any(), ( - 'None of the stages specified in `include` are present in hypno.') + assert hypno.ndim == 1, "Hypno must be a 1D array." + assert hypno.size == npts, "Hypno must have same size as data.shape[1]" + assert include.size >= 1, "`include` must have at least one element." 
+ assert hypno.dtype.kind == include.dtype.kind, "hypno and include must have same dtype" + assert np.in1d( + hypno, include + ).any(), "None of the stages specified in `include` are present in hypno." # Initialize empty dataframe and loop over stages df_bp = pd.DataFrame([]) for stage in include: if stage not in hypno: continue data_stage = data[:, hypno == stage] - freqs, psd = signal.welch(data_stage, sf, nperseg=win, - **kwargs_welch) - bp_stage = bandpower_from_psd(psd, freqs, ch_names, bands=bands, - relative=relative) - bp_stage['Stage'] = stage + freqs, psd = signal.welch(data_stage, sf, nperseg=win, **kwargs_welch) + bp_stage = bandpower_from_psd(psd, freqs, ch_names, bands=bands, relative=relative) + bp_stage["Stage"] = stage df_bp = pd.concat([df_bp, bp_stage], axis=0) - return df_bp.set_index(['Stage', 'Chan']) - - -def bandpower_from_psd(psd, freqs, ch_names=None, bands=[(0.5, 4, 'Delta'), - (4, 8, 'Theta'), (8, 12, 'Alpha'), (12, 16, 'Sigma'), - (16, 30, 'Beta'), (30, 40, 'Gamma')], relative=True): + return df_bp.set_index(["Stage", "Chan"]) + + +def bandpower_from_psd( + psd, + freqs, + ch_names=None, + bands=[ + (0.5, 4, "Delta"), + (4, 8, "Theta"), + (8, 12, "Alpha"), + (12, 16, "Sigma"), + (16, 30, "Beta"), + (30, 40, "Gamma"), + ], + relative=True, +): """Compute the average power of the EEG in specified frequency band(s) given a pre-computed PSD. @@ -184,27 +208,27 @@ def bandpower_from_psd(psd, freqs, ch_names=None, bands=[(0.5, 4, 'Delta'), Bandpower dataframe, in which each row is a channel and each column a spectral band. """ # Type checks - assert isinstance(bands, list), 'bands must be a list of tuple(s)' - assert isinstance(relative, bool), 'relative must be a boolean' + assert isinstance(bands, list), "bands must be a list of tuple(s)" + assert isinstance(relative, bool), "relative must be a boolean" # Safety checks freqs = np.asarray(freqs) assert freqs.ndim == 1 psd = np.atleast_2d(psd) - assert psd.ndim == 2, 'PSD must be of shape (n_channels, n_freqs).' + assert psd.ndim == 2, "PSD must be of shape (n_channels, n_freqs)." all_freqs = np.hstack([[b[0], b[1]] for b in bands]) fmin, fmax = min(all_freqs), max(all_freqs) idx_good_freq = np.logical_and(freqs >= fmin, freqs <= fmax) freqs = freqs[idx_good_freq] res = freqs[1] - freqs[0] nchan = psd.shape[0] - assert nchan < psd.shape[1], 'PSD must be of shape (n_channels, n_freqs).' + assert nchan < psd.shape[1], "PSD must be of shape (n_channels, n_freqs)." if ch_names is not None: ch_names = np.atleast_1d(np.asarray(ch_names, dtype=str)) - assert ch_names.ndim == 1, 'ch_names must be 1D.' - assert len(ch_names) == nchan, 'ch_names must match psd.shape[0].' + assert ch_names.ndim == 1, "ch_names must be 1D." + assert len(ch_names) == nchan, "ch_names must match psd.shape[0]." else: - ch_names = ['CHAN' + str(i).zfill(3) for i in range(nchan)] + ch_names = ["CHAN" + str(i).zfill(3) for i in range(nchan)] bp = np.zeros((nchan, len(bands)), dtype=np.float64) psd = psd[:, idx_good_freq] total_power = simps(psd, dx=res) @@ -216,7 +240,8 @@ def bandpower_from_psd(psd, freqs, ch_names=None, bands=[(0.5, 4, 'Delta'), "There are negative values in PSD. This will result in incorrect " "bandpower values. We highly recommend working with an " "all-positive PSD. 
For more details, please refer to: " - "https://github.com/raphaelvallat/yasa/issues/29") + "https://github.com/raphaelvallat/yasa/issues/29" + ) logger.warning(msg) # Enumerate over the frequency bands @@ -232,21 +257,30 @@ def bandpower_from_psd(psd, freqs, ch_names=None, bands=[(0.5, 4, 'Delta'), # Convert to DataFrame bp = pd.DataFrame(bp, columns=labels) - bp['TotalAbsPow'] = np.squeeze(total_power) - bp['FreqRes'] = res + bp["TotalAbsPow"] = np.squeeze(total_power) + bp["FreqRes"] = res # bp['WindowSec'] = 1 / res - bp['Relative'] = relative - bp['Chan'] = ch_names - bp = bp.set_index('Chan').reset_index() + bp["Relative"] = relative + bp["Chan"] = ch_names + bp = bp.set_index("Chan").reset_index() # Add hidden attributes bp.bands_ = str(bands) return bp -def bandpower_from_psd_ndarray(psd, freqs, bands=[(0.5, 4, 'Delta'), - (4, 8, 'Theta'), (8, 12, 'Alpha'), - (12, 16, 'Sigma'), (16, 30, 'Beta'), - (30, 40, 'Gamma')], relative=True): +def bandpower_from_psd_ndarray( + psd, + freqs, + bands=[ + (0.5, 4, "Delta"), + (4, 8, "Theta"), + (8, 12, "Alpha"), + (12, 16, "Sigma"), + (16, 30, "Beta"), + (30, 40, "Gamma"), + ], + relative=True, +): """Compute bandpowers in N-dimensional PSD. This is a NumPy-only implementation of the :py:func:`yasa.bandpower_from_psd` function, @@ -275,14 +309,14 @@ def bandpower_from_psd_ndarray(psd, freqs, bands=[(0.5, 4, 'Delta'), Bandpower array of shape *(n_bands, ...)*. """ # Type checks - assert isinstance(bands, list), 'bands must be a list of tuple(s)' - assert isinstance(relative, bool), 'relative must be a boolean' + assert isinstance(bands, list), "bands must be a list of tuple(s)" + assert isinstance(relative, bool), "relative must be a boolean" # Safety checks freqs = np.asarray(freqs) psd = np.asarray(psd) - assert freqs.ndim == 1, 'freqs must be a 1-D array of shape (n_freqs,)' - assert psd.shape[-1] == freqs.shape[-1], 'n_freqs must be last axis of psd' + assert freqs.ndim == 1, "freqs must be a 1-D array of shape (n_freqs,)" + assert psd.shape[-1] == freqs.shape[-1], "n_freqs must be last axis of psd" # Extract frequencies of interest all_freqs = np.hstack([[b[0], b[1]] for b in bands]) @@ -300,7 +334,8 @@ def bandpower_from_psd_ndarray(psd, freqs, bands=[(0.5, 4, 'Delta'), "There are negative values in PSD. This will result in incorrect " "bandpower values. We highly recommend working with an " "all-positive PSD. For more details, please refer to: " - "https://github.com/raphaelvallat/yasa/issues/29") + "https://github.com/raphaelvallat/yasa/issues/29" + ) logger.warning(msg) # Calculate total power @@ -323,11 +358,35 @@ def bandpower_from_psd_ndarray(psd, freqs, bands=[(0.5, 4, 'Delta'), return bp -def irasa(data, sf=None, ch_names=None, band=(1, 30), - hset=[1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4, 1.45, 1.5, 1.55, 1.6, - 1.65, 1.7, 1.75, 1.8, 1.85, 1.9], return_fit=True, win_sec=4, - kwargs_welch=dict(average='median', window='hamming'), - verbose=True): +def irasa( + data, + sf=None, + ch_names=None, + band=(1, 30), + hset=[ + 1.1, + 1.15, + 1.2, + 1.25, + 1.3, + 1.35, + 1.4, + 1.45, + 1.5, + 1.55, + 1.6, + 1.65, + 1.7, + 1.75, + 1.8, + 1.85, + 1.9, + ], + return_fit=True, + win_sec=4, + kwargs_welch=dict(average="median", window="hamming"), + verbose=True, +): r""" Separate the aperiodic (= fractal, or 1/f) and oscillatory component of the power spectra of EEG data using the IRASA method. 
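Usage sketch for `irasa`: with `return_fit=True` (the default) it returns the frequency vector, the aperiodic and oscillatory PSD components, and a fit-parameters DataFrame, per the return statements below. Synthetic data, illustrative only; the sampling rate is chosen so that the evaluated band stays below the resampled Nyquist frequency and no warning is emitted:

import numpy as np
import yasa

sf = 256
data = np.random.randn(2, 600 * sf)  # 2 channels x 10 min, in uV

freqs, psd_aperiodic, psd_osc, fit_params = yasa.irasa(
    data, sf=sf, ch_names=["Cz", "Fz"], band=(1, 30)
)
print(fit_params[["Chan", "Intercept", "Slope", "R^2"]])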
@@ -440,40 +499,41 @@ def irasa(data, sf=None, ch_names=None, band=(1, 30), [5] https://doi.org/10.1101/2021.10.15.464483 """ import fractions + set_log_level(verbose) # Check if input data is a MNE Raw object if isinstance(data, mne.io.BaseRaw): - sf = data.info['sfreq'] # Extract sampling frequency + sf = data.info["sfreq"] # Extract sampling frequency ch_names = data.ch_names # Extract channel names - hp = data.info['highpass'] # Extract highpass filter - lp = data.info['lowpass'] # Extract lowpass filter + hp = data.info["highpass"] # Extract highpass filter + lp = data.info["lowpass"] # Extract lowpass filter data = data.get_data(units=dict(eeg="uV", emg="uV", eog="uV", ecg="uV")) else: # Safety checks - assert isinstance(data, np.ndarray), 'Data must be a numpy array.' + assert isinstance(data, np.ndarray), "Data must be a numpy array." data = np.atleast_2d(data) - assert data.ndim == 2, 'Data must be of shape (nchan, n_samples).' + assert data.ndim == 2, "Data must be of shape (nchan, n_samples)." nchan, npts = data.shape - assert nchan < npts, 'Data must be of shape (nchan, n_samples).' - assert sf is not None, 'sf must be specified if passing a numpy array.' + assert nchan < npts, "Data must be of shape (nchan, n_samples)." + assert sf is not None, "sf must be specified if passing a numpy array." assert isinstance(sf, (int, float)) if ch_names is None: - ch_names = ['CHAN' + str(i).zfill(3) for i in range(nchan)] + ch_names = ["CHAN" + str(i).zfill(3) for i in range(nchan)] else: ch_names = np.atleast_1d(np.asarray(ch_names, dtype=str)) - assert ch_names.ndim == 1, 'ch_names must be 1D.' - assert len(ch_names) == nchan, 'ch_names must match data.shape[0].' + assert ch_names.ndim == 1, "ch_names must be 1D." + assert len(ch_names) == nchan, "ch_names must match data.shape[0]." hp = 0 # Highpass filter unknown -> set to 0 Hz lp = sf / 2 # Lowpass filter unknown -> set to Nyquist # Check the other arguments hset = np.asarray(hset) - assert hset.ndim == 1, 'hset must be 1D.' - assert hset.size > 1, '2 or more resampling fators are required.' + assert hset.ndim == 1, "hset must be 1D." + assert hset.size > 1, "2 or more resampling fators are required." hset = np.round(hset, 4) # avoid float precision error with np.arange. band = sorted(band) - assert band[0] > 0, 'first element of band must be > 0.' - assert band[1] < (sf / 2), 'second element of band must be < (sf / 2).' + assert band[0] > 0, "first element of band must be > 0." + assert band[1] < (sf / 2), "second element of band must be < (sf / 2)." win = int(win_sec * sf) # nperseg # Inform about maximum resampled fitting range @@ -484,21 +544,27 @@ def irasa(data, sf=None, ch_names=None, band=(1, 30), logging.info(f"Fitting range: {band[0]:.2f}Hz-{band[1]:.2f}Hz") logging.info(f"Evaluated frequency range: {band_evaluated[0]:.2f}Hz-{band_evaluated[1]:.2f}Hz") if band_evaluated[0] < hp: - logging.warning("The evaluated frequency range starts below the " - f"highpass filter ({hp:.2f}Hz). Increase the lower band" - f" ({band[0]:.2f}Hz) or decrease the maximum value of " - f"the hset ({h_max:.2f}).") + logging.warning( + "The evaluated frequency range starts below the " + f"highpass filter ({hp:.2f}Hz). Increase the lower band" + f" ({band[0]:.2f}Hz) or decrease the maximum value of " + f"the hset ({h_max:.2f})." + ) if band_evaluated[1] > lp and lp < freq_Nyq_res: - logging.warning("The evaluated frequency range ends after the " - f"lowpass filter ({lp:.2f}Hz). 
Decrease the upper band" - f" ({band[1]:.2f}Hz) or decrease the maximum value of " - f"the hset ({h_max:.2f}).") + logging.warning( + "The evaluated frequency range ends after the " + f"lowpass filter ({lp:.2f}Hz). Decrease the upper band" + f" ({band[1]:.2f}Hz) or decrease the maximum value of " + f"the hset ({h_max:.2f})." + ) if band_evaluated[1] > freq_Nyq_res: - logging.warning("The evaluated frequency range ends after the " - "resampled Nyquist frequency " - f"({freq_Nyq_res:.2f}Hz). Decrease the upper band " - f"({band[1]:.2f}Hz) or decrease the maximum value " - f"of the hset ({h_max:.2f}).") + logging.warning( + "The evaluated frequency range ends after the " + "resampled Nyquist frequency " + f"({freq_Nyq_res:.2f}Hz). Decrease the upper band " + f"({band[1]:.2f}Hz) or decrease the maximum value " + f"of the hset ({h_max:.2f})." + ) # Calculate the original PSD over the whole data freqs, psd = signal.welch(data, sf, nperseg=win, **kwargs_welch) @@ -535,6 +601,7 @@ def irasa(data, sf=None, ch_names=None, band=(1, 30), if return_fit: # Aperiodic fit in semilog space for each channel from scipy.optimize import curve_fit + intercepts, slopes, r_squared = [], [], [] def func(t, a, b): @@ -545,26 +612,31 @@ def func(t, a, b): y_log = np.log(y) # Note that here we define bounds for the slope but not for the # intercept. - popt, pcov = curve_fit(func, freqs, y_log, p0=(2, -1), - bounds=((-np.inf, -10), (np.inf, 2))) + popt, pcov = curve_fit( + func, freqs, y_log, p0=(2, -1), bounds=((-np.inf, -10), (np.inf, 2)) + ) intercepts.append(popt[0]) slopes.append(popt[1]) # Calculate R^2: https://stackoverflow.com/q/19189362/10581531 residuals = y_log - func(freqs, *popt) ss_res = np.sum(residuals**2) - ss_tot = np.sum((y_log - np.mean(y_log))**2) + ss_tot = np.sum((y_log - np.mean(y_log)) ** 2) r_squared.append(1 - (ss_res / ss_tot)) # Create fit parameters dataframe - fit_params = {'Chan': ch_names, 'Intercept': intercepts, - 'Slope': slopes, 'R^2': r_squared, - 'std(osc)': np.std(psd_osc, axis=-1, ddof=1)} + fit_params = { + "Chan": ch_names, + "Intercept": intercepts, + "Slope": slopes, + "R^2": r_squared, + "std(osc)": np.std(psd_osc, axis=-1, ddof=1), + } return freqs, psd_aperiodic, psd_osc, pd.DataFrame(fit_params) else: return freqs, psd_aperiodic, psd_osc -def stft_power(data, sf, window=2, step=.2, band=(1, 30), interp=True, norm=False): +def stft_power(data, sf, window=2, step=0.2, band=(1, 30), interp=True, norm=False): """Compute the pointwise power via STFT and interpolation. Parameters @@ -619,7 +691,8 @@ def stft_power(data, sf, window=2, step=.2, band=(1, 30), interp=True, norm=Fals # Compute STFT and remove the last epoch f, t, Sxx = signal.stft( - data, sf, nperseg=nperseg, noverlap=noverlap, detrend=False, padded=True) + data, sf, nperseg=nperseg, noverlap=noverlap, detrend=False, padded=True + ) # Let's keep only the frequency of interest if band is not None: diff --git a/yasa/staging.py b/yasa/staging.py index 85e06d9..a0e1efb 100644 --- a/yasa/staging.py +++ b/yasa/staging.py @@ -16,7 +16,7 @@ from .others import sliding_window from .spectral import bandpower_from_psd_ndarray -logger = logging.getLogger('yasa') +logger = logging.getLogger("yasa") class SleepStaging: @@ -166,21 +166,21 @@ def __init__(self, raw, eeg_name, *, eog_name=None, emg_name=None, metadata=None # Validate metadata if isinstance(metadata, dict): - if 'age' in metadata.keys(): - assert 0 < metadata['age'] < 120, 'age must be between 0 and 120.' 
- if 'male' in metadata.keys(): - metadata['male'] = int(metadata['male']) - assert metadata['male'] in [0, 1], 'male must be 0 or 1.' + if "age" in metadata.keys(): + assert 0 < metadata["age"] < 120, "age must be between 0 and 120." + if "male" in metadata.keys(): + metadata["male"] = int(metadata["male"]) + assert metadata["male"] in [0, 1], "male must be 0 or 1." # Validate Raw instance and load data - assert isinstance(raw, mne.io.BaseRaw), 'raw must be a MNE Raw object.' - sf = raw.info['sfreq'] + assert isinstance(raw, mne.io.BaseRaw), "raw must be a MNE Raw object." + sf = raw.info["sfreq"] ch_names = np.array([eeg_name, eog_name, emg_name]) - ch_types = np.array(['eeg', 'eog', 'emg']) + ch_types = np.array(["eeg", "eog", "emg"]) keep_chan = [] for c in ch_names: if c is not None: - assert c in raw.ch_names, '%s does not exist' % c + assert c in raw.ch_names, "%s does not exist" % c keep_chan.append(True) else: keep_chan.append(False) @@ -191,17 +191,17 @@ def __init__(self, raw, eeg_name, *, eog_name=None, emg_name=None, metadata=None raw_pick = raw.copy().pick_channels(ch_names, ordered=True) # Downsample if sf != 100 - assert sf > 80, 'Sampling frequency must be at least 80 Hz.' + assert sf > 80, "Sampling frequency must be at least 80 Hz." if sf != 100: raw_pick.resample(100, npad="auto") - sf = raw_pick.info['sfreq'] + sf = raw_pick.info["sfreq"] # Get data and convert to microVolts data = raw_pick.get_data(units=dict(eeg="uV", emg="uV", eog="uV", ecg="uV")) # Extract duration of recording in minutes duration_minutes = data.shape[1] / sf / 60 - assert duration_minutes >= 5, 'At least 5 minutes of data is required.' + assert duration_minutes >= 5, "At least 5 minutes of data is required." # Add to self self.sf = sf @@ -227,10 +227,14 @@ def fit(self): win_sec = 5 # = 2 / freq_broad[0] sf = self.sf win = int(win_sec * sf) - kwargs_welch = dict(window='hamming', nperseg=win, average='median') + kwargs_welch = dict(window="hamming", nperseg=win, average="median") bands = [ - (0.4, 1, 'sdelta'), (1, 4, 'fdelta'), (4, 8, 'theta'), - (8, 12, 'alpha'), (12, 16, 'sigma'), (16, 30, 'beta') + (0.4, 1, "sdelta"), + (1, 4, "fdelta"), + (4, 8, "theta"), + (8, 12, "alpha"), + (12, 16, "sigma"), + (16, 30, "beta"), ] ####################################################################### @@ -243,7 +247,8 @@ def fit(self): # Preprocessing # - Filter the data dt_filt = filter_data( - self.data[i, :], sf, l_freq=freq_broad[0], h_freq=freq_broad[1], verbose=False) + self.data[i, :], sf, l_freq=freq_broad[0], h_freq=freq_broad[1], verbose=False + ) # - Extract epochs. Data is now of shape (n_epochs, n_samples). 
times, epochs = sliding_window(dt_filt, sf=sf, window=30) @@ -251,44 +256,42 @@ def fit(self): hmob, hcomp = ant.hjorth_params(epochs, axis=1) feat = { - 'std': np.std(epochs, ddof=1, axis=1), - 'iqr': sp_stats.iqr(epochs, rng=(25, 75), axis=1), - 'skew': sp_stats.skew(epochs, axis=1), - 'kurt': sp_stats.kurtosis(epochs, axis=1), - 'nzc': ant.num_zerocross(epochs, axis=1), - 'hmob': hmob, - 'hcomp': hcomp + "std": np.std(epochs, ddof=1, axis=1), + "iqr": sp_stats.iqr(epochs, rng=(25, 75), axis=1), + "skew": sp_stats.skew(epochs, axis=1), + "kurt": sp_stats.kurtosis(epochs, axis=1), + "nzc": ant.num_zerocross(epochs, axis=1), + "hmob": hmob, + "hcomp": hcomp, } # Calculate spectral power features (for EEG + EOG) freqs, psd = sp_sig.welch(epochs, sf, **kwargs_welch) - if c != 'emg': + if c != "emg": bp = bandpower_from_psd_ndarray(psd, freqs, bands=bands) for j, (_, _, b) in enumerate(bands): feat[b] = bp[j] # Add power ratios for EEG - if c == 'eeg': - delta = feat['sdelta'] + feat['fdelta'] - feat['dt'] = delta / feat['theta'] - feat['ds'] = delta / feat['sigma'] - feat['db'] = delta / feat['beta'] - feat['at'] = feat['alpha'] / feat['theta'] + if c == "eeg": + delta = feat["sdelta"] + feat["fdelta"] + feat["dt"] = delta / feat["theta"] + feat["ds"] = delta / feat["sigma"] + feat["db"] = delta / feat["beta"] + feat["at"] = feat["alpha"] / feat["theta"] # Add total power idx_broad = np.logical_and(freqs >= freq_broad[0], freqs <= freq_broad[1]) dx = freqs[1] - freqs[0] - feat['abspow'] = np.trapz(psd[:, idx_broad], dx=dx) + feat["abspow"] = np.trapz(psd[:, idx_broad], dx=dx) # Calculate entropy and fractal dimension features - feat['perm'] = np.apply_along_axis( - ant.perm_entropy, axis=1, arr=epochs, normalize=True) - feat['higuchi'] = np.apply_along_axis( - ant.higuchi_fd, axis=1, arr=epochs) - feat['petrosian'] = ant.petrosian_fd(epochs, axis=1) + feat["perm"] = np.apply_along_axis(ant.perm_entropy, axis=1, arr=epochs, normalize=True) + feat["higuchi"] = np.apply_along_axis(ant.higuchi_fd, axis=1, arr=epochs) + feat["petrosian"] = ant.petrosian_fd(epochs, axis=1) # Convert to dataframe - feat = pd.DataFrame(feat).add_prefix(c + '_') + feat = pd.DataFrame(feat).add_prefix(c + "_") features.append(feat) ####################################################################### @@ -297,20 +300,19 @@ def fit(self): # Save features to dataframe features = pd.concat(features, axis=1) - features.index.name = 'epoch' + features.index.name = "epoch" # Apply centered rolling average (15 epochs = 7 min 30) # Triang: [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1., # 0.875, 0.75, 0.625, 0.5, 0.375, 0.25, 0.125] - rollc = features.rolling( - window=15, center=True, min_periods=1, win_type='triang').mean() + rollc = features.rolling(window=15, center=True, min_periods=1, win_type="triang").mean() rollc[rollc.columns] = robust_scale(rollc, quantile_range=(5, 95)) - rollc = rollc.add_suffix('_c7min_norm') + rollc = rollc.add_suffix("_c7min_norm") # Now look at the past 2 minutes rollp = features.rolling(window=4, min_periods=1).mean() rollp[rollp.columns] = robust_scale(rollp, quantile_range=(5, 95)) - rollp = rollp.add_suffix('_p2min_norm') + rollp = rollp.add_suffix("_p2min_norm") # Add to current set of features features = features.join(rollc).join(rollp) @@ -320,8 +322,8 @@ def fit(self): ####################################################################### # Add temporal features - features['time_hour'] = times / 3600 - features['time_norm'] = times / times[-1] + features["time_hour"] = times / 
3600 + features["time_norm"] = times / times[-1] # Add metadata if present if self.metadata is not None: @@ -332,10 +334,10 @@ def fit(self): cols_float = features.select_dtypes(np.float64).columns.tolist() features[cols_float] = features[cols_float].astype(np.float32) # Make sure that age and sex are encoded as int - if 'age' in features.columns: - features['age'] = features['age'].astype(int) - if 'male' in features.columns: - features['male'] = features['male'].astype(int) + if "age" in features.columns: + features["age"] = features["age"].astype(int) + if "male" in features.columns: + features["male"] = features["male"].astype(int) # Sort the column names here (same behavior as lightGBM) features.sort_index(axis=1, inplace=True) @@ -352,7 +354,7 @@ def get_features(self): features : :py:class:`pandas.DataFrame` Feature dataframe. """ - if not hasattr(self, '_features'): + if not hasattr(self, "_features"): self.fit() return self._features.copy() @@ -362,22 +364,32 @@ def _validate_predict(self, clf): # Note that clf.feature_name_ is only available in lightgbm>=3.0 f_diff = np.setdiff1d(clf.feature_name_, self.feature_name_) if len(f_diff): - raise ValueError("The following features are present in the " - "classifier but not in the current features set:", f_diff) - f_diff = np.setdiff1d(self.feature_name_, clf.feature_name_, ) + raise ValueError( + "The following features are present in the " + "classifier but not in the current features set:", + f_diff, + ) + f_diff = np.setdiff1d( + self.feature_name_, + clf.feature_name_, + ) if len(f_diff): - raise ValueError("The following features are present in the " - "current feature set but not in the classifier:", f_diff) + raise ValueError( + "The following features are present in the " + "current feature set but not in the classifier:", + f_diff, + ) def _load_model(self, path_to_model): """Load the relevant trained classifier.""" if path_to_model == "auto": from pathlib import Path - clf_dir = os.path.join(str(Path(__file__).parent), 'classifiers/') - name = 'clf_eeg' - name = name + '+eog' if 'eog' in self.ch_types else name - name = name + '+emg' if 'emg' in self.ch_types else name - name = name + '+demo' if self.metadata is not None else name + + clf_dir = os.path.join(str(Path(__file__).parent), "classifiers/") + name = "clf_eeg" + name = name + "+eog" if "eog" in self.ch_types else name + name = name + "+emg" if "emg" in self.ch_types else name + name = name + "+demo" if self.metadata is not None else name # e.g. clf_eeg+eog+emg+demo_lgb_0.4.0.joblib all_matching_files = glob.glob(clf_dir + name + "*.joblib") # Find the latest file @@ -410,7 +422,7 @@ def predict(self, path_to_model="auto"): pred : :py:class:`numpy.ndarray` The predicted sleep stages. """ - if not hasattr(self, '_features'): + if not hasattr(self, "_features"): self.fit() # Load and validate pre-trained classifier clf = self._load_model(path_to_model) @@ -419,7 +431,7 @@ def predict(self, path_to_model="auto"): # Predict the sleep stages and probabilities self._predicted = clf.predict(X) proba = pd.DataFrame(clf.predict_proba(X), columns=clf.classes_) - proba.index.name = 'epoch' + proba.index.name = "epoch" self._proba = proba return self._predicted.copy() @@ -442,13 +454,16 @@ def predict_proba(self, path_to_model="auto"): proba : :py:class:`pandas.DataFrame` The predicted probability for each sleep stage for each 30-sec epoch of data. 
""" - if not hasattr(self, '_proba'): + if not hasattr(self, "_proba"): self.predict(path_to_model) return self._proba.copy() - def plot_predict_proba(self, proba=None, majority_only=False, - palette=['#99d7f1', '#009DDC', 'xkcd:twilight blue', - 'xkcd:rich purple', 'xkcd:sunflower']): + def plot_predict_proba( + self, + proba=None, + majority_only=False, + palette=["#99d7f1", "#009DDC", "xkcd:twilight blue", "xkcd:rich purple", "xkcd:sunflower"], + ): """ Plot the predicted probability for each sleep stage for each 30-sec epoch of data. @@ -459,16 +474,16 @@ def plot_predict_proba(self, proba=None, majority_only=False, majority_only : boolean If True, probabilities of the non-majority classes will be set to 0. """ - if proba is None and not hasattr(self, '_features'): + if proba is None and not hasattr(self, "_features"): raise ValueError("Must call .predict_proba before this function") if proba is None: proba = self._proba.copy() else: - assert isinstance(proba, pd.DataFrame), 'proba must be a dataframe' + assert isinstance(proba, pd.DataFrame), "proba must be a dataframe" if majority_only: cond = proba.apply(lambda x: x == x.max(), axis=1) proba = proba.where(cond, other=0) - ax = proba.plot(kind='area', color=palette, figsize=(10, 5), alpha=.8, stacked=True, lw=0) + ax = proba.plot(kind="area", color=palette, figsize=(10, 5), alpha=0.8, stacked=True, lw=0) # Add confidence # confidence = proba.max(1) # ax.plot(confidence, lw=1, color='k', ls='-', alpha=0.5, diff --git a/yasa/tests/test_detection.py b/yasa/tests/test_detection.py index cbf23cd..12b178c 100644 --- a/yasa/tests/test_detection.py +++ b/yasa/tests/test_detection.py @@ -17,20 +17,20 @@ sf = 100 # 1) Single channel, we take one every other point to keep a sf of 100 Hz -data = np.loadtxt('notebooks/data_N2_spindles_15sec_200Hz.txt')[::2] -data_sigma = filter_data(data, sf, 12, 15, method='fir', verbose=0) +data = np.loadtxt("notebooks/data_N2_spindles_15sec_200Hz.txt")[::2] +data_sigma = filter_data(data, sf, 12, 15, method="fir", verbose=0) # Load an extract of N3 sleep without any spindle -data_n3 = np.loadtxt('notebooks/data_N3_no-spindles_30sec_100Hz.txt') +data_n3 = np.loadtxt("notebooks/data_N3_no-spindles_30sec_100Hz.txt") # 2) Multi-channel # Load a full recording and its hypnogram -data_full = np.load('notebooks/data_full_6hrs_100Hz_Cz+Fz+Pz.npz').get('data') -chan_full = np.load('notebooks/data_full_6hrs_100Hz_Cz+Fz+Pz.npz').get('chan') -hypno_full = np.load('notebooks/data_full_6hrs_100Hz_hypno.npz').get('hypno') +data_full = np.load("notebooks/data_full_6hrs_100Hz_Cz+Fz+Pz.npz").get("data") +chan_full = np.load("notebooks/data_full_6hrs_100Hz_Cz+Fz+Pz.npz").get("chan") +hypno_full = np.load("notebooks/data_full_6hrs_100Hz_hypno.npz").get("hypno") # Let's add a channel with bad data amplitude -chan_full = np.append(chan_full, 'Bad') # ['Cz', 'Fz', 'Pz', 'Bad'] +chan_full = np.append(chan_full, "Bad") # ['Cz', 'Fz', 'Pz', 'Bad'] data_full = np.vstack((data_full, data_full[-1, :] * 1e8)) # Keep only Fz and during a N3 sleep period with (huge) slow-waves @@ -38,16 +38,15 @@ hypno_sw = hypno_full[666000:672000] # MNE Raw -data_mne = mne.io.read_raw_fif('notebooks/sub-02_mne_raw.fif', preload=True, verbose=0) +data_mne = mne.io.read_raw_fif("notebooks/sub-02_mne_raw.fif", preload=True, verbose=0) data_mne.pick_types(eeg=True) -data_mne_single = data_mne.copy().pick_channels(['F3']) -hypno_mne = np.loadtxt('notebooks/sub-02_hypno_30s.txt', dtype=str) +data_mne_single = data_mne.copy().pick_channels(["F3"]) +hypno_mne = 
np.loadtxt("notebooks/sub-02_hypno_30s.txt", dtype=str) hypno_mne = hypno_str_to_int(hypno_mne) hypno_mne = hypno_upsample_to_data(hypno=hypno_mne, sf_hypno=(1 / 30), data=data_mne) class TestDetection(unittest.TestCase): - def test_check_data_hypno(self): """Test preprocessing of data and hypno.""" pass @@ -74,20 +73,20 @@ def test_spindles_detect(self): sp.plot_average(ci=None, filt=(None, 30)) # Skip bootstrapping np.testing.assert_array_equal(np.squeeze(sp._data), data) assert sp._sf == sf - sp.summary(grp_chan=True, grp_stage=True, aggfunc='median', sort=False) + sp.summary(grp_chan=True, grp_stage=True, aggfunc="median", sort=False) # Test with custom thresholds - spindles_detect(data, sf, thresh={'rel_pow': 0.25}) - spindles_detect(data, sf, thresh={'rms': 1.25}) - spindles_detect(data, sf, thresh={'rel_pow': 0.25, 'corr': .60}) + spindles_detect(data, sf, thresh={"rel_pow": 0.25}) + spindles_detect(data, sf, thresh={"rms": 1.25}) + spindles_detect(data, sf, thresh={"rel_pow": 0.25, "corr": 0.60}) # Test with disabled thresholds - spindles_detect(data, sf, thresh={'rel_pow': None}) - spindles_detect(data, sf, thresh={'corr': None}, verbose='debug') - spindles_detect(data, sf, thresh={'rms': None}) - spindles_detect(data, sf, thresh={'rms': None, 'corr': None}) - spindles_detect(data, sf, thresh={'rms': None, 'rel_pow': None}) - spindles_detect(data, sf, thresh={'corr': None, 'rel_pow': None}) + spindles_detect(data, sf, thresh={"rel_pow": None}) + spindles_detect(data, sf, thresh={"corr": None}, verbose="debug") + spindles_detect(data, sf, thresh={"rms": None}) + spindles_detect(data, sf, thresh={"rms": None, "corr": None}) + spindles_detect(data, sf, thresh={"rms": None, "rel_pow": None}) + spindles_detect(data, sf, thresh={"corr": None, "rel_pow": None}) # Test with hypnogram spindles_detect(data, sf, hypno=np.ones(data.size)) @@ -99,27 +98,27 @@ def test_spindles_detect(self): with pytest.raises(ValueError): sp.get_coincidence_matrix() - with self.assertLogs('yasa', level='WARNING'): + with self.assertLogs("yasa", level="WARNING"): spindles_detect(data_n3, sf) # assert sp is None --> Fails? 
# Ensure that the two warnings are tested - with self.assertLogs('yasa', level='WARNING'): - sp = spindles_detect(data_n3, sf, thresh={'corr': .95}) + with self.assertLogs("yasa", level="WARNING"): + sp = spindles_detect(data_n3, sf, thresh={"corr": 0.95}) assert sp is None # Test with wrong data amplitude (1) - with self.assertLogs('yasa', level='ERROR'): + with self.assertLogs("yasa", level="ERROR"): sp = spindles_detect(data_n3 / 1e6, sf) assert sp is None # Test with wrong data amplitude (2) - with self.assertLogs('yasa', level='ERROR'): + with self.assertLogs("yasa", level="ERROR"): sp = spindles_detect(data_n3 * 1e6, sf) assert sp is None # Test with a random array - with self.assertLogs('yasa', level='ERROR'): + with self.assertLogs("yasa", level="ERROR"): np.random.seed(123) sp = spindles_detect(np.random.random(size=1000), sf) assert sp is None @@ -144,17 +143,15 @@ def test_spindles_detect(self): assert sp._data.shape == sp._data_filt.shape np.testing.assert_array_equal(sp._data, data_full) assert sp._sf == sf - sp_no_out = spindles_detect(data_full, sf, chan_full, - remove_outliers=True) - sp_multi = spindles_detect(data_full, sf, chan_full, - multi_only=True) + sp_no_out = spindles_detect(data_full, sf, chan_full, remove_outliers=True) + sp_multi = spindles_detect(data_full, sf, chan_full, multi_only=True) assert sp_multi.summary().shape[0] < sp.summary().shape[0] assert sp_no_out.summary().shape[0] < sp.summary().shape[0] # Test with hypnogram sp = spindles_detect(data_full, sf, hypno=hypno_full, include=2) sp.summary(grp_chan=False, grp_stage=False) - sp.summary(grp_chan=False, grp_stage=True, aggfunc='median') + sp.summary(grp_chan=False, grp_stage=True, aggfunc="median") sp.summary(grp_chan=True, grp_stage=False) sp.summary(grp_chan=True, grp_stage=True, sort=False) sp.plot_average(ci=None) @@ -162,13 +159,12 @@ def test_spindles_detect(self): sp.plot_detection() # Using a MNE raw object (and disabling one threshold) - spindles_detect(data_mne, thresh={'corr': None, 'rms': 3}) + spindles_detect(data_mne, thresh={"corr": None, "rms": 3}) spindles_detect(data_mne, hypno=hypno_mne, include=2, verbose=True) - plt.close('all') + plt.close("all") def test_sw_detect(self): - """Test function slow-wave detect - """ + """Test function slow-wave detect""" # Parameters product testing freq_sw = [(0.3, 3.5), (0.5, 4)] dur_neg = [(0.3, 1.5), [0.1, 2]] @@ -181,8 +177,8 @@ def test_sw_detect(self): for i, (f, dn, dp, an, ap, aptp) in enumerate(prod_args): # print((f, dn, dp, an, ap, aptp)) sw_detect( - data_sw, sf, freq_sw=f, dur_neg=dn, dur_pos=dp, amp_neg=an, amp_pos=ap, - amp_ptp=aptp) + data_sw, sf, freq_sw=f, dur_neg=dn, dur_pos=dp, amp_neg=an, amp_pos=ap, amp_ptp=aptp + ) # With N3 hypnogram sw = sw_detect(data_sw, sf, hypno=hypno_sw, coupling=True) @@ -196,12 +192,12 @@ def test_sw_detect(self): assert sw._sf == sf # Test with wrong data amplitude - with self.assertLogs('yasa', level='ERROR'): + with self.assertLogs("yasa", level="ERROR"): sw = sw_detect(data_sw * 0, sf) # All channels are flat assert sw is None # With 2D data - sw_detect(data_sw[np.newaxis, ...], sf, verbose='INFO') + sw_detect(data_sw[np.newaxis, ...], sf, verbose="INFO") # No values in hypno intersect with include with pytest.raises(AssertionError): @@ -225,7 +221,7 @@ def test_sw_detect(self): # Test with hypnogram sw = sw_detect(data_full, sf, chan_full, hypno=hypno_full, coupling=True) sw.summary(grp_chan=False, grp_stage=False) - sw.summary(grp_chan=False, grp_stage=True, aggfunc='median') + 
sw.summary(grp_chan=False, grp_stage=True, aggfunc="median") sw.summary(grp_chan=True, grp_stage=False) sw.summary(grp_chan=True, grp_stage=True, sort=False) sw.plot_average(ci=None) @@ -235,33 +231,39 @@ def test_sw_detect(self): sw_sum = sw.summary() assert "ndPAC" in sw_sum.columns # There should be some zero in the ndPAC (full dataframe) - assert sw._events[sw._events['ndPAC'] == 0].shape[0] > 0 + assert sw._events[sw._events["ndPAC"] == 0].shape[0] > 0 # Coinciding spindles and masking sp = spindles_detect(data_full, sf, chan_full, hypno=hypno_full) sw.find_cooccurring_spindles(sp.summary()) sw_sum = sw.summary() assert "CooccurringSpindle" in sw_sum.columns assert "DistanceSpindleToSW" in sw_sum.columns - sw_sum_masked = sw.summary(grp_chan=True, grp_stage=False, - mask=sw._events['CooccurringSpindle']) + sw_sum_masked = sw.summary( + grp_chan=True, grp_stage=False, mask=sw._events["CooccurringSpindle"] + ) assert sw_sum_masked.shape[0] < sw_sum.shape[0] # Test with different coupling params - sw_detect(data_full, sf, chan_full, hypno=hypno_full, coupling=True, - coupling_params={"freq_sp": (12, 16), "time": 2, "p": None}) + sw_detect( + data_full, + sf, + chan_full, + hypno=hypno_full, + coupling=True, + coupling_params={"freq_sp": (12, 16), "time": 2, "p": None}, + ) # Using a MNE raw object sw_detect(data_mne) sw_detect(data_mne, hypno=hypno_mne, include=3) - plt.close('all') + plt.close("all") def test_rem_detect(self): - """Test function REM detect - """ - file_rem = np.load('notebooks/data_EOGs_REM_256Hz.npz') - data_rem = file_rem['data'] + """Test function REM detect""" + file_rem = np.load("notebooks/data_EOGs_REM_256Hz.npz") + data_rem = file_rem["data"] loc, roc = data_rem[0, :], data_rem[1, :] - sf_rem = file_rem['sf'] + sf_rem = file_rem["sf"] hypno_rem = 4 * np.ones_like(loc) # Parameters product testing @@ -275,7 +277,7 @@ def test_rem_detect(self): rem_detect(loc, roc, sf_rem, hypno=h, freq_rem=f, duration=dr, amplitude=am) # With isolation forest - rem = rem_detect(loc, roc, sf, verbose='info') + rem = rem_detect(loc, roc, sf, verbose="info") rem2 = rem_detect(loc, roc, sf, remove_outliers=True) assert rem.summary().shape[0] > rem2.summary().shape[0] rem.summary() @@ -290,15 +292,15 @@ def test_rem_detect(self): hypno_rem = np.r_[np.ones(int(loc.size / 2)), 4 * np.ones(int(loc.size / 2))] rem2 = rem_detect(loc, roc, sf, hypno=hypno_rem) assert rem.summary().shape[0] > rem2.summary().shape[0] - rem2.summary(grp_stage=True, aggfunc='median') + rem2.summary(grp_stage=True, aggfunc="median") # Test with wrong data amplitude on ROC - with self.assertLogs('yasa', level='ERROR'): + with self.assertLogs("yasa", level="ERROR"): rem = rem_detect(loc * 1e-8, roc, sf) assert rem is None # Test with wrong data amplitude on LOC - with self.assertLogs('yasa', level='ERROR'): + with self.assertLogs("yasa", level="ERROR"): rem = rem_detect(loc, roc * 1e8, sf) assert rem is None @@ -307,40 +309,62 @@ def test_rem_detect(self): rem_detect(loc, roc, sf, hypno=hypno_rem, include=5) def test_art_detect(self): - """Test function art_detect - """ - file_9 = np.load('notebooks/data_full_6hrs_100Hz_9channels.npz') - data_9 = file_9.get('data') - hypno_9 = np.load('notebooks/data_full_6hrs_100Hz_hypno.npz').get('hypno') # noqa + """Test function art_detect""" + file_9 = np.load("notebooks/data_full_6hrs_100Hz_9channels.npz") + data_9 = file_9.get("data") + hypno_9 = np.load("notebooks/data_full_6hrs_100Hz_hypno.npz").get("hypno") # noqa # For the sake of the example, let's add some flat data 
at the end data_9 = np.concatenate((data_9, np.zeros((data_9.shape[0], 20000))), axis=1) hypno_9 = np.concatenate((hypno_9, np.zeros(20000))) # Start different combinations - art_detect(data_9, sf=100, window=10, method='covar', threshold=3) - art_detect(data_9, sf=100, window=6, hypno=hypno_9, include=(2, 3), - method='covar', threshold=3) - art_detect(data_9, sf=100, window=5, method='std', threshold=2) - art_detect(data_9, sf=100, window=5, hypno=hypno_9, method='std', - include=(0, 1, 2, 3, 4, 5, 6), threshold=2) - art_detect(data_9, sf=100, window=5., hypno=hypno_9, method='std', - include=(0, 1, 2, 3, 4, 5, 6), threshold=10) + art_detect(data_9, sf=100, window=10, method="covar", threshold=3) + art_detect( + data_9, sf=100, window=6, hypno=hypno_9, include=(2, 3), method="covar", threshold=3 + ) + art_detect(data_9, sf=100, window=5, method="std", threshold=2) + art_detect( + data_9, + sf=100, + window=5, + hypno=hypno_9, + method="std", + include=(0, 1, 2, 3, 4, 5, 6), + threshold=2, + ) + art_detect( + data_9, + sf=100, + window=5.0, + hypno=hypno_9, + method="std", + include=(0, 1, 2, 3, 4, 5, 6), + threshold=10, + ) # Single channel - art_detect(data_9[0], 100, window=10, method='covar') - art_detect(data_9[0], 100, window=5, method='std', verbose=True) + art_detect(data_9[0], 100, window=10, method="covar") + art_detect(data_9[0], 100, window=5, method="std", verbose=True) # Not enough epochs for stage hypno_9[:100] = 6 - art_detect(data_9, sf, window=5., hypno=hypno_9, include=6, - method='std', threshold=3, n_chan_reject=5) + art_detect( + data_9, + sf, + window=5.0, + hypno=hypno_9, + include=6, + method="std", + threshold=3, + n_chan_reject=5, + ) # With a flat channel data_with_flat = np.vstack((data_9, np.zeros(data_9.shape[-1]))) - art_detect(data_with_flat, sf, method='std', n_chan_reject=5) + art_detect(data_with_flat, sf, method="std", n_chan_reject=5) # Using a MNE raw object - art_detect(data_mne, window=10., hypno=hypno_mne, method='covar', verbose='INFO') + art_detect(data_mne, window=10.0, hypno=hypno_mne, method="covar", verbose="INFO") with pytest.raises(AssertionError): # None of include in hypno - art_detect(data_mne, window=10., hypno=hypno_mne, include=[7, 8]) + art_detect(data_mne, window=10.0, hypno=hypno_mne, include=[7, 8]) diff --git a/yasa/tests/test_heart.py b/yasa/tests/test_heart.py index 9974837..873aeb3 100644 --- a/yasa/tests/test_heart.py +++ b/yasa/tests/test_heart.py @@ -11,14 +11,14 @@ class TestHeart(unittest.TestCase): - def test_hrv_stage(self): """Test function hrv_stage""" epochs, rpeaks = hrv_stage(data, sf, hypno=hypno) assert epochs.shape[0] == len(rpeaks) assert epochs["duration"].min() == 120 # 2 minutes assert np.array_equal( - epochs.columns, ['start', 'duration', 'hr_mean', 'hr_std', 'hrv_rmssd']) + epochs.columns, ["start", "duration", "hr_mean", "hr_std", "hrv_rmssd"] + ) # Only N2 epochs_N2, _ = hrv_stage(data, sf, hypno=hypno, include=2) @@ -42,12 +42,11 @@ def test_hrv_stage(self): # The heartbeat detection is applied on the entire recording! 
epochs_nohypno, _ = hrv_stage(data, sf) assert epochs_nohypno.shape[0] == 1 - assert epochs_nohypno.loc[(0, 0), 'duration'] == data.size / sf + assert epochs_nohypno.loc[(0, 0), "duration"] == data.size / sf # No hypno (= full recording) with equal_length # Equivalent to a sliding window approach - epochs_nohypno, _ = hrv_stage( - data, sf, equal_length=True) - assert epochs_nohypno['start'].is_monotonic_increasing - assert epochs_nohypno['duration'].nunique() == 1 + epochs_nohypno, _ = hrv_stage(data, sf, equal_length=True) + assert epochs_nohypno["start"].is_monotonic_increasing + assert epochs_nohypno["duration"].nunique() == 1 assert epochs_nohypno.shape[0] == data.size / (2 * 60 * sf) # 2 minutes diff --git a/yasa/tests/test_hypno.py b/yasa/tests/test_hypno.py index b672d1b..d66b88c 100644 --- a/yasa/tests/test_hypno.py +++ b/yasa/tests/test_hypno.py @@ -4,21 +4,24 @@ import numpy as np import pandas as pd from pandas.testing import assert_frame_equal -from yasa.hypno import (hypno_str_to_int, hypno_int_to_str, - hypno_upsample_to_sf, hypno_fit_to_data, - hypno_upsample_to_data) +from yasa.hypno import ( + hypno_str_to_int, + hypno_int_to_str, + hypno_upsample_to_sf, + hypno_fit_to_data, + hypno_upsample_to_data, +) from yasa.hypno import hypno_find_periods as hfp hypno = np.array([0, 0, 0, 1, 2, 2, 3, 3, 4]) -hypno_txt = np.array(['W', 'W', 'W', 'N1', 'N2', 'N2', 'N3', 'N3', 'R']) +hypno_txt = np.array(["W", "W", "W", "N1", "N2", "N2", "N3", "N3", "R"]) -def create_raw(npts, ch_names=['F4-M1', 'F3-M2'], sf=100): +def create_raw(npts, ch_names=["F4-M1", "F3-M2"], sf=100): """Utility function for test fit to data.""" nchan = len(ch_names) - info = mne.create_info(ch_names=ch_names, sfreq=sf, - ch_types=['eeg'] * nchan, verbose=0) + info = mne.create_info(ch_names=ch_names, sfreq=sf, ch_types=["eeg"] * nchan, verbose=0) data = np.random.rand(nchan, npts) raw = mne.io.RawArray(data, info, verbose=0) return raw @@ -28,14 +31,12 @@ class TestHypno(unittest.TestCase): """Test functions in the hypno.py file.""" def test_conversion(self): - """Test str <--> int conversion. - """ + """Test str <--> int conversion.""" assert np.array_equal(hypno_str_to_int(hypno_txt), hypno) assert np.array_equal(hypno_int_to_str(hypno), hypno_txt) def test_upsampling(self): - """Test hypnogram upsampling. - """ + """Test hypnogram upsampling.""" hypno100 = hypno_upsample_to_sf(hypno=hypno, sf_hypno=(1 / 30), sf_data=100) nhyp100 = hypno100.size assert nhyp100 / hypno.size == 3000 @@ -50,24 +51,27 @@ def test_upsampling(self): assert hypno_fit_to_data(hypno100, create_raw(26750)).size == 26750 # .. Using Numpy + SF from numpy.random import rand + assert np.array_equal(hypno_fit_to_data(hypno100, rand(nhyp100), 100), hypno100) assert hypno_fit_to_data(hypno100, rand(27250), 100).size == 27250 assert hypno_fit_to_data(hypno100, rand(26750), 100).size == 26750 # .. 
No SF - assert np.array_equal(hypno_fit_to_data(hypno100, rand(nhyp100)), - hypno100) + assert np.array_equal(hypno_fit_to_data(hypno100, rand(nhyp100)), hypno100) assert hypno_fit_to_data(hypno100, rand(27250)).size == 27250 assert hypno_fit_to_data(hypno100, rand(26750)).size == 26750 # Two steps combined - assert (hypno_upsample_to_data(hypno, sf_hypno=1 / 30, - data=create_raw(26750)).size == 26750) - assert (hypno_upsample_to_data(hypno, sf_hypno=1 / 30, - data=rand(27250), sf_data=100 - ).size == 27250) - assert (hypno_upsample_to_data(hypno, sf_hypno=1 / 30, - data=rand(2 * (hypno100.size + 250)), - sf_data=200).size == 2 * 27250) + assert hypno_upsample_to_data(hypno, sf_hypno=1 / 30, data=create_raw(26750)).size == 26750 + assert ( + hypno_upsample_to_data(hypno, sf_hypno=1 / 30, data=rand(27250), sf_data=100).size + == 27250 + ) + assert ( + hypno_upsample_to_data( + hypno, sf_hypno=1 / 30, data=rand(2 * (hypno100.size + 250)), sf_data=200 + ).size + == 2 * 27250 + ) def test_periods(self): """Test periods detection.""" @@ -75,41 +79,50 @@ def test_periods(self): x = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0] # 1a. No thresholding - expected = pd.DataFrame({ - 'values': [0, 1, 0, 1, 0], - 'start': [0, 11, 14, 16, 25], - 'length': [11, 3, 2, 9, 2]}) + expected = pd.DataFrame( + {"values": [0, 1, 0, 1, 0], "start": [0, 11, 14, 16, 25], "length": [11, 3, 2, 9, 2]} + ) kwargs = dict( - check_dtype=False, check_index_type=False, check_column_type=False, - check_frame_type=False) + check_dtype=False, + check_index_type=False, + check_column_type=False, + check_frame_type=False, + ) assert_frame_equal(hfp(x, sf_hypno=1 / 60, threshold="0min"), expected, **kwargs) assert_frame_equal(hfp(x, sf_hypno=1, threshold="0min"), expected, **kwargs) # 1b. With thresholding - expected = pd.DataFrame( - {'values': [0, 1], 'start': [0, 16], 'length': [11, 9]}) + expected = pd.DataFrame({"values": [0, 1], "start": [0, 16], "length": [11, 9]}) assert_frame_equal(hfp(x, sf_hypno=1 / 60, threshold="5min"), expected, **kwargs) assert hfp(x, sf_hypno=1, threshold="5min").size == 0 # 1c. 
Equal length - expected = pd.DataFrame({ - 'values': [0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0], - 'start': [0, 2, 4, 6, 8, 11, 14, 16, 18, 20, 22, 25], - 'length': [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]}) + expected = pd.DataFrame( + { + "values": [0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0], + "start": [0, 2, 4, 6, 8, 11, 14, 16, 18, 20, 22, 25], + "length": [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], + } + ) assert_frame_equal( - hfp(x, sf_hypno=1 / 60, threshold="2min", equal_length=True), expected, **kwargs) + hfp(x, sf_hypno=1 / 60, threshold="2min", equal_length=True), expected, **kwargs + ) # TEST 2: MULTI-CLASS VECTOR x = [0, 0, 0, 0, 1, 2, 2, 2, 2, 2, 2, 0, 0, 0, 1, 0, 1] - expected = pd.DataFrame({ - 'values': [0, 1, 2, 0, 1, 0, 1], - 'start': [0, 4, 5, 11, 14, 15, 16], - 'length': [4, 1, 6, 3, 1, 1, 1]}) + expected = pd.DataFrame( + { + "values": [0, 1, 2, 0, 1, 0, 1], + "start": [0, 4, 5, 11, 14, 15, 16], + "length": [4, 1, 6, 3, 1, 1, 1], + } + ) assert_frame_equal(hfp(x, sf_hypno=1 / 60, threshold="0min"), expected, **kwargs) # With a string dtype expected["values"] = expected["values"].astype(str) assert_frame_equal( - hfp(np.array(x).astype(str), sf_hypno=1 / 60, threshold="0min"), expected, **kwargs) + hfp(np.array(x).astype(str), sf_hypno=1 / 60, threshold="0min"), expected, **kwargs + ) diff --git a/yasa/tests/test_io.py b/yasa/tests/test_io.py index 1904d3c..952a2d3 100644 --- a/yasa/tests/test_io.py +++ b/yasa/tests/test_io.py @@ -2,11 +2,15 @@ import pytest import logging import unittest -from yasa.io import (is_sleepecg_installed, set_log_level, is_tensorpac_installed, - is_pyriemann_installed) +from yasa.io import ( + is_sleepecg_installed, + set_log_level, + is_tensorpac_installed, + is_pyriemann_installed, +) -logger = logging.getLogger('yasa') -levels = ['debug', 'info', 'warning', 'error', 'critical'] +logger = logging.getLogger("yasa") +levels = ["debug", "info", "warning", "error", "critical"] class TestIO(unittest.TestCase): @@ -20,7 +24,7 @@ def test_log_level(self): set_log_level(True) set_log_level(None) with pytest.raises(ValueError): - set_log_level('WRONG') + set_log_level("WRONG") def test_logger(self): """Test logger levels.""" diff --git a/yasa/tests/test_numba.py b/yasa/tests/test_numba.py index b4e117b..d158667 100644 --- a/yasa/tests/test_numba.py +++ b/yasa/tests/test_numba.py @@ -6,10 +6,8 @@ class TestNumba(unittest.TestCase): - def test_numba(self): - """Test numba functions - """ + """Test numba functions""" x = np.asarray([4, 5, 7, 8, 5, 6], dtype=np.float64) y = np.asarray([1, 5, 4, 6, 8, 5], dtype=np.float64) @@ -21,8 +19,7 @@ def test_numba(self): y = np.arange(30) + 3 * np.random.random(30) times = np.arange(y.size, dtype=np.float64) slope = _slope_lstsq(times, y) - np.testing.assert_array_almost_equal(_detrend(times, y), - detrend(y, type='linear')) + np.testing.assert_array_almost_equal(_detrend(times, y), detrend(y, type="linear")) X = times[..., np.newaxis] X = np.column_stack((np.ones(X.shape[0]), X)) slope_np = np.linalg.lstsq(X, y, rcond=None)[0][1] diff --git a/yasa/tests/test_others.py b/yasa/tests/test_others.py index 10e1205..ce74c4a 100644 --- a/yasa/tests/test_others.py +++ b/yasa/tests/test_others.py @@ -6,35 +6,38 @@ from mne.filter import filter_data from yasa.hypno import hypno_str_to_int, hypno_upsample_to_data -from yasa.others import (moving_transform, trimbothstd, get_centered_indices, - sliding_window, _merge_close, _zerocrossings, - _index_to_events) +from yasa.others import ( + moving_transform, + trimbothstd, + get_centered_indices, 
+ sliding_window, + _merge_close, + _zerocrossings, + _index_to_events, +) # Load data -data = np.loadtxt('notebooks/data_N2_spindles_15sec_200Hz.txt') +data = np.loadtxt("notebooks/data_N2_spindles_15sec_200Hz.txt") sf = 200 -data_sigma = filter_data(data, sf, 12, 15, method='fir', verbose=0) +data_sigma = filter_data(data, sf, 12, 15, method="fir", verbose=0) # Load a full recording and its hypnogram -file_full = np.load('notebooks/data_full_6hrs_100Hz_Cz+Fz+Pz.npz') -data_full = file_full.get('data') -chan_full = file_full.get('chan') +file_full = np.load("notebooks/data_full_6hrs_100Hz_Cz+Fz+Pz.npz") +data_full = file_full.get("data") +chan_full = file_full.get("chan") sf_full = 100 -hypno_full = np.load('notebooks/data_full_6hrs_100Hz_hypno.npz').get('hypno') +hypno_full = np.load("notebooks/data_full_6hrs_100Hz_hypno.npz").get("hypno") # Using MNE -data_mne = mne.io.read_raw_fif('notebooks/sub-02_mne_raw.fif', preload=True, - verbose=0) +data_mne = mne.io.read_raw_fif("notebooks/sub-02_mne_raw.fif", preload=True, verbose=0) data_mne.pick_types(eeg=True) -data_mne_single = data_mne.copy().pick_channels(['F3']) -hypno_mne = np.loadtxt('notebooks/sub-02_hypno_30s.txt', dtype=str) +data_mne_single = data_mne.copy().pick_channels(["F3"]) +hypno_mne = np.loadtxt("notebooks/sub-02_hypno_30s.txt", dtype=str) hypno_mne = hypno_str_to_int(hypno_mne) -hypno_mne = hypno_upsample_to_data(hypno=hypno_mne, sf_hypno=(1 / 30), - data=data_mne) +hypno_mne = hypno_upsample_to_data(hypno=hypno_mne, sf_hypno=(1 / 30), data=data_mne) class TestOthers(unittest.TestCase): - def test_index_to_events(self): """Test functions _index_to_events""" a = np.array([[3, 6], [8, 12], [14, 20]]) @@ -45,71 +48,67 @@ def test_index_to_events(self): def test_merge_close(self): """Test functions _merge_close""" a = np.array([4, 5, 6, 7, 10, 11, 12, 13, 20, 21, 22, 100, 102]) - good = np.array([4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 17, 18, 19, 20, 21, 22, 100, 101, 102]) + good = np.array( + [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 100, 101, 102] + ) # Events that are less than 100 ms apart (i.e. 10 points at 100 Hz sf) out = _merge_close(a, 100, 100) np.testing.assert_equal(good, out) def test_moving_transform(self): """Test moving_transform""" - method = ['mean', 'min', 'max', 'ptp', 'rms', 'prop_above_zero', - 'slope', 'corr', 'covar'] + method = ["mean", "min", "max", "ptp", "rms", "prop_above_zero", "slope", "corr", "covar"] interp = [False, True] - win = [.3, .5] - step = [0, .5] + win = [0.3, 0.5] + step = [0, 0.5] prod_args = product(win, step, method, interp) for i, (w, s, m, i) in enumerate(prod_args): moving_transform(data, data_sigma, sf, w, s, m, i) - t, out = moving_transform(data, None, sf, w, s, 'rms', True) + t, out = moving_transform(data, None, sf, w, s, "rms", True) assert t.size == out.size assert out.size == data.size def test_trimbothstd(self): - """Test function trimbothstd - """ + """Test function trimbothstd""" x = [4, 5, 7, 0, 18, 6, 7, 8, 9, 10] y = np.random.normal(size=(10, 100)) assert trimbothstd(x) < np.std(x, ddof=1) assert (trimbothstd(y) < np.std(y, ddof=1, axis=-1)).all() def test_zerocrossings(self): - """Test _zerocrossings - """ + """Test _zerocrossings""" a = np.array([4, 2, -1, -3, 1, 2, 3, -2, -5]) idx_zc = _zerocrossings(a) np.testing.assert_equal(idx_zc, [1, 3, 6]) def test_sliding_window(self): - """Test function sliding window. 
- """ + """Test function sliding window.""" x = np.arange(1000) # 1D t, sl = sliding_window(x, sf=100, window=2) # No overlap - assert np.array_equal(t, [0., 2., 4., 6., 8.]) + assert np.array_equal(t, [0.0, 2.0, 4.0, 6.0, 8.0]) assert np.array_equal(sl.shape, (5, 200)) t, sl = sliding_window(x, sf=100, window=2, step=1) # 1 sec overlap - assert np.array_equal(t, [0., 1., 2., 3., 4., 5., 6., 7., 8.]) + assert np.array_equal(t, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) assert np.array_equal(sl.shape, (9, 200)) - t, sl = sliding_window(np.arange(1002), sf=100., window=1., step=.1) + t, sl = sliding_window(np.arange(1002), sf=100.0, window=1.0, step=0.1) assert t.size == 91 assert np.array_equal(sl.shape, (91, 100)) # 2D x_2d = np.random.rand(2, 1100) - t, sl = sliding_window(x_2d, sf=100, window=2, step=1.) - assert np.array_equal(t, [0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]) + t, sl = sliding_window(x_2d, sf=100, window=2, step=1.0) + assert np.array_equal(t, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]) assert np.array_equal(sl.shape, (10, 2, 200)) - t, sl = sliding_window(x_2d, sf=100., window=4., step=None) - assert np.array_equal(t, [0., 4.]) + t, sl = sliding_window(x_2d, sf=100.0, window=4.0, step=None) + assert np.array_equal(t, [0.0, 4.0]) assert np.array_equal(sl.shape, (2, 2, 400)) def test_get_centered_indices(self): - """Test function get_centered_indices - """ + """Test function get_centered_indices""" data = np.arange(100) - idx = [1, 10., 20, 30, 50, 102] + idx = [1, 10.0, 20, 30, 50, 102] before, after = 3, 2 idx_ep, idx_nomask = get_centered_indices(data, idx, before, after) assert (data[idx_ep] == idx_ep).all() diff --git a/yasa/tests/test_plotting.py b/yasa/tests/test_plotting.py index 850e008..62c56fd 100644 --- a/yasa/tests/test_plotting.py +++ b/yasa/tests/test_plotting.py @@ -8,31 +8,31 @@ class TestPlotting(unittest.TestCase): - def test_topoplot(self): """Test topoplot""" data = pd.Series( - [4, 8, 7, 1, 2, 3, 5], - index=['F4', 'F3', 'C4', 'C3', 'P3', 'P4', 'Oz'], - name='Values') - _ = topoplot(data, title='My first topoplot') - _ = topoplot(data, vmin=0, vmax=8, cbar_title='Hello') + [4, 8, 7, 1, 2, 3, 5], index=["F4", "F3", "C4", "C3", "P3", "P4", "Oz"], name="Values" + ) + _ = topoplot(data, title="My first topoplot") + _ = topoplot(data, vmin=0, vmax=8, cbar_title="Hello") _ = topoplot(data, n_colors=10, vmin=0, cmap="Blues") - _ = topoplot(data, sensors='ko', res=64, names='values', show_names=True) + _ = topoplot(data, sensors="ko", res=64, names="values", show_names=True) data = pd.Series( - [-4, -8, -7, -1, -2, -3], - index=['F4-M1', 'F3-M1', 'C4-M1', 'C3-M1', 'P3-M1', 'P4-M1']) + [-4, -8, -7, -1, -2, -3], index=["F4-M1", "F3-M1", "C4-M1", "C3-M1", "P3-M1", "P4-M1"] + ) _ = topoplot(data) - _ = topoplot(data, vmin=0, vmax=8, cbar_title='Hello') + _ = topoplot(data, vmin=0, vmax=8, cbar_title="Hello") _ = topoplot(data, n_colors=10, vmin=0, cmap="Blues") _ = topoplot(data, show_names=False) - data = pd.Series([-0.5, -0.7, -0.3, 0.1, 0.15, 0.3, 0.55], - index=['F3', 'Fz', 'F4', 'C3', 'Cz', 'C4', 'Pz']) + data = pd.Series( + [-0.5, -0.7, -0.3, 0.1, 0.15, 0.3, 0.55], + index=["F3", "Fz", "F4", "C3", "Cz", "C4", "Pz"], + ) _ = topoplot(data, vmin=-1, vmax=1, n_colors=8) - plt.close('all') + plt.close("all") def test_plot_hypnogram(self): """Test plot_hypnogram function.""" @@ -55,4 +55,4 @@ def test_plot_hypnogram(self): with pytest.raises(AssertionError): hypno = np.repeat([0, 1, 2, 3, 4, -2, -1, -3], 120) _ = plot_hypnogram(hypno) - plt.close('all') 
+ plt.close("all") diff --git a/yasa/tests/test_sleepstats.py b/yasa/tests/test_sleepstats.py index b2b13b1..fb10761 100644 --- a/yasa/tests/test_sleepstats.py +++ b/yasa/tests/test_sleepstats.py @@ -8,10 +8,8 @@ class TestSleepStats(unittest.TestCase): - def test_transition(self): - """Test transition_matrix - """ + """Test transition_matrix""" a = [1, 1, 1, 0, 0, 2, 2, 0, 2, 0, 1, 1, 0, 0] counts, probs = transition_matrix(a) c = np.array([[2, 1, 2], [2, 3, 0], [2, 0, 1]]) @@ -24,22 +22,36 @@ def test_transition(self): counts, probs = transition_matrix(x) c = np.array([[2, 2, 1], [2, 1, 0], [1, 0, 1]]) p = np.array([[0.4, 0.4, 0.2], [2 / 3, 1 / 3, 0], [0.5, 0, 0.5]]) - assert pd.DataFrame(c, index=[0, 2, 4], - columns=[0, 2, 4]).equals(counts) - assert pd.DataFrame(p, index=[0, 2, 4], - columns=[0, 2, 4]).equals(probs) + assert pd.DataFrame(c, index=[0, 2, 4], columns=[0, 2, 4]).equals(counts) + assert pd.DataFrame(p, index=[0, 2, 4], columns=[0, 2, 4]).equals(probs) assert (probs.sum(1) == 1).all() def test_sleepstatistics(self): - """Test sleep statistics. - """ + """Test sleep statistics.""" a = [0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 3, 3, 4, 4, 4, 4, 0, 0] - validation = {'TIB': 10.0, 'SPT': 8.0, 'WASO': 0.0, 'TST': 8.0, - 'N1': 1.5, 'N2': 2.0, 'N3': 2.5, 'REM': 2.0, - 'NREM': 6.0, 'SOL': 1.0, 'Lat_N1': 1.0, 'Lat_N2': 2.5, - 'Lat_N3': 4.0, 'Lat_REM': 7.0, - '%N1': 18.75, '%N2': 25.0, '%N3': 31.25, '%REM': 25.0, - '%NREM': 75.0, 'SE': 80.0, 'SME': 100.0} + validation = { + "TIB": 10.0, + "SPT": 8.0, + "WASO": 0.0, + "TST": 8.0, + "N1": 1.5, + "N2": 2.0, + "N3": 2.5, + "REM": 2.0, + "NREM": 6.0, + "SOL": 1.0, + "Lat_N1": 1.0, + "Lat_N2": 2.5, + "Lat_N3": 4.0, + "Lat_REM": 7.0, + "%N1": 18.75, + "%N2": 25.0, + "%N3": 31.25, + "%REM": 25.0, + "%NREM": 75.0, + "SE": 80.0, + "SME": 100.0, + } s = sleep_statistics(a, sf_hyp=1 / 30) # Compare with different sampling frequencies @@ -51,4 +63,4 @@ def test_sleepstatistics(self): a = [0, 0, 1, 1, 0, 0, 0, 0, 2, 2, 2, 0, 1, 1, 0, 0, 0] s = sleep_statistics(a, sf_hyp=1 / 60) # We cannot compare with NaN - assert s['%REM'] == 0 + assert s["%REM"] == 0 diff --git a/yasa/tests/test_spectral.py b/yasa/tests/test_spectral.py index 36c4c21..e20ed6d 100644 --- a/yasa/tests/test_spectral.py +++ b/yasa/tests/test_spectral.py @@ -8,50 +8,52 @@ from yasa.plotting import plot_spectrogram from yasa.hypno import hypno_str_to_int, hypno_upsample_to_data -from yasa.spectral import (bandpower, bandpower_from_psd, - bandpower_from_psd_ndarray, irasa, stft_power) +from yasa.spectral import ( + bandpower, + bandpower_from_psd, + bandpower_from_psd_ndarray, + irasa, + stft_power, +) # Load 1D data -data = np.loadtxt('notebooks/data_N2_spindles_15sec_200Hz.txt') +data = np.loadtxt("notebooks/data_N2_spindles_15sec_200Hz.txt") sf = 200 # Load a full recording and its hypnogram -file_full = np.load('notebooks/data_full_6hrs_100Hz_Cz+Fz+Pz.npz') -data_full = file_full.get('data') -chan_full = file_full.get('chan') +file_full = np.load("notebooks/data_full_6hrs_100Hz_Cz+Fz+Pz.npz") +data_full = file_full.get("data") +chan_full = file_full.get("chan") sf_full = 100 -hypno_full = np.load('notebooks/data_full_6hrs_100Hz_hypno.npz').get('hypno') +hypno_full = np.load("notebooks/data_full_6hrs_100Hz_hypno.npz").get("hypno") # Using MNE -data_mne = mne.io.read_raw_fif('notebooks/sub-02_mne_raw.fif', preload=True, - verbose=0) +data_mne = mne.io.read_raw_fif("notebooks/sub-02_mne_raw.fif", preload=True, verbose=0) data_mne.pick_types(eeg=True) -hypno_mne = 
np.loadtxt('notebooks/sub-02_hypno_30s.txt', dtype=str) +hypno_mne = np.loadtxt("notebooks/sub-02_hypno_30s.txt", dtype=str) hypno_mne = hypno_str_to_int(hypno_mne) -hypno_mne = hypno_upsample_to_data(hypno=hypno_mne, sf_hypno=(1 / 30), - data=data_mne) +hypno_mne = hypno_upsample_to_data(hypno=hypno_mne, sf_hypno=(1 / 30), data=data_mne) # Eyes-open 6 minutes resting-state, 2 channels, 200 Hz -raw_eo = mne.io.read_raw_fif('notebooks/data_resting_EO_200Hz_raw.fif', - verbose=0) +raw_eo = mne.io.read_raw_fif("notebooks/data_resting_EO_200Hz_raw.fif", verbose=0) data_eo = raw_eo.get_data(units=dict(eeg="uV", emg="uV", eog="uV", ecg="uV")) -sf_eo = raw_eo.info['sfreq'] +sf_eo = raw_eo.info["sfreq"] chan_eo = raw_eo.ch_names class TestSpectral(unittest.TestCase): - def test_bandpower(self): - """Test function bandpower - """ + """Test function bandpower""" # BANDPOWER bandpower(data_mne) # Raw MNE multi-channel bandpower(data, sf=sf, bandpass=True) # Single channel Numpy - bandpower(data, sf=sf, ch_names='F4') # Single channel Numpy labelled - bandpower(data_full, sf=sf_full, ch_names=chan_full, hypno=hypno_full, - include=(2, 3)) # Multi channel numpy - bandpower(data_full, sf=sf_full, hypno=hypno_full, - include=(3, 4, 5), bandpass=True) # Multi channel numpy + bandpower(data, sf=sf, ch_names="F4") # Single channel Numpy labelled + bandpower( + data_full, sf=sf_full, ch_names=chan_full, hypno=hypno_full, include=(2, 3) + ) # Multi channel numpy + bandpower( + data_full, sf=sf_full, hypno=hypno_full, include=(3, 4, 5), bandpass=True + ) # Multi channel numpy bandpower(data_mne, hypno=hypno_mne, include=2) # Raw MNE with hypno # BANDPOWER_FROM_PSD @@ -59,33 +61,33 @@ def test_bandpower(self): win = int(2 * sf) freqs, psd = welch(data, sf, nperseg=win) bp_abs_true = bandpower_from_psd(psd, freqs, relative=False) - bp = bandpower_from_psd(psd, freqs, ch_names=['F4']) - bands = ['Delta', 'Theta', 'Alpha', 'Sigma', 'Beta', 'Gamma'] + bp = bandpower_from_psd(psd, freqs, ch_names=["F4"]) + bands = ["Delta", "Theta", "Alpha", "Sigma", "Beta", "Gamma"] assert bp.shape[0] == 1 - assert bp.at[0, 'Chan'] == 'F4' - assert bp.at[0, 'FreqRes'] == 1 / (win / sf) + assert bp.at[0, "Chan"] == "F4" + assert bp.at[0, "FreqRes"] == 1 / (win / sf) assert np.isclose(bp.loc[0, bands].sum(), 1, atol=1e-2) - assert (bp.bands_ == "[(0.5, 4, 'Delta'), (4, 8, 'Theta'), " - "(8, 12, 'Alpha'), (12, 16, 'Sigma'), " - "(16, 30, 'Beta'), (30, 40, 'Gamma')]") + assert ( + bp.bands_ == "[(0.5, 4, 'Delta'), (4, 8, 'Theta'), " + "(8, 12, 'Alpha'), (12, 16, 'Sigma'), " + "(16, 30, 'Beta'), (30, 40, 'Gamma')]" + ) # Check that we can recover the physical power using TotalAbsPow - bands = ['Delta', 'Theta', 'Alpha', 'Sigma', 'Beta', 'Gamma'] - bp_abs = (bp[bands] * bp['TotalAbsPow'].values[..., None]) - np.testing.assert_array_almost_equal(bp_abs[bands].values, - bp_abs_true[bands].values) + bands = ["Delta", "Theta", "Alpha", "Sigma", "Beta", "Gamma"] + bp_abs = bp[bands] * bp["TotalAbsPow"].values[..., None] + np.testing.assert_array_almost_equal(bp_abs[bands].values, bp_abs_true[bands].values) # 2-D EEG data win = int(4 * sf) freqs, psd = welch(data_full, sf_full, nperseg=win) bp = bandpower_from_psd(psd, freqs, ch_names=chan_full) assert bp.shape[0] == len(chan_full) - assert bp.at[0, 'Chan'].upper() == 'CZ' - assert bp.at[1, 'FreqRes'] == 1 / (win / sf_full) + assert bp.at[0, "Chan"].upper() == "CZ" + assert bp.at[1, "FreqRes"] == 1 / (win / sf_full) # Unlabelled bp = bandpower_from_psd(psd, freqs, ch_names=None, 
relative=False) - assert np.array_equal(bp.loc[:, 'Chan'], - ['CHAN000', 'CHAN001', 'CHAN002']) + assert np.array_equal(bp.loc[:, "Chan"], ["CHAN000", "CHAN001", "CHAN002"]) # Bandpower from PSD with NDarray n_chan = 4 @@ -99,21 +101,20 @@ def test_bandpower(self): freqs, psd_3d = welch(data_3d, sf, nperseg=int(4 * sf), axis=-1) bandpower_from_psd_ndarray(psd_1d, freqs, relative=True) bandpower_from_psd_ndarray(psd_2d, freqs, relative=False) - assert (bandpower_from_psd_ndarray(psd_3d, freqs, - bands=[(0.5, 4, 'Delta')], - relative=True) == 1).all() + assert ( + bandpower_from_psd_ndarray(psd_3d, freqs, bands=[(0.5, 4, "Delta")], relative=True) == 1 + ).all() # With negative values: we should get a logger warning freqs = np.arange(0, 50.5, 0.5) psd = np.random.normal(size=(6, freqs.size)) - with self.assertLogs('yasa', level='WARNING'): + with self.assertLogs("yasa", level="WARNING"): bandpower_from_psd(psd, freqs) - with self.assertLogs('yasa', level='WARNING'): + with self.assertLogs("yasa", level="WARNING"): bandpower_from_psd_ndarray(psd, freqs) def test_irasa(self): - """Test function IRASA. - """ + """Test function IRASA.""" # 1D Numpy freqs, psd_aperiodic, psd_osc, fit_params = irasa(data=data, sf=sf) assert np.isin(freqs, np.arange(1, 30.25, 0.25), True).all() @@ -128,10 +129,9 @@ def test_irasa(self): assert len(irasa(data_mne, band=(2, 24), win_sec=2)) == 4 def test_stft_power(self): - """Test function stft_power - """ + """Test function stft_power""" window = [2, 4] - step = [0, .1, 1] + step = [0, 0.1, 1] band = [(0.5, 20), (1, 30), [5, 12], None] norm = [True, False] interp = [True, False] @@ -139,11 +139,9 @@ def test_stft_power(self): prod_args = product(window, step, band, interp, norm) for i, (w, s, b, i, n) in enumerate(prod_args): - stft_power(data, sf, window=w, step=s, band=b, interp=i, - norm=n) + stft_power(data, sf, window=w, step=s, band=b, interp=i, norm=n) - f, t, _ = stft_power(data, sf, window=4, step=.1, band=(11, 16), - interp=True, norm=False) + f, t, _ = stft_power(data, sf, window=4, step=0.1, band=(11, 16), interp=True, norm=False) assert f[1] - f[0] == 0.25 assert t.size == data.size @@ -151,15 +149,14 @@ def test_stft_power(self): assert min(f) == 11 def test_plot_spectrogram(self): - """Test function plot_spectrogram - """ + """Test function plot_spectrogram""" plot_spectrogram(data_full[0, :], sf_full, fmin=0.5, fmax=30) plot_spectrogram(data_full[0, :], sf_full, hypno_full, trimperc=5) hypno_full_art = np.copy(hypno_full) - hypno_full_art[hypno_full_art == 3.] = -1 + hypno_full_art[hypno_full_art == 3.0] = -1 # Replace N3 by Artefact plot_spectrogram(data_full[0, :], sf_full, hypno_full_art, trimperc=5) # Now replace REM by Unscored - hypno_full_art[hypno_full_art == 4.] 
= -2 + hypno_full_art[hypno_full_art == 4.0] = -2 plot_spectrogram(data_full[0, :], sf_full, hypno_full_art) - plt.close('all') + plt.close("all") diff --git a/yasa/tests/test_staging.py b/yasa/tests/test_staging.py index 271f59b..a87edd8 100644 --- a/yasa/tests/test_staging.py +++ b/yasa/tests/test_staging.py @@ -10,9 +10,8 @@ ############################################################################## # MNE Raw -raw = mne.io.read_raw_fif('notebooks/sub-02_mne_raw.fif', preload=True, - verbose=0) -hypno = np.loadtxt('notebooks/sub-02_hypno_30s.txt', dtype=str) +raw = mne.io.read_raw_fif("notebooks/sub-02_mne_raw.fif", preload=True, verbose=0) +hypno = np.loadtxt("notebooks/sub-02_hypno_30s.txt", dtype=str) class TestStaging(unittest.TestCase): @@ -20,8 +19,9 @@ class TestStaging(unittest.TestCase): def test_sleep_staging(self): """Test sleep staging""" - sls = SleepStaging(raw, eeg_name="C4", eog_name="EOG1", - emg_name="EMG1", metadata=dict(age=21, male=False)) + sls = SleepStaging( + raw, eeg_name="C4", eog_name="EOG1", emg_name="EMG1", metadata=dict(age=21, male=False) + ) sls.get_features() y_pred = sls.predict() proba = sls.predict_proba() @@ -33,12 +33,11 @@ def test_sleep_staging(self): # Plot sls.plot_predict_proba() sls.plot_predict_proba(proba, majority_only=True) - plt.close('all') + plt.close("all") # Same with different combinations of predictors # .. without metadata - SleepStaging(raw, eeg_name="C4", eog_name="EOG1", - emg_name="EMG1").fit() + SleepStaging(raw, eeg_name="C4", eog_name="EOG1", emg_name="EMG1").fit() # .. without EMG SleepStaging(raw, eeg_name="C4", eog_name="EOG1").fit() # .. just the EEG
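
To reproduce the lint check locally before pushing (a suggested workflow, not
part of this patch; it mirrors the default check mode of the psf/black action
used in CI, whose exact options are not shown here):

    pip install black
    black --check --diff .

`--check` makes Black exit with a non-zero status when any file would be
reformatted, and `--diff` prints the proposed changes without writing them,
so the command is safe to run on a dirty working tree.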