# Source code for msqms.libs.pyprep.removeTrend

"""High-pass filter and locally detrend the EEG signal."""
import logging

import mne
import numpy as np

from .utils import _eeglab_create_highpass, _eeglab_fir_filter


def removeTrend(
    EEG,
    sample_rate,
    detrendType="high pass",
    detrendCutoff=1.0,
    detrendChannels=None,
    matlab_strict=False,
):
    """Remove trends (i.e., slow drifts in baseline) from an array of EEG data.

    Parameters
    ----------
    EEG : np.ndarray
        A 2-D array of EEG data to detrend, shaped (channels, samples). A 1-D
        array is treated as a single channel. The input array is never
        modified: all processing happens on a float64 copy.
    sample_rate : float
        The sample rate (in Hz) of the input EEG data.
    detrendType : str, optional
        Type of detrending to be performed (case-insensitive): must be one of
        'high pass', 'high pass sinc', or 'local detrend'. Any other value
        leaves the data unfiltered and logs a warning. Defaults to 'high pass'.
    detrendCutoff : float, optional
        The high-pass cutoff frequency (in Hz) to use for detrending. For
        'local detrend' it sets the window length to ``1.5 / detrendCutoff``
        seconds (clamped to the recording duration). It is ignored by
        'high pass sinc', which always uses a 1 Hz cutoff. Defaults to 1.0 Hz.
    detrendChannels : {list, np.ndarray, None}, optional
        Indices of all channels that require detrending/filtering. If
        ``None``, all channels are used (default).
    matlab_strict : bool, optional
        Whether or not detrending should strictly follow MATLAB PREP's
        internal math, ignoring any improvements made in PyPREP over the
        original code (see :ref:`matlab-diffs` for more details). Defaults to
        ``False``.

    Returns
    -------
    EEG : np.ndarray
        A 2-D float64 array containing the filtered/detrended EEG data.

    Notes
    -----
    High-pass filtering is implemented using the MNE filter function
    :func:`mne.filter.filter_data` unless `matlab_strict` is ``True``, in
    which case it is performed using a minimal re-implementation of EEGLAB's
    ``pop_eegfiltnew``.

    Local detrending is performed using a Python re-implementation of the
    ``runline`` function from the Chronux package for MATLAB [1]_, stepping
    the window by 0.02 s. When the window spans the whole recording, a single
    least-squares line is removed from each selected channel instead. This
    mirrors Chronux's ``locdetrend``.

    References
    ----------
    .. [1] http://chronux.org/
    """
    # Private float64 copy: the MATLAB-strict and local-detrend branches write
    # into the array in place, and MNE's filter requires float64 input.
    EEG = np.array(EEG, dtype=np.float64)
    if EEG.ndim == 1:
        EEG = EEG.reshape(1, -1)

    detrend_type = detrendType.lower()
    if detrend_type == "high pass":
        if matlab_strict:
            # Explicit None/length checks: truth-testing a NumPy index array
            # raises ValueError. An empty pick list still means "all channels".
            if detrendChannels is None or len(detrendChannels) == 0:
                picks = np.arange(EEG.shape[0])
            else:
                picks = np.asarray(detrendChannels)
            filt = _eeglab_create_highpass(detrendCutoff, sample_rate)
            EEG[picks, :] = _eeglab_fir_filter(EEG[picks, :], filt)
        else:
            EEG = mne.filter.filter_data(
                EEG,
                sfreq=sample_rate,
                l_freq=detrendCutoff,
                h_freq=None,
                picks=detrendChannels,
            )
    elif detrend_type == "high pass sinc":
        # Filter order used by MATLAB PREP, rounded up to an even number.
        fOrder = np.round(14080 * sample_rate / 512)
        fOrder = int(fOrder + fOrder % 2)
        EEG = mne.filter.filter_data(
            data=EEG,
            sfreq=sample_rate,
            l_freq=1,
            h_freq=None,
            picks=detrendChannels,
            filter_length=fOrder,
            fir_window="blackman",
        )
    elif detrend_type == "local detrend":
        picks = range(EEG.shape[0]) if detrendChannels is None else detrendChannels
        n_samples = EEG.shape[1]
        # Window length in seconds, clamped to the recording *duration*. The
        # clamp must compare seconds with seconds; comparing against the
        # sample count let over-long windows reach ``runline`` and crash it.
        windowSize = min(1.5 / detrendCutoff, n_samples / sample_rate)
        stepSize = 0.02
        n = int(np.round(sample_rate * windowSize))
        dn = int(np.round(sample_rate * stepSize))
        if dn > n or dn < 1:
            logging.error(
                "Step size should be less than the window size and "
                "contain at least 1 sample"
            )
        data = EEG.T  # (samples, channels) view onto our private copy
        if n >= n_samples:
            # The window covers the whole recording, so there is nothing to
            # slide: remove one least-squares line per channel.
            t = np.arange(n_samples, dtype=np.float64)
            for ch in picks:
                slope, intercept = np.polyfit(t, data[:, ch], 1)
                data[:, ch] -= slope * t + intercept
        else:
            for ch in picks:
                data[:, ch] = runline(data[:, ch], n, dn)
    else:
        logging.warning(
            "No filtering/detreding performed since the detrend type did not match"
        )
    return EEG
def runline(y, n, dn):
    """Perform local linear regression on a channel of EEG data.

    A re-implementation of the ``runline`` function from the Chronux package
    for MATLAB [1]_. A straight line is least-squares fitted to each window of
    ``n`` samples, stepping by ``dn`` samples. Overlapping fits are blended
    with tricube weights. The samples after the last full window are filled
    by extrapolating the last window's line. The blended line is then
    subtracted from the signal.

    Parameters
    ----------
    y : np.ndarray
        A 1-D array of data from a single EEG channel. It is modified in place.
    n : int
        Length of the detrending window, in samples. Must be at least 2 and
        shorter than ``len(y)``.
    dn : int
        Length of the window step size, in samples. Must be at least 1.

    Returns
    -------
    y : np.ndarray
        The detrended signal for the given EEG channel. This is the same array
        object that was passed in, modified in place.

    Raises
    ------
    ValueError
        If ``n < 2``, ``dn < 1``, or ``n`` leaves no room for a single window
        step within ``y``.

    References
    ----------
    .. [1] http://chronux.org/
    """
    nt = y.shape[0]
    if n < 2:
        raise ValueError(f"Window length must be at least 2 samples, got {n}.")
    if dn < 1:
        raise ValueError(f"Step size must be at least 1 sample, got {dn}.")
    nwin = int(np.ceil((nt - n) / dn))
    if nwin < 1:
        raise ValueError(
            f"Window length ({n}) must be shorter than the signal ({nt} samples)."
        )

    x = np.arange(1, n + 1)
    # Tricube weights across the window (zero at the window's last sample).
    xwt = (x - n / 2) / (n / 2)
    wt = np.power(1 - np.power(np.absolute(xwt), 3), 3)

    y_line = np.zeros(nt)
    norm = np.zeros(nt)
    for j in range(nwin):
        start = j * dn
        stop = start + n
        tseg = y[start:stop]
        # Closed-form least-squares slope (a) and intercept (b) over x = 1..n.
        y1 = np.mean(tseg)
        y2 = np.mean(x * tseg) * (2 / (n + 1))
        a = (y2 - y1) * (6 / (n - 1))
        b = y1 - a * (n + 1) / 2
        # Accumulate the weighted fit directly instead of storing every
        # window's fit in an (nwin, n) matrix.
        y_line[start:stop] += (x * a + b) * wt
        norm[start:stop] += wt

    covered = norm > 0
    y_line[covered] /= norm[covered]

    # Extrapolate the last window's line over the tail. The indexing mirrors
    # Chronux's runline.m exactly (kept for MATLAB parity). NOTE(review): the
    # extrapolated values appear to start two samples ahead of the fitted
    # line's own positions; confirm against Chronux before changing.
    indx = (nwin - 1) * dn + n - 1
    npts = nt - indx + 1
    y_line[indx - 1 :] = np.arange(n + 1, n + npts + 1) * a + b

    # Vectorised in-place subtraction; the original per-element loop assigned
    # shape-(1,) arrays to scalar slots, which NumPy >= 1.25 deprecates.
    y[:] = y - y_line
    return y