Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Include a publication study script and the MNI Detector #50

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 58 additions & 65 deletions mne_hfo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
true_positive_rate, precision, false_discovery_rate
from mne_hfo.sklearn import _make_ydf_sklearn
from mne_hfo.utils import (apply_std, compute_rms,
compute_line_length, compute_hilbert, apply_hilbert,
compute_line_length, apply_hilbert,
merge_contiguous_freq_bands)

ACCEPTED_THRESHOLD_METHODS = ['std', 'hilbert']
Expand All @@ -33,12 +33,12 @@ class Detector(BaseEstimator):

Detectors fit follow the following general flow by implementing
private functions:
1. Compute a statistic on the raw data in _compute_hfo_statistic.
1. Compute a statistic on the raw data in ``compute_hfo_statistic``.
i.e. the LineLength of a time-window
2. Apply a threshold to the statistic computed in (1) in
_threshold_statistic. i.e. std of LineLength
``threshold_hfo_statistic``. i.e. std of LineLength
3. Merge contiguous/overlapping events into unique detections
in _post_process_chs_hfo. i.e. contiguous time windows
in ``post_process_chs_hfo``. i.e. contiguous time windows

Parameters
----------
Expand All @@ -50,20 +50,58 @@ class Detector(BaseEstimator):
Fraction of the window overlap (0 to 1).
scoring_func : str
Either ``'f1'``, or ``'r2'``.
name : str
The name of the HFO detector.
n_jobs : int
The number of jobs used in `joblib` parallelization.
verbose: bool
"""

def __init__(self, threshold: Union[int, float],
win_size: Union[int, None], overlap: Union[float, None],
scoring_func: str, n_jobs: int,
scoring_func: str, hfo_name: str, n_jobs: int,
verbose: bool):
self.win_size = win_size
self.threshold = threshold
self.overlap = overlap
self.scoring_func = scoring_func
self.hfo_name = hfo_name
self.verbose = verbose
self.n_jobs = n_jobs

@property
def hfo_annotations(self):
"""HFO Annotations.

Returns
-------
hfo_annotations : instance of Annotations
`mne.Annotations` object with ``onset``, ``duration``
and specified ``ch_name`` for each HFO event detected.
"""
return self.hfo_annotations_

@property
def hfo_event_arr(self):
"""HFO event array.

Returns
-------
hfo_event_arr : np.ndarray
Array that is (n_chs, n_samples), which has a
value of ``1`` if there is an HFO in that sample.
"""
return self.hfo_event_arr_

@property
def step_size(self):
"""Step size of each window.

Window increment over the samples of signal.
"""
# Calculate window values for easier operation
return int(np.ceil(self.win_size * self.overlap))

def _create_empty_event_arr(self):
"""Create an empty HFO event array.

Expand All @@ -88,7 +126,7 @@ def _create_empty_event_arr(self):
hfo_event_arr = np.empty((self.n_chs, n_windows, n_bands))
return hfo_event_arr

def _compute_hfo_statistic(self, X):
def compute_hfo_statistic(self, X):
"""Compute HFO statistic.

Takes a sliding window approach and computes the existence
Expand All @@ -98,7 +136,7 @@ def _compute_hfo_statistic(self, X):

Parameters
----------
X : np.array
X : np.array shape of (n_times,)
EEG data array for single channel: N = n_times.

Returns
Expand All @@ -110,13 +148,13 @@ def _compute_hfo_statistic(self, X):
raise NotImplementedError('Private function that computes the HFOs '
'needs to be implemented.')

def _threshold_statistic(self, hfo_statistic_arr):
def threshold_hfo_statistic(self, hfo_statistic_arr):
"""Apply threshold(s) to the calculated statistic to generate hfo events.

Parameters
----------
hfo_statistic_arr: np.ndarray
The output of _compute_hfo_statistic
The output of compute_hfo_statistic

Returns
-------
Expand All @@ -127,7 +165,7 @@ def _threshold_statistic(self, hfo_statistic_arr):
raise NotImplementedError('Private function that computes the HFOs '
'needs to be implemented.')

def _post_process_ch_hfos(self, hfo_event_array):
def post_process_ch_hfos(self, hfo_event_array):
"""Post process one channel's HFO events generally after thresholding.

Joins contiguously detected HFOs as one event.
Expand Down Expand Up @@ -278,39 +316,6 @@ def score(self, X, y, sample_weight=None):
score = false_discovery_rate(y, y_pred)
return score

@property
def hfo_annotations(self):
"""HFO Annotations.

Returns
-------
hfo_annotations : instance of Annotations
`mne.Annotations` object with ``onset``, ``duration``
and specified ``ch_name`` for each HFO event detected.
"""
return self.hfo_annotations_

@property
def hfo_event_arr(self):
"""HFO event array.

Returns
-------
hfo_event_arr : np.ndarray
Array that is (n_chs, n_samples), which has a
value of ``1`` if there is an HFO in that sample.
"""
return self.hfo_event_arr_

@property
def step_size(self):
"""Step size of each window.

Window increment over the samples of signal.
"""
# Calculate window values for easier operation
return int(np.ceil(self.win_size * self.overlap))

def to_data_frame(self, format=None):
"""Export HFO annotations in tabular structure as a pandas DataFrame.

Expand Down Expand Up @@ -388,7 +393,7 @@ def fit(self, X, y=None):
ch_name = self.ch_names[idx]

# compute HFOs for this channel
ch_hfo_events, statistic = self._fit_channel(
ch_hfo_events, statistic = self.fit_channel(
sig, sfreq, ch_name, hfo_description=hfo_description)

# create list of annotations
Expand All @@ -403,7 +408,7 @@ def fit(self, X, y=None):

# run joblib parallelization over channels
ch_hfos, statistics = zip(*Parallel(n_jobs=n_jobs)(
delayed(self._fit_channel)(
delayed(self.fit_channel)(
X[idx, :], sfreq, self.ch_names[idx], hfo_description
) for idx in tqdm(range(self.n_chs))
))
Expand All @@ -419,21 +424,20 @@ def fit(self, X, y=None):

# assign annotations object
self.hfo_annotations_ = all_hfo_annots
self.chs_hfos_ = all_hfo_annots
return self

def _fit_channel(self, sig, sfreq, ch_name, hfo_description='hfo'):
def fit_channel(self, sig, sfreq, ch_name, hfo_description='hfo'):
"""Compute a list of HFO events for channel."""
# compute the metric over the signal used to compute the HFO
# e.g. RMS, or Line Length over time
hfo_statistic_arr = self._compute_hfo_statistic(sig)
hfo_statistic_arr = self.compute_hfo_statistic(sig)

# apply the threshold(s) to the statistic to get detections
# of start and stop samples
hfo_detection_arr = self._threshold_statistic(hfo_statistic_arr)
hfo_detection_arr = self.threshold_hfo_statistic(hfo_statistic_arr)

# (optionally) post process HFOs
ch_hfo_list = self._post_process_ch_hfos(hfo_detection_arr)
ch_hfo_list = self.post_process_ch_hfos(hfo_detection_arr)

# extract onset, and durations of each HFO detected to form Annotations
onset, duration = [], []
Expand Down Expand Up @@ -525,17 +529,17 @@ def _compute_sliding_window_detection(self, sig, method):

Parameters
----------
sig: np.array
sig: np.ndarray
Data (1D array) from a single channel
method: str
Method used to compute the detection. Can be one of
``'line_length', 'rms', 'hilbert'``.

Returns
-------
signal_win_stat: np.ndarray
Statistic calculated per window

signal_win_stat: np.ndarray, shape (n_chs, n_windows)
Statistic calculated per window, where the number of
windows is equal to ``(n_samples - win_size) / step_size``.
"""
if method not in ACCEPTED_HFO_METHODS:
raise ValueError(f'Sliding window HFO detection method '
Expand Down Expand Up @@ -575,17 +579,6 @@ def _compute_sliding_window_detection(self, sig, method):
win_idx += 1
return signal_win_stat

def _compute_frq_band_detection(self, sig, method):
if method not in ACCEPTED_HFO_METHODS:
raise ValueError(f'Sliding window HFO detection method '
f'{method} is not implemented. Please '
f'use one of {ACCEPTED_HFO_METHODS}.')
if method == 'hilbert':
hfo_detect_func = compute_hilbert
signal_stat = hfo_detect_func(sig, self.freq_cutoffs,
self.freq_span, self.sfreq)
return signal_stat

def _merge_contiguous_ch_detections(self, detections, method):
"""Merge contiguous hfo detections into distinct events.

Expand Down
Loading