bci_essentials.classification.ssvep_riemannian_mdm_classifier

SSVEP Riemannian MDM Classifier

Classifies SSVEP responses based on the relative band power at the expected frequencies.

  1"""
  2**SSVEP Riemannian MDM Classifier**
  3
  4Classifies SSVEP based on relative band power at the expected
  5frequencies.
  6
  7"""
  8
  9# Stock libraries
 10import numpy as np
 11from sklearn.model_selection import StratifiedKFold
 12from sklearn.pipeline import Pipeline
 13from sklearn.metrics import confusion_matrix, precision_score, recall_score
 14from pyriemann.classification import MDM
 15from pyriemann.estimation import Covariances
 16from pyriemann.channelselection import FlatChannelRemover
 17
 18# Import bci_essentials modules and methods
 19from ..classification.generic_classifier import (
 20    GenericClassifier,
 21    Prediction,
 22    KernelResults,
 23)
 24from ..signal_processing import bandpass
 25from ..channel_selection import channel_selection_by_method
 26from ..utils.logger import Logger  # Logger wrapper
 27
 28# Instantiate a logger for the module at the default level of logging.INFO
 29# Logs to bci_essentials.__module__) where __module__ is the name of the module
 30logger = Logger(name=__name__)
 31
 32
 33class SsvepRiemannianMdmClassifier(GenericClassifier):
 34    """SSVEP Riemannian MDM Classifier class
 35    (*inherits from GenericClassifier*)
 36
 37    """
 38
 39    def set_ssvep_settings(
 40        self,
 41        n_splits=3,
 42        random_seed=42,
 43        n_harmonics=2,
 44        f_width=0.2,
 45        covariance_estimator="oas",
 46        remove_flats=True,
 47    ):
 48        """Set the SSVEP settings.
 49
 50        Parameters
 51        ----------
 52        n_splits : int, *optional*
 53            Number of folds for cross-validation.
 54            - Default is `3`.
 55        random_seed : int, *optional*
 56            Random seed.
 57            - Default is `42`.
 58        n_harmonics : int, *optional*
 59            Number of harmonics to be used for each frequency.
 60            - Default is `2`.
 61        f_width : float, *optional*
 62            Width of frequency bins to be used around the target
 63            frequencies.
 64            - Default is `0.2`.
 65        covariance_estimator : str, *optional*
 66            Covariance Estimator (see Covariances - pyriemann)
 67            - Default is `"oas"`.
 68        remove_flats : bool, *optional*
 69            Remove flat channels.
 70            - Default is `True`.
 71
 72        Returns
 73        -------
 74        `None`
 75            Models created used in `fit()`.
 76
 77        """
 78        # Build the cross-validation split
 79        self.n_splits = n_splits
 80        self.cv = StratifiedKFold(
 81            n_splits=self.n_splits, shuffle=True, random_state=random_seed
 82        )
 83
 84        self.rebuild = True
 85
 86        self.n_harmonics = n_harmonics
 87        self.f_width = f_width
 88        self.covariance_estimator = covariance_estimator
 89
 90        # Use an MDM classifier, maybe there will be other options later
 91        mdm = MDM(metric=dict(mean="riemann", distance="riemann"), n_jobs=1)
 92        self.clf_model = Pipeline([("MDM", mdm)])
 93        self.clf = Pipeline([("MDM", mdm)])
 94
 95        if remove_flats:
 96            rf = FlatChannelRemover()
 97            self.clf_model.steps.insert(0, ["Remove Flat Channels", rf])
 98            self.clf.steps.insert(0, ["Remove Flat Channels", rf])
 99
100    def get_ssvep_supertrial(
101        self,
102        X,
103        target_freqs,
104        fsample,
105        f_width=0.4,
106        n_harmonics=2,
107        covariance_estimator="oas",
108    ):
109        """Get SSVEP Supertrial.
110
111        Creates the Riemannian Geometry supertrial for SSVEP.
112
113        Parameters
114        ----------
115        X : numpy.ndarray
116            Trials of EEG data.
117            3D array containing data with `float` type.
118            shape = (`n_trials`,`n_channels`,`n_samples`)
119        target_freqs : numpy.ndarray
120            Target frequencies for the SSVEP.
121        fsample : float
122            Sampling rate.
123        f_width : float, *optional*
124            Width of frequency bins to be used around the target
125            frequencies.
126            - Default is `0.4`.
127        n_harmonics : int, *optional*
128            Number of harmonics to be used for each frequency.
129            - Default is `2`.
130        covarianc_estimator : str, *optional*
131            Covariance Estimator (see Covariances - pyriemann)
132            - Default is `"oas"`.
133
134        Returns
135        -------
136        super_X : numpy.ndarray
137            Supertrials of X.
138            3D array containing data with `float` type.
139
140            shape = (`n_trials`,`n_channels*number of target_freqs`,
141            `n_channels*number of target_freqs`)
142
143        """
144        n_trials, n_channels, n_samples = X.shape
145        n_target_freqs = len(target_freqs)
146
147        super_X = np.zeros(
148            [n_trials, n_channels * n_target_freqs, n_channels * n_target_freqs]
149        )
150
151        # Create super trial of all trials filtered at all bands
152        for trial in range(n_trials):
153            for tf, target_freq in enumerate(target_freqs):
154                lower_bound = int((n_channels * tf))
155                upper_bound = int((n_channels * tf) + n_channels)
156
157                signal = X[trial, :, :]
158                for f in range(n_harmonics):
159                    if f == 0:
160                        filt_signal = bandpass(
161                            signal,
162                            f_low=target_freq - (f_width / 2),
163                            f_high=target_freq + (f_width / 2),
164                            order=5,
165                            fsample=fsample,
166                        )
167                    else:
168                        filt_signal += bandpass(
169                            signal,
170                            f_low=(target_freq * (f + 1)) - (f_width / 2),
171                            f_high=(target_freq * (f + 1)) + (f_width / 2),
172                            order=5,
173                            fsample=fsample,
174                        )
175
176                cov_mat = Covariances(estimator=covariance_estimator).transform(
177                    np.expand_dims(filt_signal, axis=0)
178                )
179
180                cov_mat_diag = np.diag(np.diag(cov_mat[0, :, :]))
181
182                super_X[trial, lower_bound:upper_bound, lower_bound:upper_bound] = (
183                    cov_mat_diag
184                )
185
186        return super_X
187
188    def fit(self):
189        """Fit the model.
190
191        Returns
192        -------
193        `None`
194            Models created used in `predict()`.
195
196        """
197
198        # Convert each trial of X into a SPD of dimensions [n_trials, n_channels*nfreqs, n_channels*nfreqs]
199        n_trials, n_channels, n_samples = self.X.shape
200
201        #################
202        # Try rebuilding the classifier each time
203        if self.rebuild:
204            self.next_fit_trial = 0
205            self.clf = self.clf_model
206
207        # get temporal subset
208        subX = self.X[self.next_fit_trial :, :, :]
209        suby = self.y[self.next_fit_trial :]
210        self.next_fit_trial = n_trials
211
212        # Init predictions to all false
213        cv_preds = np.zeros(n_trials)
214
215        def __ssvep_kernel(subX, suby):
216            """SSVEP kernel.
217
218            Parameters
219            ----------
220            subX : numpy.ndarray
221                Input data ffor training/testing the classifier.
222                3D array with shape = (`n_epochs`, `n_channels`, `n_samples`).
223            suby : numpy.ndarray
224                Target labels for the input data.
225                1D array with shape = (`n_epochs`, ).
226
227            Returns
228            -------
229            kernelResults : KernelResults
230                KernelResults object containing the following attributes:
231                    model : classifier
232                        The trained classification model.
233                    cv_preds : numpy.ndarray
234                        The predictions from the model using cross validation.
235                        1D array with the same shape as `suby`.
236                    accuracy : float
237                        The accuracy of the trained classification model.
238                    precision : float
239                        The precision of the trained classification model.
240                    recall : float
241                        The recall of the trained classification model.
242
243
244            """
245            for train_idx, test_idx in self.cv.split(subX, suby):
246                self.clf = self.clf_model
247
248                X_train, X_test = subX[train_idx], subX[test_idx]
249                y_train = suby[train_idx]
250
251                # get the covariance matrices for the training set
252                X_train_super = self.get_ssvep_supertrial(
253                    X_train,
254                    self.target_freqs,
255                    fsample=256,
256                    n_harmonics=self.n_harmonics,
257                    f_width=self.f_width,
258                    covariance_estimator=self.covariance_estimator,
259                )
260                X_test_super = self.get_ssvep_supertrial(
261                    X_test,
262                    self.target_freqs,
263                    fsample=256,
264                    n_harmonics=self.n_harmonics,
265                    f_width=self.f_width,
266                    covariance_estimator=self.covariance_estimator,
267                )
268
269                # fit the classsifier
270                self.clf.fit(X_train_super, y_train)
271                cv_preds[test_idx] = self.clf.predict(X_test_super)
272
273            # Create super trial with all available data
274            X_super = self.get_ssvep_supertrial(
275                subX,
276                self.target_freqs,
277                fsample=256,
278                n_harmonics=self.n_harmonics,
279                f_width=self.f_width,
280                covariance_estimator=self.covariance_estimator,
281            )
282
283            # Train final model with all available data
284            self.clf.fit(X_super, suby)
285            model = self.clf
286
287            accuracy = sum(cv_preds == self.y) / len(cv_preds)
288            precision = precision_score(self.y, cv_preds, average="micro")
289            recall = recall_score(self.y, cv_preds, average="micro")
290
291            return KernelResults(model, cv_preds, accuracy, precision, recall)
292
293        # Check if channel selection is true
294        if self.channel_selection_setup:
295            logger.info("Doing channel selection")
296
297            channel_selection_results = channel_selection_by_method(
298                __ssvep_kernel,
299                self.X,
300                self.y,
301                self.channel_labels,  # kernel setup
302                self.chs_method,
303                self.chs_metric,
304                self.chs_initial_subset,  # wrapper setup
305                self.chs_max_time,
306                self.chs_min_channels,
307                self.chs_max_channels,
308                self.chs_performance_delta,  # stopping criterion
309                self.chs_n_jobs,
310            )
311
312            preds = channel_selection_results.best_preds
313            accuracy = channel_selection_results.best_accuracy
314            precision = channel_selection_results.best_precision
315            recall = channel_selection_results.best_recall
316
317            logger.info(
318                "The optimal subset is: %s",
319                channel_selection_results.best_channel_subset,
320            )
321
322            self.subset = channel_selection_results.best_channel_subset
323            self.clf = channel_selection_results.best_model
324        else:
325            logger.warning("Not doing channel selection")
326            current_results = __ssvep_kernel(subX, suby)
327            self.clf = current_results.model
328            preds = current_results.cv_preds
329            accuracy = current_results.accuracy
330            precision = current_results.precision
331            recall = current_results.recall
332
333        # Log performance stats
334
335        self.offline_trial_count = n_trials
336        self.offline_trial_counts.append(self.offline_trial_count)
337
338        # accuracy
339        accuracy = sum(preds == self.y) / len(preds)
340        self.offline_accuracy.append(accuracy)
341        logger.info("Accuracy = %s", accuracy)
342
343        # precision
344        precision = precision_score(self.y, preds, average="micro")
345        self.offline_precision.append(precision)
346        logger.info("Precision = %s", precision)
347
348        # recall
349        recall = recall_score(self.y, preds, average="micro")
350        self.offline_recall.append(recall)
351        logger.info("Recall = %s", recall)
352
353        # confusion matrix in command line
354        cm = confusion_matrix(self.y, preds)
355        self.offline_cm = cm
356        logger.info("Confusion matrix:\n%s", cm)
357
358    def predict(self, X):
359        """Predict the class labels for the provided data.
360
361        Parameters
362        ----------
363        X : numpy.ndarray
364            3D array where shape = (trials, channels, samples)
365
366        Returns
367        -------
368        prediction : Prediction
369            Results of predict call containing the predicted class labels, and
370            the probabilities of the labels.
371
372        """
373        # if X is 2D, make it 3D with one as first dimension
374        if len(X.shape) < 3:
375            X = X[np.newaxis, ...]
376
377        X = self.get_subset(X)
378
379        logger.info("The shape of X is %s", X.shape)
380
381        X_super = self.get_ssvep_supertrial(
382            X,
383            self.target_freqs,
384            fsample=256,
385            n_harmonics=self.n_harmonics,
386            f_width=self.f_width,
387        )
388
389        pred = self.clf.predict(X_super)
390        pred_proba = self.clf.predict_proba(X_super)
391
392        logger.info("Prediction: %s", pred)
393        logger.info("Prediction probabilities: %s", pred_proba)
394
395        for i in range(len(pred)):
396            self.predictions.append(pred[i])
397            self.pred_probas.append(pred_proba[i])
398
399        return Prediction(labels=pred, probabilities=pred_proba)
class SsvepRiemannianMdmClassifier(bci_essentials.classification.generic_classifier.GenericClassifier):
 34class SsvepRiemannianMdmClassifier(GenericClassifier):
 35    """SSVEP Riemannian MDM Classifier class
 36    (*inherits from GenericClassifier*)
 37
 38    """
 39
 40    def set_ssvep_settings(
 41        self,
 42        n_splits=3,
 43        random_seed=42,
 44        n_harmonics=2,
 45        f_width=0.2,
 46        covariance_estimator="oas",
 47        remove_flats=True,
 48    ):
 49        """Set the SSVEP settings.
 50
 51        Parameters
 52        ----------
 53        n_splits : int, *optional*
 54            Number of folds for cross-validation.
 55            - Default is `3`.
 56        random_seed : int, *optional*
 57            Random seed.
 58            - Default is `42`.
 59        n_harmonics : int, *optional*
 60            Number of harmonics to be used for each frequency.
 61            - Default is `2`.
 62        f_width : float, *optional*
 63            Width of frequency bins to be used around the target
 64            frequencies.
 65            - Default is `0.2`.
 66        covariance_estimator : str, *optional*
 67            Covariance Estimator (see Covariances - pyriemann)
 68            - Default is `"oas"`.
 69        remove_flats : bool, *optional*
 70            Remove flat channels.
 71            - Default is `True`.
 72
 73        Returns
 74        -------
 75        `None`
 76            Models created used in `fit()`.
 77
 78        """
 79        # Build the cross-validation split
 80        self.n_splits = n_splits
 81        self.cv = StratifiedKFold(
 82            n_splits=self.n_splits, shuffle=True, random_state=random_seed
 83        )
 84
 85        self.rebuild = True
 86
 87        self.n_harmonics = n_harmonics
 88        self.f_width = f_width
 89        self.covariance_estimator = covariance_estimator
 90
 91        # Use an MDM classifier, maybe there will be other options later
 92        mdm = MDM(metric=dict(mean="riemann", distance="riemann"), n_jobs=1)
 93        self.clf_model = Pipeline([("MDM", mdm)])
 94        self.clf = Pipeline([("MDM", mdm)])
 95
 96        if remove_flats:
 97            rf = FlatChannelRemover()
 98            self.clf_model.steps.insert(0, ["Remove Flat Channels", rf])
 99            self.clf.steps.insert(0, ["Remove Flat Channels", rf])
100
101    def get_ssvep_supertrial(
102        self,
103        X,
104        target_freqs,
105        fsample,
106        f_width=0.4,
107        n_harmonics=2,
108        covariance_estimator="oas",
109    ):
110        """Get SSVEP Supertrial.
111
112        Creates the Riemannian Geometry supertrial for SSVEP.
113
114        Parameters
115        ----------
116        X : numpy.ndarray
117            Trials of EEG data.
118            3D array containing data with `float` type.
119            shape = (`n_trials`,`n_channels`,`n_samples`)
120        target_freqs : numpy.ndarray
121            Target frequencies for the SSVEP.
122        fsample : float
123            Sampling rate.
124        f_width : float, *optional*
125            Width of frequency bins to be used around the target
126            frequencies.
127            - Default is `0.4`.
128        n_harmonics : int, *optional*
129            Number of harmonics to be used for each frequency.
130            - Default is `2`.
131        covarianc_estimator : str, *optional*
132            Covariance Estimator (see Covariances - pyriemann)
133            - Default is `"oas"`.
134
135        Returns
136        -------
137        super_X : numpy.ndarray
138            Supertrials of X.
139            3D array containing data with `float` type.
140
141            shape = (`n_trials`,`n_channels*number of target_freqs`,
142            `n_channels*number of target_freqs`)
143
144        """
145        n_trials, n_channels, n_samples = X.shape
146        n_target_freqs = len(target_freqs)
147
148        super_X = np.zeros(
149            [n_trials, n_channels * n_target_freqs, n_channels * n_target_freqs]
150        )
151
152        # Create super trial of all trials filtered at all bands
153        for trial in range(n_trials):
154            for tf, target_freq in enumerate(target_freqs):
155                lower_bound = int((n_channels * tf))
156                upper_bound = int((n_channels * tf) + n_channels)
157
158                signal = X[trial, :, :]
159                for f in range(n_harmonics):
160                    if f == 0:
161                        filt_signal = bandpass(
162                            signal,
163                            f_low=target_freq - (f_width / 2),
164                            f_high=target_freq + (f_width / 2),
165                            order=5,
166                            fsample=fsample,
167                        )
168                    else:
169                        filt_signal += bandpass(
170                            signal,
171                            f_low=(target_freq * (f + 1)) - (f_width / 2),
172                            f_high=(target_freq * (f + 1)) + (f_width / 2),
173                            order=5,
174                            fsample=fsample,
175                        )
176
177                cov_mat = Covariances(estimator=covariance_estimator).transform(
178                    np.expand_dims(filt_signal, axis=0)
179                )
180
181                cov_mat_diag = np.diag(np.diag(cov_mat[0, :, :]))
182
183                super_X[trial, lower_bound:upper_bound, lower_bound:upper_bound] = (
184                    cov_mat_diag
185                )
186
187        return super_X
188
189    def fit(self):
190        """Fit the model.
191
192        Returns
193        -------
194        `None`
195            Models created used in `predict()`.
196
197        """
198
199        # Convert each trial of X into a SPD of dimensions [n_trials, n_channels*nfreqs, n_channels*nfreqs]
200        n_trials, n_channels, n_samples = self.X.shape
201
202        #################
203        # Try rebuilding the classifier each time
204        if self.rebuild:
205            self.next_fit_trial = 0
206            self.clf = self.clf_model
207
208        # get temporal subset
209        subX = self.X[self.next_fit_trial :, :, :]
210        suby = self.y[self.next_fit_trial :]
211        self.next_fit_trial = n_trials
212
213        # Init predictions to all false
214        cv_preds = np.zeros(n_trials)
215
216        def __ssvep_kernel(subX, suby):
217            """SSVEP kernel.
218
219            Parameters
220            ----------
221            subX : numpy.ndarray
222                Input data ffor training/testing the classifier.
223                3D array with shape = (`n_epochs`, `n_channels`, `n_samples`).
224            suby : numpy.ndarray
225                Target labels for the input data.
226                1D array with shape = (`n_epochs`, ).
227
228            Returns
229            -------
230            kernelResults : KernelResults
231                KernelResults object containing the following attributes:
232                    model : classifier
233                        The trained classification model.
234                    cv_preds : numpy.ndarray
235                        The predictions from the model using cross validation.
236                        1D array with the same shape as `suby`.
237                    accuracy : float
238                        The accuracy of the trained classification model.
239                    precision : float
240                        The precision of the trained classification model.
241                    recall : float
242                        The recall of the trained classification model.
243
244
245            """
246            for train_idx, test_idx in self.cv.split(subX, suby):
247                self.clf = self.clf_model
248
249                X_train, X_test = subX[train_idx], subX[test_idx]
250                y_train = suby[train_idx]
251
252                # get the covariance matrices for the training set
253                X_train_super = self.get_ssvep_supertrial(
254                    X_train,
255                    self.target_freqs,
256                    fsample=256,
257                    n_harmonics=self.n_harmonics,
258                    f_width=self.f_width,
259                    covariance_estimator=self.covariance_estimator,
260                )
261                X_test_super = self.get_ssvep_supertrial(
262                    X_test,
263                    self.target_freqs,
264                    fsample=256,
265                    n_harmonics=self.n_harmonics,
266                    f_width=self.f_width,
267                    covariance_estimator=self.covariance_estimator,
268                )
269
270                # fit the classsifier
271                self.clf.fit(X_train_super, y_train)
272                cv_preds[test_idx] = self.clf.predict(X_test_super)
273
274            # Create super trial with all available data
275            X_super = self.get_ssvep_supertrial(
276                subX,
277                self.target_freqs,
278                fsample=256,
279                n_harmonics=self.n_harmonics,
280                f_width=self.f_width,
281                covariance_estimator=self.covariance_estimator,
282            )
283
284            # Train final model with all available data
285            self.clf.fit(X_super, suby)
286            model = self.clf
287
288            accuracy = sum(cv_preds == self.y) / len(cv_preds)
289            precision = precision_score(self.y, cv_preds, average="micro")
290            recall = recall_score(self.y, cv_preds, average="micro")
291
292            return KernelResults(model, cv_preds, accuracy, precision, recall)
293
294        # Check if channel selection is true
295        if self.channel_selection_setup:
296            logger.info("Doing channel selection")
297
298            channel_selection_results = channel_selection_by_method(
299                __ssvep_kernel,
300                self.X,
301                self.y,
302                self.channel_labels,  # kernel setup
303                self.chs_method,
304                self.chs_metric,
305                self.chs_initial_subset,  # wrapper setup
306                self.chs_max_time,
307                self.chs_min_channels,
308                self.chs_max_channels,
309                self.chs_performance_delta,  # stopping criterion
310                self.chs_n_jobs,
311            )
312
313            preds = channel_selection_results.best_preds
314            accuracy = channel_selection_results.best_accuracy
315            precision = channel_selection_results.best_precision
316            recall = channel_selection_results.best_recall
317
318            logger.info(
319                "The optimal subset is: %s",
320                channel_selection_results.best_channel_subset,
321            )
322
323            self.subset = channel_selection_results.best_channel_subset
324            self.clf = channel_selection_results.best_model
325        else:
326            logger.warning("Not doing channel selection")
327            current_results = __ssvep_kernel(subX, suby)
328            self.clf = current_results.model
329            preds = current_results.cv_preds
330            accuracy = current_results.accuracy
331            precision = current_results.precision
332            recall = current_results.recall
333
334        # Log performance stats
335
336        self.offline_trial_count = n_trials
337        self.offline_trial_counts.append(self.offline_trial_count)
338
339        # accuracy
340        accuracy = sum(preds == self.y) / len(preds)
341        self.offline_accuracy.append(accuracy)
342        logger.info("Accuracy = %s", accuracy)
343
344        # precision
345        precision = precision_score(self.y, preds, average="micro")
346        self.offline_precision.append(precision)
347        logger.info("Precision = %s", precision)
348
349        # recall
350        recall = recall_score(self.y, preds, average="micro")
351        self.offline_recall.append(recall)
352        logger.info("Recall = %s", recall)
353
354        # confusion matrix in command line
355        cm = confusion_matrix(self.y, preds)
356        self.offline_cm = cm
357        logger.info("Confusion matrix:\n%s", cm)
358
359    def predict(self, X):
360        """Predict the class labels for the provided data.
361
362        Parameters
363        ----------
364        X : numpy.ndarray
365            3D array where shape = (trials, channels, samples)
366
367        Returns
368        -------
369        prediction : Prediction
370            Results of predict call containing the predicted class labels, and
371            the probabilities of the labels.
372
373        """
374        # if X is 2D, make it 3D with one as first dimension
375        if len(X.shape) < 3:
376            X = X[np.newaxis, ...]
377
378        X = self.get_subset(X)
379
380        logger.info("The shape of X is %s", X.shape)
381
382        X_super = self.get_ssvep_supertrial(
383            X,
384            self.target_freqs,
385            fsample=256,
386            n_harmonics=self.n_harmonics,
387            f_width=self.f_width,
388        )
389
390        pred = self.clf.predict(X_super)
391        pred_proba = self.clf.predict_proba(X_super)
392
393        logger.info("Prediction: %s", pred)
394        logger.info("Prediction probabilities: %s", pred_proba)
395
396        for i in range(len(pred)):
397            self.predictions.append(pred[i])
398            self.pred_probas.append(pred_proba[i])
399
400        return Prediction(labels=pred, probabilities=pred_proba)

SSVEP Riemannian MDM Classifier class (inherits from GenericClassifier)

def set_ssvep_settings(self, n_splits=3, random_seed=42, n_harmonics=2, f_width=0.2, covariance_estimator='oas', remove_flats=True):
def set_ssvep_settings(
    self,
    n_splits=3,
    random_seed=42,
    n_harmonics=2,
    f_width=0.2,
    covariance_estimator="oas",
    remove_flats=True,
):
    """Set the SSVEP settings.

    Parameters
    ----------
    n_splits : int, *optional*
        Number of folds for cross-validation.
        - Default is `3`.
    random_seed : int, *optional*
        Random seed.
        - Default is `42`.
    n_harmonics : int, *optional*
        Number of harmonics to be used for each frequency.
        - Default is `2`.
    f_width : float, *optional*
        Width of frequency bins to be used around the target
        frequencies.
        - Default is `0.2`.
    covariance_estimator : str, *optional*
        Covariance Estimator (see Covariances - pyriemann)
        - Default is `"oas"`.
    remove_flats : bool, *optional*
        Remove flat channels.
        - Default is `True`.

    Returns
    -------
    `None`
        Builds the models that are used in `fit()`.

    """
    # Build the cross-validation split
    self.n_splits = n_splits
    self.cv = StratifiedKFold(
        n_splits=self.n_splits, shuffle=True, random_state=random_seed
    )

    self.rebuild = True

    self.n_harmonics = n_harmonics
    self.f_width = f_width
    self.covariance_estimator = covariance_estimator

    def _build_pipeline():
        # Return a fresh pipeline so the template (`clf_model`) and the
        # working classifier (`clf`) never share estimator instances.
        steps = []
        if remove_flats:
            steps.append(("Remove Flat Channels", FlatChannelRemover()))
        # Use an MDM classifier, maybe there will be other options later
        steps.append(
            ("MDM", MDM(metric=dict(mean="riemann", distance="riemann"), n_jobs=1))
        )
        return Pipeline(steps)

    self.clf_model = _build_pipeline()
    self.clf = _build_pipeline()

Set the SSVEP settings.

Parameters
  • n_splits (int, optional): Number of folds for cross-validation.
    • Default is 3.
  • random_seed (int, optional): Random seed.
    • Default is 42.
  • n_harmonics (int, optional): Number of harmonics to be used for each frequency.
    • Default is 2.
  • f_width (float, optional): Width of frequency bins to be used around the target frequencies.
    • Default is 0.2.
  • covariance_estimator (str, optional): Covariance Estimator (see Covariances - pyriemann)
    • Default is "oas".
  • remove_flats (bool, optional): Remove flat channels.
    • Default is True.
Returns
  • None: Builds the models that are used in fit().
def get_ssvep_supertrial( self, X, target_freqs, fsample, f_width=0.4, n_harmonics=2, covariance_estimator='oas'):
def get_ssvep_supertrial(
    self,
    X,
    target_freqs,
    fsample,
    f_width=0.4,
    n_harmonics=2,
    covariance_estimator="oas",
):
    """Get SSVEP Supertrial.

    Creates the Riemannian Geometry supertrial for SSVEP.

    For every trial and every target frequency, the trial is band-pass
    filtered around the target frequency and each of its first
    `n_harmonics` harmonics. The filtered signals are summed, a covariance
    matrix is estimated from the sum, and only its diagonal is kept. The
    resulting diagonal matrices are placed along the diagonal of the
    supertrial, one `n_channels` x `n_channels` block per target frequency
    (in the order of `target_freqs`).

    Parameters
    ----------
    X : numpy.ndarray
        Trials of EEG data.
        3D array containing data with `float` type.
        shape = (`n_trials`,`n_channels`,`n_samples`)
    target_freqs : numpy.ndarray
        Target frequencies for the SSVEP.
    fsample : float
        Sampling rate.
    f_width : float, *optional*
        Width of frequency bins to be used around the target
        frequencies.
        - Default is `0.4`.
    n_harmonics : int, *optional*
        Number of harmonics to be used for each frequency; the target
        frequency itself counts as the first harmonic. Must be >= 1.
        - Default is `2`.
    covariance_estimator : str, *optional*
        Covariance Estimator (see Covariances - pyriemann)
        - Default is `"oas"`.

    Returns
    -------
    super_X : numpy.ndarray
        Supertrials of X.
        3D array containing data with `float` type.

        shape = (`n_trials`,`n_channels*number of target_freqs`,
        `n_channels*number of target_freqs`)

    Raises
    ------
    ValueError
        If `n_harmonics` is less than 1 (no band would be filtered).

    """
    # Without at least one harmonic there is no filtered signal to
    # estimate a covariance from; fail with a clear message instead of
    # an UnboundLocalError further down.
    if n_harmonics < 1:
        raise ValueError(f"n_harmonics must be at least 1, got {n_harmonics}")

    n_trials, n_channels, _ = X.shape
    n_target_freqs = len(target_freqs)
    half_width = f_width / 2

    # Build the covariance estimator once rather than once per trial and
    # per target frequency.
    cov_estimator = Covariances(estimator=covariance_estimator)

    super_X = np.zeros(
        [n_trials, n_channels * n_target_freqs, n_channels * n_target_freqs]
    )

    # Create super trial of all trials filtered at all bands
    for trial in range(n_trials):
        signal = X[trial, :, :]
        for tf, target_freq in enumerate(target_freqs):
            # Rows/columns of the block that belongs to this target frequency
            lower_bound = n_channels * tf
            upper_bound = lower_bound + n_channels

            # Sum of the signal band-passed around every harmonic
            # (harmonic 1 is the target frequency itself).
            filt_signal = None
            for harmonic in range(1, n_harmonics + 1):
                centre_freq = target_freq * harmonic
                band = bandpass(
                    signal,
                    f_low=centre_freq - half_width,
                    f_high=centre_freq + half_width,
                    order=5,
                    fsample=fsample,
                )
                filt_signal = band if filt_signal is None else filt_signal + band

            cov_mat = cov_estimator.transform(np.expand_dims(filt_signal, axis=0))

            # Keep only the diagonal of this band's covariance matrix
            super_X[trial, lower_bound:upper_bound, lower_bound:upper_bound] = (
                np.diag(np.diag(cov_mat[0, :, :]))
            )

    return super_X

Get SSVEP Supertrial.

Creates the Riemannian Geometry supertrial for SSVEP.

Parameters
  • X (numpy.ndarray): Trials of EEG data. 3D array containing data with float type. shape = (n_trials,n_channels,n_samples)
  • target_freqs (numpy.ndarray): Target frequencies for the SSVEP.
  • fsample (float): Sampling rate.
  • f_width (float, optional): Width of frequency bins to be used around the target frequencies.
    • Default is 0.4.
  • n_harmonics (int, optional): Number of harmonics to be used for each frequency.
    • Default is 2.
  • covariance_estimator (str, optional): Covariance Estimator (see Covariances - pyriemann)
    • Default is "oas".
Returns
  • super_X (numpy.ndarray): Supertrials of X. 3D array containing data with float type.

    shape = (n_trials,n_channels*number of target_freqs, n_channels*number of target_freqs)

def fit(self):
def fit(self):
    """Fit the model.

    Builds SSVEP supertrials from the stored trials (`self.X`, `self.y`),
    evaluates the classifier with the configured cross-validation split,
    optionally runs channel selection, trains the final model on all
    available data and records offline performance statistics
    (trial count, accuracy, precision, recall, confusion matrix).

    Returns
    -------
    `None`
        Models created used in `predict()`.

    """
    # Each trial of X becomes an SPD matrix of dimensions
    # [n_channels*nfreqs, n_channels*nfreqs] (see get_ssvep_supertrial).
    n_trials, _, _ = self.X.shape

    # Try rebuilding the classifier each time
    if self.rebuild:
        self.next_fit_trial = 0
        self.clf = self.clf_model

    # get temporal subset
    subX = self.X[self.next_fit_trial :, :, :]
    suby = self.y[self.next_fit_trial :]
    self.next_fit_trial = n_trials

    def _supertrial(data):
        """Build supertrials for `data` using the current SSVEP settings."""
        # NOTE(review): the sampling rate is hard-coded to 256 Hz; confirm
        # it matches the acquisition device.
        return self.get_ssvep_supertrial(
            data,
            self.target_freqs,
            fsample=256,
            n_harmonics=self.n_harmonics,
            f_width=self.f_width,
            covariance_estimator=self.covariance_estimator,
        )

    def __ssvep_kernel(subX, suby):
        """SSVEP kernel.

        Parameters
        ----------
        subX : numpy.ndarray
            Input data for training/testing the classifier.
            3D array with shape = (`n_epochs`, `n_channels`, `n_samples`).
        suby : numpy.ndarray
            Target labels for the input data.
            1D array with shape = (`n_epochs`, ).

        Returns
        -------
        kernelResults : KernelResults
            KernelResults object containing the following attributes:
                model : classifier
                    The trained classification model.
                cv_preds : numpy.ndarray
                    The predictions from the model using cross validation.
                    1D array with the same shape as `suby`.
                accuracy : float
                    The accuracy of the trained classification model.
                precision : float
                    The precision of the trained classification model.
                recall : float
                    The recall of the trained classification model.

        """
        # NOTE(review): every call reuses the same `self.clf_model` pipeline
        # object, so models returned by different calls (e.g. during channel
        # selection) are one object, fitted by whichever call ran last.
        # Consider cloning the pipeline per call.
        model = self.clf_model

        # Predictions are local to this call and sized to `suby`, so results
        # returned by earlier calls are not overwritten, and fold indices
        # (which are relative to `subX`) line up with the labels.
        cv_preds = np.zeros(len(suby))

        for train_idx, test_idx in self.cv.split(subX, suby):
            # fit the classifier on this fold's training trials
            model.fit(_supertrial(subX[train_idx]), suby[train_idx])
            cv_preds[test_idx] = model.predict(_supertrial(subX[test_idx]))

        # Train final model with all available data
        model.fit(_supertrial(subX), suby)

        accuracy = np.mean(cv_preds == suby)
        precision = precision_score(suby, cv_preds, average="micro")
        recall = recall_score(suby, cv_preds, average="micro")

        return KernelResults(model, cv_preds, accuracy, precision, recall)

    # Check if channel selection is true
    if self.channel_selection_setup:
        logger.info("Doing channel selection")

        channel_selection_results = channel_selection_by_method(
            __ssvep_kernel,
            self.X,
            self.y,
            self.channel_labels,  # kernel setup
            self.chs_method,
            self.chs_metric,
            self.chs_initial_subset,  # wrapper setup
            self.chs_max_time,
            self.chs_min_channels,
            self.chs_max_channels,
            self.chs_performance_delta,  # stopping criterion
            self.chs_n_jobs,
        )

        preds = channel_selection_results.best_preds
        # The kernel was given the full data set, so score against all labels.
        labels = self.y

        logger.info(
            "The optimal subset is: %s",
            channel_selection_results.best_channel_subset,
        )

        self.subset = channel_selection_results.best_channel_subset
        self.clf = channel_selection_results.best_model
    else:
        logger.warning("Not doing channel selection")
        current_results = __ssvep_kernel(subX, suby)
        self.clf = current_results.model
        preds = current_results.cv_preds
        # The kernel only saw the temporal subset; score against its labels.
        labels = suby

    # Log performance stats
    self.offline_trial_count = n_trials
    self.offline_trial_counts.append(self.offline_trial_count)

    # accuracy
    accuracy = np.mean(preds == labels)
    self.offline_accuracy.append(accuracy)
    logger.info("Accuracy = %s", accuracy)

    # precision
    precision = precision_score(labels, preds, average="micro")
    self.offline_precision.append(precision)
    logger.info("Precision = %s", precision)

    # recall
    recall = recall_score(labels, preds, average="micro")
    self.offline_recall.append(recall)
    logger.info("Recall = %s", recall)

    # confusion matrix in command line
    cm = confusion_matrix(labels, preds)
    self.offline_cm = cm
    logger.info("Confusion matrix:\n%s", cm)

Fit the model.

Returns
  • None: The trained models are used in predict().
def predict(self, X):
def predict(self, X):
    """Predict the class labels for the provided data.

    Parameters
    ----------
    X : numpy.ndarray
        3D array where shape = (trials, channels, samples). An array with
        fewer than three dimensions gets a leading trial axis of size one
        prepended.

    Returns
    -------
    prediction : Prediction
        Results of predict call containing the predicted class labels, and
        the probabilities of the labels.

    """
    # if X is 2D, make it 3D with one as first dimension
    if len(X.shape) < 3:
        X = X[np.newaxis, ...]

    # Apply the channel subset handling from get_subset()
    X = self.get_subset(X)

    logger.info("The shape of X is %s", X.shape)

    # Build supertrials with the same settings fit() uses. The covariance
    # estimator in particular must match the one the model was trained
    # with, otherwise the features seen at prediction time differ from the
    # training features.
    X_super = self.get_ssvep_supertrial(
        X,
        self.target_freqs,
        fsample=256,
        n_harmonics=self.n_harmonics,
        f_width=self.f_width,
        covariance_estimator=self.covariance_estimator,
    )

    pred = self.clf.predict(X_super)
    pred_proba = self.clf.predict_proba(X_super)

    logger.info("Prediction: %s", pred)
    logger.info("Prediction probabilities: %s", pred_proba)

    # Keep a running history of every prediction and its probabilities
    for label, proba in zip(pred, pred_proba):
        self.predictions.append(label)
        self.pred_probas.append(proba)

    return Prediction(labels=pred, probabilities=pred_proba)

Predict the class labels for the provided data.

Parameters
  • X (numpy.ndarray): 3D array where shape = (trials, channels, samples)
Returns
  • prediction (Prediction): Results of predict call containing the predicted class labels, and the probabilities of the labels.