Source code for msqms.qc.statistic_metrics

# -*- coding: utf-8 -*-
"""Statistics Domain quality control metric."""
import mne
import numpy as np
import pandas as pd
from mne.preprocessing import annotate_muscle_zscore
from scipy.stats import skew, kurtosis
from msqms.qc import Metrics
from msqms.constants import MEG_TYPE
from msqms.utils import clogger
from msqms.libs.pyprep.find_noisy_channels import NoisyChannels
from msqms.libs.osl import detect_badsegments, detect_badchannels
from msqms.utils import normative_score


[docs] class StatsDomainMetric(Metrics): """ Compute statistical domain quality metrics for MEG data. Includes baseline offset, skewness, kurtosis, and identification of bad channels, segments, and flat signals. """ def __init__(self, raw: mne.io.Raw, data_type, origin_raw: mne.io.Raw = None, n_jobs=-1, verbose=False): """ Initialize the statistical domain metric computation. Parameters ---------- raw : mne.io.Raw The MEG raw data. data_type : str Data type of the MEG ('opm' or 'squid'). origin_raw : mne.io.Raw, optional Original raw data for muscle annotation, by default None. n_jobs : int, optional Number of parallel jobs, by default -1. verbose : bool, optional Enable verbose logging, by default False. """ super().__init__(raw, n_jobs=n_jobs, data_type=data_type, origin_raw=origin_raw, verbose=verbose)
[docs] def compute_metrics(self, meg_type: MEG_TYPE): """ Compute statistical quality metrics for MEG data(all_channels * all_timepoints). Parameters ---------- meg_type : MEG_TYPE The MEG channel type ('mag', 'grad', or 'eeg'). Returns ------- pd.DataFrame DataFrame containing average and standard deviation of the metrics. """ meg_metrics = dict() self.all_meg_names = self._get_meg_names(True) self.all_meg_data = self.raw.get_data('meg') self.meg_type = meg_type self.meg_names = self._get_meg_names(self.meg_type) self.meg_data = self.raw.get_data(meg_type) max_mean_offset, mean_offset, std_mean_offset, max_median_offset, median_offset, std_median_offset = self.compute_baseline_offset( self.meg_data) meg_metrics['max_mean_offset'] = max_mean_offset meg_metrics['mean_offset'] = mean_offset meg_metrics['std_mean_offset'] = std_mean_offset meg_metrics['max_median_offset'] = max_median_offset meg_metrics['median_offset'] = median_offset meg_metrics['std_median_offset'] = std_median_offset # Bad channels bad_channels_ratio, self.bad_chan_names, self.bad_chan_index, self.bad_chan_mask = self.identify_bad_channels() meg_metrics["BadChanRatio"] = bad_channels_ratio clogger.info(f"Bad channels: {self.bad_chan_names} (ratio: {bad_channels_ratio})") clogger.info(f"Get all bad channel:{self.bad_chan_names}--BadChanRatio:{meg_metrics['BadChanRatio']}") # bad segments detection bad_segs_num, self.bad_seg_mask = self.find_bad_segments_by_osl() bad_segs_thres = float(self.data_type_specific_config['BadSegmentsRatio_threshold']) bad_segs_ratio = normative_score(bad_segs_num, bad_segs_thres) meg_metrics['BadSegmentsRatio'] = bad_segs_ratio clogger.info(f"BadSegmentsRatio is {bad_segs_ratio}") # Zero and NaN values self.zero_mask, zero_ratio = self.find_zero_values(self.all_meg_data) self.nan_mask, nan_ratio = self.find_NaN_values(self.all_meg_data) # Flat channels flat_thres_det = self.data_type_specific_config['flat_wave_detection_threshold'] flat_info = self.find_flat(flat_thres_det) flat_ratio = flat_info['flat_chan_ratio'] self.flat_mask = flat_info['flat_chan_mask'] meg_metrics['Zero_ratio'] = zero_ratio meg_metrics['NaN_ratio'] = nan_ratio meg_metrics['Flat_chan_ratio'] = flat_ratio # average meg_metrics_df = pd.DataFrame([meg_metrics], index=[f'avg_{meg_type}']) meg_metrics_df.loc[f'std_{meg_type}'] = [0] * len(meg_metrics_df.columns) # meg_metrics_df.std() return meg_metrics_df
[docs] def identify_bad_channels(self): """ Identify bad channels using multiple methods and create a mask. Returns ------- float Ratio of bad channels. list List of bad channel names. np.ndarray Mask indicating bad channels. """ bad_channels = set() # Pyprep prep_ratio, prep_bad = self.find_bad_channels_by_prep() bad_channels.update(prep_bad) # PSD psd_ratio, psd_bad = self.find_bad_channels_by_psd() bad_channels.update(psd_bad) # OSL osl_ratio, osl_bad = self.find_bad_channels_by_osl() bad_channels.update(osl_bad) # Create bad channel mask bad_chan_names = list(bad_channels) bad_chan_index = self._get_channel_index(bad_chan_names) bad_chan_mask = np.full(self.all_meg_data.shape, False, dtype=bool) bad_chan_mask[bad_chan_index] = True print("bad_chan_index:", bad_chan_index, "bad_chan_mask:", bad_chan_mask.shape) # Final ratio and list ratio = len(bad_channels) / len(self.all_meg_names) if len(self.all_meg_names) > 0 else 0.0 return ratio, bad_chan_names, bad_chan_index, bad_chan_mask
def _get_channel_index(self, channel_name_list): """Returns the channel index based on the channel name """ ch_index = [] for i in channel_name_list: try: ch_index.append(self.raw.ch_names.index(i)) except ValueError: # Channel name not found, skip it clogger.warning(f"Channel {i} not found in raw.ch_names, skipping.") continue return ch_index
[docs] def compute_skewness(self, data: np.ndarray): """ Skewness: measure the shape of the distribution # skewness = 0, normally distributed # skewness > 0, more weight in the left tail of the distribution. # skewnees < 0, more weight in the right tail of the distribution. compute the ratio of left tail of the distributrion.[by channels] compute the mean of skewness. """ skewness = skew(data, axis=1, bias=True) left_tail = len(skewness[skewness > 0]) left_skew_ratio = left_tail / data.shape[0] if data.shape[0] > 0 else 0.0 mean_skewness = np.nanmean(skewness) std_skewness = np.nanstd(skewness) return left_skew_ratio, mean_skewness, std_skewness
[docs] def compute_kurtosis(self, data: np.ndarray): """Kurtosis: # It is also a statistical term and an important characteristic of frequency distribution. # It determines whether a distribution is heavy-tailed in respect of the distribution. # It provides information about the shape of a frequency distribution. # Kurtosis for normal distribution is equal to 3. # For a distribution having kurtosis < 3: It is called playkurtic. # For a distribution having kurtosis > 3, It is called leptokurtic # and it signifies that it tries to produce more outliers rather than the normal distribution. compute mean kurtosis by channel. compute the playkurtic ratio by channel. """ kurtosis_value = kurtosis(data, bias=True, fisher=False) playkurtic_ratio = len(kurtosis_value[kurtosis_value < 3]) / data.shape[0] mean_kurtosis = np.mean(kurtosis_value) return mean_kurtosis, playkurtic_ratio
[docs] def find_bad_channels_by_prep(self): noisy_data = NoisyChannels(self.raw, random_state=1337) # find bad by corr # noisy_data.find_bad_by_correlation() # clogger.info(f"pyprep: finding bad channels by corr.{noisy_data.bad_by_correlation}") # find bad by deviation noisy_data.find_bad_by_deviation() clogger.info(f"pyprep: finding bad channels by deviation.{noisy_data.bad_by_deviation}") # find bad by snr noisy_data.find_bad_by_SNR() clogger.info(f"pyprep: finding bad channels by snr.{noisy_data.bad_by_SNR}") # find bad by nan flat noisy_data.find_bad_by_nan_flat() clogger.info(f"pyprep: finding bad channels by nan:{noisy_data.bad_by_nan}--flat:{noisy_data.bad_by_flat}") noisy_data.find_bad_by_hfnoise() clogger.info(f"pyprep: finding bad channels by hfonoise.{noisy_data.bad_by_hf_noise}") # find bad by ransac # noisy_data.find_bad_by_ransac(channel_wise=True, max_chunk_size=1) # clogger.info(f"pyprep: finding bad channels by ransac[slow].{noisy_data.bad_by_ransac}") bad_channels = noisy_data.get_bads() clogger.info(f"Get All Bad Channels:{bad_channels}") bad_channels_ratio = len(bad_channels) / len(self.all_meg_names) return bad_channels_ratio, bad_channels
[docs] def find_bad_channels_by_psd(self): """Calculate the PSD (power spectral density) of all channels. find the ones that exceed the mean plus 3*standard deviation, and determine them as bad channels. """ ch_names = np.array(self.raw.info['ch_names']) psd = self.raw.compute_psd() psd_data = psd.get_data() ch_mean_psd = np.mean(psd_data, axis=1) total_mean = np.mean(ch_mean_psd) total_std = np.std(ch_mean_psd) ids = np.where((ch_mean_psd > total_mean + 3 * total_std)) bad_channel = ch_names[ids[0]] bad_channels_ratio = len(bad_channel) / len(self.all_meg_names) clogger.info(f"Detect BadChannels by PSD: {bad_channel}") return bad_channels_ratio, bad_channel
[docs] def find_bad_channels_by_osl(self): """Find the bad channels by OSL Library. """ bad_channel_mag = detect_badchannels(self.raw, picks='mag', ref_meg=False) try: bad_channel_grad = detect_badchannels(self.raw, picks='grad') except ValueError: bad_channel_grad = [] bad_channel = bad_channel_mag + bad_channel_grad bad_channels_ratio = len(bad_channel) / len(self.all_meg_names) clogger.info( f"Bad channel name:{bad_channel}--Bad channels ratio:{bad_channels_ratio}--all channels:{len(self.all_meg_names)}") return bad_channels_ratio, bad_channel
[docs] def find_bad_segments_by_osl(self): bad_segs_num = 0 annot_muscle = None bad_segs_osl = detect_badsegments(self.raw, picks=self.meg_type, ref_meg=False, segment_len=1000, detect_zeros=False, significance_level=0.05, annotate=False) # clogger.info(f"bad segments by osl:{bad_segs_osl['onsets']}") # mne if self.origin_raw is not None and self.origin_raw.info['lowpass'] >= 140 and self.origin_raw.info['highpass'] <= 110: annot_muscle, _ = annotate_muscle_zscore(self.origin_raw, ch_type=self.meg_type, threshold=5, filter_freq=[110, 140]) # clogger.info(f"bad segments by mne:{annot_muscle.onset}") # merge if annot_muscle != None: osl_onsets = bad_segs_osl['onsets'] mne_onsets = annot_muscle.onset mne_durs = annot_muscle.duration tmp_onset = [] tmp_dur = [] for idx, o in enumerate(osl_onsets): if o not in mne_onsets: tmp_onset.append(o) tmp_dur.append(bad_segs_osl['durations'][idx]) bad_segs = {"onsets": np.append(mne_onsets, tmp_onset), "durations": np.append(mne_durs, tmp_dur)} else: bad_segs = bad_segs_osl if bad_segs and 'onsets' in bad_segs and len(bad_segs['onsets']) > 0: bad_segs_num = len(bad_segs['onsets']) else: bad_segs_num = 0 clogger.info(f"bad segments num:{bad_segs_num}") bad_seg_mask = np.full(self.meg_data.shape, False, dtype=bool) # bad segments mask if bad_segs and 'onsets' in bad_segs and 'durations' in bad_segs and len(bad_segs['onsets']) > 0: for idx, onset in enumerate(bad_segs['onsets']): duration = bad_segs['durations'][idx] seg_start = self.raw.time_as_index(onset)[0] seg_end = seg_start + int(duration * self.raw.info['sfreq']) # Ensure seg_end doesn't exceed data length seg_end = min(seg_end, self.meg_data.shape[1]) if seg_start < seg_end: bad_seg_mask[:, seg_start:seg_end] = True return bad_segs_num, bad_seg_mask
[docs] def find_zero_values(self, data: np.ndarray): """ Detect zero values. Parameters ---------- data : Returns ------- zero_mask: np.ndarray the mask of zero values. zero_ratio: float the ratio of zero values. """ zero_mask_positions = np.argwhere(data == 0) zero_mask = np.full(data.shape, False, dtype=bool) for pos in zero_mask_positions: zero_mask[tuple(pos)] = True zero_count = len(zero_mask_positions) total_elements = data.size zero_ratio = (zero_count / total_elements) * 100 return zero_mask, zero_ratio
[docs] def find_NaN_values(self, data: np.ndarray): """ Detect NaN values Parameters ---------- data : Returns ------- - NaN mask matrix - NaN ratio, accounts for all data points. """ nan_mask = np.isnan(data) nan_count = np.sum(nan_mask) # total_elements = data.size thres = float(self.data_type_specific_config['NaN_ratio_threshold']) nan_ratio = normative_score(nan_count, thres) # nan_ratio = (nan_count / total_elements) * 100 return nan_mask, nan_ratio
[docs] def find_flat(self, flat_thres): """detect flat channels or constant channels.""" if isinstance(flat_thres, str): flat_thres = float(flat_thres) std_values = np.nanstd(self.all_meg_data, axis=1) flat_chan_inds = np.argwhere(std_values <= flat_thres) flat_chan_names = [self.raw.info['ch_names'][fc[0]] for fc in flat_chan_inds] flat_chan_ratio = (len(flat_chan_names) / len(self.all_meg_names)) if len(self.all_meg_names) > 0 else 0.0 # * 100 # percentage flat_chan_mask = np.full(self.all_meg_data.shape, False, dtype=bool) for fc in flat_chan_inds: flat_chan_mask[fc] = True return {"flat_chan_names": flat_chan_names, "flat_chan_ratio": flat_chan_ratio, "flat_chan_mask": flat_chan_mask}
[docs] def compute_mag_field_change(self, data: np.ndarray): """Calculate the Mag Field Change,and record the degree of magnetic field change. Calculate the maximum value of the magnetic field change, and the mean value and variance of the magnetic field change by channel. """ diff_field = np.abs(np.diff(data, axis=1)) max_field_change = np.max(diff_field) mean_field_change = np.mean(diff_field) std_field_change = np.std(diff_field) return max_field_change, mean_field_change, std_field_change
[docs] def compute_baseline_offset(self, data: np.ndarray): """Baseline offset: Calculate the baseline drift of each channel (mean, median); Calculate the average deviation degree of the channel data mean relative to the population mean and population median. """ overall_mean = np.mean(data) channel_means = np.mean(data, axis=1) mea_offset_abs = np.abs(channel_means - overall_mean) mean_offset = np.mean(mea_offset_abs) std_mean_offset = np.std(mea_offset_abs) max_mean_offset = np.max(mea_offset_abs) # median overall_median = np.median(data) channel_medians = np.median(data, axis=1) med_offset_abs = np.abs(channel_medians - overall_median) median_offset = np.mean(med_offset_abs) std_median_offset = np.std(med_offset_abs) max_median_offset = np.max(med_offset_abs) return max_mean_offset, mean_offset, std_mean_offset, max_median_offset, median_offset, std_median_offset