bci_essentials.classification.erp_single_channel_classifier

ERP Single Channel Classifier

This classifier is used to classify ERPs when only a single channel (e.g. ear EEG) is available.

  1"""**ERP Single Channel Classifier**
  2
  3This classifier is used to classify ERPs when only a single channel (ex. ear EEG) is available.
  4
  5"""
  6
  7# Stock libraries
  8import random
  9import numpy as np
 10import matplotlib.pyplot as plt
 11from sklearn.pipeline import make_pipeline
 12from sklearn.model_selection import StratifiedKFold
 13from sklearn.metrics import (
 14    confusion_matrix,
 15    ConfusionMatrixDisplay,
 16    precision_score,
 17    recall_score,
 18)
 19from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
 20
 21# Import bci_essentials modules and methods
 22from .generic_classifier import GenericClassifier, Prediction, KernelResults
 23from ..signal_processing import lico
 24from ..channel_selection import channel_selection_by_method
 25from ..utils.logger import Logger  # Logger wrapper
 26from ..utils.reduce_to_single_channel import ReduceToSingleChannel
 27
# Instantiate a logger for this module at the default level of logging.INFO.
# Records are logged under __name__, i.e. this module's dotted path within
# the bci_essentials package.
logger = Logger(name=__name__)
 31
 32
 33class ErpSingleChannelClassifier(GenericClassifier):
 34    """ERP Single Channel Classifier class (*inherits from `GenericClassifier`*)."""
 35
 36    def set_p300_clf_settings(
 37        self,
 38        n_splits=3,
 39        lico_expansion_factor=1,
 40        oversample_ratio=0,
 41        undersample_ratio=0,
 42        random_seed=42,
 43    ):
 44        """Set P300 Classifier Settings.
 45
 46        Parameters
 47        ----------
 48        n_splits : int, *optional*
 49            Number of folds for cross-validation.
 50            - Default is `3`.
 51        lico_expansion_factor : int, *optional*
 52            Linear Combination Oversampling expansion factor, which is the
 53            factor by which the number of ERPs in the training set will be
 54            expanded.
 55            - Default is `1`.
 56        oversample_ratio : float, *optional*
 57            Traditional oversampling. Range is from from 0.1-1 resulting
 58            from the ratio of erp to non-erp class. 0 for no oversampling.
 59            - Default is `0`.
 60        undersample_ratio : float, *optional*
 61            Traditional undersampling. Range is from from 0.1-1 resulting
 62            from the ratio of erp to non-erp class. 0 for no undersampling.
 63            - Default is `0`.
 64        random_seed : int, *optional*
 65            Random seed.
 66            - Default is `42`.
 67
 68        Returns
 69        -------
 70        `None`
 71
 72        """
 73        self.n_splits = n_splits
 74        self.lico_expansion_factor = lico_expansion_factor
 75        self.oversample_ratio = oversample_ratio
 76        self.undersample_ratio = undersample_ratio
 77        self.random_seed = random_seed
 78
 79    def fit(
 80        self,
 81        plot_cm=False,
 82        plot_roc=False,
 83        lico_expansion_factor=1,
 84    ):
 85        """Fit the model.
 86
 87        Parameters
 88        ----------
 89        plot_cm : bool, *optional*
 90            Whether to plot the confusion matrix during training.
 91            - Default is `False`.
 92        plot_roc : bool, *optional*
 93            Whether to plot the ROC curve during training.
 94            - Default is `False`.
 95        lico_expansion_factor : int, *optional*
 96            Linear combination oversampling expansion factor.
 97            Determines the number of ERPs in the training set that will be expanded.
 98            Higher value increases the oversampling, generating more synthetic
 99            samples for the minority class.
100            - Default is `1`.
101
102        Returns
103        -------
104        `None`
105            Models created used in `predict()`.
106
107        """
108
109        logger.info("Fitting the model using sLDA")
110        logger.info("X shape: %s", self.X.shape)
111        logger.info("y shape: %s", self.y.shape)
112
113        # Define the strategy for cross validation
114        cv = StratifiedKFold(
115            n_splits=self.n_splits, shuffle=True, random_state=self.random_seed
116        )
117
118        # Define the classifier
119        self.clf = make_pipeline(
120            ReduceToSingleChannel(),
121            LinearDiscriminantAnalysis(solver="eigen", shrinkage="auto"),
122        )
123
124        # Init predictions to all false
125        cv_preds = np.zeros(len(self.y))
126
127        #
128        def __erp_single_channel_kernel(X, y):
129            """ERP Single Channel kernel.
130
131            Parameters
132            ----------
133            X : numpy.ndarray
134                Input features (ERP data) for training.
135                3D numpy array with shape = (`n_trials`, `n_channels`, `n_samples`).
136                E.g. (100, 1, 1000) for 100 trials, 1 channel and 1000 samples.
137
138            y : numpy.ndarray
139                Target labels corresponding to the input features in `X`.
140                1D numpy array with shape (n_trails, ).
141                Each label indicates the class of the corresponding trial in `X`.
142                E.g. (100, ) for 100 trials.
143
144
145            Returns
146            -------
147            kernelResults : KernelResults
148                KernelResults object containing the following attributes:
149                    model : classifier
150                        The trained classification model.
151                    cv_preds : numpy.ndarray
152                        The predictions from the model using cross validation.
153                        1D array with the same shape as `y`.
154                    accuracy : float
155                        The accuracy of the trained classification model.
156                    precision : float
157                        The precision of the trained classification model.
158                    recall : float
159                        The recall of the trained classification model.
160
161            """
162            logger.info("X shape: %s", X.shape)
163
164            for train_idx, test_idx in cv.split(X, y):
165                y_train, y_test = y[train_idx], y[test_idx]
166
167                X_train, X_test = X[train_idx], X[test_idx]
168
169                # LICO
170                logger.debug(
171                    "Before LICO:\n\tShape X: %s\n\tShape y: %s",
172                    X_train.shape,
173                    y_train.shape,
174                )
175
176                if sum(y_train) > 2:
177                    if lico_expansion_factor > 1:
178                        X_train, y_train = lico(
179                            X_train,
180                            y_train,
181                            expansion_factor=lico_expansion_factor,
182                            sum_num=2,
183                            shuffle=False,
184                        )
185                        logger.debug("y_train = %s", y_train)
186
187                logger.debug(
188                    "After LICO:\n\tShape X: %s\n\tShape y: %s",
189                    X_train.shape,
190                    y_train.shape,
191                )
192
193                # Oversampling
194                if self.oversample_ratio > 0:
195                    p_count = sum(y_train)
196                    n_count = len(y_train) - sum(y_train)
197
198                    num_to_add = int(
199                        np.floor((self.oversample_ratio * n_count) - p_count)
200                    )
201
202                    # Add num_to_add random selections from the positive
203                    true_X_train = X_train[y_train == 1]
204
205                    len_X_train = len(true_X_train)
206
207                    for s in range(num_to_add):
208                        to_add_X = true_X_train[random.randrange(0, len_X_train), :, :]
209
210                        X_train = np.append(X_train, to_add_X[np.newaxis, :], axis=0)
211                        y_train = np.append(y_train, [1], axis=0)
212
213                # Undersampling
214                if self.undersample_ratio > 0:
215                    p_count = sum(y_train)
216                    n_count = len(y_train) - sum(y_train)
217
218                    num_to_remove = int(
219                        np.floor(n_count - (p_count / self.undersample_ratio))
220                    )
221
222                    ind_range = np.arange(len(y_train))
223                    ind_list = list(ind_range)
224                    to_remove = []
225
226                    # Remove num_to_remove random selections from the negative
227                    false_ind = list(ind_range[y_train == 0])
228
229                    for s in range(num_to_remove):
230                        # select a random value from the list of false indices
231                        remove_at = false_ind[random.randrange(0, len(false_ind))]
232
233                        # remove that value from the false ind list
234                        false_ind.remove(remove_at)
235
236                        # add the index to be removed to a list
237                        to_remove.append(remove_at)
238
239                    remaining_ind = ind_list
240                    for i in range(len(to_remove)):
241                        remaining_ind.remove(to_remove[i])
242
243                    X_train = X_train[remaining_ind, :, :]
244                    y_train = y_train[remaining_ind]
245
246                self.clf.fit(X_train, y_train)
247                cv_preds[test_idx] = self.clf.predict(X_test)
248                predproba = self.clf.predict_proba(X_test)
249
250                # Use pred proba to show what would be predicted
251                predprobs = predproba[:, 1]
252                real = np.where(y_test == 1)
253
254                # TODO handle exception where two probabilities are the same
255                prediction = int(np.where(predprobs == np.amax(predprobs))[0][0])
256
257                logger.debug("y_test = %s", y_test)
258                logger.debug("predproba = %s", predproba)
259                logger.debug("real = %s", real[0])
260                logger.debug("prediction = %s", prediction)
261
262            # Train final model with all available data
263            self.clf.fit(X, y)
264            model = self.clf
265
266            accuracy = sum(cv_preds == self.y) / len(cv_preds)
267            precision = precision_score(self.y, cv_preds)
268            recall = recall_score(self.y, cv_preds)
269
270            return KernelResults(model, cv_preds, accuracy, precision, recall)
271
272        # Check if channel selection is true
273        if self.channel_selection_setup:
274            logger.info("Doing channel selection")
275            logger.debug("Initial subset: %s", self.chs_initial_subset)
276
277            channel_selection_results = channel_selection_by_method(
278                __erp_single_channel_kernel,
279                self.X,
280                self.y,
281                self.channel_labels,  # kernel setup
282                self.chs_method,
283                self.chs_metric,
284                self.chs_initial_subset,  # wrapper setup
285                self.chs_max_time,
286                self.chs_min_channels,
287                self.chs_max_channels,
288                self.chs_performance_delta,  # stopping criterion
289                self.chs_n_jobs,
290            )  # njobs, output messages
291
292            preds = channel_selection_results.best_preds
293            accuracy = channel_selection_results.best_accuracy
294            precision = channel_selection_results.best_precision
295            recall = channel_selection_results.best_recall
296
297            logger.info(
298                "The optimal subset is %s",
299                channel_selection_results.best_channel_subset,
300            )
301
302            self.results_df = channel_selection_results.results_df
303            self.subset = channel_selection_results.best_channel_subset
304            self.clf = channel_selection_results.best_model
305        else:
306            logger.warning("Not doing channel selection")
307            current_results = __erp_single_channel_kernel(self.X, self.y)
308            self.clf = current_results.model
309            preds = current_results.cv_preds
310            accuracy = current_results.accuracy
311            precision = current_results.precision
312            recall = current_results.recall
313
314        # Log performance stats
315        # accuracy
316        accuracy = sum(preds == self.y) / len(preds)
317        self.offline_accuracy = accuracy
318        logger.info("Accuracy = %s", accuracy)
319
320        # precision
321        precision = precision_score(self.y, preds)
322        self.offline_precision = precision
323        logger.info("Precision = %s", precision)
324
325        # recall
326        recall = recall_score(self.y, preds)
327        self.offline_recall = recall
328        logger.info("Recall = %s", recall)
329
330        # confusion matrix in command line
331        cm = confusion_matrix(self.y, preds)
332        self.offline_cm = cm
333        logger.info("Confusion matrix:\n%s", cm)
334
335        if plot_cm:
336            cm = confusion_matrix(self.y, preds)
337            ConfusionMatrixDisplay(cm).plot()
338            plt.show()
339
340        if plot_roc:
341            logger.error("ROC plot has not been implemented yet")
342
343    def predict(self, X):
344        """Predict the class of the data (Unused in this classifier)
345
346        Parameters
347        ----------
348        X : numpy.ndarray
349            3D array where shape = (n_trials, n_channels, n_samples)
350
351        Returns
352        -------
353        prediction : Prediction
354            Empty Predict object
355
356        """
357
358        return Prediction()
class ErpSingleChannelClassifier(bci_essentials.classification.generic_classifier.GenericClassifier):
 34class ErpSingleChannelClassifier(GenericClassifier):
 35    """ERP Single Channel Classifier class (*inherits from `GenericClassifier`*)."""
 36
 37    def set_p300_clf_settings(
 38        self,
 39        n_splits=3,
 40        lico_expansion_factor=1,
 41        oversample_ratio=0,
 42        undersample_ratio=0,
 43        random_seed=42,
 44    ):
 45        """Set P300 Classifier Settings.
 46
 47        Parameters
 48        ----------
 49        n_splits : int, *optional*
 50            Number of folds for cross-validation.
 51            - Default is `3`.
 52        lico_expansion_factor : int, *optional*
 53            Linear Combination Oversampling expansion factor, which is the
 54            factor by which the number of ERPs in the training set will be
 55            expanded.
 56            - Default is `1`.
 57        oversample_ratio : float, *optional*
 58            Traditional oversampling. Range is from from 0.1-1 resulting
 59            from the ratio of erp to non-erp class. 0 for no oversampling.
 60            - Default is `0`.
 61        undersample_ratio : float, *optional*
 62            Traditional undersampling. Range is from from 0.1-1 resulting
 63            from the ratio of erp to non-erp class. 0 for no undersampling.
 64            - Default is `0`.
 65        random_seed : int, *optional*
 66            Random seed.
 67            - Default is `42`.
 68
 69        Returns
 70        -------
 71        `None`
 72
 73        """
 74        self.n_splits = n_splits
 75        self.lico_expansion_factor = lico_expansion_factor
 76        self.oversample_ratio = oversample_ratio
 77        self.undersample_ratio = undersample_ratio
 78        self.random_seed = random_seed
 79
 80    def fit(
 81        self,
 82        plot_cm=False,
 83        plot_roc=False,
 84        lico_expansion_factor=1,
 85    ):
 86        """Fit the model.
 87
 88        Parameters
 89        ----------
 90        plot_cm : bool, *optional*
 91            Whether to plot the confusion matrix during training.
 92            - Default is `False`.
 93        plot_roc : bool, *optional*
 94            Whether to plot the ROC curve during training.
 95            - Default is `False`.
 96        lico_expansion_factor : int, *optional*
 97            Linear combination oversampling expansion factor.
 98            Determines the number of ERPs in the training set that will be expanded.
 99            Higher value increases the oversampling, generating more synthetic
100            samples for the minority class.
101            - Default is `1`.
102
103        Returns
104        -------
105        `None`
106            Models created used in `predict()`.
107
108        """
109
110        logger.info("Fitting the model using sLDA")
111        logger.info("X shape: %s", self.X.shape)
112        logger.info("y shape: %s", self.y.shape)
113
114        # Define the strategy for cross validation
115        cv = StratifiedKFold(
116            n_splits=self.n_splits, shuffle=True, random_state=self.random_seed
117        )
118
119        # Define the classifier
120        self.clf = make_pipeline(
121            ReduceToSingleChannel(),
122            LinearDiscriminantAnalysis(solver="eigen", shrinkage="auto"),
123        )
124
125        # Init predictions to all false
126        cv_preds = np.zeros(len(self.y))
127
128        #
129        def __erp_single_channel_kernel(X, y):
130            """ERP Single Channel kernel.
131
132            Parameters
133            ----------
134            X : numpy.ndarray
135                Input features (ERP data) for training.
136                3D numpy array with shape = (`n_trials`, `n_channels`, `n_samples`).
137                E.g. (100, 1, 1000) for 100 trials, 1 channel and 1000 samples.
138
139            y : numpy.ndarray
140                Target labels corresponding to the input features in `X`.
141                1D numpy array with shape (n_trails, ).
142                Each label indicates the class of the corresponding trial in `X`.
143                E.g. (100, ) for 100 trials.
144
145
146            Returns
147            -------
148            kernelResults : KernelResults
149                KernelResults object containing the following attributes:
150                    model : classifier
151                        The trained classification model.
152                    cv_preds : numpy.ndarray
153                        The predictions from the model using cross validation.
154                        1D array with the same shape as `y`.
155                    accuracy : float
156                        The accuracy of the trained classification model.
157                    precision : float
158                        The precision of the trained classification model.
159                    recall : float
160                        The recall of the trained classification model.
161
162            """
163            logger.info("X shape: %s", X.shape)
164
165            for train_idx, test_idx in cv.split(X, y):
166                y_train, y_test = y[train_idx], y[test_idx]
167
168                X_train, X_test = X[train_idx], X[test_idx]
169
170                # LICO
171                logger.debug(
172                    "Before LICO:\n\tShape X: %s\n\tShape y: %s",
173                    X_train.shape,
174                    y_train.shape,
175                )
176
177                if sum(y_train) > 2:
178                    if lico_expansion_factor > 1:
179                        X_train, y_train = lico(
180                            X_train,
181                            y_train,
182                            expansion_factor=lico_expansion_factor,
183                            sum_num=2,
184                            shuffle=False,
185                        )
186                        logger.debug("y_train = %s", y_train)
187
188                logger.debug(
189                    "After LICO:\n\tShape X: %s\n\tShape y: %s",
190                    X_train.shape,
191                    y_train.shape,
192                )
193
194                # Oversampling
195                if self.oversample_ratio > 0:
196                    p_count = sum(y_train)
197                    n_count = len(y_train) - sum(y_train)
198
199                    num_to_add = int(
200                        np.floor((self.oversample_ratio * n_count) - p_count)
201                    )
202
203                    # Add num_to_add random selections from the positive
204                    true_X_train = X_train[y_train == 1]
205
206                    len_X_train = len(true_X_train)
207
208                    for s in range(num_to_add):
209                        to_add_X = true_X_train[random.randrange(0, len_X_train), :, :]
210
211                        X_train = np.append(X_train, to_add_X[np.newaxis, :], axis=0)
212                        y_train = np.append(y_train, [1], axis=0)
213
214                # Undersampling
215                if self.undersample_ratio > 0:
216                    p_count = sum(y_train)
217                    n_count = len(y_train) - sum(y_train)
218
219                    num_to_remove = int(
220                        np.floor(n_count - (p_count / self.undersample_ratio))
221                    )
222
223                    ind_range = np.arange(len(y_train))
224                    ind_list = list(ind_range)
225                    to_remove = []
226
227                    # Remove num_to_remove random selections from the negative
228                    false_ind = list(ind_range[y_train == 0])
229
230                    for s in range(num_to_remove):
231                        # select a random value from the list of false indices
232                        remove_at = false_ind[random.randrange(0, len(false_ind))]
233
234                        # remove that value from the false ind list
235                        false_ind.remove(remove_at)
236
237                        # add the index to be removed to a list
238                        to_remove.append(remove_at)
239
240                    remaining_ind = ind_list
241                    for i in range(len(to_remove)):
242                        remaining_ind.remove(to_remove[i])
243
244                    X_train = X_train[remaining_ind, :, :]
245                    y_train = y_train[remaining_ind]
246
247                self.clf.fit(X_train, y_train)
248                cv_preds[test_idx] = self.clf.predict(X_test)
249                predproba = self.clf.predict_proba(X_test)
250
251                # Use pred proba to show what would be predicted
252                predprobs = predproba[:, 1]
253                real = np.where(y_test == 1)
254
255                # TODO handle exception where two probabilities are the same
256                prediction = int(np.where(predprobs == np.amax(predprobs))[0][0])
257
258                logger.debug("y_test = %s", y_test)
259                logger.debug("predproba = %s", predproba)
260                logger.debug("real = %s", real[0])
261                logger.debug("prediction = %s", prediction)
262
263            # Train final model with all available data
264            self.clf.fit(X, y)
265            model = self.clf
266
267            accuracy = sum(cv_preds == self.y) / len(cv_preds)
268            precision = precision_score(self.y, cv_preds)
269            recall = recall_score(self.y, cv_preds)
270
271            return KernelResults(model, cv_preds, accuracy, precision, recall)
272
273        # Check if channel selection is true
274        if self.channel_selection_setup:
275            logger.info("Doing channel selection")
276            logger.debug("Initial subset: %s", self.chs_initial_subset)
277
278            channel_selection_results = channel_selection_by_method(
279                __erp_single_channel_kernel,
280                self.X,
281                self.y,
282                self.channel_labels,  # kernel setup
283                self.chs_method,
284                self.chs_metric,
285                self.chs_initial_subset,  # wrapper setup
286                self.chs_max_time,
287                self.chs_min_channels,
288                self.chs_max_channels,
289                self.chs_performance_delta,  # stopping criterion
290                self.chs_n_jobs,
291            )  # njobs, output messages
292
293            preds = channel_selection_results.best_preds
294            accuracy = channel_selection_results.best_accuracy
295            precision = channel_selection_results.best_precision
296            recall = channel_selection_results.best_recall
297
298            logger.info(
299                "The optimal subset is %s",
300                channel_selection_results.best_channel_subset,
301            )
302
303            self.results_df = channel_selection_results.results_df
304            self.subset = channel_selection_results.best_channel_subset
305            self.clf = channel_selection_results.best_model
306        else:
307            logger.warning("Not doing channel selection")
308            current_results = __erp_single_channel_kernel(self.X, self.y)
309            self.clf = current_results.model
310            preds = current_results.cv_preds
311            accuracy = current_results.accuracy
312            precision = current_results.precision
313            recall = current_results.recall
314
315        # Log performance stats
316        # accuracy
317        accuracy = sum(preds == self.y) / len(preds)
318        self.offline_accuracy = accuracy
319        logger.info("Accuracy = %s", accuracy)
320
321        # precision
322        precision = precision_score(self.y, preds)
323        self.offline_precision = precision
324        logger.info("Precision = %s", precision)
325
326        # recall
327        recall = recall_score(self.y, preds)
328        self.offline_recall = recall
329        logger.info("Recall = %s", recall)
330
331        # confusion matrix in command line
332        cm = confusion_matrix(self.y, preds)
333        self.offline_cm = cm
334        logger.info("Confusion matrix:\n%s", cm)
335
336        if plot_cm:
337            cm = confusion_matrix(self.y, preds)
338            ConfusionMatrixDisplay(cm).plot()
339            plt.show()
340
341        if plot_roc:
342            logger.error("ROC plot has not been implemented yet")
343
344    def predict(self, X):
345        """Predict the class of the data (Unused in this classifier)
346
347        Parameters
348        ----------
349        X : numpy.ndarray
350            3D array where shape = (n_trials, n_channels, n_samples)
351
352        Returns
353        -------
354        prediction : Prediction
355            Empty Predict object
356
357        """
358
359        return Prediction()

ERP Single Channel Classifier class (inherits from GenericClassifier).

def set_p300_clf_settings(self, n_splits=3, lico_expansion_factor=1, oversample_ratio=0, undersample_ratio=0, random_seed=42):
37    def set_p300_clf_settings(
38        self,
39        n_splits=3,
40        lico_expansion_factor=1,
41        oversample_ratio=0,
42        undersample_ratio=0,
43        random_seed=42,
44    ):
45        """Set P300 Classifier Settings.
46
47        Parameters
48        ----------
49        n_splits : int, *optional*
50            Number of folds for cross-validation.
51            - Default is `3`.
52        lico_expansion_factor : int, *optional*
53            Linear Combination Oversampling expansion factor, which is the
54            factor by which the number of ERPs in the training set will be
55            expanded.
56            - Default is `1`.
57        oversample_ratio : float, *optional*
58            Traditional oversampling. Range is from from 0.1-1 resulting
59            from the ratio of erp to non-erp class. 0 for no oversampling.
60            - Default is `0`.
61        undersample_ratio : float, *optional*
62            Traditional undersampling. Range is from from 0.1-1 resulting
63            from the ratio of erp to non-erp class. 0 for no undersampling.
64            - Default is `0`.
65        random_seed : int, *optional*
66            Random seed.
67            - Default is `42`.
68
69        Returns
70        -------
71        `None`
72
73        """
74        self.n_splits = n_splits
75        self.lico_expansion_factor = lico_expansion_factor
76        self.oversample_ratio = oversample_ratio
77        self.undersample_ratio = undersample_ratio
78        self.random_seed = random_seed

Set P300 Classifier Settings.

Parameters
  • n_splits (int, optional): Number of folds for cross-validation.
    • Default is 3.
  • lico_expansion_factor (int, optional): Linear Combination Oversampling expansion factor, which is the factor by which the number of ERPs in the training set will be expanded.
    • Default is 1.
  • oversample_ratio (float, optional): Traditional oversampling. The desired ratio of ERP to non-ERP trials, ranging from 0.1 to 1. 0 for no oversampling.
    • Default is 0.
  • undersample_ratio (float, optional): Traditional undersampling. The desired ratio of ERP to non-ERP trials, ranging from 0.1 to 1. 0 for no undersampling.
    • Default is 0.
  • random_seed (int, optional): Random seed.
    • Default is 42.
Returns
  • None
def fit(self, plot_cm=False, plot_roc=False, lico_expansion_factor=1):
 80    def fit(
 81        self,
 82        plot_cm=False,
 83        plot_roc=False,
 84        lico_expansion_factor=1,
 85    ):
 86        """Fit the model.
 87
 88        Parameters
 89        ----------
 90        plot_cm : bool, *optional*
 91            Whether to plot the confusion matrix during training.
 92            - Default is `False`.
 93        plot_roc : bool, *optional*
 94            Whether to plot the ROC curve during training.
 95            - Default is `False`.
 96        lico_expansion_factor : int, *optional*
 97            Linear combination oversampling expansion factor.
 98            Determines the number of ERPs in the training set that will be expanded.
 99            Higher value increases the oversampling, generating more synthetic
100            samples for the minority class.
101            - Default is `1`.
102
103        Returns
104        -------
105        `None`
106            Models created used in `predict()`.
107
108        """
109
110        logger.info("Fitting the model using sLDA")
111        logger.info("X shape: %s", self.X.shape)
112        logger.info("y shape: %s", self.y.shape)
113
114        # Define the strategy for cross validation
115        cv = StratifiedKFold(
116            n_splits=self.n_splits, shuffle=True, random_state=self.random_seed
117        )
118
119        # Define the classifier
120        self.clf = make_pipeline(
121            ReduceToSingleChannel(),
122            LinearDiscriminantAnalysis(solver="eigen", shrinkage="auto"),
123        )
124
125        # Init predictions to all false
126        cv_preds = np.zeros(len(self.y))
127
128        #
129        def __erp_single_channel_kernel(X, y):
130            """ERP Single Channel kernel.
131
132            Parameters
133            ----------
134            X : numpy.ndarray
135                Input features (ERP data) for training.
136                3D numpy array with shape = (`n_trials`, `n_channels`, `n_samples`).
137                E.g. (100, 1, 1000) for 100 trials, 1 channel and 1000 samples.
138
139            y : numpy.ndarray
140                Target labels corresponding to the input features in `X`.
141                1D numpy array with shape (n_trails, ).
142                Each label indicates the class of the corresponding trial in `X`.
143                E.g. (100, ) for 100 trials.
144
145
146            Returns
147            -------
148            kernelResults : KernelResults
149                KernelResults object containing the following attributes:
150                    model : classifier
151                        The trained classification model.
152                    cv_preds : numpy.ndarray
153                        The predictions from the model using cross validation.
154                        1D array with the same shape as `y`.
155                    accuracy : float
156                        The accuracy of the trained classification model.
157                    precision : float
158                        The precision of the trained classification model.
159                    recall : float
160                        The recall of the trained classification model.
161
162            """
163            logger.info("X shape: %s", X.shape)
164
165            for train_idx, test_idx in cv.split(X, y):
166                y_train, y_test = y[train_idx], y[test_idx]
167
168                X_train, X_test = X[train_idx], X[test_idx]
169
170                # LICO
171                logger.debug(
172                    "Before LICO:\n\tShape X: %s\n\tShape y: %s",
173                    X_train.shape,
174                    y_train.shape,
175                )
176
177                if sum(y_train) > 2:
178                    if lico_expansion_factor > 1:
179                        X_train, y_train = lico(
180                            X_train,
181                            y_train,
182                            expansion_factor=lico_expansion_factor,
183                            sum_num=2,
184                            shuffle=False,
185                        )
186                        logger.debug("y_train = %s", y_train)
187
188                logger.debug(
189                    "After LICO:\n\tShape X: %s\n\tShape y: %s",
190                    X_train.shape,
191                    y_train.shape,
192                )
193
194                # Oversampling
195                if self.oversample_ratio > 0:
196                    p_count = sum(y_train)
197                    n_count = len(y_train) - sum(y_train)
198
199                    num_to_add = int(
200                        np.floor((self.oversample_ratio * n_count) - p_count)
201                    )
202
203                    # Add num_to_add random selections from the positive
204                    true_X_train = X_train[y_train == 1]
205
206                    len_X_train = len(true_X_train)
207
208                    for s in range(num_to_add):
209                        to_add_X = true_X_train[random.randrange(0, len_X_train), :, :]
210
211                        X_train = np.append(X_train, to_add_X[np.newaxis, :], axis=0)
212                        y_train = np.append(y_train, [1], axis=0)
213
214                # Undersampling
215                if self.undersample_ratio > 0:
216                    p_count = sum(y_train)
217                    n_count = len(y_train) - sum(y_train)
218
219                    num_to_remove = int(
220                        np.floor(n_count - (p_count / self.undersample_ratio))
221                    )
222
223                    ind_range = np.arange(len(y_train))
224                    ind_list = list(ind_range)
225                    to_remove = []
226
227                    # Remove num_to_remove random selections from the negative
228                    false_ind = list(ind_range[y_train == 0])
229
230                    for s in range(num_to_remove):
231                        # select a random value from the list of false indices
232                        remove_at = false_ind[random.randrange(0, len(false_ind))]
233
234                        # remove that value from the false ind list
235                        false_ind.remove(remove_at)
236
237                        # add the index to be removed to a list
238                        to_remove.append(remove_at)
239
240                    remaining_ind = ind_list
241                    for i in range(len(to_remove)):
242                        remaining_ind.remove(to_remove[i])
243
244                    X_train = X_train[remaining_ind, :, :]
245                    y_train = y_train[remaining_ind]
246
247                self.clf.fit(X_train, y_train)
248                cv_preds[test_idx] = self.clf.predict(X_test)
249                predproba = self.clf.predict_proba(X_test)
250
251                # Use pred proba to show what would be predicted
252                predprobs = predproba[:, 1]
253                real = np.where(y_test == 1)
254
255                # TODO handle exception where two probabilities are the same
256                prediction = int(np.where(predprobs == np.amax(predprobs))[0][0])
257
258                logger.debug("y_test = %s", y_test)
259                logger.debug("predproba = %s", predproba)
260                logger.debug("real = %s", real[0])
261                logger.debug("prediction = %s", prediction)
262
263            # Train final model with all available data
264            self.clf.fit(X, y)
265            model = self.clf
266
267            accuracy = sum(cv_preds == self.y) / len(cv_preds)
268            precision = precision_score(self.y, cv_preds)
269            recall = recall_score(self.y, cv_preds)
270
271            return KernelResults(model, cv_preds, accuracy, precision, recall)
272
273        # Check if channel selection is true
274        if self.channel_selection_setup:
275            logger.info("Doing channel selection")
276            logger.debug("Initial subset: %s", self.chs_initial_subset)
277
278            channel_selection_results = channel_selection_by_method(
279                __erp_single_channel_kernel,
280                self.X,
281                self.y,
282                self.channel_labels,  # kernel setup
283                self.chs_method,
284                self.chs_metric,
285                self.chs_initial_subset,  # wrapper setup
286                self.chs_max_time,
287                self.chs_min_channels,
288                self.chs_max_channels,
289                self.chs_performance_delta,  # stopping criterion
290                self.chs_n_jobs,
291            )  # njobs, output messages
292
293            preds = channel_selection_results.best_preds
294            accuracy = channel_selection_results.best_accuracy
295            precision = channel_selection_results.best_precision
296            recall = channel_selection_results.best_recall
297
298            logger.info(
299                "The optimal subset is %s",
300                channel_selection_results.best_channel_subset,
301            )
302
303            self.results_df = channel_selection_results.results_df
304            self.subset = channel_selection_results.best_channel_subset
305            self.clf = channel_selection_results.best_model
306        else:
307            logger.warning("Not doing channel selection")
308            current_results = __erp_single_channel_kernel(self.X, self.y)
309            self.clf = current_results.model
310            preds = current_results.cv_preds
311            accuracy = current_results.accuracy
312            precision = current_results.precision
313            recall = current_results.recall
314
315        # Log performance stats
316        # accuracy
317        accuracy = sum(preds == self.y) / len(preds)
318        self.offline_accuracy = accuracy
319        logger.info("Accuracy = %s", accuracy)
320
321        # precision
322        precision = precision_score(self.y, preds)
323        self.offline_precision = precision
324        logger.info("Precision = %s", precision)
325
326        # recall
327        recall = recall_score(self.y, preds)
328        self.offline_recall = recall
329        logger.info("Recall = %s", recall)
330
331        # confusion matrix in command line
332        cm = confusion_matrix(self.y, preds)
333        self.offline_cm = cm
334        logger.info("Confusion matrix:\n%s", cm)
335
336        if plot_cm:
337            cm = confusion_matrix(self.y, preds)
338            ConfusionMatrixDisplay(cm).plot()
339            plt.show()
340
341        if plot_roc:
342            logger.error("ROC plot has not been implemented yet")

Fit the model.

Parameters
  • plot_cm (bool, optional): Whether to plot the confusion matrix during training.
    • Default is False.
  • plot_roc (bool, optional): Whether to plot the ROC curve during training.
    • Default is False.
  • lico_expansion_factor (int, optional): Linear combination oversampling expansion factor: the factor by which the number of ERPs in the training set is expanded. Higher values increase the oversampling, generating more synthetic samples for the minority class.
    • Default is 1.
Returns
  • None: the trained model is stored in self.clf.
def predict(self, X):
344    def predict(self, X):
345        """Predict the class of the data (Unused in this classifier)
346
347        Parameters
348        ----------
349        X : numpy.ndarray
350            3D array where shape = (n_trials, n_channels, n_samples)
351
352        Returns
353        -------
354        prediction : Prediction
355            Empty Predict object
356
357        """
358
359        return Prediction()

Predict the class of the data (Unused in this classifier)

Parameters
  • X (numpy.ndarray): 3D array where shape = (n_trials, n_channels, n_samples)
Returns
  • prediction (Prediction): Empty Prediction object