bci_essentials.classification.erp_rg_classifier

ERP RG Classifier

This classifier is used to classify ERPs using the Riemannian Geometry approach.

  1"""**ERP RG Classifier**
  2
  3This classifier is used to classify ERPs using the Riemannian Geometry
  4approach.
  5
  6"""
  7
  8# Stock libraries
  9import random
 10import numpy as np
 11import matplotlib.pyplot as plt
 12from sklearn.pipeline import make_pipeline
 13from sklearn.model_selection import StratifiedKFold
 14from sklearn.metrics import (
 15    confusion_matrix,
 16    ConfusionMatrixDisplay,
 17    precision_score,
 18    recall_score,
 19)
 20from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
 21from pyriemann.estimation import XdawnCovariances
 22from pyriemann.tangentspace import TangentSpace
 23from pyriemann.channelselection import FlatChannelRemover
 24
 25
 26# Import bci_essentials modules and methods
 27from ..classification.generic_classifier import (
 28    GenericClassifier,
 29    Prediction,
 30    KernelResults,
 31)
 32from ..signal_processing import lico
 33from ..channel_selection import channel_selection_by_method
 34from ..utils.logger import Logger  # Logger wrapper
 35
 36# Instantiate a logger for the module at the default level of logging.INFO
 37# Logs to bci_essentials.__module__, where __module__ is the name of the module
 38logger = Logger(name=__name__)
 39
 40
 41class ErpRgClassifier(GenericClassifier):
 42    """ERP RG Classifier class (*inherits from `GenericClassifier`*)."""
 43
 44    def set_p300_clf_settings(
 45        self,
 46        n_splits=3,
 47        lico_expansion_factor=1,
 48        oversample_ratio=0,
 49        undersample_ratio=0,
 50        random_seed=42,
 51        covariance_estimator="oas",  # Covariance estimator, see pyriemann Covariances
 52        remove_flats=True,
 53    ):
 54        """Set P300 Classifier Settings.
 55
 56        Parameters
 57        ----------
 58        n_splits : int, *optional*
 59            Number of folds for cross-validation.
 60            - Default is `3`.
 61        lico_expansion_factor : int, *optional*
 62            Linear Combination Oversampling expansion factor, which is the
 63            factor by which the number of ERPs in the training set will be
 64            expanded.
 65            - Default is `1`.
 66        oversample_ratio : float, *optional*
 67            Traditional oversampling. Range is from from 0.1-1 resulting
 68            from the ratio of erp to non-erp class. 0 for no oversampling.
 69            - Default is `0`.
 70        undersample_ratio : float, *optional*
 71            Traditional undersampling. Range is from from 0.1-1 resulting
 72            from the ratio of erp to non-erp class. 0 for no undersampling.
 73            - Default is `0`.
 74        random_seed : int, *optional*
 75            Random seed.
 76            - Default is `42`.
 77        covariance_estimator : str, *optional*
 78            Covariance estimator. See pyriemann Covariances.
 79            - Default is `"oas"`.
 80        remove_flats : bool, *optional*
 81            Whether to remove flat channels.
 82            - Default is `True`.
 83
 84        Returns
 85        -------
 86        `None`
 87
 88        """
 89        self.n_splits = n_splits
 90        self.lico_expansion_factor = lico_expansion_factor
 91        self.oversample_ratio = oversample_ratio
 92        self.undersample_ratio = undersample_ratio
 93        self.random_seed = random_seed
 94        self.covariance_estimator = covariance_estimator
 95
 96        # Define the classifier
 97        self.clf = make_pipeline(
 98            XdawnCovariances(estimator=self.covariance_estimator),
 99            TangentSpace(metric="riemann"),
100            LinearDiscriminantAnalysis(solver="eigen", shrinkage="auto"),
101        )
102
103        if remove_flats:
104            rf = FlatChannelRemover()
105            self.clf.steps.insert(0, ["Remove Flat Channels", rf])
106
107    def fit(
108        self,
109        plot_cm=False,
110        plot_roc=False,
111        lico_expansion_factor=1,
112    ):
113        """Fit the model.
114
115        Parameters
116        ----------
117        plot_cm : bool, *optional*
118            Whether to plot the confusion matrix during training.
119            - Default is `False`.
120        plot_roc : bool, *optional*
121            Whether to plot the ROC curve during training.
122            - Default is `False`.
123        lico_expansion_factor : int, *optional*
124            Linear combination oversampling expansion factor.
125            Determines the number of ERPs in the training set that will be expanded.
126            Higher value increases the oversampling, generating more synthetic
127            samples for the minority class.
128            - Default is `1`.
129
130        Returns
131        -------
132        `None`
133            Models created used in `predict()`.
134
135        """
136        logger.info("Fitting the model using RG")
137        logger.info("X shape: %s", self.X.shape)
138        logger.info("y shape: %s", self.y.shape)
139
140        # Define the strategy for cross validation
141        cv = StratifiedKFold(
142            n_splits=self.n_splits, shuffle=True, random_state=self.random_seed
143        )
144
145        # Init predictions to all false
146        cv_preds = np.zeros(len(self.y))
147
148        def __erp_rg_kernel(X, y):
149            """ERP RG kernel.
150
151            Parameters
152            ----------
153            X : numpy.ndarray
154                Input features (ERP data) for training.
155                3D numpy array with shape = (`n_trials`, `n_channels`, `n_samples`).
156                E.g. (100, 32, 1000) for 100 trials, 32 channels and 1000 samples per channel.
157
158            y : numpy.ndarray
159                Target labels corresponding to the input features in `X`.
160                1D numpy array with shape (n_trails, ).
161                Each label indicates the class of the corresponding trial in `X`.
162                E.g. (100, ) for 100 trials.
163
164
165            Returns
166            -------
167            kernelResults : KernelResults
168                KernelResults object containing the following attributes:
169                    model : classifier
170                        The trained classification model.
171                    cv_preds : numpy.ndarray
172                        The predictions from the model using cross validation.
173                        1D array with the same shape as `y`.
174                    accuracy : float
175                        The accuracy of the trained classification model.
176                    precision : float
177                        The precision of the trained classification model.
178                    recall : float
179                        The recall of the trained classification model.
180
181            """
182            for train_idx, test_idx in cv.split(X, y):
183                y_train, y_test = y[train_idx], y[test_idx]
184
185                X_train, X_test = X[train_idx], X[test_idx]
186
187                # LICO
188                logger.debug(
189                    "Before LICO:\n\tShape X: %s\n\tShape y: %s",
190                    X_train.shape,
191                    y_train.shape,
192                )
193
194                if sum(y_train) > 2:
195                    if lico_expansion_factor > 1:
196                        X_train, y_train = lico(
197                            X_train,
198                            y_train,
199                            expansion_factor=lico_expansion_factor,
200                            sum_num=2,
201                            shuffle=False,
202                        )
203                        logger.debug("y_train = %s", y_train)
204
205                logger.debug(
206                    "After LICO:\n\tShape X: %s\n\tShape y: %s",
207                    X_train.shape,
208                    y_train.shape,
209                )
210
211                # Oversampling
212                if self.oversample_ratio > 0:
213                    p_count = sum(y_train)
214                    n_count = len(y_train) - sum(y_train)
215
216                    num_to_add = int(
217                        np.floor((self.oversample_ratio * n_count) - p_count)
218                    )
219
220                    # Add num_to_add random selections from the positive
221                    true_X_train = X_train[y_train == 1]
222
223                    len_X_train = len(true_X_train)
224
225                    for s in range(num_to_add):
226                        to_add_X = true_X_train[random.randrange(0, len_X_train), :, :]
227
228                        X_train = np.append(X_train, to_add_X[np.newaxis, :], axis=0)
229                        y_train = np.append(y_train, [1], axis=0)
230
231                # Undersampling
232                if self.undersample_ratio > 0:
233                    p_count = sum(y_train)
234                    n_count = len(y_train) - sum(y_train)
235
236                    num_to_remove = int(
237                        np.floor(n_count - (p_count / self.undersample_ratio))
238                    )
239
240                    ind_range = np.arange(len(y_train))
241                    ind_list = list(ind_range)
242                    to_remove = []
243
244                    # Remove num_to_remove random selections from the negative
245                    false_ind = list(ind_range[y_train == 0])
246
247                    for s in range(num_to_remove):
248                        # select a random value from the list of false indices
249                        remove_at = false_ind[random.randrange(0, len(false_ind))]
250
251                        # remove that value from the false index list
252                        false_ind.remove(remove_at)
253
254                        # add the index to be removed to a list
255                        to_remove.append(remove_at)
256
257                    remaining_ind = ind_list
258                    for i in range(len(to_remove)):
259                        remaining_ind.remove(to_remove[i])
260
261                    X_train = X_train[remaining_ind, :, :]
262                    y_train = y_train[remaining_ind]
263
264                self.clf.fit(X_train, y_train)
265                cv_preds[test_idx] = self.clf.predict(X_test)
266                predproba = self.clf.predict_proba(X_test)
267
268                # Use pred proba to show what would be predicted
269                predprobs = predproba[:, 1]
270                real = np.where(y_test == 1)
271
272                # TODO handle exception where two probabilities are the same
273                prediction = int(np.where(predprobs == np.amax(predprobs))[0][0])
274
275                logger.debug("y_test = %s", y_test)
276                logger.debug("predproba = %s", predproba)
277                logger.debug("real = %s", real[0])
278                logger.debug("prediction = %s", prediction)
279
280            # Train final model with all available data
281            self.clf.fit(X, y)
282            model = self.clf
283
284            accuracy = sum(cv_preds == self.y) / len(cv_preds)
285            precision = precision_score(self.y, cv_preds)
286            recall = recall_score(self.y, cv_preds)
287
288            return KernelResults(model, cv_preds, accuracy, precision, recall)
289
290        # Check if channel selection is true
291        if self.channel_selection_setup:
292            logger.info("Doing channel selection")
293            logger.debug("Initial subset: %s", self.chs_initial_subset)
294
295            channel_selection_results = channel_selection_by_method(
296                __erp_rg_kernel,
297                self.X,
298                self.y,
299                self.channel_labels,  # kernel setup
300                self.chs_method,
301                self.chs_metric,
302                self.chs_initial_subset,  # wrapper setup
303                self.chs_max_time,
304                self.chs_min_channels,
305                self.chs_max_channels,
306                self.chs_performance_delta,  # stopping criterion
307                self.chs_n_jobs,
308            )  # njobs, output messages
309
310            preds = channel_selection_results.best_preds
311            accuracy = channel_selection_results.best_accuracy
312            precision = channel_selection_results.best_precision
313            recall = channel_selection_results.best_recall
314
315            logger.info(
316                "The optimal subset is %s",
317                channel_selection_results.best_channel_subset,
318            )
319
320            self.results_df = channel_selection_results.results_df
321            self.subset = channel_selection_results.best_channel_subset
322            self.subset_defined = True
323            self.clf = channel_selection_results.best_model
324        else:
325            logger.warning("Not doing channel selection")
326            X = self.get_subset(self.X, self.subset, self.channel_labels)
327
328            current_results = __erp_rg_kernel(X, self.y)
329            self.clf = current_results.model
330            preds = current_results.cv_preds
331            accuracy = current_results.accuracy
332            precision = current_results.precision
333            recall = current_results.recall
334
335        # Log performance stats
336        # accuracy
337        accuracy = sum(preds == self.y) / len(preds)
338        self.offline_accuracy = accuracy
339        logger.info("Accuracy = %s", accuracy)
340
341        # precision
342        precision = precision_score(self.y, preds)
343        self.offline_precision = precision
344        logger.info("Precision = %s", precision)
345
346        # recall
347        recall = recall_score(self.y, preds)
348        self.offline_recall = recall
349        logger.info("Recall = %s", recall)
350
351        # confusion matrix in command line
352        cm = confusion_matrix(self.y, preds)
353        self.offline_cm = cm
354        logger.info("Confusion matrix:\n%s", cm)
355
356        if plot_cm:
357            cm = confusion_matrix(self.y, preds)
358            ConfusionMatrixDisplay(cm).plot()
359            plt.show()
360
361        if plot_roc:
362            logger.error("ROC plot has not been implemented yet")
363
364    def predict(self, X):
365        """Predict the class of the data (Unused in this classifier)
366
367        Parameters
368        ----------
369        X : numpy.ndarray
370            3D array where shape = (n_epochs, n_channels, n_samples)
371
372        Returns
373        -------
374        prediction : Prediction
375            Predict object. Contains the predicted labels and and the probability.
376            Because this classifier chooses the P300 object with the highest posterior probability,
377            the probability is only the posterior probability of the chosen object.
378
379        """
380
381        subset_X = self.get_subset(X, self.subset, self.channel_labels)
382
383        # Get posterior probability for each target
384        posterior_prob = self.clf.predict_proba(subset_X)[:, 1]
385
386        label = [int(np.argmax(posterior_prob))]
387        probability = [np.max(posterior_prob)]
388
389        return Prediction(label, probability)
 42class ErpRgClassifier(GenericClassifier):
 43    """ERP RG Classifier class (*inherits from `GenericClassifier`*)."""
 44
 45    def set_p300_clf_settings(
 46        self,
 47        n_splits=3,
 48        lico_expansion_factor=1,
 49        oversample_ratio=0,
 50        undersample_ratio=0,
 51        random_seed=42,
 52        covariance_estimator="oas",  # Covariance estimator, see pyriemann Covariances
 53        remove_flats=True,
 54    ):
 55        """Set P300 Classifier Settings.
 56
 57        Parameters
 58        ----------
 59        n_splits : int, *optional*
 60            Number of folds for cross-validation.
 61            - Default is `3`.
 62        lico_expansion_factor : int, *optional*
 63            Linear Combination Oversampling expansion factor, which is the
 64            factor by which the number of ERPs in the training set will be
 65            expanded.
 66            - Default is `1`.
 67        oversample_ratio : float, *optional*
 68            Traditional oversampling. Range is from from 0.1-1 resulting
 69            from the ratio of erp to non-erp class. 0 for no oversampling.
 70            - Default is `0`.
 71        undersample_ratio : float, *optional*
 72            Traditional undersampling. Range is from from 0.1-1 resulting
 73            from the ratio of erp to non-erp class. 0 for no undersampling.
 74            - Default is `0`.
 75        random_seed : int, *optional*
 76            Random seed.
 77            - Default is `42`.
 78        covariance_estimator : str, *optional*
 79            Covariance estimator. See pyriemann Covariances.
 80            - Default is `"oas"`.
 81        remove_flats : bool, *optional*
 82            Whether to remove flat channels.
 83            - Default is `True`.
 84
 85        Returns
 86        -------
 87        `None`
 88
 89        """
 90        self.n_splits = n_splits
 91        self.lico_expansion_factor = lico_expansion_factor
 92        self.oversample_ratio = oversample_ratio
 93        self.undersample_ratio = undersample_ratio
 94        self.random_seed = random_seed
 95        self.covariance_estimator = covariance_estimator
 96
 97        # Define the classifier
 98        self.clf = make_pipeline(
 99            XdawnCovariances(estimator=self.covariance_estimator),
100            TangentSpace(metric="riemann"),
101            LinearDiscriminantAnalysis(solver="eigen", shrinkage="auto"),
102        )
103
104        if remove_flats:
105            rf = FlatChannelRemover()
106            self.clf.steps.insert(0, ["Remove Flat Channels", rf])
107
108    def fit(
109        self,
110        plot_cm=False,
111        plot_roc=False,
112        lico_expansion_factor=1,
113    ):
114        """Fit the model.
115
116        Parameters
117        ----------
118        plot_cm : bool, *optional*
119            Whether to plot the confusion matrix during training.
120            - Default is `False`.
121        plot_roc : bool, *optional*
122            Whether to plot the ROC curve during training.
123            - Default is `False`.
124        lico_expansion_factor : int, *optional*
125            Linear combination oversampling expansion factor.
126            Determines the number of ERPs in the training set that will be expanded.
127            Higher value increases the oversampling, generating more synthetic
128            samples for the minority class.
129            - Default is `1`.
130
131        Returns
132        -------
133        `None`
134            Models created used in `predict()`.
135
136        """
137        logger.info("Fitting the model using RG")
138        logger.info("X shape: %s", self.X.shape)
139        logger.info("y shape: %s", self.y.shape)
140
141        # Define the strategy for cross validation
142        cv = StratifiedKFold(
143            n_splits=self.n_splits, shuffle=True, random_state=self.random_seed
144        )
145
146        # Init predictions to all false
147        cv_preds = np.zeros(len(self.y))
148
149        def __erp_rg_kernel(X, y):
150            """ERP RG kernel.
151
152            Parameters
153            ----------
154            X : numpy.ndarray
155                Input features (ERP data) for training.
156                3D numpy array with shape = (`n_trials`, `n_channels`, `n_samples`).
157                E.g. (100, 32, 1000) for 100 trials, 32 channels and 1000 samples per channel.
158
159            y : numpy.ndarray
160                Target labels corresponding to the input features in `X`.
161                1D numpy array with shape (n_trails, ).
162                Each label indicates the class of the corresponding trial in `X`.
163                E.g. (100, ) for 100 trials.
164
165
166            Returns
167            -------
168            kernelResults : KernelResults
169                KernelResults object containing the following attributes:
170                    model : classifier
171                        The trained classification model.
172                    cv_preds : numpy.ndarray
173                        The predictions from the model using cross validation.
174                        1D array with the same shape as `y`.
175                    accuracy : float
176                        The accuracy of the trained classification model.
177                    precision : float
178                        The precision of the trained classification model.
179                    recall : float
180                        The recall of the trained classification model.
181
182            """
183            for train_idx, test_idx in cv.split(X, y):
184                y_train, y_test = y[train_idx], y[test_idx]
185
186                X_train, X_test = X[train_idx], X[test_idx]
187
188                # LICO
189                logger.debug(
190                    "Before LICO:\n\tShape X: %s\n\tShape y: %s",
191                    X_train.shape,
192                    y_train.shape,
193                )
194
195                if sum(y_train) > 2:
196                    if lico_expansion_factor > 1:
197                        X_train, y_train = lico(
198                            X_train,
199                            y_train,
200                            expansion_factor=lico_expansion_factor,
201                            sum_num=2,
202                            shuffle=False,
203                        )
204                        logger.debug("y_train = %s", y_train)
205
206                logger.debug(
207                    "After LICO:\n\tShape X: %s\n\tShape y: %s",
208                    X_train.shape,
209                    y_train.shape,
210                )
211
212                # Oversampling
213                if self.oversample_ratio > 0:
214                    p_count = sum(y_train)
215                    n_count = len(y_train) - sum(y_train)
216
217                    num_to_add = int(
218                        np.floor((self.oversample_ratio * n_count) - p_count)
219                    )
220
221                    # Add num_to_add random selections from the positive
222                    true_X_train = X_train[y_train == 1]
223
224                    len_X_train = len(true_X_train)
225
226                    for s in range(num_to_add):
227                        to_add_X = true_X_train[random.randrange(0, len_X_train), :, :]
228
229                        X_train = np.append(X_train, to_add_X[np.newaxis, :], axis=0)
230                        y_train = np.append(y_train, [1], axis=0)
231
232                # Undersampling
233                if self.undersample_ratio > 0:
234                    p_count = sum(y_train)
235                    n_count = len(y_train) - sum(y_train)
236
237                    num_to_remove = int(
238                        np.floor(n_count - (p_count / self.undersample_ratio))
239                    )
240
241                    ind_range = np.arange(len(y_train))
242                    ind_list = list(ind_range)
243                    to_remove = []
244
245                    # Remove num_to_remove random selections from the negative
246                    false_ind = list(ind_range[y_train == 0])
247
248                    for s in range(num_to_remove):
249                        # select a random value from the list of false indices
250                        remove_at = false_ind[random.randrange(0, len(false_ind))]
251
252                        # remove that value from the false index list
253                        false_ind.remove(remove_at)
254
255                        # add the index to be removed to a list
256                        to_remove.append(remove_at)
257
258                    remaining_ind = ind_list
259                    for i in range(len(to_remove)):
260                        remaining_ind.remove(to_remove[i])
261
262                    X_train = X_train[remaining_ind, :, :]
263                    y_train = y_train[remaining_ind]
264
265                self.clf.fit(X_train, y_train)
266                cv_preds[test_idx] = self.clf.predict(X_test)
267                predproba = self.clf.predict_proba(X_test)
268
269                # Use pred proba to show what would be predicted
270                predprobs = predproba[:, 1]
271                real = np.where(y_test == 1)
272
273                # TODO handle exception where two probabilities are the same
274                prediction = int(np.where(predprobs == np.amax(predprobs))[0][0])
275
276                logger.debug("y_test = %s", y_test)
277                logger.debug("predproba = %s", predproba)
278                logger.debug("real = %s", real[0])
279                logger.debug("prediction = %s", prediction)
280
281            # Train final model with all available data
282            self.clf.fit(X, y)
283            model = self.clf
284
285            accuracy = sum(cv_preds == self.y) / len(cv_preds)
286            precision = precision_score(self.y, cv_preds)
287            recall = recall_score(self.y, cv_preds)
288
289            return KernelResults(model, cv_preds, accuracy, precision, recall)
290
291        # Check if channel selection is true
292        if self.channel_selection_setup:
293            logger.info("Doing channel selection")
294            logger.debug("Initial subset: %s", self.chs_initial_subset)
295
296            channel_selection_results = channel_selection_by_method(
297                __erp_rg_kernel,
298                self.X,
299                self.y,
300                self.channel_labels,  # kernel setup
301                self.chs_method,
302                self.chs_metric,
303                self.chs_initial_subset,  # wrapper setup
304                self.chs_max_time,
305                self.chs_min_channels,
306                self.chs_max_channels,
307                self.chs_performance_delta,  # stopping criterion
308                self.chs_n_jobs,
309            )  # njobs, output messages
310
311            preds = channel_selection_results.best_preds
312            accuracy = channel_selection_results.best_accuracy
313            precision = channel_selection_results.best_precision
314            recall = channel_selection_results.best_recall
315
316            logger.info(
317                "The optimal subset is %s",
318                channel_selection_results.best_channel_subset,
319            )
320
321            self.results_df = channel_selection_results.results_df
322            self.subset = channel_selection_results.best_channel_subset
323            self.subset_defined = True
324            self.clf = channel_selection_results.best_model
325        else:
326            logger.warning("Not doing channel selection")
327            X = self.get_subset(self.X, self.subset, self.channel_labels)
328
329            current_results = __erp_rg_kernel(X, self.y)
330            self.clf = current_results.model
331            preds = current_results.cv_preds
332            accuracy = current_results.accuracy
333            precision = current_results.precision
334            recall = current_results.recall
335
336        # Log performance stats
337        # accuracy
338        accuracy = sum(preds == self.y) / len(preds)
339        self.offline_accuracy = accuracy
340        logger.info("Accuracy = %s", accuracy)
341
342        # precision
343        precision = precision_score(self.y, preds)
344        self.offline_precision = precision
345        logger.info("Precision = %s", precision)
346
347        # recall
348        recall = recall_score(self.y, preds)
349        self.offline_recall = recall
350        logger.info("Recall = %s", recall)
351
352        # confusion matrix in command line
353        cm = confusion_matrix(self.y, preds)
354        self.offline_cm = cm
355        logger.info("Confusion matrix:\n%s", cm)
356
357        if plot_cm:
358            cm = confusion_matrix(self.y, preds)
359            ConfusionMatrixDisplay(cm).plot()
360            plt.show()
361
362        if plot_roc:
363            logger.error("ROC plot has not been implemented yet")
364
365    def predict(self, X):
366        """Predict the class of the data (Unused in this classifier)
367
368        Parameters
369        ----------
370        X : numpy.ndarray
371            3D array where shape = (n_epochs, n_channels, n_samples)
372
373        Returns
374        -------
375        prediction : Prediction
376            Predict object. Contains the predicted labels and and the probability.
377            Because this classifier chooses the P300 object with the highest posterior probability,
378            the probability is only the posterior probability of the chosen object.
379
380        """
381
382        subset_X = self.get_subset(X, self.subset, self.channel_labels)
383
384        # Get posterior probability for each target
385        posterior_prob = self.clf.predict_proba(subset_X)[:, 1]
386
387        label = [int(np.argmax(posterior_prob))]
388        probability = [np.max(posterior_prob)]
389
390        return Prediction(label, probability)

ERP RG Classifier class (inherits from GenericClassifier).

def set_p300_clf_settings( self, n_splits=3, lico_expansion_factor=1, oversample_ratio=0, undersample_ratio=0, random_seed=42, covariance_estimator='oas', remove_flats=True):
 45    def set_p300_clf_settings(
 46        self,
 47        n_splits=3,
 48        lico_expansion_factor=1,
 49        oversample_ratio=0,
 50        undersample_ratio=0,
 51        random_seed=42,
 52        covariance_estimator="oas",  # Covariance estimator, see pyriemann Covariances
 53        remove_flats=True,
 54    ):
 55        """Set P300 Classifier Settings.
 56
 57        Parameters
 58        ----------
 59        n_splits : int, *optional*
 60            Number of folds for cross-validation.
 61            - Default is `3`.
 62        lico_expansion_factor : int, *optional*
 63            Linear Combination Oversampling expansion factor, which is the
 64            factor by which the number of ERPs in the training set will be
 65            expanded.
 66            - Default is `1`.
 67        oversample_ratio : float, *optional*
 68            Traditional oversampling. Range is from from 0.1-1 resulting
 69            from the ratio of erp to non-erp class. 0 for no oversampling.
 70            - Default is `0`.
 71        undersample_ratio : float, *optional*
 72            Traditional undersampling. Range is from from 0.1-1 resulting
 73            from the ratio of erp to non-erp class. 0 for no undersampling.
 74            - Default is `0`.
 75        random_seed : int, *optional*
 76            Random seed.
 77            - Default is `42`.
 78        covariance_estimator : str, *optional*
 79            Covariance estimator. See pyriemann Covariances.
 80            - Default is `"oas"`.
 81        remove_flats : bool, *optional*
 82            Whether to remove flat channels.
 83            - Default is `True`.
 84
 85        Returns
 86        -------
 87        `None`
 88
 89        """
 90        self.n_splits = n_splits
 91        self.lico_expansion_factor = lico_expansion_factor
 92        self.oversample_ratio = oversample_ratio
 93        self.undersample_ratio = undersample_ratio
 94        self.random_seed = random_seed
 95        self.covariance_estimator = covariance_estimator
 96
 97        # Define the classifier
 98        self.clf = make_pipeline(
 99            XdawnCovariances(estimator=self.covariance_estimator),
100            TangentSpace(metric="riemann"),
101            LinearDiscriminantAnalysis(solver="eigen", shrinkage="auto"),
102        )
103
104        if remove_flats:
105            rf = FlatChannelRemover()
106            self.clf.steps.insert(0, ["Remove Flat Channels", rf])

Set P300 Classifier Settings.

Parameters
  • n_splits (int, optional): Number of folds for cross-validation.
    • Default is 3.
  • lico_expansion_factor (int, optional): Linear Combination Oversampling expansion factor, which is the factor by which the number of ERPs in the training set will be expanded.
    • Default is 1.
  • oversample_ratio (float, optional): Traditional oversampling. Target ratio of ERP to non-ERP trials, in the range 0.1-1. 0 for no oversampling.
    • Default is 0.
  • undersample_ratio (float, optional): Traditional undersampling. Target ratio of ERP to non-ERP trials, in the range 0.1-1. 0 for no undersampling.
    • Default is 0.
  • random_seed (int, optional): Random seed.
    • Default is 42.
  • covariance_estimator (str, optional): Covariance estimator. See pyriemann Covariances.
    • Default is "oas".
  • remove_flats (bool, optional): Whether to remove flat channels.
    • Default is True.
Returns
  • None
def fit(self, plot_cm=False, plot_roc=False, lico_expansion_factor=1):
108    def fit(
109        self,
110        plot_cm=False,
111        plot_roc=False,
112        lico_expansion_factor=1,
113    ):
114        """Fit the model.
115
116        Parameters
117        ----------
118        plot_cm : bool, *optional*
119            Whether to plot the confusion matrix during training.
120            - Default is `False`.
121        plot_roc : bool, *optional*
122            Whether to plot the ROC curve during training.
123            - Default is `False`.
124        lico_expansion_factor : int, *optional*
125            Linear combination oversampling expansion factor.
126            Determines the number of ERPs in the training set that will be expanded.
127            Higher value increases the oversampling, generating more synthetic
128            samples for the minority class.
129            - Default is `1`.
130
131        Returns
132        -------
133        `None`
134            Models created used in `predict()`.
135
136        """
137        logger.info("Fitting the model using RG")
138        logger.info("X shape: %s", self.X.shape)
139        logger.info("y shape: %s", self.y.shape)
140
141        # Define the strategy for cross validation
142        cv = StratifiedKFold(
143            n_splits=self.n_splits, shuffle=True, random_state=self.random_seed
144        )
145
146        # Init predictions to all false
147        cv_preds = np.zeros(len(self.y))
148
149        def __erp_rg_kernel(X, y):
150            """ERP RG kernel.
151
152            Parameters
153            ----------
154            X : numpy.ndarray
155                Input features (ERP data) for training.
156                3D numpy array with shape = (`n_trials`, `n_channels`, `n_samples`).
157                E.g. (100, 32, 1000) for 100 trials, 32 channels and 1000 samples per channel.
158
159            y : numpy.ndarray
160                Target labels corresponding to the input features in `X`.
161                1D numpy array with shape (n_trails, ).
162                Each label indicates the class of the corresponding trial in `X`.
163                E.g. (100, ) for 100 trials.
164
165
166            Returns
167            -------
168            kernelResults : KernelResults
169                KernelResults object containing the following attributes:
170                    model : classifier
171                        The trained classification model.
172                    cv_preds : numpy.ndarray
173                        The predictions from the model using cross validation.
174                        1D array with the same shape as `y`.
175                    accuracy : float
176                        The accuracy of the trained classification model.
177                    precision : float
178                        The precision of the trained classification model.
179                    recall : float
180                        The recall of the trained classification model.
181
182            """
183            for train_idx, test_idx in cv.split(X, y):
184                y_train, y_test = y[train_idx], y[test_idx]
185
186                X_train, X_test = X[train_idx], X[test_idx]
187
188                # LICO
189                logger.debug(
190                    "Before LICO:\n\tShape X: %s\n\tShape y: %s",
191                    X_train.shape,
192                    y_train.shape,
193                )
194
195                if sum(y_train) > 2:
196                    if lico_expansion_factor > 1:
197                        X_train, y_train = lico(
198                            X_train,
199                            y_train,
200                            expansion_factor=lico_expansion_factor,
201                            sum_num=2,
202                            shuffle=False,
203                        )
204                        logger.debug("y_train = %s", y_train)
205
206                logger.debug(
207                    "After LICO:\n\tShape X: %s\n\tShape y: %s",
208                    X_train.shape,
209                    y_train.shape,
210                )
211
212                # Oversampling
213                if self.oversample_ratio > 0:
214                    p_count = sum(y_train)
215                    n_count = len(y_train) - sum(y_train)
216
217                    num_to_add = int(
218                        np.floor((self.oversample_ratio * n_count) - p_count)
219                    )
220
221                    # Add num_to_add random selections from the positive
222                    true_X_train = X_train[y_train == 1]
223
224                    len_X_train = len(true_X_train)
225
226                    for s in range(num_to_add):
227                        to_add_X = true_X_train[random.randrange(0, len_X_train), :, :]
228
229                        X_train = np.append(X_train, to_add_X[np.newaxis, :], axis=0)
230                        y_train = np.append(y_train, [1], axis=0)
231
232                # Undersampling
233                if self.undersample_ratio > 0:
234                    p_count = sum(y_train)
235                    n_count = len(y_train) - sum(y_train)
236
237                    num_to_remove = int(
238                        np.floor(n_count - (p_count / self.undersample_ratio))
239                    )
240
241                    ind_range = np.arange(len(y_train))
242                    ind_list = list(ind_range)
243                    to_remove = []
244
245                    # Remove num_to_remove random selections from the negative
246                    false_ind = list(ind_range[y_train == 0])
247
248                    for s in range(num_to_remove):
249                        # select a random value from the list of false indices
250                        remove_at = false_ind[random.randrange(0, len(false_ind))]
251
252                        # remove that value from the false index list
253                        false_ind.remove(remove_at)
254
255                        # add the index to be removed to a list
256                        to_remove.append(remove_at)
257
258                    remaining_ind = ind_list
259                    for i in range(len(to_remove)):
260                        remaining_ind.remove(to_remove[i])
261
262                    X_train = X_train[remaining_ind, :, :]
263                    y_train = y_train[remaining_ind]
264
265                self.clf.fit(X_train, y_train)
266                cv_preds[test_idx] = self.clf.predict(X_test)
267                predproba = self.clf.predict_proba(X_test)
268
269                # Use pred proba to show what would be predicted
270                predprobs = predproba[:, 1]
271                real = np.where(y_test == 1)
272
273                # TODO handle exception where two probabilities are the same
274                prediction = int(np.where(predprobs == np.amax(predprobs))[0][0])
275
276                logger.debug("y_test = %s", y_test)
277                logger.debug("predproba = %s", predproba)
278                logger.debug("real = %s", real[0])
279                logger.debug("prediction = %s", prediction)
280
281            # Train final model with all available data
282            self.clf.fit(X, y)
283            model = self.clf
284
285            accuracy = sum(cv_preds == self.y) / len(cv_preds)
286            precision = precision_score(self.y, cv_preds)
287            recall = recall_score(self.y, cv_preds)
288
289            return KernelResults(model, cv_preds, accuracy, precision, recall)
290
291        # Check if channel selection is true
292        if self.channel_selection_setup:
293            logger.info("Doing channel selection")
294            logger.debug("Initial subset: %s", self.chs_initial_subset)
295
296            channel_selection_results = channel_selection_by_method(
297                __erp_rg_kernel,
298                self.X,
299                self.y,
300                self.channel_labels,  # kernel setup
301                self.chs_method,
302                self.chs_metric,
303                self.chs_initial_subset,  # wrapper setup
304                self.chs_max_time,
305                self.chs_min_channels,
306                self.chs_max_channels,
307                self.chs_performance_delta,  # stopping criterion
308                self.chs_n_jobs,
309            )  # njobs, output messages
310
311            preds = channel_selection_results.best_preds
312            accuracy = channel_selection_results.best_accuracy
313            precision = channel_selection_results.best_precision
314            recall = channel_selection_results.best_recall
315
316            logger.info(
317                "The optimal subset is %s",
318                channel_selection_results.best_channel_subset,
319            )
320
321            self.results_df = channel_selection_results.results_df
322            self.subset = channel_selection_results.best_channel_subset
323            self.subset_defined = True
324            self.clf = channel_selection_results.best_model
325        else:
326            logger.warning("Not doing channel selection")
327            X = self.get_subset(self.X, self.subset, self.channel_labels)
328
329            current_results = __erp_rg_kernel(X, self.y)
330            self.clf = current_results.model
331            preds = current_results.cv_preds
332            accuracy = current_results.accuracy
333            precision = current_results.precision
334            recall = current_results.recall
335
336        # Log performance stats
337        # accuracy
338        accuracy = sum(preds == self.y) / len(preds)
339        self.offline_accuracy = accuracy
340        logger.info("Accuracy = %s", accuracy)
341
342        # precision
343        precision = precision_score(self.y, preds)
344        self.offline_precision = precision
345        logger.info("Precision = %s", precision)
346
347        # recall
348        recall = recall_score(self.y, preds)
349        self.offline_recall = recall
350        logger.info("Recall = %s", recall)
351
352        # confusion matrix in command line
353        cm = confusion_matrix(self.y, preds)
354        self.offline_cm = cm
355        logger.info("Confusion matrix:\n%s", cm)
356
357        if plot_cm:
358            cm = confusion_matrix(self.y, preds)
359            ConfusionMatrixDisplay(cm).plot()
360            plt.show()
361
362        if plot_roc:
363            logger.error("ROC plot has not been implemented yet")

Fit the model.

Parameters
  • plot_cm (bool, optional): Whether to plot the confusion matrix during training.
    • Default is False.
  • plot_roc (bool, optional): Whether to plot the ROC curve during training.
    • Default is False.
  • lico_expansion_factor (int, optional): Linear combination oversampling expansion factor. Determines the number of ERPs in the training set that will be expanded. Higher value increases the oversampling, generating more synthetic samples for the minority class.
    • Default is 1.
Returns
  • None: Models created are used in predict().
def predict(self, X):
365    def predict(self, X):
366        """Predict the class of the data (Unused in this classifier)
367
368        Parameters
369        ----------
370        X : numpy.ndarray
371            3D array where shape = (n_epochs, n_channels, n_samples)
372
373        Returns
374        -------
375        prediction : Prediction
376            Predict object. Contains the predicted labels and and the probability.
377            Because this classifier chooses the P300 object with the highest posterior probability,
378            the probability is only the posterior probability of the chosen object.
379
380        """
381
382        subset_X = self.get_subset(X, self.subset, self.channel_labels)
383
384        # Get posterior probability for each target
385        posterior_prob = self.clf.predict_proba(subset_X)[:, 1]
386
387        label = [int(np.argmax(posterior_prob))]
388        probability = [np.max(posterior_prob)]
389
390        return Prediction(label, probability)

Predict the class of the data (Unused in this classifier)

Parameters
  • X (numpy.ndarray): 3D array where shape = (n_epochs, n_channels, n_samples)
Returns
  • prediction (Prediction): Predict object. Contains the predicted labels and the probability. Because this classifier chooses the P300 object with the highest posterior probability, the probability is only the posterior probability of the chosen object.