bci_essentials.classification.erp_rg_classifier_hyperparamgridsearch

ERP RG Classifier

This classifier is used to classify ERPs using the Riemannian Geometry approach.

  1"""**ERP RG Classifier**
  2
  3This classifier is used to classify ERPs using the Riemannian Geometry
  4approach.
  5
  6"""
  7
  8# Stock libraries
  9import numpy as np
 10import matplotlib.pyplot as plt
 11from sklearn.model_selection import StratifiedKFold, GridSearchCV
 12from sklearn.metrics import (
 13    confusion_matrix,
 14    ConfusionMatrixDisplay,
 15    precision_score,
 16    recall_score,
 17    roc_auc_score,
 18    make_scorer,
 19)
 20from sklearn.pipeline import Pipeline
 21from pyriemann.tangentspace import TangentSpace
 22from pyriemann.estimation import XdawnCovariances
 23from pyriemann.channelselection import FlatChannelRemover
 24from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
 25
 26# Import bci_essentials modules and methods
 27from .generic_classifier import (
 28    GenericClassifier,
 29    Prediction,
 30)
 31from ..signal_processing import lico, random_oversampling, random_undersampling
 32from ..utils.logger import Logger  # Logger wrapper
 33
 34# Instantiate a logger for the module at the default level of logging.INFO
  35# Logs to `bci_essentials.__module__`, where `__module__` is the name of the module
 36logger = Logger(name=__name__)
 37
 38
 39class ErpRgClassifierHyperparamGridSearch(GenericClassifier):
 40    """ERP RG Classifier with hyperparameter grid search
 41    class (*inherits from `GenericClassifier`*)."""
 42
 43    def set_p300_clf_settings(
 44        self,
 45        n_splits=3,
 46        resampling_method=None,
 47        lico_expansion_factor=1,
 48        oversample_ratio=0,
 49        undersample_ratio=0,
 50        random_seed=42,
 51        remove_flats=True,
 52    ):
 53        """Set P300 Classifier Settings.
 54
 55        Parameters
 56        ----------
 57        n_splits : int, *optional*
 58            Number of folds for cross-validation.
 59            - Default is `3`.
 60        resampling_method : str, *optional*, None
 61            Resampling method to use ["lico", "oversample", "undersample"].
 62            Default is None.
 63        lico_expansion_factor : int, *optional*
 64            Linear Combination Oversampling expansion factor, which is the
 65            factor by which the number of ERPs in the training set will be
 66            expanded.
 67            - Default is `1`.
 68        oversample_ratio : float, *optional*
 69            Traditional oversampling. Range is from from 0.1-1 resulting
 70            from the ratio of erp to non-erp class. 0 for no oversampling.
 71            - Default is `0`.
 72        undersample_ratio : float, *optional*
 73            Traditional undersampling. Range is from from 0.1-1 resulting
 74            from the ratio of erp to non-erp class. 0 for no undersampling.
 75            - Default is `0`.
 76        random_seed : int, *optional*
 77            Random seed.
 78            - Default is `42`.
 79        remove_flats : bool, *optional*
 80            Whether to remove flat channels.
 81            - Default is `True`.
 82
 83        Returns
 84        -------
 85        `None`
 86
 87        """
 88        self.n_splits = n_splits
 89        self.resampling_method = resampling_method
 90        self.lico_expansion_factor = lico_expansion_factor
 91        self.oversample_ratio = oversample_ratio
 92        self.undersample_ratio = undersample_ratio
 93        self.random_seed = random_seed
 94
 95        # # Create steps list with proper formatting
 96        steps = []
 97        if remove_flats:
 98            steps.append(("remove_flats", FlatChannelRemover()))
 99
100        steps.extend(
101            [
102                ("xdawn", XdawnCovariances()),
103                ("tangent", TangentSpace()),
104                ("lda", LinearDiscriminantAnalysis()),
105            ]
106        )
107
108        # Create pipeline
109        self.clf = Pipeline(steps)
110
111        # Hyperparameters to be optimized
112        # TODO: Implement an extended nfilter set, dynamically based on the number of channels
113        # Example of dynamic nfilter set
114        # n_channels = self.X.shape[1]
115        # nfilter_set = list(range(2, n_channels+1))  # Example range from 2 to n_channels inclusive
116        # Then set "xdawn__nfilter": nfilter_set in the param_grid below
117        self.param_grid = {
118            "xdawn__nfilter": [2, 3, 4],
119            "xdawn__estimator": ["oas", "lwf"],
120            "tangent__metric": ["riemann"],
121            "lda__solver": ["lsqr", "eigen"],
122            "lda__shrinkage": np.linspace(0.5, 0.9, 5),
123        }
124
125    def fit(
126        self,
127        plot_cm=False,
128        plot_roc=False,
129    ):
130        """Fit the model.
131
132        Parameters
133        ----------
134        plot_cm : bool, *optional*
135            Whether to plot the confusion matrix during training.
136            - Default is `False`.
137        plot_roc : bool, *optional*
138            Whether to plot the ROC curve during training.
139            - Default is `False`.
140
141        Returns
142        -------
143        `None`
144            Models created used in `predict()`.
145
146        """
147        logger.info("Fitting the model using RG")
148        logger.info("X shape: %s", self.X.shape)
149        logger.info("y shape: %s", self.y.shape)
150
151        # Resample data if needed
152        self.X, self.y = self.__resample_data()
153
154        # Optimize hyperparameters with cross-validation
155        self.__optimize_hyperparameters()
156
157        # Fit the model with the complete dataset and optimized hyperparameters
158        self.clf.fit(self.X, self.y)
159
160        # Get predictions for final model
161        y_pred_proba = self.clf.predict_proba(self.X)[:, 1]
162
163        # Calculate estimate of training metrics of final model
164        # TODO: Implement proper training metrics calculation, using cross validation.
165        # self.offline_accuracy = sum(y_pred == self.y) / len(self.y)
166        # self.offline_precision = precision_score(self.y, y_pred)
167        # self.offline_recall = recall_score(self.y, y_pred)
168
169        try:
170            roc_auc = roc_auc_score(self.y, y_pred_proba)
171            logger.info(f"ROC AUC Score: {roc_auc:0.3f}")
172        except Exception as e:
173            logger.warning(f"Could not calculate ROC AUC score: {e}")
174
175        # Display training confusion matrix
176        # self.offline_cm = confusion_matrix(self.y, y_pred)
177        if plot_cm:
178            disp = ConfusionMatrixDisplay(confusion_matrix=self.offline_cm)
179            disp.plot()
180            plt.title("Training confusion matrix")
181
182        if plot_roc:
183            # TODO Implementation missing
184            pass
185
186        # Log training metrics
187        logger.info("Final model training performance metrics:")
188        logger.info(f"Accuracy: {self.offline_accuracy:0.3f} - MAY NOT BE ACCURATE")
189        logger.info(f"Precision: {self.offline_precision:0.3f} - MAY NOT BE ACCURATE")
190        logger.info(f"Recall: {self.offline_recall:0.3f} - MAY NOT BE ACCURATE")
191        logger.info(f"Confusion Matrix:\n{self.offline_cm} ")
192        logger.warning(
193            "Note: Training metrics may not be accurate due to the use of "
194            "cross-validation and resampling methods. Use with caution."
195        )
196
197    def predict(self, X):
198        """Predict the class of the data
199
200        Parameters
201        ----------
202        X : numpy.ndarray
203            3D array where shape = (n_epochs, n_channels, n_samples)
204
205        Returns
206        -------
207        prediction : Prediction
208            Predict object. Contains the predicted labels and and the probability.
209            Because this classifier chooses the P300 object with the highest posterior probability,
210            the probability is only the posterior probability of the chosen object.
211
212        """
213
214        subset_X = self.get_subset(X, self.subset, self.channel_labels)
215
216        # Get posterior probability for each target
217        posterior_prob = self.clf.predict_proba(subset_X)[:, 1]
218
219        label = [int(np.argmax(posterior_prob))]
220        probability = [np.max(posterior_prob)]
221
222        return Prediction(label, probability)
223
224    # TODO implement additional resampling methods, JIRA ticket: B4K-342
225    def __resample_data(self):
226        """Resample data based on the selected method"""
227
228        X_resampled = self.X.copy()
229        y_resampled = self.y.copy()
230
231        try:
232            if (self.resampling_method == "lico") and (self.lico_expansion_factor > 1):
233                [X_resampled, y_resampled] = lico(
234                    self.X, self.y, self.lico_expansion_factor
235                )
236                pass
237
238            elif (self.resampling_method == "oversample") and (
239                self.oversample_ratio > 0
240            ):
241                [X_resampled, y_resampled] = random_oversampling(
242                    self.X, self.y, self.oversample_ratio
243                )
244                pass
245
246            elif (self.resampling_method == "undersample") and (
247                self.undersample_ratio > 0
248            ):
249                [X_resampled, y_resampled] = random_undersampling(
250                    self.X, self.y, self.undersample_ratio
251                )
252                pass
253
254            logger.info(f"Resampling  with {self.resampling_method} done")
255            logger.info(f"X_resampled shape: {X_resampled.shape}")
256            logger.info(f"y_resampled shape: {y_resampled.shape}")
257
258        except Exception as e:
259            logger.error(
260                f"{self.resampling_method.capitalize()} resampling method failed"
261            )
262            logger.error(e)
263
264        return X_resampled, y_resampled
265
266    def __optimize_hyperparameters(self):
267        """Optimize hyperparameters with cross-validation using brute force grid search
268
269        Returns
270        -------
271        `None`
272            Model with best hyperparameters to be used in `predict()`.
273
274        """
275
276        # Perform cross-validation
277        cv = StratifiedKFold(
278            n_splits=self.n_splits, shuffle=True, random_state=self.random_seed
279        )
280
281        # Create custom scorer function
282        custom_scorer = make_scorer(
283            self._valid_roc_auc, response_method="predict_proba", greater_is_better=True
284        )
285
286        # Create GridSearchCV object
287        grid_search = GridSearchCV(
288            estimator=self.clf,
289            param_grid=self.param_grid,
290            cv=cv,
291            n_jobs=-1,
292            verbose=1,
293            scoring=custom_scorer,
294            refit=True,
295            return_train_score=True,
296        )
297
298        # Start grid search optimization
299        logger.info("Starting grid search optimization...")
300        grid_search.fit(self.X, self.y)
301
302        # Get best parameters and score
303        logger.info("Grid search optimization completed.")
304        best_params = grid_search.best_params_
305        best_score = grid_search.best_score_
306
307        # Report training metrics: TODO: Verify this is the right way to calculate training metrics
308        self.offline_accuracy = grid_search.best_estimator_.score(self.X, self.y)
309        self.offline_cm = confusion_matrix(
310            self.y, grid_search.best_estimator_.predict(self.X)
311        )
312        self.offline_precision = precision_score(
313            self.y, grid_search.best_estimator_.predict(self.X)
314        )
315        self.offline_recall = recall_score(
316            self.y, grid_search.best_estimator_.predict(self.X)
317        )
318
319        # Update classifier with best parameters
320        self.clf.set_params(**best_params)
321        logger.info(f"Best parameters found: {best_params}")
322        logger.info(f"Best CV score: {best_score:0.3f}")
323
324    def _valid_roc_auc(self, y_true, y_pred, **kwargs):
325        """Calculate the ROC AUC score for the classifier.
326        This method is used because the stock `roc_auc_score` function
327        does not handle the case where one class is missing in the fold.
328        This method will return 0.5 in that case.
329
330        Parameters
331        ----------
332        y_true : numpy.ndarray
333            True labels.
334        y_pred : numpy.ndarray
335            Predicted labels.
336        **kwargs : dict
337            Additional keyword arguments passed by make_scorer.
338
339        Returns
340        -------
341        roc_auc : float
342            ROC AUC score.
343
344        """
345        try:
346            # Check if we have both classes in the fold
347            if len(np.unique(y_true)) < 2:
348                logger.warning("Fold contains only one class")
349                return 0.5
350
351            return roc_auc_score(y_true, y_pred)
352
353        except Exception as e:
354            logger.warning(f"ROC AUC calculation failed: {e}")
355            return 0.5
class ErpRgClassifierHyperparamGridSearch(bci_essentials.classification.generic_classifier.GenericClassifier):
 40class ErpRgClassifierHyperparamGridSearch(GenericClassifier):
 41    """ERP RG Classifier with hyperparameter grid search
 42    class (*inherits from `GenericClassifier`*)."""
 43
 44    def set_p300_clf_settings(
 45        self,
 46        n_splits=3,
 47        resampling_method=None,
 48        lico_expansion_factor=1,
 49        oversample_ratio=0,
 50        undersample_ratio=0,
 51        random_seed=42,
 52        remove_flats=True,
 53    ):
 54        """Set P300 Classifier Settings.
 55
 56        Parameters
 57        ----------
 58        n_splits : int, *optional*
 59            Number of folds for cross-validation.
 60            - Default is `3`.
 61        resampling_method : str, *optional*, None
 62            Resampling method to use ["lico", "oversample", "undersample"].
 63            Default is None.
 64        lico_expansion_factor : int, *optional*
 65            Linear Combination Oversampling expansion factor, which is the
 66            factor by which the number of ERPs in the training set will be
 67            expanded.
 68            - Default is `1`.
 69        oversample_ratio : float, *optional*
 70            Traditional oversampling. Range is from from 0.1-1 resulting
 71            from the ratio of erp to non-erp class. 0 for no oversampling.
 72            - Default is `0`.
 73        undersample_ratio : float, *optional*
 74            Traditional undersampling. Range is from from 0.1-1 resulting
 75            from the ratio of erp to non-erp class. 0 for no undersampling.
 76            - Default is `0`.
 77        random_seed : int, *optional*
 78            Random seed.
 79            - Default is `42`.
 80        remove_flats : bool, *optional*
 81            Whether to remove flat channels.
 82            - Default is `True`.
 83
 84        Returns
 85        -------
 86        `None`
 87
 88        """
 89        self.n_splits = n_splits
 90        self.resampling_method = resampling_method
 91        self.lico_expansion_factor = lico_expansion_factor
 92        self.oversample_ratio = oversample_ratio
 93        self.undersample_ratio = undersample_ratio
 94        self.random_seed = random_seed
 95
 96        # # Create steps list with proper formatting
 97        steps = []
 98        if remove_flats:
 99            steps.append(("remove_flats", FlatChannelRemover()))
100
101        steps.extend(
102            [
103                ("xdawn", XdawnCovariances()),
104                ("tangent", TangentSpace()),
105                ("lda", LinearDiscriminantAnalysis()),
106            ]
107        )
108
109        # Create pipeline
110        self.clf = Pipeline(steps)
111
112        # Hyperparameters to be optimized
113        # TODO: Implement an extended nfilter set, dynamically based on the number of channels
114        # Example of dynamic nfilter set
115        # n_channels = self.X.shape[1]
116        # nfilter_set = list(range(2, n_channels+1))  # Example range from 2 to n_channels inclusive
117        # Then set "xdawn__nfilter": nfilter_set in the param_grid below
118        self.param_grid = {
119            "xdawn__nfilter": [2, 3, 4],
120            "xdawn__estimator": ["oas", "lwf"],
121            "tangent__metric": ["riemann"],
122            "lda__solver": ["lsqr", "eigen"],
123            "lda__shrinkage": np.linspace(0.5, 0.9, 5),
124        }
125
126    def fit(
127        self,
128        plot_cm=False,
129        plot_roc=False,
130    ):
131        """Fit the model.
132
133        Parameters
134        ----------
135        plot_cm : bool, *optional*
136            Whether to plot the confusion matrix during training.
137            - Default is `False`.
138        plot_roc : bool, *optional*
139            Whether to plot the ROC curve during training.
140            - Default is `False`.
141
142        Returns
143        -------
144        `None`
145            Models created used in `predict()`.
146
147        """
148        logger.info("Fitting the model using RG")
149        logger.info("X shape: %s", self.X.shape)
150        logger.info("y shape: %s", self.y.shape)
151
152        # Resample data if needed
153        self.X, self.y = self.__resample_data()
154
155        # Optimize hyperparameters with cross-validation
156        self.__optimize_hyperparameters()
157
158        # Fit the model with the complete dataset and optimized hyperparameters
159        self.clf.fit(self.X, self.y)
160
161        # Get predictions for final model
162        y_pred_proba = self.clf.predict_proba(self.X)[:, 1]
163
164        # Calculate estimate of training metrics of final model
165        # TODO: Implement proper training metrics calculation, using cross validation.
166        # self.offline_accuracy = sum(y_pred == self.y) / len(self.y)
167        # self.offline_precision = precision_score(self.y, y_pred)
168        # self.offline_recall = recall_score(self.y, y_pred)
169
170        try:
171            roc_auc = roc_auc_score(self.y, y_pred_proba)
172            logger.info(f"ROC AUC Score: {roc_auc:0.3f}")
173        except Exception as e:
174            logger.warning(f"Could not calculate ROC AUC score: {e}")
175
176        # Display training confusion matrix
177        # self.offline_cm = confusion_matrix(self.y, y_pred)
178        if plot_cm:
179            disp = ConfusionMatrixDisplay(confusion_matrix=self.offline_cm)
180            disp.plot()
181            plt.title("Training confusion matrix")
182
183        if plot_roc:
184            # TODO Implementation missing
185            pass
186
187        # Log training metrics
188        logger.info("Final model training performance metrics:")
189        logger.info(f"Accuracy: {self.offline_accuracy:0.3f} - MAY NOT BE ACCURATE")
190        logger.info(f"Precision: {self.offline_precision:0.3f} - MAY NOT BE ACCURATE")
191        logger.info(f"Recall: {self.offline_recall:0.3f} - MAY NOT BE ACCURATE")
192        logger.info(f"Confusion Matrix:\n{self.offline_cm} ")
193        logger.warning(
194            "Note: Training metrics may not be accurate due to the use of "
195            "cross-validation and resampling methods. Use with caution."
196        )
197
198    def predict(self, X):
199        """Predict the class of the data
200
201        Parameters
202        ----------
203        X : numpy.ndarray
204            3D array where shape = (n_epochs, n_channels, n_samples)
205
206        Returns
207        -------
208        prediction : Prediction
209            Predict object. Contains the predicted labels and and the probability.
210            Because this classifier chooses the P300 object with the highest posterior probability,
211            the probability is only the posterior probability of the chosen object.
212
213        """
214
215        subset_X = self.get_subset(X, self.subset, self.channel_labels)
216
217        # Get posterior probability for each target
218        posterior_prob = self.clf.predict_proba(subset_X)[:, 1]
219
220        label = [int(np.argmax(posterior_prob))]
221        probability = [np.max(posterior_prob)]
222
223        return Prediction(label, probability)
224
225    # TODO implement additional resampling methods, JIRA ticket: B4K-342
226    def __resample_data(self):
227        """Resample data based on the selected method"""
228
229        X_resampled = self.X.copy()
230        y_resampled = self.y.copy()
231
232        try:
233            if (self.resampling_method == "lico") and (self.lico_expansion_factor > 1):
234                [X_resampled, y_resampled] = lico(
235                    self.X, self.y, self.lico_expansion_factor
236                )
237                pass
238
239            elif (self.resampling_method == "oversample") and (
240                self.oversample_ratio > 0
241            ):
242                [X_resampled, y_resampled] = random_oversampling(
243                    self.X, self.y, self.oversample_ratio
244                )
245                pass
246
247            elif (self.resampling_method == "undersample") and (
248                self.undersample_ratio > 0
249            ):
250                [X_resampled, y_resampled] = random_undersampling(
251                    self.X, self.y, self.undersample_ratio
252                )
253                pass
254
255            logger.info(f"Resampling  with {self.resampling_method} done")
256            logger.info(f"X_resampled shape: {X_resampled.shape}")
257            logger.info(f"y_resampled shape: {y_resampled.shape}")
258
259        except Exception as e:
260            logger.error(
261                f"{self.resampling_method.capitalize()} resampling method failed"
262            )
263            logger.error(e)
264
265        return X_resampled, y_resampled
266
267    def __optimize_hyperparameters(self):
268        """Optimize hyperparameters with cross-validation using brute force grid search
269
270        Returns
271        -------
272        `None`
273            Model with best hyperparameters to be used in `predict()`.
274
275        """
276
277        # Perform cross-validation
278        cv = StratifiedKFold(
279            n_splits=self.n_splits, shuffle=True, random_state=self.random_seed
280        )
281
282        # Create custom scorer function
283        custom_scorer = make_scorer(
284            self._valid_roc_auc, response_method="predict_proba", greater_is_better=True
285        )
286
287        # Create GridSearchCV object
288        grid_search = GridSearchCV(
289            estimator=self.clf,
290            param_grid=self.param_grid,
291            cv=cv,
292            n_jobs=-1,
293            verbose=1,
294            scoring=custom_scorer,
295            refit=True,
296            return_train_score=True,
297        )
298
299        # Start grid search optimization
300        logger.info("Starting grid search optimization...")
301        grid_search.fit(self.X, self.y)
302
303        # Get best parameters and score
304        logger.info("Grid search optimization completed.")
305        best_params = grid_search.best_params_
306        best_score = grid_search.best_score_
307
308        # Report training metrics: TODO: Verify this is the right way to calculate training metrics
309        self.offline_accuracy = grid_search.best_estimator_.score(self.X, self.y)
310        self.offline_cm = confusion_matrix(
311            self.y, grid_search.best_estimator_.predict(self.X)
312        )
313        self.offline_precision = precision_score(
314            self.y, grid_search.best_estimator_.predict(self.X)
315        )
316        self.offline_recall = recall_score(
317            self.y, grid_search.best_estimator_.predict(self.X)
318        )
319
320        # Update classifier with best parameters
321        self.clf.set_params(**best_params)
322        logger.info(f"Best parameters found: {best_params}")
323        logger.info(f"Best CV score: {best_score:0.3f}")
324
325    def _valid_roc_auc(self, y_true, y_pred, **kwargs):
326        """Calculate the ROC AUC score for the classifier.
327        This method is used because the stock `roc_auc_score` function
328        does not handle the case where one class is missing in the fold.
329        This method will return 0.5 in that case.
330
331        Parameters
332        ----------
333        y_true : numpy.ndarray
334            True labels.
335        y_pred : numpy.ndarray
336            Predicted labels.
337        **kwargs : dict
338            Additional keyword arguments passed by make_scorer.
339
340        Returns
341        -------
342        roc_auc : float
343            ROC AUC score.
344
345        """
346        try:
347            # Check if we have both classes in the fold
348            if len(np.unique(y_true)) < 2:
349                logger.warning("Fold contains only one class")
350                return 0.5
351
352            return roc_auc_score(y_true, y_pred)
353
354        except Exception as e:
355            logger.warning(f"ROC AUC calculation failed: {e}")
356            return 0.5

ERP RG Classifier with hyperparameter grid search class (inherits from GenericClassifier).

def set_p300_clf_settings( self, n_splits=3, resampling_method=None, lico_expansion_factor=1, oversample_ratio=0, undersample_ratio=0, random_seed=42, remove_flats=True):
 44    def set_p300_clf_settings(
 45        self,
 46        n_splits=3,
 47        resampling_method=None,
 48        lico_expansion_factor=1,
 49        oversample_ratio=0,
 50        undersample_ratio=0,
 51        random_seed=42,
 52        remove_flats=True,
 53    ):
 54        """Set P300 Classifier Settings.
 55
 56        Parameters
 57        ----------
 58        n_splits : int, *optional*
 59            Number of folds for cross-validation.
 60            - Default is `3`.
 61        resampling_method : str, *optional*, None
 62            Resampling method to use ["lico", "oversample", "undersample"].
 63            Default is None.
 64        lico_expansion_factor : int, *optional*
 65            Linear Combination Oversampling expansion factor, which is the
 66            factor by which the number of ERPs in the training set will be
 67            expanded.
 68            - Default is `1`.
 69        oversample_ratio : float, *optional*
 70            Traditional oversampling. Range is from from 0.1-1 resulting
 71            from the ratio of erp to non-erp class. 0 for no oversampling.
 72            - Default is `0`.
 73        undersample_ratio : float, *optional*
 74            Traditional undersampling. Range is from from 0.1-1 resulting
 75            from the ratio of erp to non-erp class. 0 for no undersampling.
 76            - Default is `0`.
 77        random_seed : int, *optional*
 78            Random seed.
 79            - Default is `42`.
 80        remove_flats : bool, *optional*
 81            Whether to remove flat channels.
 82            - Default is `True`.
 83
 84        Returns
 85        -------
 86        `None`
 87
 88        """
 89        self.n_splits = n_splits
 90        self.resampling_method = resampling_method
 91        self.lico_expansion_factor = lico_expansion_factor
 92        self.oversample_ratio = oversample_ratio
 93        self.undersample_ratio = undersample_ratio
 94        self.random_seed = random_seed
 95
 96        # # Create steps list with proper formatting
 97        steps = []
 98        if remove_flats:
 99            steps.append(("remove_flats", FlatChannelRemover()))
100
101        steps.extend(
102            [
103                ("xdawn", XdawnCovariances()),
104                ("tangent", TangentSpace()),
105                ("lda", LinearDiscriminantAnalysis()),
106            ]
107        )
108
109        # Create pipeline
110        self.clf = Pipeline(steps)
111
112        # Hyperparameters to be optimized
113        # TODO: Implement an extended nfilter set, dynamically based on the number of channels
114        # Example of dynamic nfilter set
115        # n_channels = self.X.shape[1]
116        # nfilter_set = list(range(2, n_channels+1))  # Example range from 2 to n_channels inclusive
117        # Then set "xdawn__nfilter": nfilter_set in the param_grid below
118        self.param_grid = {
119            "xdawn__nfilter": [2, 3, 4],
120            "xdawn__estimator": ["oas", "lwf"],
121            "tangent__metric": ["riemann"],
122            "lda__solver": ["lsqr", "eigen"],
123            "lda__shrinkage": np.linspace(0.5, 0.9, 5),
124        }

Set P300 Classifier Settings.

Parameters
  • n_splits (int, optional): Number of folds for cross-validation.
    • Default is 3.
  • resampling_method (str, optional, None): Resampling method to use ["lico", "oversample", "undersample"]. Default is None.
  • lico_expansion_factor (int, optional): Linear Combination Oversampling expansion factor, which is the factor by which the number of ERPs in the training set will be expanded.
    • Default is 1.
  • oversample_ratio (float, optional): Traditional oversampling. Range is from 0.1-1, resulting from the ratio of erp to non-erp class. 0 for no oversampling.
    • Default is 0.
  • undersample_ratio (float, optional): Traditional undersampling. Range is from 0.1-1, resulting from the ratio of erp to non-erp class. 0 for no undersampling.
    • Default is 0.
  • random_seed (int, optional): Random seed.
    • Default is 42.
  • remove_flats (bool, optional): Whether to remove flat channels.
    • Default is True.
Returns
  • None
def fit(self, plot_cm=False, plot_roc=False):
126    def fit(
127        self,
128        plot_cm=False,
129        plot_roc=False,
130    ):
131        """Fit the model.
132
133        Parameters
134        ----------
135        plot_cm : bool, *optional*
136            Whether to plot the confusion matrix during training.
137            - Default is `False`.
138        plot_roc : bool, *optional*
139            Whether to plot the ROC curve during training.
140            - Default is `False`.
141
142        Returns
143        -------
144        `None`
145            Models created used in `predict()`.
146
147        """
148        logger.info("Fitting the model using RG")
149        logger.info("X shape: %s", self.X.shape)
150        logger.info("y shape: %s", self.y.shape)
151
152        # Resample data if needed
153        self.X, self.y = self.__resample_data()
154
155        # Optimize hyperparameters with cross-validation
156        self.__optimize_hyperparameters()
157
158        # Fit the model with the complete dataset and optimized hyperparameters
159        self.clf.fit(self.X, self.y)
160
161        # Get predictions for final model
162        y_pred_proba = self.clf.predict_proba(self.X)[:, 1]
163
164        # Calculate estimate of training metrics of final model
165        # TODO: Implement proper training metrics calculation, using cross validation.
166        # self.offline_accuracy = sum(y_pred == self.y) / len(self.y)
167        # self.offline_precision = precision_score(self.y, y_pred)
168        # self.offline_recall = recall_score(self.y, y_pred)
169
170        try:
171            roc_auc = roc_auc_score(self.y, y_pred_proba)
172            logger.info(f"ROC AUC Score: {roc_auc:0.3f}")
173        except Exception as e:
174            logger.warning(f"Could not calculate ROC AUC score: {e}")
175
176        # Display training confusion matrix
177        # self.offline_cm = confusion_matrix(self.y, y_pred)
178        if plot_cm:
179            disp = ConfusionMatrixDisplay(confusion_matrix=self.offline_cm)
180            disp.plot()
181            plt.title("Training confusion matrix")
182
183        if plot_roc:
184            # TODO Implementation missing
185            pass
186
187        # Log training metrics
188        logger.info("Final model training performance metrics:")
189        logger.info(f"Accuracy: {self.offline_accuracy:0.3f} - MAY NOT BE ACCURATE")
190        logger.info(f"Precision: {self.offline_precision:0.3f} - MAY NOT BE ACCURATE")
191        logger.info(f"Recall: {self.offline_recall:0.3f} - MAY NOT BE ACCURATE")
192        logger.info(f"Confusion Matrix:\n{self.offline_cm} ")
193        logger.warning(
194            "Note: Training metrics may not be accurate due to the use of "
195            "cross-validation and resampling methods. Use with caution."
196        )

Fit the model.

Parameters
  • plot_cm (bool, optional): Whether to plot the confusion matrix during training.
    • Default is False.
  • plot_roc (bool, optional): Whether to plot the ROC curve during training.
    • Default is False.
Returns
  • None: Models created used in predict().
def predict(self, X):
198    def predict(self, X):
199        """Predict the class of the data
200
201        Parameters
202        ----------
203        X : numpy.ndarray
204            3D array where shape = (n_epochs, n_channels, n_samples)
205
206        Returns
207        -------
208        prediction : Prediction
209            Predict object. Contains the predicted labels and and the probability.
210            Because this classifier chooses the P300 object with the highest posterior probability,
211            the probability is only the posterior probability of the chosen object.
212
213        """
214
215        subset_X = self.get_subset(X, self.subset, self.channel_labels)
216
217        # Get posterior probability for each target
218        posterior_prob = self.clf.predict_proba(subset_X)[:, 1]
219
220        label = [int(np.argmax(posterior_prob))]
221        probability = [np.max(posterior_prob)]
222
223        return Prediction(label, probability)

Predict the class of the data

Parameters
  • X (numpy.ndarray): 3D array where shape = (n_epochs, n_channels, n_samples)
Returns
  • prediction (Prediction): Predict object. Contains the predicted labels and the probability. Because this classifier chooses the P300 object with the highest posterior probability, the probability is only the posterior probability of the chosen object.