bci_essentials.classification.erp_rg_classifier

ERP RG Classifier

This classifier is used to classify ERPs using the Riemannian Geometry approach.

  1"""**ERP RG Classifier**
  2
  3This classifier is used to classify ERPs using the Riemannian Geometry
  4approach.
  5
  6"""
  7
  8# Stock libraries
  9import random
 10import numpy as np
 11import matplotlib.pyplot as plt
 12from sklearn.pipeline import make_pipeline
 13from sklearn.model_selection import StratifiedKFold
 14from sklearn.metrics import (
 15    confusion_matrix,
 16    ConfusionMatrixDisplay,
 17    precision_score,
 18    recall_score,
 19)
 20from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
 21from pyriemann.estimation import XdawnCovariances
 22from pyriemann.tangentspace import TangentSpace
 23from pyriemann.channelselection import FlatChannelRemover
 24
 25# Import bci_essentials modules and methods
 26from ..classification.generic_classifier import (
 27    GenericClassifier,
 28    Prediction,
 29    KernelResults,
 30)
 31from ..signal_processing import lico
 32from ..channel_selection import channel_selection_by_method
 33from ..utils.logger import Logger  # Logger wrapper
 34
 35# Instantiate a logger for the module at the default level of logging.INFO
 36# Logs to bci_essentials.__module__) where __module__ is the name of the module
 37logger = Logger(name=__name__)
 38
 39
 40class ErpRgClassifier(GenericClassifier):
 41    """ERP RG Classifier class (*inherits from `GenericClassifier`*)."""
 42
 43    def set_p300_clf_settings(
 44        self,
 45        n_splits=3,
 46        lico_expansion_factor=1,
 47        oversample_ratio=0,
 48        undersample_ratio=0,
 49        random_seed=42,
 50        covariance_estimator="oas",  # Covariance estimator, see pyriemann Covariances
 51        remove_flats=True,
 52    ):
 53        """Set P300 Classifier Settings.
 54
 55        Parameters
 56        ----------
 57        n_splits : int, *optional*
 58            Number of folds for cross-validation.
 59            - Default is `3`.
 60        lico_expansion_factor : int, *optional*
 61            Linear Combination Oversampling expansion factor, which is the
 62            factor by which the number of ERPs in the training set will be
 63            expanded.
 64            - Default is `1`.
 65        oversample_ratio : float, *optional*
 66            Traditional oversampling. Range is from from 0.1-1 resulting
 67            from the ratio of erp to non-erp class. 0 for no oversampling.
 68            - Default is `0`.
 69        undersample_ratio : float, *optional*
 70            Traditional undersampling. Range is from from 0.1-1 resulting
 71            from the ratio of erp to non-erp class. 0 for no undersampling.
 72            - Default is `0`.
 73        random_seed : int, *optional*
 74            Random seed.
 75            - Default is `42`.
 76        covariance_estimator : str, *optional*
 77            Covariance estimator. See pyriemann Covariances.
 78            - Default is `"oas"`.
 79        remove_flats : bool, *optional*
 80            Whether to remove flat channels.
 81            - Default is `True`.
 82
 83        Returns
 84        -------
 85        `None`
 86
 87        """
 88        self.n_splits = n_splits
 89        self.lico_expansion_factor = lico_expansion_factor
 90        self.oversample_ratio = oversample_ratio
 91        self.undersample_ratio = undersample_ratio
 92        self.random_seed = random_seed
 93        self.covariance_estimator = covariance_estimator
 94
 95        # Define the classifier
 96        self.clf = make_pipeline(
 97            XdawnCovariances(estimator=self.covariance_estimator),
 98            TangentSpace(metric="riemann"),
 99            LinearDiscriminantAnalysis(solver="eigen", shrinkage="auto"),
100        )
101
102        if remove_flats:
103            rf = FlatChannelRemover()
104            self.clf.steps.insert(0, ["Remove Flat Channels", rf])
105
106    def fit(
107        self,
108        plot_cm=False,
109        plot_roc=False,
110        lico_expansion_factor=1,
111    ):
112        """Fit the model.
113
114        Parameters
115        ----------
116        plot_cm : bool, *optional*
117            Whether to plot the confusion matrix during training.
118            - Default is `False`.
119        plot_roc : bool, *optional*
120            Whether to plot the ROC curve during training.
121            - Default is `False`.
122        lico_expansion_factor : int, *optional*
123            Linear combination oversampling expansion factor.
124            Determines the number of ERPs in the training set that will be expanded.
125            Higher value increases the oversampling, generating more synthetic
126            samples for the minority class.
127            - Default is `1`.
128
129        Returns
130        -------
131        `None`
132            Models created used in `predict()`.
133
134        """
135        logger.info("Fitting the model using RG")
136        logger.info("X shape: %s", self.X.shape)
137        logger.info("y shape: %s", self.y.shape)
138
139        # Define the strategy for cross validation
140        cv = StratifiedKFold(
141            n_splits=self.n_splits, shuffle=True, random_state=self.random_seed
142        )
143
144        # Init predictions to all false
145        cv_preds = np.zeros(len(self.y))
146
147        def __erp_rg_kernel(X, y):
148            """ERP RG kernel.
149
150            Parameters
151            ----------
152            X : numpy.ndarray
153                Input features (ERP data) for training.
154                3D numpy array with shape = (`n_trials`, `n_channels`, `n_samples`).
155                E.g. (100, 32, 1000) for 100 trials, 32 channels and 1000 samples per channel.
156
157            y : numpy.ndarray
158                Target labels corresponding to the input features in `X`.
159                1D numpy array with shape (n_trails, ).
160                Each label indicates the class of the corresponding trial in `X`.
161                E.g. (100, ) for 100 trials.
162
163
164            Returns
165            -------
166            kernelResults : KernelResults
167                KernelResults object containing the following attributes:
168                    model : classifier
169                        The trained classification model.
170                    cv_preds : numpy.ndarray
171                        The predictions from the model using cross validation.
172                        1D array with the same shape as `y`.
173                    accuracy : float
174                        The accuracy of the trained classification model.
175                    precision : float
176                        The precision of the trained classification model.
177                    recall : float
178                        The recall of the trained classification model.
179
180            """
181            for train_idx, test_idx in cv.split(X, y):
182                y_train, y_test = y[train_idx], y[test_idx]
183
184                X_train, X_test = X[train_idx], X[test_idx]
185
186                # LICO
187                logger.debug(
188                    "Before LICO:\n\tShape X: %s\n\tShape y: %s",
189                    X_train.shape,
190                    y_train.shape,
191                )
192
193                if sum(y_train) > 2:
194                    if lico_expansion_factor > 1:
195                        X_train, y_train = lico(
196                            X_train,
197                            y_train,
198                            expansion_factor=lico_expansion_factor,
199                            sum_num=2,
200                            shuffle=False,
201                        )
202                        logger.debug("y_train = %s", y_train)
203
204                logger.debug(
205                    "After LICO:\n\tShape X: %s\n\tShape y: %s",
206                    X_train.shape,
207                    y_train.shape,
208                )
209
210                # Oversampling
211                if self.oversample_ratio > 0:
212                    p_count = sum(y_train)
213                    n_count = len(y_train) - sum(y_train)
214
215                    num_to_add = int(
216                        np.floor((self.oversample_ratio * n_count) - p_count)
217                    )
218
219                    # Add num_to_add random selections from the positive
220                    true_X_train = X_train[y_train == 1]
221
222                    len_X_train = len(true_X_train)
223
224                    for s in range(num_to_add):
225                        to_add_X = true_X_train[random.randrange(0, len_X_train), :, :]
226
227                        X_train = np.append(X_train, to_add_X[np.newaxis, :], axis=0)
228                        y_train = np.append(y_train, [1], axis=0)
229
230                # Undersampling
231                if self.undersample_ratio > 0:
232                    p_count = sum(y_train)
233                    n_count = len(y_train) - sum(y_train)
234
235                    num_to_remove = int(
236                        np.floor(n_count - (p_count / self.undersample_ratio))
237                    )
238
239                    ind_range = np.arange(len(y_train))
240                    ind_list = list(ind_range)
241                    to_remove = []
242
243                    # Remove num_to_remove random selections from the negative
244                    false_ind = list(ind_range[y_train == 0])
245
246                    for s in range(num_to_remove):
247                        # select a random value from the list of false indices
248                        remove_at = false_ind[random.randrange(0, len(false_ind))]
249
250                        # remove that value from the false index list
251                        false_ind.remove(remove_at)
252
253                        # add the index to be removed to a list
254                        to_remove.append(remove_at)
255
256                    remaining_ind = ind_list
257                    for i in range(len(to_remove)):
258                        remaining_ind.remove(to_remove[i])
259
260                    X_train = X_train[remaining_ind, :, :]
261                    y_train = y_train[remaining_ind]
262
263                self.clf.fit(X_train, y_train)
264                cv_preds[test_idx] = self.clf.predict(X_test)
265                predproba = self.clf.predict_proba(X_test)
266
267                # Use pred proba to show what would be predicted
268                predprobs = predproba[:, 1]
269                real = np.where(y_test == 1)
270
271                # TODO handle exception where two probabilities are the same
272                prediction = int(np.where(predprobs == np.amax(predprobs))[0][0])
273
274                logger.debug("y_test = %s", y_test)
275                logger.debug("predproba = %s", predproba)
276                logger.debug("real = %s", real[0])
277                logger.debug("prediction = %s", prediction)
278
279            # Train final model with all available data
280            self.clf.fit(X, y)
281            model = self.clf
282
283            accuracy = sum(cv_preds == self.y) / len(cv_preds)
284            precision = precision_score(self.y, cv_preds)
285            recall = recall_score(self.y, cv_preds)
286
287            return KernelResults(model, cv_preds, accuracy, precision, recall)
288
289        # Check if channel selection is true
290        if self.channel_selection_setup:
291            logger.info("Doing channel selection")
292            logger.debug("Initial subset: %s", self.chs_initial_subset)
293
294            channel_selection_results = channel_selection_by_method(
295                __erp_rg_kernel,
296                self.X,
297                self.y,
298                self.channel_labels,  # kernel setup
299                self.chs_method,
300                self.chs_metric,
301                self.chs_initial_subset,  # wrapper setup
302                self.chs_max_time,
303                self.chs_min_channels,
304                self.chs_max_channels,
305                self.chs_performance_delta,  # stopping criterion
306                self.chs_n_jobs,
307            )  # njobs, output messages
308
309            preds = channel_selection_results.best_preds
310            accuracy = channel_selection_results.best_accuracy
311            precision = channel_selection_results.best_precision
312            recall = channel_selection_results.best_recall
313
314            logger.info(
315                "The optimal subset is %s",
316                channel_selection_results.best_channel_subset,
317            )
318
319            self.results_df = channel_selection_results.results_df
320            self.subset = channel_selection_results.best_channel_subset
321            self.subset_defined = True
322            self.clf = channel_selection_results.best_model
323        else:
324            logger.warning("Not doing channel selection")
325            X = self.get_subset(self.X, self.subset, self.channel_labels)
326
327            current_results = __erp_rg_kernel(X, self.y)
328            self.clf = current_results.model
329            preds = current_results.cv_preds
330            accuracy = current_results.accuracy
331            precision = current_results.precision
332            recall = current_results.recall
333
334        # Log performance stats
335        # accuracy
336        accuracy = sum(preds == self.y) / len(preds)
337        self.offline_accuracy = accuracy
338        logger.info("Accuracy = %s", accuracy)
339
340        # precision
341        precision = precision_score(self.y, preds)
342        self.offline_precision = precision
343        logger.info("Precision = %s", precision)
344
345        # recall
346        recall = recall_score(self.y, preds)
347        self.offline_recall = recall
348        logger.info("Recall = %s", recall)
349
350        # confusion matrix in command line
351        cm = confusion_matrix(self.y, preds)
352        self.offline_cm = cm
353        logger.info("Confusion matrix:\n%s", cm)
354
355        if plot_cm:
356            cm = confusion_matrix(self.y, preds)
357            ConfusionMatrixDisplay(cm).plot()
358            plt.show()
359
360        if plot_roc:
361            logger.error("ROC plot has not been implemented yet")
362
363    def predict(self, X):
364        """Predict the class of the data (Unused in this classifier)
365
366        Parameters
367        ----------
368        X : numpy.ndarray
369            3D array where shape = (n_epochs, n_channels, n_samples)
370
371        Returns
372        -------
373        prediction : Prediction
374            Predict object. Contains the predicted labels and and the probability.
375            Because this classifier chooses the P300 object with the highest posterior probability,
376            the probability is only the posterior probability of the chosen object.
377
378        """
379
380        subset_X = self.get_subset(X, self.subset, self.channel_labels)
381
382        # Get posterior probability for each target
383        posterior_probabilities = self.clf.predict_proba(subset_X)[:, 1]
384        label = [int(np.argmax(posterior_probabilities) + 1)]
385
386        return Prediction(label, posterior_probabilities)
 41class ErpRgClassifier(GenericClassifier):
 42    """ERP RG Classifier class (*inherits from `GenericClassifier`*)."""
 43
 44    def set_p300_clf_settings(
 45        self,
 46        n_splits=3,
 47        lico_expansion_factor=1,
 48        oversample_ratio=0,
 49        undersample_ratio=0,
 50        random_seed=42,
 51        covariance_estimator="oas",  # Covariance estimator, see pyriemann Covariances
 52        remove_flats=True,
 53    ):
 54        """Set P300 Classifier Settings.
 55
 56        Parameters
 57        ----------
 58        n_splits : int, *optional*
 59            Number of folds for cross-validation.
 60            - Default is `3`.
 61        lico_expansion_factor : int, *optional*
 62            Linear Combination Oversampling expansion factor, which is the
 63            factor by which the number of ERPs in the training set will be
 64            expanded.
 65            - Default is `1`.
 66        oversample_ratio : float, *optional*
 67            Traditional oversampling. Range is from from 0.1-1 resulting
 68            from the ratio of erp to non-erp class. 0 for no oversampling.
 69            - Default is `0`.
 70        undersample_ratio : float, *optional*
 71            Traditional undersampling. Range is from from 0.1-1 resulting
 72            from the ratio of erp to non-erp class. 0 for no undersampling.
 73            - Default is `0`.
 74        random_seed : int, *optional*
 75            Random seed.
 76            - Default is `42`.
 77        covariance_estimator : str, *optional*
 78            Covariance estimator. See pyriemann Covariances.
 79            - Default is `"oas"`.
 80        remove_flats : bool, *optional*
 81            Whether to remove flat channels.
 82            - Default is `True`.
 83
 84        Returns
 85        -------
 86        `None`
 87
 88        """
 89        self.n_splits = n_splits
 90        self.lico_expansion_factor = lico_expansion_factor
 91        self.oversample_ratio = oversample_ratio
 92        self.undersample_ratio = undersample_ratio
 93        self.random_seed = random_seed
 94        self.covariance_estimator = covariance_estimator
 95
 96        # Define the classifier
 97        self.clf = make_pipeline(
 98            XdawnCovariances(estimator=self.covariance_estimator),
 99            TangentSpace(metric="riemann"),
100            LinearDiscriminantAnalysis(solver="eigen", shrinkage="auto"),
101        )
102
103        if remove_flats:
104            rf = FlatChannelRemover()
105            self.clf.steps.insert(0, ["Remove Flat Channels", rf])
106
107    def fit(
108        self,
109        plot_cm=False,
110        plot_roc=False,
111        lico_expansion_factor=1,
112    ):
113        """Fit the model.
114
115        Parameters
116        ----------
117        plot_cm : bool, *optional*
118            Whether to plot the confusion matrix during training.
119            - Default is `False`.
120        plot_roc : bool, *optional*
121            Whether to plot the ROC curve during training.
122            - Default is `False`.
123        lico_expansion_factor : int, *optional*
124            Linear combination oversampling expansion factor.
125            Determines the number of ERPs in the training set that will be expanded.
126            Higher value increases the oversampling, generating more synthetic
127            samples for the minority class.
128            - Default is `1`.
129
130        Returns
131        -------
132        `None`
133            Models created used in `predict()`.
134
135        """
136        logger.info("Fitting the model using RG")
137        logger.info("X shape: %s", self.X.shape)
138        logger.info("y shape: %s", self.y.shape)
139
140        # Define the strategy for cross validation
141        cv = StratifiedKFold(
142            n_splits=self.n_splits, shuffle=True, random_state=self.random_seed
143        )
144
145        # Init predictions to all false
146        cv_preds = np.zeros(len(self.y))
147
148        def __erp_rg_kernel(X, y):
149            """ERP RG kernel.
150
151            Parameters
152            ----------
153            X : numpy.ndarray
154                Input features (ERP data) for training.
155                3D numpy array with shape = (`n_trials`, `n_channels`, `n_samples`).
156                E.g. (100, 32, 1000) for 100 trials, 32 channels and 1000 samples per channel.
157
158            y : numpy.ndarray
159                Target labels corresponding to the input features in `X`.
160                1D numpy array with shape (n_trails, ).
161                Each label indicates the class of the corresponding trial in `X`.
162                E.g. (100, ) for 100 trials.
163
164
165            Returns
166            -------
167            kernelResults : KernelResults
168                KernelResults object containing the following attributes:
169                    model : classifier
170                        The trained classification model.
171                    cv_preds : numpy.ndarray
172                        The predictions from the model using cross validation.
173                        1D array with the same shape as `y`.
174                    accuracy : float
175                        The accuracy of the trained classification model.
176                    precision : float
177                        The precision of the trained classification model.
178                    recall : float
179                        The recall of the trained classification model.
180
181            """
182            for train_idx, test_idx in cv.split(X, y):
183                y_train, y_test = y[train_idx], y[test_idx]
184
185                X_train, X_test = X[train_idx], X[test_idx]
186
187                # LICO
188                logger.debug(
189                    "Before LICO:\n\tShape X: %s\n\tShape y: %s",
190                    X_train.shape,
191                    y_train.shape,
192                )
193
194                if sum(y_train) > 2:
195                    if lico_expansion_factor > 1:
196                        X_train, y_train = lico(
197                            X_train,
198                            y_train,
199                            expansion_factor=lico_expansion_factor,
200                            sum_num=2,
201                            shuffle=False,
202                        )
203                        logger.debug("y_train = %s", y_train)
204
205                logger.debug(
206                    "After LICO:\n\tShape X: %s\n\tShape y: %s",
207                    X_train.shape,
208                    y_train.shape,
209                )
210
211                # Oversampling
212                if self.oversample_ratio > 0:
213                    p_count = sum(y_train)
214                    n_count = len(y_train) - sum(y_train)
215
216                    num_to_add = int(
217                        np.floor((self.oversample_ratio * n_count) - p_count)
218                    )
219
220                    # Add num_to_add random selections from the positive
221                    true_X_train = X_train[y_train == 1]
222
223                    len_X_train = len(true_X_train)
224
225                    for s in range(num_to_add):
226                        to_add_X = true_X_train[random.randrange(0, len_X_train), :, :]
227
228                        X_train = np.append(X_train, to_add_X[np.newaxis, :], axis=0)
229                        y_train = np.append(y_train, [1], axis=0)
230
231                # Undersampling
232                if self.undersample_ratio > 0:
233                    p_count = sum(y_train)
234                    n_count = len(y_train) - sum(y_train)
235
236                    num_to_remove = int(
237                        np.floor(n_count - (p_count / self.undersample_ratio))
238                    )
239
240                    ind_range = np.arange(len(y_train))
241                    ind_list = list(ind_range)
242                    to_remove = []
243
244                    # Remove num_to_remove random selections from the negative
245                    false_ind = list(ind_range[y_train == 0])
246
247                    for s in range(num_to_remove):
248                        # select a random value from the list of false indices
249                        remove_at = false_ind[random.randrange(0, len(false_ind))]
250
251                        # remove that value from the false index list
252                        false_ind.remove(remove_at)
253
254                        # add the index to be removed to a list
255                        to_remove.append(remove_at)
256
257                    remaining_ind = ind_list
258                    for i in range(len(to_remove)):
259                        remaining_ind.remove(to_remove[i])
260
261                    X_train = X_train[remaining_ind, :, :]
262                    y_train = y_train[remaining_ind]
263
264                self.clf.fit(X_train, y_train)
265                cv_preds[test_idx] = self.clf.predict(X_test)
266                predproba = self.clf.predict_proba(X_test)
267
268                # Use pred proba to show what would be predicted
269                predprobs = predproba[:, 1]
270                real = np.where(y_test == 1)
271
272                # TODO handle exception where two probabilities are the same
273                prediction = int(np.where(predprobs == np.amax(predprobs))[0][0])
274
275                logger.debug("y_test = %s", y_test)
276                logger.debug("predproba = %s", predproba)
277                logger.debug("real = %s", real[0])
278                logger.debug("prediction = %s", prediction)
279
280            # Train final model with all available data
281            self.clf.fit(X, y)
282            model = self.clf
283
284            accuracy = sum(cv_preds == self.y) / len(cv_preds)
285            precision = precision_score(self.y, cv_preds)
286            recall = recall_score(self.y, cv_preds)
287
288            return KernelResults(model, cv_preds, accuracy, precision, recall)
289
290        # Check if channel selection is true
291        if self.channel_selection_setup:
292            logger.info("Doing channel selection")
293            logger.debug("Initial subset: %s", self.chs_initial_subset)
294
295            channel_selection_results = channel_selection_by_method(
296                __erp_rg_kernel,
297                self.X,
298                self.y,
299                self.channel_labels,  # kernel setup
300                self.chs_method,
301                self.chs_metric,
302                self.chs_initial_subset,  # wrapper setup
303                self.chs_max_time,
304                self.chs_min_channels,
305                self.chs_max_channels,
306                self.chs_performance_delta,  # stopping criterion
307                self.chs_n_jobs,
308            )  # njobs, output messages
309
310            preds = channel_selection_results.best_preds
311            accuracy = channel_selection_results.best_accuracy
312            precision = channel_selection_results.best_precision
313            recall = channel_selection_results.best_recall
314
315            logger.info(
316                "The optimal subset is %s",
317                channel_selection_results.best_channel_subset,
318            )
319
320            self.results_df = channel_selection_results.results_df
321            self.subset = channel_selection_results.best_channel_subset
322            self.subset_defined = True
323            self.clf = channel_selection_results.best_model
324        else:
325            logger.warning("Not doing channel selection")
326            X = self.get_subset(self.X, self.subset, self.channel_labels)
327
328            current_results = __erp_rg_kernel(X, self.y)
329            self.clf = current_results.model
330            preds = current_results.cv_preds
331            accuracy = current_results.accuracy
332            precision = current_results.precision
333            recall = current_results.recall
334
335        # Log performance stats
336        # accuracy
337        accuracy = sum(preds == self.y) / len(preds)
338        self.offline_accuracy = accuracy
339        logger.info("Accuracy = %s", accuracy)
340
341        # precision
342        precision = precision_score(self.y, preds)
343        self.offline_precision = precision
344        logger.info("Precision = %s", precision)
345
346        # recall
347        recall = recall_score(self.y, preds)
348        self.offline_recall = recall
349        logger.info("Recall = %s", recall)
350
351        # confusion matrix in command line
352        cm = confusion_matrix(self.y, preds)
353        self.offline_cm = cm
354        logger.info("Confusion matrix:\n%s", cm)
355
356        if plot_cm:
357            cm = confusion_matrix(self.y, preds)
358            ConfusionMatrixDisplay(cm).plot()
359            plt.show()
360
361        if plot_roc:
362            logger.error("ROC plot has not been implemented yet")
363
364    def predict(self, X):
365        """Predict the class of the data (Unused in this classifier)
366
367        Parameters
368        ----------
369        X : numpy.ndarray
370            3D array where shape = (n_epochs, n_channels, n_samples)
371
372        Returns
373        -------
374        prediction : Prediction
375            Predict object. Contains the predicted labels and and the probability.
376            Because this classifier chooses the P300 object with the highest posterior probability,
377            the probability is only the posterior probability of the chosen object.
378
379        """
380
381        subset_X = self.get_subset(X, self.subset, self.channel_labels)
382
383        # Get posterior probability for each target
384        posterior_probabilities = self.clf.predict_proba(subset_X)[:, 1]
385        label = [int(np.argmax(posterior_probabilities) + 1)]
386
387        return Prediction(label, posterior_probabilities)

ERP RG Classifier class (inherits from GenericClassifier).

def set_p300_clf_settings( self, n_splits=3, lico_expansion_factor=1, oversample_ratio=0, undersample_ratio=0, random_seed=42, covariance_estimator='oas', remove_flats=True):
 44    def set_p300_clf_settings(
 45        self,
 46        n_splits=3,
 47        lico_expansion_factor=1,
 48        oversample_ratio=0,
 49        undersample_ratio=0,
 50        random_seed=42,
 51        covariance_estimator="oas",  # Covariance estimator, see pyriemann Covariances
 52        remove_flats=True,
 53    ):
 54        """Set P300 Classifier Settings.
 55
 56        Parameters
 57        ----------
 58        n_splits : int, *optional*
 59            Number of folds for cross-validation.
 60            - Default is `3`.
 61        lico_expansion_factor : int, *optional*
 62            Linear Combination Oversampling expansion factor, which is the
 63            factor by which the number of ERPs in the training set will be
 64            expanded.
 65            - Default is `1`.
 66        oversample_ratio : float, *optional*
 67            Traditional oversampling. Range is from from 0.1-1 resulting
 68            from the ratio of erp to non-erp class. 0 for no oversampling.
 69            - Default is `0`.
 70        undersample_ratio : float, *optional*
 71            Traditional undersampling. Range is from from 0.1-1 resulting
 72            from the ratio of erp to non-erp class. 0 for no undersampling.
 73            - Default is `0`.
 74        random_seed : int, *optional*
 75            Random seed.
 76            - Default is `42`.
 77        covariance_estimator : str, *optional*
 78            Covariance estimator. See pyriemann Covariances.
 79            - Default is `"oas"`.
 80        remove_flats : bool, *optional*
 81            Whether to remove flat channels.
 82            - Default is `True`.
 83
 84        Returns
 85        -------
 86        `None`
 87
 88        """
 89        self.n_splits = n_splits
 90        self.lico_expansion_factor = lico_expansion_factor
 91        self.oversample_ratio = oversample_ratio
 92        self.undersample_ratio = undersample_ratio
 93        self.random_seed = random_seed
 94        self.covariance_estimator = covariance_estimator
 95
 96        # Define the classifier
 97        self.clf = make_pipeline(
 98            XdawnCovariances(estimator=self.covariance_estimator),
 99            TangentSpace(metric="riemann"),
100            LinearDiscriminantAnalysis(solver="eigen", shrinkage="auto"),
101        )
102
103        if remove_flats:
104            rf = FlatChannelRemover()
105            self.clf.steps.insert(0, ["Remove Flat Channels", rf])

Set P300 Classifier Settings.

Parameters
  • n_splits (int, optional): Number of folds for cross-validation.
    • Default is 3.
  • lico_expansion_factor (int, optional): Linear Combination Oversampling expansion factor, which is the factor by which the number of ERPs in the training set will be expanded.
    • Default is 1.
  • oversample_ratio (float, optional): Traditional oversampling. Range is from 0.1-1 resulting from the ratio of erp to non-erp class. 0 for no oversampling.
    • Default is 0.
  • undersample_ratio (float, optional): Traditional undersampling. Range is from 0.1-1 resulting from the ratio of erp to non-erp class. 0 for no undersampling.
    • Default is 0.
  • random_seed (int, optional): Random seed.
    • Default is 42.
  • covariance_estimator (str, optional): Covariance estimator. See pyriemann Covariances.
    • Default is "oas".
  • remove_flats (bool, optional): Whether to remove flat channels.
    • Default is True.
Returns
  • None
def fit(self, plot_cm=False, plot_roc=False, lico_expansion_factor=1):
    def fit(
        self,
        plot_cm=False,
        plot_roc=False,
        lico_expansion_factor=1,
    ):
        """Fit the model.

        Trains the Riemannian-geometry pipeline (`self.clf`) with stratified
        k-fold cross-validation, optionally rebalancing each training fold
        via LICO, oversampling, and/or undersampling, and optionally running
        channel selection first. Offline performance metrics are stored on
        the instance.

        Parameters
        ----------
        plot_cm : bool, *optional*
            Whether to plot the confusion matrix during training.
            - Default is `False`.
        plot_roc : bool, *optional*
            Whether to plot the ROC curve during training.
            - Default is `False`.
        lico_expansion_factor : int, *optional*
            Linear combination oversampling expansion factor.
            Determines the number of ERPs in the training set that will be expanded.
            Higher value increases the oversampling, generating more synthetic
            samples for the minority class.
            NOTE(review): this argument shadows `self.lico_expansion_factor`
            set in `set_p300_clf_settings()`; only the value passed here is
            used by the kernel below — confirm which one callers intend.
            - Default is `1`.

        Returns
        -------
        `None`
            Models created used in `predict()`.

        """
        logger.info("Fitting the model using RG")
        logger.info("X shape: %s", self.X.shape)
        logger.info("y shape: %s", self.y.shape)

        # Define the strategy for cross validation.
        # shuffle=True with a fixed random_state keeps fold assignment
        # reproducible across calls.
        cv = StratifiedKFold(
            n_splits=self.n_splits, shuffle=True, random_state=self.random_seed
        )

        # Init predictions to all false. This array is closed over by
        # __erp_rg_kernel and filled in fold by fold.
        cv_preds = np.zeros(len(self.y))

        def __erp_rg_kernel(X, y):
            """ERP RG kernel.

            Runs cross-validated training with per-fold class rebalancing,
            then refits `self.clf` on all of `X`/`y` as the final model.

            Parameters
            ----------
            X : numpy.ndarray
                Input features (ERP data) for training.
                3D numpy array with shape = (`n_trials`, `n_channels`, `n_samples`).
                E.g. (100, 32, 1000) for 100 trials, 32 channels and 1000 samples per channel.

            y : numpy.ndarray
                Target labels corresponding to the input features in `X`.
                1D numpy array with shape (n_trials, ).
                Each label indicates the class of the corresponding trial in `X`.
                E.g. (100, ) for 100 trials.


            Returns
            -------
            kernelResults : KernelResults
                KernelResults object containing the following attributes:
                    model : classifier
                        The trained classification model.
                    cv_preds : numpy.ndarray
                        The predictions from the model using cross validation.
                        1D array with the same shape as `y`.
                    accuracy : float
                        The accuracy of the trained classification model.
                    precision : float
                        The precision of the trained classification model.
                    recall : float
                        The recall of the trained classification model.

            """
            for train_idx, test_idx in cv.split(X, y):
                y_train, y_test = y[train_idx], y[test_idx]

                X_train, X_test = X[train_idx], X[test_idx]

                # LICO (linear combination oversampling) of the positive class
                logger.debug(
                    "Before LICO:\n\tShape X: %s\n\tShape y: %s",
                    X_train.shape,
                    y_train.shape,
                )

                # LICO needs more than 2 positive trials in the fold
                # (labels are 0/1, so sum(y_train) counts positives).
                if sum(y_train) > 2:
                    if lico_expansion_factor > 1:
                        X_train, y_train = lico(
                            X_train,
                            y_train,
                            expansion_factor=lico_expansion_factor,
                            sum_num=2,
                            shuffle=False,
                        )
                        logger.debug("y_train = %s", y_train)

                logger.debug(
                    "After LICO:\n\tShape X: %s\n\tShape y: %s",
                    X_train.shape,
                    y_train.shape,
                )

                # Oversampling: duplicate random positive trials until
                # positives reach oversample_ratio * negatives.
                # NOTE(review): `random` is used unseeded here, so resampling
                # is not reproducible even though the CV folds are.
                if self.oversample_ratio > 0:
                    p_count = sum(y_train)
                    n_count = len(y_train) - sum(y_train)

                    num_to_add = int(
                        np.floor((self.oversample_ratio * n_count) - p_count)
                    )

                    # Add num_to_add random selections from the positive
                    true_X_train = X_train[y_train == 1]

                    len_X_train = len(true_X_train)

                    for s in range(num_to_add):
                        to_add_X = true_X_train[random.randrange(0, len_X_train), :, :]

                        X_train = np.append(X_train, to_add_X[np.newaxis, :], axis=0)
                        y_train = np.append(y_train, [1], axis=0)

                # Undersampling: drop random negative trials until
                # negatives fall to positives / undersample_ratio.
                if self.undersample_ratio > 0:
                    p_count = sum(y_train)
                    n_count = len(y_train) - sum(y_train)

                    num_to_remove = int(
                        np.floor(n_count - (p_count / self.undersample_ratio))
                    )

                    ind_range = np.arange(len(y_train))
                    ind_list = list(ind_range)
                    to_remove = []

                    # Remove num_to_remove random selections from the negative
                    false_ind = list(ind_range[y_train == 0])

                    for s in range(num_to_remove):
                        # select a random value from the list of false indices
                        remove_at = false_ind[random.randrange(0, len(false_ind))]

                        # remove that value from the false index list
                        false_ind.remove(remove_at)

                        # add the index to be removed to a list
                        to_remove.append(remove_at)

                    # NOTE: remaining_ind aliases ind_list (no copy); the
                    # removals below mutate ind_list, which is local per fold.
                    remaining_ind = ind_list
                    for i in range(len(to_remove)):
                        remaining_ind.remove(to_remove[i])

                    X_train = X_train[remaining_ind, :, :]
                    y_train = y_train[remaining_ind]

                # Fit on the (rebalanced) fold and record held-out predictions
                self.clf.fit(X_train, y_train)
                cv_preds[test_idx] = self.clf.predict(X_test)
                predproba = self.clf.predict_proba(X_test)

                # Use pred proba to show what would be predicted
                predprobs = predproba[:, 1]
                real = np.where(y_test == 1)

                # TODO handle exception where two probabilities are the same
                prediction = int(np.where(predprobs == np.amax(predprobs))[0][0])

                logger.debug("y_test = %s", y_test)
                logger.debug("predproba = %s", predproba)
                logger.debug("real = %s", real[0])
                logger.debug("prediction = %s", prediction)

            # Train final model with all available data
            self.clf.fit(X, y)
            model = self.clf

            # NOTE(review): metrics are computed against self.y rather than
            # the `y` argument. Both call sites below pass self.y, so this is
            # currently equivalent, but it would break if the kernel were
            # ever called with a trial subset — confirm before reusing.
            accuracy = sum(cv_preds == self.y) / len(cv_preds)
            precision = precision_score(self.y, cv_preds)
            recall = recall_score(self.y, cv_preds)

            return KernelResults(model, cv_preds, accuracy, precision, recall)

        # Check if channel selection is true
        if self.channel_selection_setup:
            logger.info("Doing channel selection")
            logger.debug("Initial subset: %s", self.chs_initial_subset)

            channel_selection_results = channel_selection_by_method(
                __erp_rg_kernel,
                self.X,
                self.y,
                self.channel_labels,  # kernel setup
                self.chs_method,
                self.chs_metric,
                self.chs_initial_subset,  # wrapper setup
                self.chs_max_time,
                self.chs_min_channels,
                self.chs_max_channels,
                self.chs_performance_delta,  # stopping criterion
                self.chs_n_jobs,
            )  # njobs, output messages

            preds = channel_selection_results.best_preds
            accuracy = channel_selection_results.best_accuracy
            precision = channel_selection_results.best_precision
            recall = channel_selection_results.best_recall

            logger.info(
                "The optimal subset is %s",
                channel_selection_results.best_channel_subset,
            )

            self.results_df = channel_selection_results.results_df
            self.subset = channel_selection_results.best_channel_subset
            self.subset_defined = True
            self.clf = channel_selection_results.best_model
        else:
            logger.warning("Not doing channel selection")
            X = self.get_subset(self.X, self.subset, self.channel_labels)

            current_results = __erp_rg_kernel(X, self.y)
            self.clf = current_results.model
            preds = current_results.cv_preds
            accuracy = current_results.accuracy
            precision = current_results.precision
            recall = current_results.recall

        # Log performance stats. These recompute (and overwrite) the values
        # obtained above, using the selected predictions in both branches.
        # accuracy
        accuracy = sum(preds == self.y) / len(preds)
        self.offline_accuracy = accuracy
        logger.info("Accuracy = %s", accuracy)

        # precision
        precision = precision_score(self.y, preds)
        self.offline_precision = precision
        logger.info("Precision = %s", precision)

        # recall
        recall = recall_score(self.y, preds)
        self.offline_recall = recall
        logger.info("Recall = %s", recall)

        # confusion matrix in command line
        cm = confusion_matrix(self.y, preds)
        self.offline_cm = cm
        logger.info("Confusion matrix:\n%s", cm)

        if plot_cm:
            cm = confusion_matrix(self.y, preds)
            ConfusionMatrixDisplay(cm).plot()
            plt.show()

        if plot_roc:
            logger.error("ROC plot has not been implemented yet")

Fit the model.

Parameters
  • plot_cm (bool, optional): Whether to plot the confusion matrix during training.
    • Default is False.
  • plot_roc (bool, optional): Whether to plot the ROC curve during training.
    • Default is False.
  • lico_expansion_factor (int, optional): Linear combination oversampling expansion factor. Determines the number of ERPs in the training set that will be expanded. Higher value increases the oversampling, generating more synthetic samples for the minority class.
    • Default is 1.
Returns
  • None
def predict(self, X):
364    def predict(self, X):
365        """Predict the class of the data (Unused in this classifier)
366
367        Parameters
368        ----------
369        X : numpy.ndarray
370            3D array where shape = (n_epochs, n_channels, n_samples)
371
372        Returns
373        -------
374        prediction : Prediction
375            Predict object. Contains the predicted labels and and the probability.
376            Because this classifier chooses the P300 object with the highest posterior probability,
377            the probability is only the posterior probability of the chosen object.
378
379        """
380
381        subset_X = self.get_subset(X, self.subset, self.channel_labels)
382
383        # Get posterior probability for each target
384        posterior_probabilities = self.clf.predict_proba(subset_X)[:, 1]
385        label = [int(np.argmax(posterior_probabilities) + 1)]
386
387        return Prediction(label, posterior_probabilities)

Predict the class of the data (Unused in this classifier)

Parameters
  • X (numpy.ndarray): 3D array where shape = (n_epochs, n_channels, n_samples)
Returns
  • prediction (Prediction): Predict object. Contains the predicted labels and the probability. Because this classifier chooses the P300 object with the highest posterior probability, the probability is only the posterior probability of the chosen object.