bci_essentials.classification.mi_classifier
MI Classifier
This classifier is used to classify MI data.
"""**MI Classifier**

This classifier is used to classify MI data.

"""

# Stock libraries
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix, precision_score, recall_score
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from pyriemann.preprocessing import Whitening
from pyriemann.estimation import Covariances
from pyriemann.classification import MDM, TSclassifier
from pyriemann.channelselection import FlatChannelRemover

# Import bci_essentials modules and methods
from ..classification.generic_classifier import (
    GenericClassifier,
    Prediction,
    KernelResults,
)
from ..channel_selection import channel_selection_by_method
from ..utils.logger import Logger  # Logger wrapper

# Instantiate a logger for the module at the default level of logging.INFO
# Logs to bci_essentials.__module__, where __module__ is the name of the module
logger = Logger(name=__name__)


class MiClassifier(GenericClassifier):
    """MI Classifier class (*inherits from `GenericClassifier`*)."""

    def set_mi_classifier_settings(
        self,
        n_splits=5,
        type="TS",
        remove_flats=True,
        whitening=False,
        covariance_estimator="oas",
        artifact_rejection="none",
        pred_threshold=0.5,
        random_seed=42,
        n_jobs=1,
    ):
        """Set MI classifier settings.

        Parameters
        ----------
        n_splits : int, *optional*
            Number of folds for cross-validation.
            - Default is `5`.
        type : str, *optional*
            Type of classifier to be used.
            Options = sLDA, RandomForest, TS, or MDM.
            - Default is `"TS"`.
        remove_flats : bool, *optional*
            Whether to remove flat channels from the EEG data.
            - Default is `True`.
        whitening : bool, *optional*
            Whether to apply whitening to the covariance matrices.
            - Default is `False`.
        covariance_estimator : str, *optional*
            Covariance estimator. See pyriemann Covariances.
            - Default is `"oas"`.
        artifact_rejection : str, *optional*
            Method for artifact rejection. `"potato"` is not implemented
            and only logs an error.
            - Default is `"none"`.
        pred_threshold : float, *optional*
            Prediction threshold used for classification.
            - Default is `0.5`.
        random_seed : int, *optional*
            Random seed.
            - Default is `42`.
        n_jobs : int, *optional*
            The number of threads to dedicate to this calculation.
            - Default is `1`.

        Returns
        -------
        `None`
            Models created are used in `fit()`.

        Raises
        ------
        ValueError
            If `type` is not one of the supported classifier types.

        """
        # Build the cross-validation split
        self.n_splits = n_splits
        self.cv = StratifiedKFold(
            n_splits=n_splits, shuffle=True, random_state=random_seed
        )

        self.covariance_estimator = covariance_estimator

        # Factories (not instances) for the final classification step, so
        # that every pipeline built below owns its own estimator object.
        # NOTE(review): LinearDiscriminantAnalysis and RandomForestClassifier
        # expect 2D features, but the Covariances step outputs 3D matrices;
        # confirm these options work (a vectorizing step such as pyriemann
        # TangentSpace may be needed).
        classifier_factories = {
            # Shrinkage LDA
            "sLDA": (
                "Shrinkage LDA",
                lambda: LinearDiscriminantAnalysis(
                    solver="eigen", shrinkage="auto"
                ),
            ),
            # Random Forest
            "RandomForest": ("Random Forest", RandomForestClassifier),
            # Tangent Space Logistic Regression
            "TS": ("Tangent Space", TSclassifier),
            # Minimum Distance to Mean
            "MDM": (
                "MDM",
                lambda: MDM(
                    metric=dict(mean="riemann", distance="riemann"), n_jobs=n_jobs
                ),
            ),
        }

        if type not in classifier_factories:
            logger.error("Classifier type not defined")
            # Fail loudly instead of crashing later (or silently mutating a
            # pipeline left over from a previous call)
            raise ValueError(
                f"Unknown classifier type {type!r}; "
                f"expected one of {sorted(classifier_factories)}"
            )
        clf_name, make_clf = classifier_factories[type]

        if artifact_rejection == "potato":
            logger.error("Potato not implemented")

        def build_pipeline():
            """Build a fresh, unfitted pipeline from the chosen settings."""
            steps = []
            if remove_flats:
                # Operates on the raw (trials, channels, samples) signals
                steps.append(("Remove Flat Channels", FlatChannelRemover()))
            # All algorithms use covariance estimation
            steps.append(
                ("Covariances", Covariances(estimator=self.covariance_estimator))
            )
            if whitening:
                # pyriemann's Whitening takes covariance matrices as input,
                # so it must come after covariance estimation
                steps.append(("Whitening", Whitening()))
            steps.append((clf_name, make_clf()))
            return Pipeline(steps)

        # Template pipeline (cloned for each fit) and the current model
        self.clf_model = build_pipeline()
        self.clf = build_pipeline()

        # Threshold
        self.pred_threshold = pred_threshold

        # Rebuild from scratch with each training
        self.rebuild = True

    def fit(self):
        """Fit the model.

        Cross-validates the configured pipeline on the stored trials
        (`self.X`, `self.y`), optionally wrapped in channel selection. It then
        trains a final model on all available trials and records offline
        accuracy, precision, recall and the confusion matrix.

        Returns
        -------
        `None`
            Models created are used in `predict()`.

        """
        # `clone` returns an unfitted copy of the template pipeline
        from sklearn.base import clone

        # Convert first so that `.shape` also works if X was stored as a list
        self.X = np.array(self.X)
        # get dimensions (unpacking also checks that X is 3D)
        n_trials, _n_channels, _n_samples = self.X.shape

        # Try rebuilding the classifier each time
        if self.rebuild:
            self.next_fit_trial = 0
            self.clf = clone(self.clf_model)

        # get temporal subset
        subX = self.X[self.next_fit_trial :, :, :]
        suby = self.y[self.next_fit_trial :]
        self.next_fit_trial = n_trials

        def __mi_kernel(subX, suby):
            """MI kernel.

            Cross-validates fresh copies of the template pipeline on the
            given trials, then fits a final model on all of them.

            Parameters
            ----------
            subX : numpy.ndarray
                EEG data for training.
                3D array with shape = (`n_epochs`, `n_channels`, `n_samples`).
            suby : numpy.ndarray
                Labels for training data.
                1D array with shape = (`n_epochs`, ).

            Returns
            -------
            kernelResults : KernelResults
                KernelResults object containing the following attributes:
                    model : classifier
                        The classification model trained on all of `subX`.
                    cv_preds : numpy.ndarray
                        The predictions from the model using cross validation.
                        1D array with the same shape as `suby`.
                    accuracy : float
                        The cross-validated accuracy.
                    precision : float
                        The cross-validated (micro-averaged) precision.
                    recall : float
                        The cross-validated (micro-averaged) recall.

            """
            # Allocated per call: the channel-selection wrapper may call this
            # kernel many times and keep an earlier result, so results must
            # not share a predictions array or a fitted model.
            cv_preds = np.zeros(len(suby))

            for train_idx, test_idx in self.cv.split(subX, suby):
                # Fresh, unfitted model for every fold
                fold_model = clone(self.clf_model)
                fold_model.fit(subX[train_idx], suby[train_idx])
                cv_preds[test_idx] = fold_model.predict(subX[test_idx])

            # Train final model with all available data
            model = clone(self.clf_model)
            model.fit(subX, suby)

            # Score against the labels this kernel was given
            accuracy = np.mean(cv_preds == suby)
            precision = precision_score(suby, cv_preds, average="micro")
            recall = recall_score(suby, cv_preds, average="micro")

            return KernelResults(model, cv_preds, accuracy, precision, recall)

        # Check if channel selection is true
        if self.channel_selection_setup:
            if self.chs_iterative_selection is True and self.subset is not None:
                initial_subset = self.subset
                logger.info(
                    "Using subset from previous channel selection "
                    + "because iterative selection is TRUE"
                )
            else:
                initial_subset = self.chs_initial_subset

            logger.info("Doing channel selection")
            channel_selection_results = channel_selection_by_method(
                __mi_kernel,
                self.X,
                self.y,
                self.channel_labels,  # kernel setup
                self.chs_method,
                self.chs_metric,
                initial_subset,  # wrapper setup
                self.chs_max_time,
                self.chs_min_channels,
                self.chs_max_channels,
                self.chs_performance_delta,  # stopping criterion
                self.chs_n_jobs,
            )

            preds = channel_selection_results.best_preds

            self.results_df = channel_selection_results.results_df
            self.subset = channel_selection_results.best_channel_subset
            self.subset_defined = True
            self.clf = channel_selection_results.best_model
        else:
            logger.warning("Not doing channel selection")

            subX = self.get_subset(subX, self.subset, self.channel_labels)

            current_results = __mi_kernel(subX, suby)
            self.clf = current_results.model
            preds = current_results.cv_preds

        # Log performance stats

        self.offline_trial_count = n_trials
        self.offline_trial_counts.append(self.offline_trial_count)

        # accuracy
        accuracy = sum(preds == self.y) / len(preds)
        self.offline_accuracy.append(accuracy)
        logger.info("Accuracy = %s", accuracy)

        # precision
        precision = precision_score(self.y, preds, average="micro")
        self.offline_precision.append(precision)
        logger.info("Precision = %s", precision)

        # recall
        recall = recall_score(self.y, preds, average="micro")
        self.offline_recall.append(recall)
        logger.info("Recall = %s", recall)

        # confusion matrix in command line
        cm = confusion_matrix(self.y, preds)
        self.offline_cm = cm
        logger.info("Confusion matrix:\n%s", cm)

    def predict(self, X):
        """Predict the class labels for the provided data.

        Parameters
        ----------
        X : numpy.ndarray
            3D array where shape = (trials, channels, samples). An array with
            fewer dimensions gets a leading axis added (treated as one trial).

        Returns
        -------
        prediction : Prediction
            Results of predict call containing the predicted class labels, and
            the probabilities of the labels.

        """
        # Promote a single (channels, samples) trial to a batch of one
        trials = X[np.newaxis, ...] if len(X.shape) < 3 else X

        trials = self.get_subset(trials, self.subset, self.channel_labels)
        logger.info("The shape of X is %s", trials.shape)

        labels = list(map(int, self.clf.predict(trials)))
        probabilities = self.clf.predict_proba(trials)

        logger.info("Prediction: %s", labels)
        logger.info("Prediction probabilities: %s", probabilities)

        # Keep a running history of every prediction made
        self.predictions.extend(labels)
        self.pred_probas.extend(probabilities)

        return Prediction(labels=labels, probabilities=probabilities)
logger =
<bci_essentials.utils.logger.Logger object>
class MiClassifier(GenericClassifier):
    """MI Classifier class (*inherits from `GenericClassifier`*)."""

    def set_mi_classifier_settings(
        self,
        n_splits=5,
        type="TS",
        remove_flats=True,
        whitening=False,
        covariance_estimator="oas",
        artifact_rejection="none",
        pred_threshold=0.5,
        random_seed=42,
        n_jobs=1,
    ):
        """Set MI classifier settings.

        Parameters
        ----------
        n_splits : int, *optional*
            Number of folds for cross-validation.
            - Default is `5`.
        type : str, *optional*
            Type of classifier to be used.
            Options = sLDA, RandomForest, TS, or MDM.
            - Default is `"TS"`.
        remove_flats : bool, *optional*
            Whether to remove flat channels from the EEG data.
            - Default is `True`.
        whitening : bool, *optional*
            Whether to apply whitening to the covariance matrices.
            - Default is `False`.
        covariance_estimator : str, *optional*
            Covariance estimator. See pyriemann Covariances.
            - Default is `"oas"`.
        artifact_rejection : str, *optional*
            Method for artifact rejection. `"potato"` is not implemented
            and only logs an error.
            - Default is `"none"`.
        pred_threshold : float, *optional*
            Prediction threshold used for classification.
            - Default is `0.5`.
        random_seed : int, *optional*
            Random seed.
            - Default is `42`.
        n_jobs : int, *optional*
            The number of threads to dedicate to this calculation.
            - Default is `1`.

        Returns
        -------
        `None`
            Models created are used in `fit()`.

        Raises
        ------
        ValueError
            If `type` is not one of the supported classifier types.

        """
        # Build the cross-validation split
        self.n_splits = n_splits
        self.cv = StratifiedKFold(
            n_splits=n_splits, shuffle=True, random_state=random_seed
        )

        self.covariance_estimator = covariance_estimator

        # Factories (not instances) for the final classification step, so
        # that every pipeline built below owns its own estimator object.
        # NOTE(review): LinearDiscriminantAnalysis and RandomForestClassifier
        # expect 2D features, but the Covariances step outputs 3D matrices;
        # confirm these options work (a vectorizing step such as pyriemann
        # TangentSpace may be needed).
        classifier_factories = {
            # Shrinkage LDA
            "sLDA": (
                "Shrinkage LDA",
                lambda: LinearDiscriminantAnalysis(
                    solver="eigen", shrinkage="auto"
                ),
            ),
            # Random Forest
            "RandomForest": ("Random Forest", RandomForestClassifier),
            # Tangent Space Logistic Regression
            "TS": ("Tangent Space", TSclassifier),
            # Minimum Distance to Mean
            "MDM": (
                "MDM",
                lambda: MDM(
                    metric=dict(mean="riemann", distance="riemann"), n_jobs=n_jobs
                ),
            ),
        }

        if type not in classifier_factories:
            logger.error("Classifier type not defined")
            # Fail loudly instead of crashing later (or silently mutating a
            # pipeline left over from a previous call)
            raise ValueError(
                f"Unknown classifier type {type!r}; "
                f"expected one of {sorted(classifier_factories)}"
            )
        clf_name, make_clf = classifier_factories[type]

        if artifact_rejection == "potato":
            logger.error("Potato not implemented")

        def build_pipeline():
            """Build a fresh, unfitted pipeline from the chosen settings."""
            steps = []
            if remove_flats:
                # Operates on the raw (trials, channels, samples) signals
                steps.append(("Remove Flat Channels", FlatChannelRemover()))
            # All algorithms use covariance estimation
            steps.append(
                ("Covariances", Covariances(estimator=self.covariance_estimator))
            )
            if whitening:
                # pyriemann's Whitening takes covariance matrices as input,
                # so it must come after covariance estimation
                steps.append(("Whitening", Whitening()))
            steps.append((clf_name, make_clf()))
            return Pipeline(steps)

        # Template pipeline (cloned for each fit) and the current model
        self.clf_model = build_pipeline()
        self.clf = build_pipeline()

        # Threshold
        self.pred_threshold = pred_threshold

        # Rebuild from scratch with each training
        self.rebuild = True

    def fit(self):
        """Fit the model.

        Cross-validates the configured pipeline on the stored trials
        (`self.X`, `self.y`), optionally wrapped in channel selection. It then
        trains a final model on all available trials and records offline
        accuracy, precision, recall and the confusion matrix.

        Returns
        -------
        `None`
            Models created are used in `predict()`.

        """
        # `clone` returns an unfitted copy of the template pipeline
        from sklearn.base import clone

        # Convert first so that `.shape` also works if X was stored as a list
        self.X = np.array(self.X)
        # get dimensions (unpacking also checks that X is 3D)
        n_trials, _n_channels, _n_samples = self.X.shape

        # Try rebuilding the classifier each time
        if self.rebuild:
            self.next_fit_trial = 0
            self.clf = clone(self.clf_model)

        # get temporal subset
        subX = self.X[self.next_fit_trial :, :, :]
        suby = self.y[self.next_fit_trial :]
        self.next_fit_trial = n_trials

        def __mi_kernel(subX, suby):
            """MI kernel.

            Cross-validates fresh copies of the template pipeline on the
            given trials, then fits a final model on all of them.

            Parameters
            ----------
            subX : numpy.ndarray
                EEG data for training.
                3D array with shape = (`n_epochs`, `n_channels`, `n_samples`).
            suby : numpy.ndarray
                Labels for training data.
                1D array with shape = (`n_epochs`, ).

            Returns
            -------
            kernelResults : KernelResults
                KernelResults object containing the following attributes:
                    model : classifier
                        The classification model trained on all of `subX`.
                    cv_preds : numpy.ndarray
                        The predictions from the model using cross validation.
                        1D array with the same shape as `suby`.
                    accuracy : float
                        The cross-validated accuracy.
                    precision : float
                        The cross-validated (micro-averaged) precision.
                    recall : float
                        The cross-validated (micro-averaged) recall.

            """
            # Allocated per call: the channel-selection wrapper may call this
            # kernel many times and keep an earlier result, so results must
            # not share a predictions array or a fitted model.
            cv_preds = np.zeros(len(suby))

            for train_idx, test_idx in self.cv.split(subX, suby):
                # Fresh, unfitted model for every fold
                fold_model = clone(self.clf_model)
                fold_model.fit(subX[train_idx], suby[train_idx])
                cv_preds[test_idx] = fold_model.predict(subX[test_idx])

            # Train final model with all available data
            model = clone(self.clf_model)
            model.fit(subX, suby)

            # Score against the labels this kernel was given
            accuracy = np.mean(cv_preds == suby)
            precision = precision_score(suby, cv_preds, average="micro")
            recall = recall_score(suby, cv_preds, average="micro")

            return KernelResults(model, cv_preds, accuracy, precision, recall)

        # Check if channel selection is true
        if self.channel_selection_setup:
            if self.chs_iterative_selection is True and self.subset is not None:
                initial_subset = self.subset
                logger.info(
                    "Using subset from previous channel selection "
                    + "because iterative selection is TRUE"
                )
            else:
                initial_subset = self.chs_initial_subset

            logger.info("Doing channel selection")
            channel_selection_results = channel_selection_by_method(
                __mi_kernel,
                self.X,
                self.y,
                self.channel_labels,  # kernel setup
                self.chs_method,
                self.chs_metric,
                initial_subset,  # wrapper setup
                self.chs_max_time,
                self.chs_min_channels,
                self.chs_max_channels,
                self.chs_performance_delta,  # stopping criterion
                self.chs_n_jobs,
            )

            preds = channel_selection_results.best_preds

            self.results_df = channel_selection_results.results_df
            self.subset = channel_selection_results.best_channel_subset
            self.subset_defined = True
            self.clf = channel_selection_results.best_model
        else:
            logger.warning("Not doing channel selection")

            subX = self.get_subset(subX, self.subset, self.channel_labels)

            current_results = __mi_kernel(subX, suby)
            self.clf = current_results.model
            preds = current_results.cv_preds

        # Log performance stats

        self.offline_trial_count = n_trials
        self.offline_trial_counts.append(self.offline_trial_count)

        # accuracy
        accuracy = sum(preds == self.y) / len(preds)
        self.offline_accuracy.append(accuracy)
        logger.info("Accuracy = %s", accuracy)

        # precision
        precision = precision_score(self.y, preds, average="micro")
        self.offline_precision.append(precision)
        logger.info("Precision = %s", precision)

        # recall
        recall = recall_score(self.y, preds, average="micro")
        self.offline_recall.append(recall)
        logger.info("Recall = %s", recall)

        # confusion matrix in command line
        cm = confusion_matrix(self.y, preds)
        self.offline_cm = cm
        logger.info("Confusion matrix:\n%s", cm)

    def predict(self, X):
        """Predict the class labels for the provided data.

        Parameters
        ----------
        X : numpy.ndarray
            3D array where shape = (trials, channels, samples). An array with
            fewer dimensions gets a leading axis added (treated as one trial).

        Returns
        -------
        prediction : Prediction
            Results of predict call containing the predicted class labels, and
            the probabilities of the labels.

        """
        # Promote a single (channels, samples) trial to a batch of one
        trials = X[np.newaxis, ...] if len(X.shape) < 3 else X

        trials = self.get_subset(trials, self.subset, self.channel_labels)
        logger.info("The shape of X is %s", trials.shape)

        labels = list(map(int, self.clf.predict(trials)))
        probabilities = self.clf.predict_proba(trials)

        logger.info("Prediction: %s", labels)
        logger.info("Prediction probabilities: %s", probabilities)

        # Keep a running history of every prediction made
        self.predictions.extend(labels)
        self.pred_probas.extend(probabilities)

        return Prediction(labels=labels, probabilities=probabilities)
MI Classifier class (inherits from GenericClassifier).
def
set_mi_classifier_settings( self, n_splits=5, type='TS', remove_flats=True, whitening=False, covariance_estimator='oas', artifact_rejection='none', pred_threshold=0.5, random_seed=42, n_jobs=1):
def set_mi_classifier_settings(
    self,
    n_splits=5,
    type="TS",
    remove_flats=True,
    whitening=False,
    covariance_estimator="oas",
    artifact_rejection="none",
    pred_threshold=0.5,
    random_seed=42,
    n_jobs=1,
):
    """Set MI classifier settings.

    Parameters
    ----------
    n_splits : int, *optional*
        Number of folds for cross-validation.
        - Default is `5`.
    type : str, *optional*
        Type of classifier to be used.
        Options = sLDA, RandomForest, TS, or MDM.
        - Default is `"TS"`.
    remove_flats : bool, *optional*
        Whether to remove flat channels from the EEG data.
        - Default is `True`.
    whitening : bool, *optional*
        Whether to apply whitening to the covariance matrices.
        - Default is `False`.
    covariance_estimator : str, *optional*
        Covariance estimator. See pyriemann Covariances.
        - Default is `"oas"`.
    artifact_rejection : str, *optional*
        Method for artifact rejection. `"potato"` is not implemented
        and only logs an error.
        - Default is `"none"`.
    pred_threshold : float, *optional*
        Prediction threshold used for classification.
        - Default is `0.5`.
    random_seed : int, *optional*
        Random seed.
        - Default is `42`.
    n_jobs : int, *optional*
        The number of threads to dedicate to this calculation.
        - Default is `1`.

    Returns
    -------
    `None`
        Models created are used in `fit()`.

    Raises
    ------
    ValueError
        If `type` is not one of the supported classifier types.

    """
    # Build the cross-validation split
    self.n_splits = n_splits
    self.cv = StratifiedKFold(
        n_splits=n_splits, shuffle=True, random_state=random_seed
    )

    self.covariance_estimator = covariance_estimator

    # Factories (not instances) for the final classification step, so
    # that every pipeline built below owns its own estimator object.
    # NOTE(review): LinearDiscriminantAnalysis and RandomForestClassifier
    # expect 2D features, but the Covariances step outputs 3D matrices;
    # confirm these options work (a vectorizing step such as pyriemann
    # TangentSpace may be needed).
    classifier_factories = {
        # Shrinkage LDA
        "sLDA": (
            "Shrinkage LDA",
            lambda: LinearDiscriminantAnalysis(solver="eigen", shrinkage="auto"),
        ),
        # Random Forest
        "RandomForest": ("Random Forest", RandomForestClassifier),
        # Tangent Space Logistic Regression
        "TS": ("Tangent Space", TSclassifier),
        # Minimum Distance to Mean
        "MDM": (
            "MDM",
            lambda: MDM(
                metric=dict(mean="riemann", distance="riemann"), n_jobs=n_jobs
            ),
        ),
    }

    if type not in classifier_factories:
        logger.error("Classifier type not defined")
        # Fail loudly instead of crashing later (or silently mutating a
        # pipeline left over from a previous call)
        raise ValueError(
            f"Unknown classifier type {type!r}; "
            f"expected one of {sorted(classifier_factories)}"
        )
    clf_name, make_clf = classifier_factories[type]

    if artifact_rejection == "potato":
        logger.error("Potato not implemented")

    def build_pipeline():
        """Build a fresh, unfitted pipeline from the chosen settings."""
        steps = []
        if remove_flats:
            # Operates on the raw (trials, channels, samples) signals
            steps.append(("Remove Flat Channels", FlatChannelRemover()))
        # All algorithms use covariance estimation
        steps.append(
            ("Covariances", Covariances(estimator=self.covariance_estimator))
        )
        if whitening:
            # pyriemann's Whitening takes covariance matrices as input,
            # so it must come after covariance estimation
            steps.append(("Whitening", Whitening()))
        steps.append((clf_name, make_clf()))
        return Pipeline(steps)

    # Template pipeline (cloned for each fit) and the current model
    self.clf_model = build_pipeline()
    self.clf = build_pipeline()

    # Threshold
    self.pred_threshold = pred_threshold

    # Rebuild from scratch with each training
    self.rebuild = True
Set MI classifier settings.
Parameters
- n_splits (int, optional):
Number of folds for cross-validation.
- Default is 5.
- type (str, optional):
Type of classifier to be used.
Options = sLDA, RandomForest, TS, or MDM.
- Default is "TS".
- remove_flats (bool, optional):
Whether to remove flat channels from the EEG data.
- Default is True.
- whitening (bool, optional):
Whether to apply whitening to the EEG data.
- Default is False.
- covariance_estimator (str, optional):
Covariance estimator. See pyriemann Covariances.
- Default is "oas".
- artifact_rejection (str, optional):
Method for artifact rejection.
- Default is "none".
- pred_threshold (float, optional):
Prediction threshold used for classification.
- Default is 0.5.
- random_seed (int, optional):
Random seed.
- Default is 42.
- n_jobs (int, optional):
The number of threads to dedicate to this calculation.
- Default is 1.
Returns
None: Models created are used in fit().
def
fit(self):
def fit(self):
    """Fit the model.

    Cross-validates the configured pipeline on the stored trials
    (`self.X`, `self.y`), optionally wrapped in channel selection. It then
    trains a final model on all available trials and records offline
    accuracy, precision, recall and the confusion matrix.

    Returns
    -------
    `None`
        Models created are used in `predict()`.

    """
    # `clone` returns an unfitted copy of the template pipeline
    from sklearn.base import clone

    # Convert first so that `.shape` also works if X was stored as a list
    self.X = np.array(self.X)
    # get dimensions (unpacking also checks that X is 3D)
    n_trials, _n_channels, _n_samples = self.X.shape

    # Try rebuilding the classifier each time
    if self.rebuild:
        self.next_fit_trial = 0
        self.clf = clone(self.clf_model)

    # get temporal subset
    subX = self.X[self.next_fit_trial :, :, :]
    suby = self.y[self.next_fit_trial :]
    self.next_fit_trial = n_trials

    def __mi_kernel(subX, suby):
        """MI kernel.

        Cross-validates fresh copies of the template pipeline on the
        given trials, then fits a final model on all of them.

        Parameters
        ----------
        subX : numpy.ndarray
            EEG data for training.
            3D array with shape = (`n_epochs`, `n_channels`, `n_samples`).
        suby : numpy.ndarray
            Labels for training data.
            1D array with shape = (`n_epochs`, ).

        Returns
        -------
        kernelResults : KernelResults
            KernelResults object containing the following attributes:
                model : classifier
                    The classification model trained on all of `subX`.
                cv_preds : numpy.ndarray
                    The predictions from the model using cross validation.
                    1D array with the same shape as `suby`.
                accuracy : float
                    The cross-validated accuracy.
                precision : float
                    The cross-validated (micro-averaged) precision.
                recall : float
                    The cross-validated (micro-averaged) recall.

        """
        # Allocated per call: the channel-selection wrapper may call this
        # kernel many times and keep an earlier result, so results must
        # not share a predictions array or a fitted model.
        cv_preds = np.zeros(len(suby))

        for train_idx, test_idx in self.cv.split(subX, suby):
            # Fresh, unfitted model for every fold
            fold_model = clone(self.clf_model)
            fold_model.fit(subX[train_idx], suby[train_idx])
            cv_preds[test_idx] = fold_model.predict(subX[test_idx])

        # Train final model with all available data
        model = clone(self.clf_model)
        model.fit(subX, suby)

        # Score against the labels this kernel was given
        accuracy = np.mean(cv_preds == suby)
        precision = precision_score(suby, cv_preds, average="micro")
        recall = recall_score(suby, cv_preds, average="micro")

        return KernelResults(model, cv_preds, accuracy, precision, recall)

    # Check if channel selection is true
    if self.channel_selection_setup:
        if self.chs_iterative_selection is True and self.subset is not None:
            initial_subset = self.subset
            logger.info(
                "Using subset from previous channel selection "
                + "because iterative selection is TRUE"
            )
        else:
            initial_subset = self.chs_initial_subset

        logger.info("Doing channel selection")
        channel_selection_results = channel_selection_by_method(
            __mi_kernel,
            self.X,
            self.y,
            self.channel_labels,  # kernel setup
            self.chs_method,
            self.chs_metric,
            initial_subset,  # wrapper setup
            self.chs_max_time,
            self.chs_min_channels,
            self.chs_max_channels,
            self.chs_performance_delta,  # stopping criterion
            self.chs_n_jobs,
        )

        preds = channel_selection_results.best_preds

        self.results_df = channel_selection_results.results_df
        self.subset = channel_selection_results.best_channel_subset
        self.subset_defined = True
        self.clf = channel_selection_results.best_model
    else:
        logger.warning("Not doing channel selection")

        subX = self.get_subset(subX, self.subset, self.channel_labels)

        current_results = __mi_kernel(subX, suby)
        self.clf = current_results.model
        preds = current_results.cv_preds

    # Log performance stats

    self.offline_trial_count = n_trials
    self.offline_trial_counts.append(self.offline_trial_count)

    # accuracy
    accuracy = sum(preds == self.y) / len(preds)
    self.offline_accuracy.append(accuracy)
    logger.info("Accuracy = %s", accuracy)

    # precision
    precision = precision_score(self.y, preds, average="micro")
    self.offline_precision.append(precision)
    logger.info("Precision = %s", precision)

    # recall
    recall = recall_score(self.y, preds, average="micro")
    self.offline_recall.append(recall)
    logger.info("Recall = %s", recall)

    # confusion matrix in command line
    cm = confusion_matrix(self.y, preds)
    self.offline_cm = cm
    logger.info("Confusion matrix:\n%s", cm)
def
predict(self, X):
def predict(self, X):
    """Predict the class labels for the provided data.

    Parameters
    ----------
    X : numpy.ndarray
        3D array where shape = (trials, channels, samples). An array with
        fewer dimensions gets a leading axis added (treated as one trial).

    Returns
    -------
    prediction : Prediction
        Results of predict call containing the predicted class labels, and
        the probabilities of the labels.

    """
    # Promote a single (channels, samples) trial to a batch of one
    trials = X[np.newaxis, ...] if len(X.shape) < 3 else X

    trials = self.get_subset(trials, self.subset, self.channel_labels)
    logger.info("The shape of X is %s", trials.shape)

    labels = list(map(int, self.clf.predict(trials)))
    probabilities = self.clf.predict_proba(trials)

    logger.info("Prediction: %s", labels)
    logger.info("Prediction probabilities: %s", probabilities)

    # Keep a running history of every prediction made
    self.predictions.extend(labels)
    self.pred_probas.extend(probabilities)

    return Prediction(labels=labels, probabilities=probabilities)
Predict the class labels for the provided data.
Parameters
- X (numpy.ndarray): 3D array where shape = (trials, channels, samples)
Returns
- prediction (Prediction): Results of predict call containing the predicted class labels, and the probabilities of the labels.