Source code for msqms.qc.metrics_factory
# -*- coding: utf-8 -*-
"""
Factory class for creating and managing different metrics for MEG data.
"""
from msqms.qc import Metrics
from msqms.constants import MEG_TYPE, METRICS_DOMAIN, METRICS_COLUMNS
[docs]
class MetricsFactory:
    """
    A factory class for creating and managing metrics for MEG data quality control.

    The factory keeps a class-level registry that maps a metric (domain) name
    to the class implementing it. Metric classes can be registered directly,
    or generated on the fly from a user-supplied function.

    Attributes
    ----------
    _registry : dict
        A registry that maps metric names to their corresponding classes.
        It is a class attribute and therefore shared by every user of the factory.
    _add_domain : str or None
        The name of the most recently added custom domain.

    Methods
    -------
    register_metric(name, metric_class)
        Registers a new metric class by name.
    create_metric(name, *args, **kwargs)
        Creates an instance of a registered metric class.
    register_custom_metric(name, func, custom_metrics_name)
        Registers a custom metric class using a user-defined function.
    """

    _registry = {}
    _add_domain = None

    @classmethod
    def register_metric(cls, name: str, metric_class: type):
        """
        Register a new metric class.

        Registering under a name that is already in use replaces the
        previous entry.

        Parameters
        ----------
        name : str
            The name of the metric to register.
        metric_class : type
            The metric class to register. Must be a subclass of `Metrics`.

        Raises
        ------
        ValueError
            If the provided object is not a class, or is a class that is not
            a subclass of `Metrics`.
        """
        # `issubclass` raises TypeError for non-class arguments; check first so
        # every invalid input surfaces as the documented ValueError.
        if not (isinstance(metric_class, type) and issubclass(metric_class, Metrics)):
            raise ValueError(f"{metric_class} must be a subclass of Metrics")
        cls._registry[name] = metric_class

    @classmethod
    def create_metric(cls, name: str, *args, **kwargs):
        """
        Create an instance of a registered metric class.

        Parameters
        ----------
        name : str
            The name of the registered metric to instantiate.
        *args
            Positional arguments to pass to the metric class constructor.
        **kwargs
            Keyword arguments to pass to the metric class constructor.

        Returns
        -------
        instance : Metrics
            An instance of the specified metric class.

        Raises
        ------
        ValueError
            If no metric class is registered under the given name.
        """
        metric_class = cls._registry.get(name)
        if metric_class is None:
            raise ValueError(f"No metric registered under name: {name}")
        return metric_class(*args, **kwargs)

    @classmethod
    def register_custom_metric(cls, name: str, func, custom_metrics_name: list):
        """
        Register a custom metric class using a user-defined function.

        A `Metrics` subclass wrapping `func` is generated and registered under
        `name`, so it can afterwards be instantiated with
        ``MetricsFactory.create_metric(name, ...)``.

        Parameters
        ----------
        name : str
            The name of the custom metric domain.
        func : callable
            A function that defines how to compute the custom metrics. The function
            should accept a `Metrics` instance as its first argument and `meg_type`
            as a keyword argument. Its return value is treated as a DataFrame
            (`mean`, `std` and `.loc` are used on it).
        custom_metrics_name : list of str
            Names of the custom metrics computed by the function.

        Notes
        -----
        - The new custom domain `name` is added to `METRICS_DOMAIN` (once).
        - The corresponding metric names are appended to `METRICS_COLUMNS[name]`;
          the entry is created if it does not exist and names already present
          are not duplicated.
        - `_add_domain` is set to `name`.
        """
        # Use the caller-supplied domain name everywhere; a hard-coded key
        # would make every custom domain overwrite the previous one and leave
        # METRICS_DOMAIN, METRICS_COLUMNS and the registry out of sync.
        if name not in METRICS_DOMAIN:
            METRICS_DOMAIN.append(name)

        # A brand-new domain has no column list yet, so create it on demand.
        columns = METRICS_COLUMNS.setdefault(name, [])
        for metric_name in custom_metrics_name:
            if metric_name not in columns:
                columns.append(metric_name)

        cls._add_domain = name

        class CustomMetric(Metrics):
            """
            A custom metric class for user-defined metrics.

            Parameters
            ----------
            *args
                Positional arguments for the metric class.
            **kwargs
                Keyword arguments for the metric class.
            """

            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                # Keep the constructor arguments available to `func`.
                self.args = args
                self.kwargs = kwargs

            def compute_metrics(self, meg_type: MEG_TYPE):
                """
                Compute metrics for the custom domain.

                Parameters
                ----------
                meg_type : MEG_TYPE
                    Type of MEG channels to process ('mag' or 'grad').

                Returns
                -------
                meg_metrics_df : pd.DataFrame
                    The frame returned by `func`, with two extra rows
                    ``avg_<meg_type>`` and ``std_<meg_type>`` holding the
                    column-wise mean and standard deviation of the rows
                    produced by `func`.
                """
                meg_metrics_df = func(self, meg_type=meg_type)
                # Compute both statistics before appending either row;
                # otherwise the std would also include the freshly added
                # average row and be biased.
                column_avg = meg_metrics_df.mean(axis=0)
                column_std = meg_metrics_df.std(axis=0)
                meg_metrics_df.loc[f"avg_{meg_type}"] = column_avg
                meg_metrics_df.loc[f"std_{meg_type}"] = column_std
                return meg_metrics_df

        cls.register_metric(name, CustomMetric)