"""Wrappers for MNE functions to perform preprocessing.
"""
# Authors: Andrew Quinn <a.quinn@bham.ac.uk>
# Chetan Gohil <chetan.gohil@psych.ox.ac.uk>
# Mats van Es <mats.vanes@psych.ox.ac.uk>
import logging
import mne
import numpy as np
from os.path import exists
from scipy import stats
logger = logging.getLogger(__name__)
# --------------------------------------------------------------
# OSL preprocessing functions
#
[docs]
def gesd(x, alpha=0.05, p_out=.1, outlier_side=0):
    """Detect outliers using the Generalized ESD test.

    Parameters
    ----------
    x : array_like
        One-dimensional data set that may contain outliers.
    alpha : scalar
        Significance level to detect at (default = 0.05). It is halved
        internally for the two-sided test (``outlier_side=0``).
    p_out : float
        Maximum fraction of the (finite) samples that may be tested as
        outliers (default = 0.1, i.e. 10% of the data set).
    outlier_side : {-1,0,1}
        Specify sidedness of the test

        - outlier_side = -1 -> outliers are all smaller
        - outlier_side = 0 -> outliers could be small/negative or large/positive (default)
        - outlier_side = 1 -> outliers are all larger

    Returns
    -------
    idx : boolean vector
        Boolean array of the same length as ``x`` with True wherever a sample
        is an outlier. Non-finite samples (NaN/inf) are left out of the test
        and are never flagged.
    x2 : vector
        Input array with outliers removed

    Raises
    ------
    ValueError
        If ``outlier_side`` is not one of -1, 0 or 1.

    References
    ----------
    B. Rosner (1983). Percentage Points for a Generalized ESD Many-Outlier Procedure. Technometrics 25(2), pp. 165-172.
    http://www.jstor.org/stable/1268549?seq=1
    """
    if outlier_side not in (-1, 0, 1):
        raise ValueError("outlier_side: '{0}' not recognised".format(outlier_side))

    if outlier_side == 0:
        alpha = alpha / 2

    x = np.asarray(x)

    # Work on a float copy: removed samples are marked with NaN below, which
    # an integer array cannot store.
    values = x.astype(float)

    # Only finite samples take part in the test; finite_pos maps positions in
    # the reduced vector back to positions in x.
    finite_pos = np.flatnonzero(np.isfinite(values))
    temp = values[finite_pos]
    n = len(temp)

    # The critical value needs n - i - 1 >= 1 degrees of freedom, so at most
    # n - 2 samples can ever be declared outliers.
    n_out = int(np.ceil(n * p_out))
    n_out = max(min(n_out, n - 2), 0)

    R = np.zeros((n_out,))
    rm_idx = np.zeros((n_out,), dtype=int)
    lam = np.zeros((n_out,))

    for j in range(n_out):
        i = j + 1
        centre = np.nanmean(temp)

        # Pick the most extreme remaining sample on the requested side
        if outlier_side == -1:
            rm_idx[j] = np.nanargmin(temp)
            R[j] = centre - temp[rm_idx[j]]
        elif outlier_side == 0:
            deviation = np.abs(temp - centre)
            rm_idx[j] = np.nanargmax(deviation)
            R[j] = deviation[rm_idx[j]]
        else:
            rm_idx[j] = np.nanargmax(temp)
            R[j] = temp[rm_idx[j]] - centre

        # Test statistic: deviation in units of the current standard deviation
        R[j] = R[j] / np.nanstd(temp)

        # Remove the sample from subsequent iterations
        temp[rm_idx[j]] = np.nan

        # Critical value for the i-th removal
        p = 1 - alpha / (n - i + 1)
        t = stats.t.ppf(p, n - i - 1)
        lam[j] = ((n - i) * t) / (np.sqrt((n - i - 1 + t ** 2) * (n - i + 1)))

    # Create a boolean array of outliers in the space of the input.
    # NOTE(review): each removal step is flagged on its own (R[j] > lam[j]);
    # Rosner's procedure instead flags every step up to the last one that
    # exceeds its critical value. Kept as-is for backward compatibility.
    idx = np.zeros((len(x),), dtype=bool)
    idx[finite_pos[rm_idx[R > lam]]] = True

    x2 = x[~idx]

    return idx, x2
def _find_outliers_in_dims(X, axis=-1, metric_func=np.std, gesd_args=None):
    """Find outliers across specified dimensions of an array.

    The metric is computed over every axis of ``X`` except ``axis``, giving
    one value per entry along ``axis``; those values are then passed to
    ``gesd``.

    Parameters
    ----------
    X : ndarray
        Array to find outliers in.
    axis : int
        Axis along which observations are compared. Negative values count
        from the last axis.
    metric_func : function
        Function called as ``metric_func(X, axis=<tuple of axes>)``.
    gesd_args : dict or None
        Keyword arguments forwarded to ``gesd``.

    Returns
    -------
    rm_ind : boolean vector
        The outlier mask returned by ``gesd``, one entry per position along
        ``axis``.
    """
    if gesd_args is None:
        gesd_args = {}

    # Normalise any negative axis (not just -1) so that it is really left out
    # of the axes the metric is reduced over.
    if axis < 0:
        axis = axis + X.ndim

    squashed_axes = tuple(ii for ii in range(X.ndim) if ii != axis)
    metric = metric_func(X, axis=squashed_axes)

    rm_ind, _ = gesd(metric, **gesd_args)

    return rm_ind
def _find_outliers_in_segments(X, axis=-1, segment_len=100,
                               metric_func=np.std, gesd_args=None):
    """Split one axis of an array into consecutive chunks and flag outlying chunks.

    The axis is cut into non-overlapping chunks of ``segment_len`` samples
    (the final chunk takes whatever is left). ``metric_func`` is evaluated on
    each chunk, the per-chunk values are passed to ``gesd``, and the result is
    expanded back to one boolean per sample along ``axis``.
    """
    gesd_kwargs = {} if gesd_args is None else gesd_args

    if axis == -1:
        axis = X.ndim - 1

    n_samples = X.shape[axis]
    seg_starts = np.arange(0, n_samples, segment_len)
    n_segments = len(seg_starts)

    # Full slice on every axis; the entry for `axis` is replaced per chunk
    index = [slice(None)] * X.ndim

    seg_metric = np.zeros((n_segments,))
    # For each sample, the number of the chunk it belongs to
    seg_label = np.full(n_samples, np.nan)

    for seg, start in enumerate(seg_starts):
        # The last chunk is open-ended so it runs to the end of the axis
        stop = None if seg == n_segments - 1 else start + segment_len
        index[axis] = slice(start, stop)

        seg_metric[seg] = metric_func(X[tuple(index)])
        seg_label[index[axis]] = seg

    # Outlying chunks, as integer chunk numbers
    is_bad_segment, _ = gesd(seg_metric, **gesd_kwargs)
    bad_segment_numbers = np.where(is_bad_segment)[0]

    # Samples belonging to any outlying chunk
    return np.isin(seg_label, bad_segment_numbers)
[docs]
def detect_artefacts(X, axis=None, reject_mode='dim', metric_func=np.std,
                     segment_len=100, gesd_args=None, ret_mode='bad_inds'):
    """Detect bad observations or segments in a dataset

    Parameters
    ----------
    X : ndarray
        Array to find artefacts in.
    axis : int
        Index of the axis to detect artefacts in. Negative values count from
        the last axis.
    reject_mode : {'dim' | 'segments'}
        Flag indicating whether to detect outliers across a dimension (dim;
        default) or whether to split a dim into segments and detect outliers
        in them (segments)
    metric_func : function
        Function defining metric to detect outliers on. Defaults to np.std but
        can be any function taking an array and returning a single number.
    segment_len : int > 0
        Integer window length of dummy epochs for bad_segment detection
    gesd_args : dict
        Dictionary of arguments to pass to gesd
    ret_mode : {'good_inds','bad_inds','zero_bads','nan_bads'}
        Flag indicating whether to return the indices for good observations,
        indices for bad observations (default), the input data with outliers
        set to zero (zero_bads) or the input data with outliers replaced with
        nans (nan_bads)

    Returns
    -------
    ndarray
        If ret_mode is ``'bad_inds'`` or ``'good_inds'``, this returns a boolean vector
        of length ``X.shape[axis]`` indicating good or bad samples. If ``ret_mode`` is
        ``'zero_bads'`` or ``'nan_bads'`` this returns an array copy of the input data
        ``X`` with bad samples set to zero or ``np.nan`` respectively. For
        ``'nan_bads'`` a non-floating input is returned as a float copy so that
        it can hold NaN.

    Raises
    ------
    ValueError
        If ``reject_mode`` or ``ret_mode`` is not recognised, or if ``axis``
        is None or outside the dimensions of ``X``.
    """
    if reject_mode not in ['dim', 'segments']:
        raise ValueError("reject_mode: '{0}' not recognised".format(reject_mode))

    if ret_mode not in ['bad_inds', 'good_inds', 'zero_bads', 'nan_bads']:
        raise ValueError("ret_mode: '{0}' not recognised".format(ret_mode))

    if axis is None or not -X.ndim <= axis < X.ndim:
        raise ValueError('bad axis')

    # Use a non-negative axis throughout so the masking below always targets
    # the requested axis.
    if axis < 0:
        axis = axis + X.ndim

    if reject_mode == 'dim':
        bad_inds = _find_outliers_in_dims(X, axis=axis, metric_func=metric_func, gesd_args=gesd_args)
    else:
        bad_inds = _find_outliers_in_segments(X, axis=axis,
                                              segment_len=segment_len,
                                              metric_func=metric_func,
                                              gesd_args=gesd_args)

    if ret_mode == 'bad_inds':
        return bad_inds

    if ret_mode == 'good_inds':
        return np.logical_not(bad_inds)

    # 'zero_bads' / 'nan_bads': overwrite the bad entries in a copy of X
    out = X.copy()
    if ret_mode == 'nan_bads' and not np.issubdtype(out.dtype, np.inexact):
        # Integer/boolean arrays cannot store NaN
        out = out.astype(float)

    slc = [slice(None)] * X.ndim
    slc[axis] = bad_inds
    out[tuple(slc)] = 0 if ret_mode == 'zero_bads' else np.nan
    return out
[docs]
def detect_maxfilt_zeros(raw):
    """This function tries to load the maxfilter log files in order
    to annotate zeroed out data in the :py:class:`mne.io.Raw <mne.io.Raw>` object. It
    assumes that the log file is in the same directory as the
    raw file and has the same name, but with the extension ``.log``.

    Parameters
    ----------
    raw : :py:class:`mne.io.Raw <mne.io.Raw>`
        MNE raw object.

    Returns
    -------
    bad_inds : np.array of bool (n_times,) or None
        Boolean array indicating which time points are zeroed out. ``None``
        is returned when no log file is found or when it cannot be parsed.
    """
    # Derive the log file name from the first raw file name, if there is one.
    # str() lets plain strings and path-like objects be handled alike.
    filenames = raw.filenames
    log_fname = None
    if len(filenames) > 0 and filenames[0] is not None:
        log_fname = str(filenames[0]).replace('.fif', '.log')

    if log_fname is None or not exists(log_fname):
        s = "No maxfilter logfile detected - detecting zeroed out data not possible"
        logger.info(s)
        return None

    # Text surrounding the number of data buffers in the log
    phrase_ndataseg = ['(', ' data buffers)']
    # Text surrounding the time stamp of a skipped (zeroed) data block
    phrase_zero = ['Time ', ': cont HPI is off, data block is skipped!']

    try:
        starttime = raw.first_time
        with open(log_fname) as f:
            lines = f.readlines()

        n_dataseg = None
        zeroed = []
        for line in lines:
            # Only the first line reporting the number of buffers is used
            if n_dataseg is None and phrase_ndataseg[1] in line:
                n_dataseg = float(
                    line.split(phrase_ndataseg[0])[1].split(phrase_ndataseg[1])[0])  # number of segments
            if phrase_zero[1] in line:
                zeroed.append(float(line.split(phrase_zero[0])[1].split(phrase_zero[1])[0]))  # in seconds

        if n_dataseg is None:
            raise ValueError("number of data buffers not found in log file")

        duration = raw.n_times / n_dataseg  # duration of each data segment in samples
        starts = (np.array(zeroed) - starttime) * raw.info['sfreq']  # in samples

        bad_inds = np.zeros(raw.n_times, dtype=bool)
        for start in starts:
            # Clip at zero so a block starting before the recording does not
            # wrap around to the end of the array through negative indexing.
            first = max(int(start), 0)
            last = max(int(start + duration), 0)
            bad_inds[first:last] = True
        return bad_inds
    except Exception:
        # Best effort: a missing/unparsable entry must not stop preprocessing
        s = "detecting zeroed out data from maxfilter log file failed"
        logger.warning(s)
        return None
[docs]
def detect_badsegments(
    raw,
    picks,
    segment_len=1000,
    significance_level=0.05,
    metric='std',
    ref_meg='auto',
    mode=None,
    detect_zeros=True,
    annotate=False,
):
    """Set bad segments in an MNE :py:class:`Raw <mne.io.Raw>` object as defined by the Generalized ESD test in :py:func:`osl.preprocessing.osl_wrappers.gesd <osl.preprocessing.osl_wrappers.gesd>`.

    Parameters
    ----------
    raw : :py:class:`mne.io.Raw <mne.io.Raw>`
        MNE raw object.
    picks : str
        Channel types to pick. See Notes for recommendations.
    segment_len : int
        Window length to divide the data into (non-overlapping).
    significance_level : float
        Significance level for detecting outliers. Must be between 0-1.
    metric : str
        Metric to use. Could be ``'std'``, ``'var'`` or ``'kurtosis'``.
    ref_meg : str
        ref_meg argument to pass with :py:func:`mne.pick_types <mne.pick_types>`.
    mode : str
        Should be ``None`` ``'diff'`` or ``'maxfilter'``.
        When ``mode='diff'`` we calculate a difference time series before
        detecting bad segments. When ``mode='maxfilter'`` we only mark the
        segments with zeros from MaxFiltering as bad.
    detect_zeros : bool
        Should we detect segments of zeros based on the maxfilter files?
        Only used when ``mode`` is ``None``.
    annotate : bool
        add annotations to the Raw object.

    Returns
    -------
    if annotate is False:
        bad segments : dict with keys ``"onsets"`` and ``"durations"`` (both
        in seconds) covering all bad segments that were found.
    if annotate is True:
        raw : :py:class:`mne.io.Raw <mne.io.Raw>`
            MNE raw object with bad segments annotated.

    Raises
    ------
    NotImplementedError
        If ``picks`` is not one of the supported channel types.
    ValueError
        If ``mode`` or ``metric`` is not recognised.

    Notes
    -----
    Note that for Elekta/MEGIN data, we recommend using ``picks: 'mag'`` or ``picks: 'grad'`` separately (in no particular order).

    Note that with CTF data, mne.pick_types will return:
        ~274 axial grads (as magnetometers) if ``{picks: 'mag', ref_meg: False}``
        ~28 reference axial grads if ``{picks: 'grad'}``.
    Thus, it is recommended to use ``picks:'mag'`` in combination with ``ref_meg: False``, and ``picks:'grad'`` separately (in no particular order).
    """
    gesd_args = {'alpha': significance_level}

    # Keyword arguments for mne.pick_types per supported channel type
    pick_kwargs = {
        "mag": {"meg": "mag", "ref_meg": ref_meg},
        "grad": {"meg": "grad", "ref_meg": ref_meg},
        "meg": {"meg": True, "ref_meg": ref_meg},
        "eeg": {"eeg": True, "ref_meg": ref_meg},
        "eog": {"eog": True, "ref_meg": ref_meg},
        "ecg": {"ecg": True, "ref_meg": ref_meg},
        "emg": {"emg": True, "ref_meg": ref_meg},
        "misc": {"misc": True},
    }

    # Validate all options before touching the data
    type_kwargs = pick_kwargs.get(picks) if isinstance(picks, str) else None
    if type_kwargs is None:
        raise NotImplementedError(f"picks={picks} not available.")

    if mode not in (None, "diff", "maxfilter"):
        raise ValueError(f"mode {mode} unknown.")

    allowed_metrics = ["std", "var", "kurtosis"]
    if metric not in allowed_metrics:
        raise ValueError(f"metric {metric} unknown.")

    if metric == "std":
        metric_func = np.std
    elif metric == "var":
        metric_func = np.var
    else:
        def kurtosis(inputs):
            return stats.kurtosis(inputs, axis=None)
        metric_func = kurtosis

    # Each source is (boolean mask, time of each mask sample, description
    # prefix, log suffix).
    sources = []
    if mode == "maxfilter":
        # detect_maxfilt_zeros builds its mask over all raw.n_times samples,
        # so it is mapped to seconds with raw.times.
        sources.append((detect_maxfilt_zeros(raw), raw.times, '', ''))
    else:
        chinds = mne.pick_types(raw.info, exclude='bads', **type_kwargs)
        XX, XX_times = raw.get_data(picks=chinds, reject_by_annotation='omit', return_times=True)
        if mode == "diff":
            XX = np.diff(XX, axis=1)
            XX_times = XX_times[1:]  # remove the first time point

        bdinds = detect_artefacts(
            XX,
            axis=1,
            reject_mode="segments",
            metric_func=metric_func,
            segment_len=segment_len,
            ret_mode="bad_inds",
            gesd_args=gesd_args,
        )
        # This mask lives in the space of the annotation-rejected data, hence
        # XX_times rather than raw.times.
        sources.append((bdinds, XX_times, '', ''))

        if mode is None and detect_zeros:
            sources.append((detect_maxfilt_zeros(raw), raw.times, 'maxfilter_', ' (maxfilter)'))

    first_samp_secs = raw.first_samp / raw.info["sfreq"]
    full_dur = raw.n_times / raw.info["sfreq"]

    all_onsets = []
    all_durations = []
    for mask, times, descp1, descp2 in sources:
        if mask is None or len(mask) == 0:
            continue

        # Rising/falling edges of the mask mark where bad segments start/end
        edges = np.diff(mask.astype(float))
        onsets = np.where(edges == 1)[0]
        if mask[0]:
            onsets = np.r_[0, onsets]
        offsets = np.where(edges == -1)[0]
        if mask[-1]:
            offsets = np.r_[offsets, len(mask) - 1]
        assert len(onsets) == len(offsets)

        descriptions = np.repeat("{0}bad_segment_{1}".format(descp1, picks), len(onsets))
        logger.info("Found {0} bad segments".format(len(onsets)))

        onsets_secs = first_samp_secs + times[onsets.astype(int)]
        offsets_secs = first_samp_secs + times[offsets.astype(int)]
        durations_secs = offsets_secs - onsets_secs

        mod_dur = durations_secs.sum()
        pc = (mod_dur / full_dur) * 100
        s = "Modality {0}{1} - {2:02f}/{3} seconds rejected ({4:02f}%)"
        logger.info(s.format(picks, descp2, mod_dur, full_dur, pc))

        all_onsets.append(onsets_secs)
        all_durations.append(durations_secs)

        if annotate:
            raw.annotations.append(onsets_secs, durations_secs, descriptions)

    if annotate:
        return raw

    # Always return a well-formed dict, even when nothing was detected
    bad_segments_annots = {
        "onsets": np.concatenate(all_onsets) if all_onsets else np.array([]),
        "durations": np.concatenate(all_durations) if all_durations else np.array([]),
    }
    return bad_segments_annots
[docs]
def detect_badchannels(raw, picks, ref_meg="auto", significance_level=0.05, annotate=False):
    """Set bad channels in an MNE :py:class:`Raw <mne.io.Raw>` object as defined by the Generalized ESD test in :py:func:`osl.preprocessing.osl_wrappers.gesd <osl.preprocessing.osl_wrappers.gesd>`.

    Parameters
    ----------
    raw : :py:class:`mne.io.Raw <mne.io.Raw>`
        MNE raw object.
    picks : str
        Channel types to pick. See Notes for recommendations.
    ref_meg : str
        ref_meg argument to pass with :py:func:`mne.pick_types <mne.pick_types>`.
    significance_level : float
        Significance level for detecting outliers. Must be between 0-1.
    annotate : bool
        add annotations to the Raw object.

    Returns
    -------
    if annotate is False:
        bad_channels : the list of bad channels (possibly empty).
    if annotate is True:
        raw : :py:class:`mne.io.Raw <mne.io.Raw>`
            MNE raw object with the newly found bad channels appended to
            ``raw.info["bads"]``.

    Raises
    ------
    NotImplementedError
        If ``picks`` is not one of the supported channel types.

    Notes
    -----
    Note that for Elekta/MEGIN data, we recommend using ``picks:'mag'`` or ``picks:'grad'`` separately (in no particular order).

    Note that with CTF data, mne.pick_types will return:
        ~274 axial grads (as magnetometers) if ``{picks: 'mag', ref_meg: False}``
        ~28 reference axial grads if ``{picks: 'grad'}``.
    Thus, it is recommended to use ``picks:'mag'`` in combination with ``ref_meg: False``, and ``picks:'grad'`` separately (in no particular order).
    """
    gesd_args = {'alpha': significance_level}

    # Keyword arguments for mne.pick_types per supported channel type
    pick_kwargs = {
        "mag": {"meg": "mag", "ref_meg": ref_meg},
        "grad": {"meg": "grad", "ref_meg": ref_meg},
        "meg": {"meg": True, "ref_meg": ref_meg},
        "eeg": {"eeg": True, "ref_meg": ref_meg},
        "eog": {"eog": True, "ref_meg": ref_meg},
        "ecg": {"ecg": True, "ref_meg": ref_meg},
        "emg": {"emg": True, "ref_meg": ref_meg},
        "misc": {"misc": True},
    }
    type_kwargs = pick_kwargs.get(picks) if isinstance(picks, str) else None
    if type_kwargs is None:
        raise NotImplementedError(f"picks={picks} not available.")

    chinds = mne.pick_types(raw.info, exclude='bads', **type_kwargs)
    ch_names = np.array(raw.ch_names)[chinds]

    # One metric value per channel (axis 0); outlying channels are flagged
    bdinds = detect_artefacts(
        raw.get_data(picks=chinds),
        axis=0,
        reject_mode="dim",
        ret_mode="bad_inds",
        gesd_args=gesd_args,
    )

    n_bad = int(bdinds.sum())
    n_total = len(bdinds)
    # Guard the percentage against an empty channel selection
    pc = (n_bad / n_total) * 100 if n_total else 0.0
    s = "Modality {0} - {1}/{2} channels rejected ({3:02f}%)"
    logger.info(s.format(picks, n_bad, n_total, pc))

    # Plain Python strings, in channel order
    bad_channels = ch_names[np.where(bdinds)[0]].tolist()

    if annotate:
        # concatenate newly found bads to existing bads
        raw.info["bads"].extend(bad_channels)
        return raw

    return bad_channels
[docs]
def drop_bad_epochs(
    epochs,
    picks,
    significance_level=0.05,
    max_percentage=0.1,
    outlier_side=0,
    metric='std',
    ref_meg='auto',
    mode=None,
):
    """Drop bad epochs in an MNE :py:class:`Epochs <mne.Epochs>` object as defined by the Generalized ESD test in :py:func:`osl.preprocessing.osl_wrappers.gesd <osl.preprocessing.osl_wrappers.gesd>`.

    Parameters
    ----------
    epochs : :py:class:`mne.Epochs <mne.Epochs>`
        MNE Epochs object.
    picks : str
        Channel types to pick.
    significance_level : float
        Significance level for detecting outliers. Must be between 0-1.
    max_percentage : float
        Maximum fraction of the epochs to drop. Should be between 0-1.
    outlier_side : int
        Specify sidedness of the test:

        * outlier_side = -1 -> outliers are all smaller
        * outlier_side = 0 -> outliers could be small/negative or large/positive (default)
        * outlier_side = 1 -> outliers are all larger
    metric : str
        Metric to use. Could be ``'std'``, ``'var'`` or ``'kurtosis'``.
    ref_meg : str
        ref_meg argument to pass with :py:func:`mne.pick_types <mne.pick_types>`.
    mode : str
        Should be ``'diff'`` or ``None``. When ``mode='diff'`` we calculate a difference time
        series before detecting bad segments.

    Returns
    -------
    epochs : :py:meth:`mne.Epochs <mne.Epochs>`
        The input Epochs object, after ``epochs.drop`` has been called with
        the detected bad epochs.

    Raises
    ------
    NotImplementedError
        If ``picks`` is not one of the supported channel types.
    ValueError
        If ``mode`` or ``metric`` is not recognised.

    Notes
    -----
    Note that with CTF data, mne.pick_types will return:
        ~274 axial grads (as magnetometers) if ``{picks: 'mag', ref_meg: False}``
        ~28 reference axial grads if ``{picks: 'grad'}``.
    """
    gesd_args = {
        'alpha': significance_level,
        'p_out': max_percentage,
        'outlier_side': outlier_side,
    }

    # Keyword arguments for mne.pick_types per supported channel type
    pick_kwargs = {
        "mag": {"meg": "mag"},
        "grad": {"meg": "grad"},
        "meg": {"meg": True},
        "eeg": {"eeg": True},
        "eog": {"eog": True},
        "ecg": {"ecg": True},
        "emg": {"emg": True},
        "misc": {"misc": True},
    }

    # Validate all options before touching the data
    type_kwargs = pick_kwargs.get(picks) if isinstance(picks, str) else None
    if type_kwargs is None:
        raise NotImplementedError(f"picks={picks} not available.")

    if mode not in (None, "diff"):
        raise ValueError(f"mode {mode} unknown.")

    # Get the function used to calculate the evaluation metric
    allowed_metrics = ["std", "var", "kurtosis"]
    if metric not in allowed_metrics:
        raise ValueError(f"metric {metric} unknown.")
    if metric == "std":
        metric_func = np.std
    elif metric == "var":
        metric_func = np.var
    else:
        metric_func = stats.kurtosis

    chinds = mne.pick_types(epochs.info, ref_meg=ref_meg, exclude='bads', **type_kwargs)

    X = epochs.get_data(picks=chinds)
    if mode == "diff":
        X = np.diff(X, axis=-1)

    # Calculate the metric used to evaluate whether an epoch is bad
    X = metric_func(X, axis=-1)

    # Average over channels so we have a metric for each trial
    X = np.mean(X, axis=1)

    # Use gesd to find outliers
    bad_epochs, _ = gesd(X, **gesd_args)
    logger.info(
        f"Modality {picks} - {np.sum(bad_epochs)}/{X.shape[0]} epochs rejected"
    )

    # Drop bad epochs
    epochs.drop(bad_epochs)

    return epochs