bci_essentials.classification.mi_classifier

MI Classifier

This classifier is used to classify MI data.

  1"""**MI Classifier**
  2
  3This classifier is used to classify MI data.
  4
  5"""
  6
  7# Stock libraries
  8import numpy as np
  9from sklearn.model_selection import StratifiedKFold
 10from sklearn.pipeline import Pipeline
 11from sklearn.ensemble import RandomForestClassifier
 12from sklearn.metrics import confusion_matrix, precision_score, recall_score
 13from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
 14from pyriemann.preprocessing import Whitening
 15from pyriemann.estimation import Covariances
 16from pyriemann.classification import MDM, TSclassifier
 17from pyriemann.channelselection import FlatChannelRemover
 18
 19# Import bci_essentials modules and methods
 20from ..classification.generic_classifier import (
 21    GenericClassifier,
 22    Prediction,
 23    KernelResults,
 24)
 25from ..channel_selection import channel_selection_by_method
 26from ..utils.logger import Logger  # Logger wrapper
 27
 28# Instantiate a logger for the module at the default level of logging.INFO
 29# Logs to bci_essentials.__module__) where __module__ is the name of the module
 30logger = Logger(name=__name__)
 31
 32
 33class MiClassifier(GenericClassifier):
 34    """MI Classifier class (*inherits from `GenericClassifier`*)."""
 35
 36    def set_mi_classifier_settings(
 37        self,
 38        n_splits=5,
 39        type="TS",
 40        remove_flats=True,
 41        whitening=False,
 42        covariance_estimator="oas",
 43        artifact_rejection="none",
 44        pred_threshold=0.5,
 45        random_seed=42,
 46        n_jobs=1,
 47    ):
 48        """Set MI classifier settings.
 49
 50        Parameters
 51        ----------
 52        n_splits : int, *optional*
 53            Number of folds for cross-validation.
 54            - Default is `5`.
 55        type : str, *optional*
 56            Type of classifier to be used.
 57            Options = sLDA, RandomForest, TS, or MDM.
 58            - Default is `"TS"`.
 59        remove_flats : bool, *optional*
 60            Whether to remove flat channels from the EEG data.
 61            - Default is `True`.
 62        whitening : bool, *optional*
 63            Whether to apply whitening to the EEG data.
 64            - Default is `False`.
 65        covariance_estimator : str, *optional*
 66            Covariance estimator. See pyriemann Covariances.
 67            - Default is `"oas"`.
 68        artifact_rejection : str, *optional*
 69            Method for artefact rejection.
 70            - Default is `"none"`.
 71        pred_threshold : float, *optional*
 72            Prediction threshold used for classification.
 73            - Default is `0.5`.
 74        random_seed : int, *optional*
 75            Random seed.
 76            - Default is `42`.
 77        n_jobs : int, *optional*
 78            The number of threads to dedicate to this calculation.
 79            - Default is `1`.
 80
 81        Returns
 82        -------
 83        `None`
 84            Models created are used in `fit()`.
 85
 86        """
 87        # Build the cross-validation split
 88        self.n_splits = n_splits
 89        self.cv = StratifiedKFold(
 90            n_splits=n_splits, shuffle=True, random_state=random_seed
 91        )
 92
 93        self.covariance_estimator = covariance_estimator
 94
 95        # Shrinkage LDA
 96        if type == "sLDA":
 97            slda = LinearDiscriminantAnalysis(solver="eigen", shrinkage="auto")
 98            self.clf_model = Pipeline([("Shrinkage LDA", slda)])
 99            self.clf = Pipeline([("Shrinkage LDA", slda)])
100
101        # Random Forest
102        elif type == "RandomForest":
103            rf = RandomForestClassifier()
104            self.clf_model = Pipeline([("Random Forest", rf)])
105            self.clf = Pipeline([("Random Forest", rf)])
106
107        # Tangent Space Logistic Regression
108        elif type == "TS":
109            ts = TSclassifier()
110            self.clf_model = Pipeline([("Tangent Space", ts)])
111            self.clf = Pipeline([("Tangent Space", ts)])
112
113        # Minimum Distance to Mean
114        elif type == "MDM":
115            mdm = MDM(metric=dict(mean="riemann", distance="riemann"), n_jobs=n_jobs)
116            self.clf_model = Pipeline([("MDM", mdm)])
117            self.clf = Pipeline([("MDM", mdm)])
118
119        else:
120            logger.error("Classifier type not defined")
121
122        # All algorithms have covariance estimation as the first step
123        self.clf_model.steps.insert(
124            0, ["Covariances", Covariances(estimator=self.covariance_estimator)]
125        )
126        self.clf.steps.insert(
127            0, ["Covariances", Covariances(estimator=self.covariance_estimator)]
128        )
129
130        if artifact_rejection == "potato":
131            logger.error("Potato not implemented")
132
133        if whitening:
134            self.clf_model.steps.insert(0, ["Whitening", Whitening()])
135            self.clf.steps.insert(0, ["Whitening", Whitening()])
136
137        if remove_flats:
138            rf = FlatChannelRemover()
139            self.clf_model.steps.insert(0, ["Remove Flat Channels", rf])
140            self.clf.steps.insert(0, ["Remove Flat Channels", rf])
141
142        # Threshold
143        self.pred_threshold = pred_threshold
144
145        # Rebuild from scratch with each training
146        self.rebuild = True
147
148    def fit(self):
149        """Fit the model.
150
151        Returns
152        -------
153        `None`
154            Models created used in `predict()`.
155
156        """
157        # get dimensions
158        n_trials, n_channels, n_samples = self.X.shape
159
160        # do the rest of the training if train_free is false
161        self.X = np.array(self.X)
162
163        # Try rebuilding the classifier each time
164        if self.rebuild:
165            self.next_fit_trial = 0
166            self.clf = self.clf_model
167
168        # get temporal subset
169        subX = self.X[self.next_fit_trial :, :, :]
170        suby = self.y[self.next_fit_trial :]
171        self.next_fit_trial = n_trials
172
173        # Init predictions to all false
174        cv_preds = np.zeros(n_trials)
175
176        def __mi_kernel(subX, suby):
177            """MI kernel.
178
179            Parameters
180            ----------
181            subX : numpy.ndarray
182                EEG data for training.
183                3D array with shape = (`n_epochs`, `n_channels`, `n_samples`).
184            suby : numpy.ndarray
185                Labels for training data.
186                1D array with shape = (`n_epochs`, ).
187
188            Returns
189            -------
190            kernelResults : KernelResults
191                KernelResults object containing the following attributes:
192                    model : classifier
193                        The trained classification model.
194                    cv_preds : numpy.ndarray
195                        The predictions from the model using cross validation.
196                        1D array with the same shape as `suby`.
197                    accuracy : float
198                        The accuracy of the trained classification model.
199                    precision : float
200                        The precision of the trained classification model.
201                    recall : float
202                        The recall of the trained classification model.
203
204
205            """
206            for train_idx, test_idx in self.cv.split(subX, suby):
207                self.clf = self.clf_model
208
209                X_train, X_test = subX[train_idx], subX[test_idx]
210                # y_test not implemented
211                y_train = suby[train_idx]
212
213                # fit the classsifier
214                self.clf.fit(X_train, y_train)
215                cv_preds[test_idx] = self.clf.predict(X_test)
216
217            # Train final model with all available data
218            self.clf.fit(subX, suby)
219            model = self.clf
220
221            accuracy = sum(cv_preds == self.y) / len(cv_preds)
222            precision = precision_score(self.y, cv_preds, average="micro")
223            recall = recall_score(self.y, cv_preds, average="micro")
224
225            return KernelResults(model, cv_preds, accuracy, precision, recall)
226
227        # Check if channel selection is true
228        if self.channel_selection_setup:
229            if self.chs_iterative_selection is True and self.subset is not None:
230                initial_subset = self.subset
231                logger.info(
232                    "Using subset from previous channel selection "
233                    + "because iterative selection is TRUE"
234                )
235            else:
236                initial_subset = self.chs_initial_subset
237
238            logger.info("Doing channel selection")
239            channel_selection_results = channel_selection_by_method(
240                __mi_kernel,
241                self.X,
242                self.y,
243                self.channel_labels,  # kernel setup
244                self.chs_method,
245                self.chs_metric,
246                initial_subset,  # wrapper setup
247                self.chs_max_time,
248                self.chs_min_channels,
249                self.chs_max_channels,
250                self.chs_performance_delta,  # stopping criterion
251                self.chs_n_jobs,
252            )
253
254            preds = channel_selection_results.best_preds
255            accuracy = channel_selection_results.best_accuracy
256            precision = channel_selection_results.best_precision
257            recall = channel_selection_results.best_recall
258
259            self.results_df = channel_selection_results.results_df
260            self.subset = channel_selection_results.best_channel_subset
261            self.subset_defined = True
262            self.clf = channel_selection_results.best_model
263        else:
264            logger.warning("Not doing channel selection")
265
266            subX = self.get_subset(subX, self.subset, self.channel_labels)
267
268            current_results = __mi_kernel(subX, suby)
269            self.clf = current_results.model
270            preds = current_results.cv_preds
271            accuracy = current_results.accuracy
272            precision = current_results.precision
273            recall = current_results.recall
274
275        # Log performance stats
276
277        self.offline_trial_count = n_trials
278        self.offline_trial_counts.append(self.offline_trial_count)
279
280        # accuracy
281        accuracy = sum(preds == self.y) / len(preds)
282        self.offline_accuracy.append(accuracy)
283        logger.info("Accuracy = %s", accuracy)
284
285        # precision
286        precision = precision_score(self.y, preds, average="micro")
287        self.offline_precision.append(precision)
288        logger.info("Precision = %s", precision)
289
290        # recall
291        recall = recall_score(self.y, preds, average="micro")
292        self.offline_recall.append(recall)
293        logger.info("Recall = %s", recall)
294
295        # confusion matrix in command line
296        cm = confusion_matrix(self.y, preds)
297        self.offline_cm = cm
298        logger.info("Confusion matrix:\n%s", cm)
299
300    def predict(self, X):
301        """Predict the class labels for the provided data.
302
303        Parameters
304        ----------
305        X : numpy.ndarray
306            3D array where shape = (trials, channels, samples)
307
308        Returns
309        -------
310        prediction : Prediction
311            Results of predict call containing the predicted class labels, and
312            the probabilities of the labels.
313
314        """
315        # if X is 2D, make it 3D with one as first dimension
316        if len(X.shape) < 3:
317            X = X[np.newaxis, ...]
318
319        subset_X = self.get_subset(X, self.subset, self.channel_labels)
320
321        logger.info("The shape of X is %s", subset_X.shape)
322
323        pred = [int(x) for x in self.clf.predict(subset_X)]
324        pred_proba = self.clf.predict_proba(subset_X)
325
326        logger.info("Prediction: %s", pred)
327        logger.info("Prediction probabilities: %s", pred_proba)
328
329        for i in range(len(pred)):
330            self.predictions.append(pred[i])
331            self.pred_probas.append(pred_proba[i])
332
333        return Prediction(labels=pred, probabilities=pred_proba)
 34class MiClassifier(GenericClassifier):
 35    """MI Classifier class (*inherits from `GenericClassifier`*)."""
 36
 37    def set_mi_classifier_settings(
 38        self,
 39        n_splits=5,
 40        type="TS",
 41        remove_flats=True,
 42        whitening=False,
 43        covariance_estimator="oas",
 44        artifact_rejection="none",
 45        pred_threshold=0.5,
 46        random_seed=42,
 47        n_jobs=1,
 48    ):
 49        """Set MI classifier settings.
 50
 51        Parameters
 52        ----------
 53        n_splits : int, *optional*
 54            Number of folds for cross-validation.
 55            - Default is `5`.
 56        type : str, *optional*
 57            Type of classifier to be used.
 58            Options = sLDA, RandomForest, TS, or MDM.
 59            - Default is `"TS"`.
 60        remove_flats : bool, *optional*
 61            Whether to remove flat channels from the EEG data.
 62            - Default is `True`.
 63        whitening : bool, *optional*
 64            Whether to apply whitening to the EEG data.
 65            - Default is `False`.
 66        covariance_estimator : str, *optional*
 67            Covariance estimator. See pyriemann Covariances.
 68            - Default is `"oas"`.
 69        artifact_rejection : str, *optional*
 70            Method for artefact rejection.
 71            - Default is `"none"`.
 72        pred_threshold : float, *optional*
 73            Prediction threshold used for classification.
 74            - Default is `0.5`.
 75        random_seed : int, *optional*
 76            Random seed.
 77            - Default is `42`.
 78        n_jobs : int, *optional*
 79            The number of threads to dedicate to this calculation.
 80            - Default is `1`.
 81
 82        Returns
 83        -------
 84        `None`
 85            Models created are used in `fit()`.
 86
 87        """
 88        # Build the cross-validation split
 89        self.n_splits = n_splits
 90        self.cv = StratifiedKFold(
 91            n_splits=n_splits, shuffle=True, random_state=random_seed
 92        )
 93
 94        self.covariance_estimator = covariance_estimator
 95
 96        # Shrinkage LDA
 97        if type == "sLDA":
 98            slda = LinearDiscriminantAnalysis(solver="eigen", shrinkage="auto")
 99            self.clf_model = Pipeline([("Shrinkage LDA", slda)])
100            self.clf = Pipeline([("Shrinkage LDA", slda)])
101
102        # Random Forest
103        elif type == "RandomForest":
104            rf = RandomForestClassifier()
105            self.clf_model = Pipeline([("Random Forest", rf)])
106            self.clf = Pipeline([("Random Forest", rf)])
107
108        # Tangent Space Logistic Regression
109        elif type == "TS":
110            ts = TSclassifier()
111            self.clf_model = Pipeline([("Tangent Space", ts)])
112            self.clf = Pipeline([("Tangent Space", ts)])
113
114        # Minimum Distance to Mean
115        elif type == "MDM":
116            mdm = MDM(metric=dict(mean="riemann", distance="riemann"), n_jobs=n_jobs)
117            self.clf_model = Pipeline([("MDM", mdm)])
118            self.clf = Pipeline([("MDM", mdm)])
119
120        else:
121            logger.error("Classifier type not defined")
122
123        # All algorithms have covariance estimation as the first step
124        self.clf_model.steps.insert(
125            0, ["Covariances", Covariances(estimator=self.covariance_estimator)]
126        )
127        self.clf.steps.insert(
128            0, ["Covariances", Covariances(estimator=self.covariance_estimator)]
129        )
130
131        if artifact_rejection == "potato":
132            logger.error("Potato not implemented")
133
134        if whitening:
135            self.clf_model.steps.insert(0, ["Whitening", Whitening()])
136            self.clf.steps.insert(0, ["Whitening", Whitening()])
137
138        if remove_flats:
139            rf = FlatChannelRemover()
140            self.clf_model.steps.insert(0, ["Remove Flat Channels", rf])
141            self.clf.steps.insert(0, ["Remove Flat Channels", rf])
142
143        # Threshold
144        self.pred_threshold = pred_threshold
145
146        # Rebuild from scratch with each training
147        self.rebuild = True
148
149    def fit(self):
150        """Fit the model.
151
152        Returns
153        -------
154        `None`
155            Models created used in `predict()`.
156
157        """
158        # get dimensions
159        n_trials, n_channels, n_samples = self.X.shape
160
161        # do the rest of the training if train_free is false
162        self.X = np.array(self.X)
163
164        # Try rebuilding the classifier each time
165        if self.rebuild:
166            self.next_fit_trial = 0
167            self.clf = self.clf_model
168
169        # get temporal subset
170        subX = self.X[self.next_fit_trial :, :, :]
171        suby = self.y[self.next_fit_trial :]
172        self.next_fit_trial = n_trials
173
174        # Init predictions to all false
175        cv_preds = np.zeros(n_trials)
176
177        def __mi_kernel(subX, suby):
178            """MI kernel.
179
180            Parameters
181            ----------
182            subX : numpy.ndarray
183                EEG data for training.
184                3D array with shape = (`n_epochs`, `n_channels`, `n_samples`).
185            suby : numpy.ndarray
186                Labels for training data.
187                1D array with shape = (`n_epochs`, ).
188
189            Returns
190            -------
191            kernelResults : KernelResults
192                KernelResults object containing the following attributes:
193                    model : classifier
194                        The trained classification model.
195                    cv_preds : numpy.ndarray
196                        The predictions from the model using cross validation.
197                        1D array with the same shape as `suby`.
198                    accuracy : float
199                        The accuracy of the trained classification model.
200                    precision : float
201                        The precision of the trained classification model.
202                    recall : float
203                        The recall of the trained classification model.
204
205
206            """
207            for train_idx, test_idx in self.cv.split(subX, suby):
208                self.clf = self.clf_model
209
210                X_train, X_test = subX[train_idx], subX[test_idx]
211                # y_test not implemented
212                y_train = suby[train_idx]
213
214                # fit the classsifier
215                self.clf.fit(X_train, y_train)
216                cv_preds[test_idx] = self.clf.predict(X_test)
217
218            # Train final model with all available data
219            self.clf.fit(subX, suby)
220            model = self.clf
221
222            accuracy = sum(cv_preds == self.y) / len(cv_preds)
223            precision = precision_score(self.y, cv_preds, average="micro")
224            recall = recall_score(self.y, cv_preds, average="micro")
225
226            return KernelResults(model, cv_preds, accuracy, precision, recall)
227
228        # Check if channel selection is true
229        if self.channel_selection_setup:
230            if self.chs_iterative_selection is True and self.subset is not None:
231                initial_subset = self.subset
232                logger.info(
233                    "Using subset from previous channel selection "
234                    + "because iterative selection is TRUE"
235                )
236            else:
237                initial_subset = self.chs_initial_subset
238
239            logger.info("Doing channel selection")
240            channel_selection_results = channel_selection_by_method(
241                __mi_kernel,
242                self.X,
243                self.y,
244                self.channel_labels,  # kernel setup
245                self.chs_method,
246                self.chs_metric,
247                initial_subset,  # wrapper setup
248                self.chs_max_time,
249                self.chs_min_channels,
250                self.chs_max_channels,
251                self.chs_performance_delta,  # stopping criterion
252                self.chs_n_jobs,
253            )
254
255            preds = channel_selection_results.best_preds
256            accuracy = channel_selection_results.best_accuracy
257            precision = channel_selection_results.best_precision
258            recall = channel_selection_results.best_recall
259
260            self.results_df = channel_selection_results.results_df
261            self.subset = channel_selection_results.best_channel_subset
262            self.subset_defined = True
263            self.clf = channel_selection_results.best_model
264        else:
265            logger.warning("Not doing channel selection")
266
267            subX = self.get_subset(subX, self.subset, self.channel_labels)
268
269            current_results = __mi_kernel(subX, suby)
270            self.clf = current_results.model
271            preds = current_results.cv_preds
272            accuracy = current_results.accuracy
273            precision = current_results.precision
274            recall = current_results.recall
275
276        # Log performance stats
277
278        self.offline_trial_count = n_trials
279        self.offline_trial_counts.append(self.offline_trial_count)
280
281        # accuracy
282        accuracy = sum(preds == self.y) / len(preds)
283        self.offline_accuracy.append(accuracy)
284        logger.info("Accuracy = %s", accuracy)
285
286        # precision
287        precision = precision_score(self.y, preds, average="micro")
288        self.offline_precision.append(precision)
289        logger.info("Precision = %s", precision)
290
291        # recall
292        recall = recall_score(self.y, preds, average="micro")
293        self.offline_recall.append(recall)
294        logger.info("Recall = %s", recall)
295
296        # confusion matrix in command line
297        cm = confusion_matrix(self.y, preds)
298        self.offline_cm = cm
299        logger.info("Confusion matrix:\n%s", cm)
300
301    def predict(self, X):
302        """Predict the class labels for the provided data.
303
304        Parameters
305        ----------
306        X : numpy.ndarray
307            3D array where shape = (trials, channels, samples)
308
309        Returns
310        -------
311        prediction : Prediction
312            Results of predict call containing the predicted class labels, and
313            the probabilities of the labels.
314
315        """
316        # if X is 2D, make it 3D with one as first dimension
317        if len(X.shape) < 3:
318            X = X[np.newaxis, ...]
319
320        subset_X = self.get_subset(X, self.subset, self.channel_labels)
321
322        logger.info("The shape of X is %s", subset_X.shape)
323
324        pred = [int(x) for x in self.clf.predict(subset_X)]
325        pred_proba = self.clf.predict_proba(subset_X)
326
327        logger.info("Prediction: %s", pred)
328        logger.info("Prediction probabilities: %s", pred_proba)
329
330        for i in range(len(pred)):
331            self.predictions.append(pred[i])
332            self.pred_probas.append(pred_proba[i])
333
334        return Prediction(labels=pred, probabilities=pred_proba)

MI Classifier class (inherits from GenericClassifier).

def set_mi_classifier_settings( self, n_splits=5, type='TS', remove_flats=True, whitening=False, covariance_estimator='oas', artifact_rejection='none', pred_threshold=0.5, random_seed=42, n_jobs=1):
 37    def set_mi_classifier_settings(
 38        self,
 39        n_splits=5,
 40        type="TS",
 41        remove_flats=True,
 42        whitening=False,
 43        covariance_estimator="oas",
 44        artifact_rejection="none",
 45        pred_threshold=0.5,
 46        random_seed=42,
 47        n_jobs=1,
 48    ):
 49        """Set MI classifier settings.
 50
 51        Parameters
 52        ----------
 53        n_splits : int, *optional*
 54            Number of folds for cross-validation.
 55            - Default is `5`.
 56        type : str, *optional*
 57            Type of classifier to be used.
 58            Options = sLDA, RandomForest, TS, or MDM.
 59            - Default is `"TS"`.
 60        remove_flats : bool, *optional*
 61            Whether to remove flat channels from the EEG data.
 62            - Default is `True`.
 63        whitening : bool, *optional*
 64            Whether to apply whitening to the EEG data.
 65            - Default is `False`.
 66        covariance_estimator : str, *optional*
 67            Covariance estimator. See pyriemann Covariances.
 68            - Default is `"oas"`.
 69        artifact_rejection : str, *optional*
 70            Method for artefact rejection.
 71            - Default is `"none"`.
 72        pred_threshold : float, *optional*
 73            Prediction threshold used for classification.
 74            - Default is `0.5`.
 75        random_seed : int, *optional*
 76            Random seed.
 77            - Default is `42`.
 78        n_jobs : int, *optional*
 79            The number of threads to dedicate to this calculation.
 80            - Default is `1`.
 81
 82        Returns
 83        -------
 84        `None`
 85            Models created are used in `fit()`.
 86
 87        """
 88        # Build the cross-validation split
 89        self.n_splits = n_splits
 90        self.cv = StratifiedKFold(
 91            n_splits=n_splits, shuffle=True, random_state=random_seed
 92        )
 93
 94        self.covariance_estimator = covariance_estimator
 95
 96        # Shrinkage LDA
 97        if type == "sLDA":
 98            slda = LinearDiscriminantAnalysis(solver="eigen", shrinkage="auto")
 99            self.clf_model = Pipeline([("Shrinkage LDA", slda)])
100            self.clf = Pipeline([("Shrinkage LDA", slda)])
101
102        # Random Forest
103        elif type == "RandomForest":
104            rf = RandomForestClassifier()
105            self.clf_model = Pipeline([("Random Forest", rf)])
106            self.clf = Pipeline([("Random Forest", rf)])
107
108        # Tangent Space Logistic Regression
109        elif type == "TS":
110            ts = TSclassifier()
111            self.clf_model = Pipeline([("Tangent Space", ts)])
112            self.clf = Pipeline([("Tangent Space", ts)])
113
114        # Minimum Distance to Mean
115        elif type == "MDM":
116            mdm = MDM(metric=dict(mean="riemann", distance="riemann"), n_jobs=n_jobs)
117            self.clf_model = Pipeline([("MDM", mdm)])
118            self.clf = Pipeline([("MDM", mdm)])
119
120        else:
121            logger.error("Classifier type not defined")
122
123        # All algorithms have covariance estimation as the first step
124        self.clf_model.steps.insert(
125            0, ["Covariances", Covariances(estimator=self.covariance_estimator)]
126        )
127        self.clf.steps.insert(
128            0, ["Covariances", Covariances(estimator=self.covariance_estimator)]
129        )
130
131        if artifact_rejection == "potato":
132            logger.error("Potato not implemented")
133
134        if whitening:
135            self.clf_model.steps.insert(0, ["Whitening", Whitening()])
136            self.clf.steps.insert(0, ["Whitening", Whitening()])
137
138        if remove_flats:
139            rf = FlatChannelRemover()
140            self.clf_model.steps.insert(0, ["Remove Flat Channels", rf])
141            self.clf.steps.insert(0, ["Remove Flat Channels", rf])
142
143        # Threshold
144        self.pred_threshold = pred_threshold
145
146        # Rebuild from scratch with each training
147        self.rebuild = True

Set MI classifier settings.

Parameters
  • n_splits (int, optional): Number of folds for cross-validation.
    • Default is 5.
  • type (str, optional): Type of classifier to be used. Options = sLDA, RandomForest, TS, or MDM.
    • Default is "TS".
  • remove_flats (bool, optional): Whether to remove flat channels from the EEG data.
    • Default is True.
  • whitening (bool, optional): Whether to apply whitening to the EEG data.
    • Default is False.
  • covariance_estimator (str, optional): Covariance estimator. See pyriemann Covariances.
    • Default is "oas".
  • artifact_rejection (str, optional): Method for artifact rejection. Only "none" is supported; "potato" is recognized but not implemented (an error is logged).
    • Default is "none".
  • pred_threshold (float, optional): Prediction threshold used for classification.
    • Default is 0.5.
  • random_seed (int, optional): Random seed.
    • Default is 42.
  • n_jobs (int, optional): The number of threads to dedicate to this calculation.
    • Default is 1.
Returns
  • None: Models created are used in fit().
def fit(self):
def fit(self):
    """Fit the model.

    Uses ``self.cv`` to obtain out-of-fold predictions for the trials being
    fitted, trains a final model on all of those trials, and records the
    resulting accuracy, precision, recall and confusion matrix on the
    instance. If ``self.channel_selection_setup`` is truthy, the kernel is
    driven by ``channel_selection_by_method`` and the best channel subset
    and model it reports are kept.

    Returns
    -------
    `None`
        Models created used in `predict()`.

    """
    # Local import: the module-level import block is not part of this listing.
    from sklearn.base import clone

    # Convert before reading the shape so list-of-arrays input also works.
    self.X = np.array(self.X)
    n_trials = self.X.shape[0]

    # Try rebuilding the classifier each time: restart from the first trial.
    if self.rebuild:
        self.next_fit_trial = 0

    # get temporal subset (trials not yet used for fitting)
    subX = self.X[self.next_fit_trial :, :, :]
    suby = self.y[self.next_fit_trial :]
    self.next_fit_trial = n_trials

    def __mi_kernel(subX, suby):
        """MI kernel.

        Parameters
        ----------
        subX : numpy.ndarray
            EEG data for training.
            3D array with shape = (`n_epochs`, `n_channels`, `n_samples`).
        suby : numpy.ndarray
            Labels for training data.
            1D array with shape = (`n_epochs`, ).

        Returns
        -------
        kernelResults : KernelResults
            KernelResults object containing the following attributes:
                model : classifier
                    The trained classification model (a fresh clone of
                    `self.clf_model`, independent of other kernel calls).
                cv_preds : numpy.ndarray
                    The predictions from the model using cross validation.
                    1D array with the same shape as `suby`.
                accuracy : float
                    The accuracy of the trained classification model.
                precision : float
                    The precision of the trained classification model.
                recall : float
                    The recall of the trained classification model.

        """
        suby = np.asarray(suby)

        # Per-call buffer sized to this call's trials, so repeated (or
        # parallel) kernel calls during channel selection share no state.
        cv_preds = np.zeros(len(suby))

        for train_idx, test_idx in self.cv.split(subX, suby):
            # Fresh, unfitted copy of the template pipeline for every fold.
            fold_clf = clone(self.clf_model)
            fold_clf.fit(subX[train_idx], suby[train_idx])
            cv_preds[test_idx] = fold_clf.predict(subX[test_idx])

        # Train final model with all available data. Cloning keeps this model
        # independent of the ones returned by other kernel calls, so a model
        # chosen as "best" stays fitted on its own channel subset.
        model = clone(self.clf_model)
        model.fit(subX, suby)

        accuracy = np.mean(cv_preds == suby)
        precision = precision_score(suby, cv_preds, average="micro")
        recall = recall_score(suby, cv_preds, average="micro")

        return KernelResults(model, cv_preds, accuracy, precision, recall)

    # Check if channel selection is true
    if self.channel_selection_setup:
        if self.chs_iterative_selection is True and self.subset is not None:
            initial_subset = self.subset
            logger.info(
                "Using subset from previous channel selection "
                + "because iterative selection is TRUE"
            )
        else:
            initial_subset = self.chs_initial_subset

        logger.info("Doing channel selection")
        channel_selection_results = channel_selection_by_method(
            __mi_kernel,
            self.X,
            self.y,
            self.channel_labels,  # kernel setup
            self.chs_method,
            self.chs_metric,
            initial_subset,  # wrapper setup
            self.chs_max_time,
            self.chs_min_channels,
            self.chs_max_channels,
            self.chs_performance_delta,  # stopping criterion
            self.chs_n_jobs,
        )

        preds = channel_selection_results.best_preds
        # NOTE(review): assumes the wrapper feeds every trial of self.y to the
        # kernel, so best_preds aligns with self.y — confirm in
        # channel_selection_by_method.
        y_true = np.asarray(self.y)

        self.results_df = channel_selection_results.results_df
        self.subset = channel_selection_results.best_channel_subset
        self.subset_defined = True
        self.clf = channel_selection_results.best_model
    else:
        logger.warning("Not doing channel selection")

        subX = self.get_subset(subX, self.subset, self.channel_labels)

        current_results = __mi_kernel(subX, suby)
        self.clf = current_results.model
        preds = current_results.cv_preds
        # The kernel's predictions cover only the trials fitted in this call.
        y_true = np.asarray(suby)

    # Log performance stats

    self.offline_trial_count = n_trials
    self.offline_trial_counts.append(self.offline_trial_count)

    # accuracy
    accuracy = np.mean(preds == y_true)
    self.offline_accuracy.append(accuracy)
    logger.info("Accuracy = %s", accuracy)

    # NOTE: on single-label data, micro-averaged precision and recall are
    # both equal to accuracy (see scikit-learn docs on `average="micro"`).

    # precision
    precision = precision_score(y_true, preds, average="micro")
    self.offline_precision.append(precision)
    logger.info("Precision = %s", precision)

    # recall
    recall = recall_score(y_true, preds, average="micro")
    self.offline_recall.append(recall)
    logger.info("Recall = %s", recall)

    # confusion matrix in command line
    cm = confusion_matrix(y_true, preds)
    self.offline_cm = cm
    logger.info("Confusion matrix:\n%s", cm)

Fit the model.

Returns
  • None: Models created are used in predict().
def predict(self, X):
def predict(self, X):
    """Predict the class labels for the provided data.

    Parameters
    ----------
    X : numpy.ndarray
        3D array where shape = (trials, channels, samples). A 2D array of
        shape (channels, samples) is treated as a single trial. Array-likes
        (e.g. nested lists) are converted with `numpy.asarray`.

    Returns
    -------
    prediction : Prediction
        Results of predict call containing the predicted class labels, and
        the probabilities of the labels.

    """
    # Accept array-likes as well as numpy arrays (no copy for ndarray input).
    X = np.asarray(X)

    # if X is 2D, make it 3D with one as first dimension
    if X.ndim < 3:
        X = X[np.newaxis, ...]

    subset_X = self.get_subset(X, self.subset, self.channel_labels)

    logger.info("The shape of X is %s", subset_X.shape)

    pred = [int(x) for x in self.clf.predict(subset_X)]
    pred_proba = self.clf.predict_proba(subset_X)

    logger.info("Prediction: %s", pred)
    logger.info("Prediction probabilities: %s", pred_proba)

    # Keep a running history of every label and per-trial probability row.
    self.predictions.extend(pred)
    self.pred_probas.extend(pred_proba)

    return Prediction(labels=pred, probabilities=pred_proba)

Predict the class labels for the provided data.

Parameters
  • X (numpy.ndarray): 3D array where shape = (trials, channels, samples). A 2D array of shape (channels, samples) is treated as a single trial.
Returns
  • prediction (Prediction): Results of predict call containing the predicted class labels, and the probabilities of the labels.