Source code for msqms.qc.freq_domain_metrics

# -*- coding: utf-8 -*-
"""Frequency Domain Metric for MEG Data.
"""

import mne
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from msqms.utils import clogger
from msqms.constants import MEG_TYPE
from msqms.qc import Metrics


[docs] class FreqDomainMetric(Metrics): """ Class to calculate frequency domain metrics for MEG data. This class processes MEG data and computes frequency domain features for all MEG channels, with support for both sequential and parallel computation. Parameters ---------- raw : mne.io.Raw The raw MEG data. data_type : str The type of MEG data (e.g., 'opm' or 'squid'). origin_raw : mne.io.Raw The original raw MEG data for comparison. n_jobs : int, optional Number of parallel jobs to use for computation. Default is 1 (no parallelization). verbose : bool, optional If True, enables verbose output. Default is False. """ def __init__(self, raw: mne.io.Raw, data_type, origin_raw, n_jobs=1, verbose=False): super().__init__(raw, n_jobs=n_jobs, data_type=data_type, origin_raw=origin_raw, verbose=verbose) @staticmethod def _get_fre_domain_features(signal: np.ndarray, Fs=1000) -> dict: """ Compute frequency domain features for a single channel. Parameters ---------- signal : np.ndarray Time series data for a single channel. Fs : float, optional Sampling frequency of the data. Default is 1000 Hz. Returns ------- features : dict A dictionary of computed frequency domain features: - mean_amplitude(p1) - std_amplitude(p2) - skewness_amplitude(p3) - kurtosis_amplitude(p4) - mean_frequency(p5) - std_frequency(p6) - rms_frequency(p7) - fourth_moment_frequency(p8) - normalized_second_moment(p9) - frequency_dispersion(p10) - frequency_skewness(p11) - frequency_kurtosis(p12) - mean_absolute_deviation(p13) """ L = len(signal) y = abs(np.fft.fft(signal / L))[: int(L / 2)] y[0] = 0 # Remove DC component f = np.fft.fftfreq(L, 1 / Fs)[: int(L / 2)] fre_line_num = len(y) # Pre-compute common values to avoid repeated calculations and handle edge cases y_mean = y.mean() y_var = np.var(y) y_std = np.sqrt(y_var) if y_var > 0 else 0.0 y_sum = np.sum(y) # Handle zero sum case (flat signal or all zeros after DC removal) if y_sum == 0 or y_sum < np.finfo(float).eps: # Return default values for flat/zero signals return { "mean_amplitude": 0.0, "std_amplitude": 0.0, "skewness_amplitude": 0.0, "kurtosis_amplitude": 0.0, "mean_frequency": 0.0, "std_frequency": 0.0, "rms_frequency": 0.0, "fourth_moment_frequency": 0.0, "normalized_second_moment": 0.0, "frequency_dispersion": 0.0, "frequency_skewness": 0.0, "frequency_kurtosis": 0.0, "mean_absolute_deviation": 0.0, } # Compute mean frequency mean_freq = np.sum(f * y) / y_sum # Compute frequency variance and standard deviation freq_var = np.sum((f - mean_freq) ** 2 * y) / fre_line_num freq_std = np.sqrt(freq_var) if freq_var > 0 else 0.0 # Compute sum of f^2 * y and f^4 * y for later use f2y_sum = np.sum(f ** 2 * y) f4y_sum = np.sum(f ** 4 * y) # Handle division by zero for f2y_sum if f2y_sum == 0 or f2y_sum < np.finfo(float).eps: fourth_moment_freq = 0.0 else: fourth_moment_freq = np.sqrt(f4y_sum / f2y_sum) # Handle division by zero for normalized_second_moment if f4y_sum == 0 or f4y_sum < np.finfo(float).eps: normalized_second_moment = 0.0 else: normalized_second_moment = f2y_sum / np.sqrt(y_sum * f4y_sum) features = { "mean_amplitude": y_mean, "std_amplitude": np.sqrt(np.sum((y - y_mean) ** 2) / fre_line_num) if fre_line_num > 0 else 0.0, "skewness_amplitude": (np.sum((y - y_mean) ** 3) / (fre_line_num * y_std ** 3)) if y_std > 0 and fre_line_num > 0 else 0.0, "kurtosis_amplitude": (np.sum((y - y_mean) ** 4) / (fre_line_num * y_std ** 4)) if y_std > 0 and fre_line_num > 0 else 0.0, "mean_frequency": mean_freq, "std_frequency": freq_std, "rms_frequency": np.sqrt(f2y_sum / y_sum) if y_sum > 0 else 0.0, "fourth_moment_frequency": fourth_moment_freq, "normalized_second_moment": normalized_second_moment, "frequency_dispersion": (freq_std / mean_freq) if mean_freq != 0 else 0.0, "frequency_skewness": (np.sum((f - mean_freq) ** 3 * y) / (freq_std ** 3 * fre_line_num)) if freq_std > 0 and fre_line_num > 0 else 0.0, "frequency_kurtosis": (np.sum((f - mean_freq) ** 4 * y) / (freq_std ** 4 * fre_line_num)) if freq_std > 0 and fre_line_num > 0 else 0.0, "mean_absolute_deviation": (np.sum(abs(f - mean_freq) * y) / (freq_std * fre_line_num)) if freq_std > 0 and fre_line_num > 0 else 0.0, } return features
[docs] def compute_metrics(self, meg_type: MEG_TYPE) -> pd.DataFrame: """ Compute frequency domain metrics for MEG data. Parameters ---------- meg_type : MEG_TYPE Type of MEG channels to process (e.g., 'mag', 'grad'). Returns ------- freq_feat_df : pd.DataFrame DataFrame containing frequency domain metrics for all channels, including their average and standard deviation. """ self.meg_type = meg_type self.meg_data = self.raw.get_data(meg_type) self.meg_names = self._get_meg_names(meg_type) if self.n_jobs == 1: # Sequential computation freq_list = [] for i in range(self.meg_data.shape[0]): features = self._get_fre_domain_features(self.meg_data[i, :], Fs=self.samp_freq) freq_list.append(pd.DataFrame([features], index=[self.meg_names[i]])) freq_feat_df = pd.concat(freq_list) else: # Parallel computation clogger.info(f"Using {self.n_jobs} parallel cores.") # Store samp_freq in a local variable to ensure it's properly captured in the closure samp_freq = self.samp_freq freq_list = Parallel(self.n_jobs, verbose=10)( delayed(self._get_fre_domain_features)(single_ch_data, Fs=samp_freq) for single_ch_data in self.meg_data ) freq_feat_df = pd.DataFrame(freq_list, index=self.meg_names) # Compute average and standard deviation avg_freq_feat = freq_feat_df.mean(axis=0) std_freq_feat = freq_feat_df.std(axis=0) freq_feat_df.loc[f'avg_{meg_type}'] = avg_freq_feat freq_feat_df.loc[f'std_{meg_type}'] = std_freq_feat return freq_feat_df