bci_essentials.classification.ssvep_riemannian_mdm_classifier
SSVEP Riemannian MDM Classifier
Classifies SSVEP based on relative band power at the expected frequencies.
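Concretely, each expected stimulus frequency is expanded into a set of narrow pass-bands, one per harmonic, each `f_width` Hz wide and centred on the harmonic; band power is then measured through channel covariances of the filtered signal. The short sketch below works through the band edges for the default settings (`n_harmonics=2`, `f_width=0.2`); the two stimulus frequencies are assumed example values, not part of the module.

```python
# Illustration only: how target frequencies, harmonics, and f_width map to
# the narrow pass-bands used by this module. The frequencies are examples.
target_freqs = [7.5, 9.6]  # assumed stimulus frequencies in Hz
n_harmonics = 2            # module default
f_width = 0.2              # module default bin width in Hz

for f0 in target_freqs:
    bands = [
        (f0 * h - f_width / 2, f0 * h + f_width / 2)
        for h in range(1, n_harmonics + 1)
    ]
    print(f0, bands)
# 7.5 Hz -> roughly (7.4, 7.6) and (14.9, 15.1)
# 9.6 Hz -> roughly (9.5, 9.7) and (19.1, 19.3)
```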
1""" 2**SSVEP Riemannian MDM Classifier** 3 4Classifies SSVEP based on relative band power at the expected 5frequencies. 6 7""" 8 9# Stock libraries 10import numpy as np 11from sklearn.model_selection import StratifiedKFold 12from sklearn.pipeline import Pipeline 13from sklearn.metrics import confusion_matrix, precision_score, recall_score 14from pyriemann.classification import MDM 15from pyriemann.estimation import Covariances 16from pyriemann.channelselection import FlatChannelRemover 17 18# Import bci_essentials modules and methods 19from ..classification.generic_classifier import ( 20 GenericClassifier, 21 Prediction, 22 KernelResults, 23) 24from ..signal_processing import bandpass 25from ..channel_selection import channel_selection_by_method 26from ..utils.logger import Logger # Logger wrapper 27 28# Instantiate a logger for the module at the default level of logging.INFO 29# Logs to bci_essentials.__module__) where __module__ is the name of the module 30logger = Logger(name=__name__) 31 32 33class SsvepRiemannianMdmClassifier(GenericClassifier): 34 """SSVEP Riemannian MDM Classifier class 35 (*inherits from GenericClassifier*) 36 37 """ 38 39 def set_ssvep_settings( 40 self, 41 n_splits=3, 42 random_seed=42, 43 n_harmonics=2, 44 f_width=0.2, 45 covariance_estimator="oas", 46 remove_flats=True, 47 ): 48 """Set the SSVEP settings. 49 50 Parameters 51 ---------- 52 n_splits : int, *optional* 53 Number of folds for cross-validation. 54 - Default is `3`. 55 random_seed : int, *optional* 56 Random seed. 57 - Default is `42`. 58 n_harmonics : int, *optional* 59 Number of harmonics to be used for each frequency. 60 - Default is `2`. 61 f_width : float, *optional* 62 Width of frequency bins to be used around the target 63 frequencies. 64 - Default is `0.2`. 65 covariance_estimator : str, *optional* 66 Covariance Estimator (see Covariances - pyriemann) 67 - Default is `"oas"`. 68 remove_flats : bool, *optional* 69 Remove flat channels. 70 - Default is `True`. 71 72 Returns 73 ------- 74 `None` 75 Models created used in `fit()`. 76 77 """ 78 # Build the cross-validation split 79 self.n_splits = n_splits 80 self.cv = StratifiedKFold( 81 n_splits=self.n_splits, shuffle=True, random_state=random_seed 82 ) 83 84 self.rebuild = True 85 86 self.n_harmonics = n_harmonics 87 self.f_width = f_width 88 self.covariance_estimator = covariance_estimator 89 90 # Use an MDM classifier, maybe there will be other options later 91 mdm = MDM(metric=dict(mean="riemann", distance="riemann"), n_jobs=1) 92 self.clf_model = Pipeline([("MDM", mdm)]) 93 self.clf = Pipeline([("MDM", mdm)]) 94 95 if remove_flats: 96 rf = FlatChannelRemover() 97 self.clf_model.steps.insert(0, ["Remove Flat Channels", rf]) 98 self.clf.steps.insert(0, ["Remove Flat Channels", rf]) 99 100 def get_ssvep_supertrial( 101 self, 102 X, 103 target_freqs, 104 fsample, 105 f_width=0.4, 106 n_harmonics=2, 107 covariance_estimator="oas", 108 ): 109 """Get SSVEP Supertrial. 110 111 Creates the Riemannian Geometry supertrial for SSVEP. 112 113 Parameters 114 ---------- 115 X : numpy.ndarray 116 Trials of EEG data. 117 3D array containing data with `float` type. 118 shape = (`n_trials`,`n_channels`,`n_samples`) 119 target_freqs : numpy.ndarray 120 Target frequencies for the SSVEP. 121 fsample : float 122 Sampling rate. 123 f_width : float, *optional* 124 Width of frequency bins to be used around the target 125 frequencies. 126 - Default is `0.4`. 127 n_harmonics : int, *optional* 128 Number of harmonics to be used for each frequency. 129 - Default is `2`. 
130 covarianc_estimator : str, *optional* 131 Covariance Estimator (see Covariances - pyriemann) 132 - Default is `"oas"`. 133 134 Returns 135 ------- 136 super_X : numpy.ndarray 137 Supertrials of X. 138 3D array containing data with `float` type. 139 140 shape = (`n_trials`,`n_channels*number of target_freqs`, 141 `n_channels*number of target_freqs`) 142 143 """ 144 n_trials, n_channels, n_samples = X.shape 145 n_target_freqs = len(target_freqs) 146 147 super_X = np.zeros( 148 [n_trials, n_channels * n_target_freqs, n_channels * n_target_freqs] 149 ) 150 151 # Create super trial of all trials filtered at all bands 152 for trial in range(n_trials): 153 for tf, target_freq in enumerate(target_freqs): 154 lower_bound = int((n_channels * tf)) 155 upper_bound = int((n_channels * tf) + n_channels) 156 157 signal = X[trial, :, :] 158 for f in range(n_harmonics): 159 if f == 0: 160 filt_signal = bandpass( 161 signal, 162 f_low=target_freq - (f_width / 2), 163 f_high=target_freq + (f_width / 2), 164 order=5, 165 fsample=fsample, 166 ) 167 else: 168 filt_signal += bandpass( 169 signal, 170 f_low=(target_freq * (f + 1)) - (f_width / 2), 171 f_high=(target_freq * (f + 1)) + (f_width / 2), 172 order=5, 173 fsample=fsample, 174 ) 175 176 cov_mat = Covariances(estimator=covariance_estimator).transform( 177 np.expand_dims(filt_signal, axis=0) 178 ) 179 180 cov_mat_diag = np.diag(np.diag(cov_mat[0, :, :])) 181 182 super_X[trial, lower_bound:upper_bound, lower_bound:upper_bound] = ( 183 cov_mat_diag 184 ) 185 186 return super_X 187 188 def fit(self): 189 """Fit the model. 190 191 Returns 192 ------- 193 `None` 194 Models created used in `predict()`. 195 196 """ 197 198 # Convert each trial of X into a SPD of dimensions [n_trials, n_channels*nfreqs, n_channels*nfreqs] 199 n_trials, n_channels, n_samples = self.X.shape 200 201 ################# 202 # Try rebuilding the classifier each time 203 if self.rebuild: 204 self.next_fit_trial = 0 205 self.clf = self.clf_model 206 207 # get temporal subset 208 subX = self.X[self.next_fit_trial :, :, :] 209 suby = self.y[self.next_fit_trial :] 210 self.next_fit_trial = n_trials 211 212 # Init predictions to all false 213 cv_preds = np.zeros(n_trials) 214 215 def __ssvep_kernel(subX, suby): 216 """SSVEP kernel. 217 218 Parameters 219 ---------- 220 subX : numpy.ndarray 221 Input data ffor training/testing the classifier. 222 3D array with shape = (`n_epochs`, `n_channels`, `n_samples`). 223 suby : numpy.ndarray 224 Target labels for the input data. 225 1D array with shape = (`n_epochs`, ). 226 227 Returns 228 ------- 229 kernelResults : KernelResults 230 KernelResults object containing the following attributes: 231 model : classifier 232 The trained classification model. 233 cv_preds : numpy.ndarray 234 The predictions from the model using cross validation. 235 1D array with the same shape as `suby`. 236 accuracy : float 237 The accuracy of the trained classification model. 238 precision : float 239 The precision of the trained classification model. 240 recall : float 241 The recall of the trained classification model. 
242 243 244 """ 245 for train_idx, test_idx in self.cv.split(subX, suby): 246 self.clf = self.clf_model 247 248 X_train, X_test = subX[train_idx], subX[test_idx] 249 y_train = suby[train_idx] 250 251 # get the covariance matrices for the training set 252 X_train_super = self.get_ssvep_supertrial( 253 X_train, 254 self.target_freqs, 255 fsample=256, 256 n_harmonics=self.n_harmonics, 257 f_width=self.f_width, 258 covariance_estimator=self.covariance_estimator, 259 ) 260 X_test_super = self.get_ssvep_supertrial( 261 X_test, 262 self.target_freqs, 263 fsample=256, 264 n_harmonics=self.n_harmonics, 265 f_width=self.f_width, 266 covariance_estimator=self.covariance_estimator, 267 ) 268 269 # fit the classsifier 270 self.clf.fit(X_train_super, y_train) 271 cv_preds[test_idx] = self.clf.predict(X_test_super) 272 273 # Create super trial with all available data 274 X_super = self.get_ssvep_supertrial( 275 subX, 276 self.target_freqs, 277 fsample=256, 278 n_harmonics=self.n_harmonics, 279 f_width=self.f_width, 280 covariance_estimator=self.covariance_estimator, 281 ) 282 283 # Train final model with all available data 284 self.clf.fit(X_super, suby) 285 model = self.clf 286 287 accuracy = sum(cv_preds == self.y) / len(cv_preds) 288 precision = precision_score(self.y, cv_preds, average="micro") 289 recall = recall_score(self.y, cv_preds, average="micro") 290 291 return KernelResults(model, cv_preds, accuracy, precision, recall) 292 293 # Check if channel selection is true 294 if self.channel_selection_setup: 295 logger.info("Doing channel selection") 296 297 channel_selection_results = channel_selection_by_method( 298 __ssvep_kernel, 299 self.X, 300 self.y, 301 self.channel_labels, # kernel setup 302 self.chs_method, 303 self.chs_metric, 304 self.chs_initial_subset, # wrapper setup 305 self.chs_max_time, 306 self.chs_min_channels, 307 self.chs_max_channels, 308 self.chs_performance_delta, # stopping criterion 309 self.chs_n_jobs, 310 ) 311 312 preds = channel_selection_results.best_preds 313 accuracy = channel_selection_results.best_accuracy 314 precision = channel_selection_results.best_precision 315 recall = channel_selection_results.best_recall 316 317 logger.info( 318 "The optimal subset is: %s", 319 channel_selection_results.best_channel_subset, 320 ) 321 322 self.subset = channel_selection_results.best_channel_subset 323 self.clf = channel_selection_results.best_model 324 else: 325 logger.warning("Not doing channel selection") 326 current_results = __ssvep_kernel(subX, suby) 327 self.clf = current_results.model 328 preds = current_results.cv_preds 329 accuracy = current_results.accuracy 330 precision = current_results.precision 331 recall = current_results.recall 332 333 # Log performance stats 334 335 self.offline_trial_count = n_trials 336 self.offline_trial_counts.append(self.offline_trial_count) 337 338 # accuracy 339 accuracy = sum(preds == self.y) / len(preds) 340 self.offline_accuracy.append(accuracy) 341 logger.info("Accuracy = %s", accuracy) 342 343 # precision 344 precision = precision_score(self.y, preds, average="micro") 345 self.offline_precision.append(precision) 346 logger.info("Precision = %s", precision) 347 348 # recall 349 recall = recall_score(self.y, preds, average="micro") 350 self.offline_recall.append(recall) 351 logger.info("Recall = %s", recall) 352 353 # confusion matrix in command line 354 cm = confusion_matrix(self.y, preds) 355 self.offline_cm = cm 356 logger.info("Confusion matrix:\n%s", cm) 357 358 def predict(self, X): 359 """Predict the class labels for the 
provided data. 360 361 Parameters 362 ---------- 363 X : numpy.ndarray 364 3D array where shape = (trials, channels, samples) 365 366 Returns 367 ------- 368 prediction : Prediction 369 Results of predict call containing the predicted class labels, and 370 the probabilities of the labels. 371 372 """ 373 # if X is 2D, make it 3D with one as first dimension 374 if len(X.shape) < 3: 375 X = X[np.newaxis, ...] 376 377 X = self.get_subset(X) 378 379 logger.info("The shape of X is %s", X.shape) 380 381 X_super = self.get_ssvep_supertrial( 382 X, 383 self.target_freqs, 384 fsample=256, 385 n_harmonics=self.n_harmonics, 386 f_width=self.f_width, 387 ) 388 389 pred = self.clf.predict(X_super) 390 pred_proba = self.clf.predict_proba(X_super) 391 392 logger.info("Prediction: %s", pred) 393 logger.info("Prediction probabilities: %s", pred_proba) 394 395 for i in range(len(pred)): 396 self.predictions.append(pred[i]) 397 self.pred_probas.append(pred_proba[i]) 398 399 return Prediction(labels=pred, probabilities=pred_proba)
logger: module-level `bci_essentials.utils.logger.Logger` instance, created as `Logger(name=__name__)`.
class SsvepRiemannianMdmClassifier(bci_essentials.classification.generic_classifier.GenericClassifier):
SSVEP Riemannian MDM Classifier class (inherits from GenericClassifier)
def set_ssvep_settings(self, n_splits=3, random_seed=42, n_harmonics=2, f_width=0.2, covariance_estimator='oas', remove_flats=True):
Set the SSVEP settings.
Parameters
- n_splits (int, optional): Number of folds for cross-validation. Default is `3`.
- random_seed (int, optional): Random seed. Default is `42`.
- n_harmonics (int, optional): Number of harmonics to be used for each frequency. Default is `2`.
- f_width (float, optional): Width of frequency bins to be used around the target frequencies. Default is `0.2`.
- covariance_estimator (str, optional): Covariance estimator (see `Covariances` in pyriemann). Default is `"oas"`.
- remove_flats (bool, optional): Remove flat channels. Default is `True`.

Returns
- None: Models created are used in `fit()`.
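For reference, the objects assembled by `set_ssvep_settings()` with its defaults can be built standalone. The sketch below mirrors the source above (a `StratifiedKFold` splitter plus a scikit-learn `Pipeline` of `FlatChannelRemover` followed by Riemannian `MDM`); it is illustrative only and does not attach the objects to a classifier instance.

```python
# Minimal standalone sketch of what set_ssvep_settings() builds by default.
from sklearn.model_selection import StratifiedKFold
from sklearn.pipeline import Pipeline
from pyriemann.classification import MDM
from pyriemann.channelselection import FlatChannelRemover

# Seeded, shuffled cross-validation splitter (n_splits=3, random_seed=42)
cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)

# Flat-channel removal followed by a minimum-distance-to-mean classifier
# that uses the Riemannian metric for both the class means and the distance.
clf = Pipeline(
    [
        ("Remove Flat Channels", FlatChannelRemover()),
        ("MDM", MDM(metric=dict(mean="riemann", distance="riemann"), n_jobs=1)),
    ]
)
```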
def get_ssvep_supertrial(self, X, target_freqs, fsample, f_width=0.4, n_harmonics=2, covariance_estimator='oas'):
Get SSVEP Supertrial.
Creates the Riemannian Geometry supertrial for SSVEP.
Parameters
- X (numpy.ndarray): Trials of EEG data. 3D array containing data with `float` type. shape = (`n_trials`, `n_channels`, `n_samples`)
- target_freqs (numpy.ndarray): Target frequencies for the SSVEP.
- fsample (float): Sampling rate.
- f_width (float, optional): Width of frequency bins to be used around the target frequencies. Default is `0.4`.
- n_harmonics (int, optional): Number of harmonics to be used for each frequency. Default is `2`.
- covariance_estimator (str, optional): Covariance estimator (see `Covariances` in pyriemann). Default is `"oas"`.

Returns
- super_X (numpy.ndarray): Supertrials of X. 3D array containing data with `float` type. shape = (`n_trials`, `n_channels * n_target_freqs`, `n_channels * n_target_freqs`)
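The construction described above can be sketched outside of bci_essentials. The example below is an approximation: it substitutes a zero-phase Butterworth band-pass from SciPy for the package's `bandpass` helper (an assumption about that helper's behaviour) and a local `ssvep_supertrial` function for the method, but it reproduces the same layout of per-frequency diagonal covariance blocks on the diagonal of each supertrial.

```python
import numpy as np
from scipy.signal import butter, sosfiltfilt
from pyriemann.estimation import Covariances


def _bandpass(signal, f_low, f_high, order, fsample):
    # Stand-in for bci_essentials.signal_processing.bandpass (assumed to be a
    # zero-phase Butterworth band-pass applied along the samples axis).
    sos = butter(order, [f_low, f_high], btype="bandpass", fs=fsample, output="sos")
    return sosfiltfilt(sos, signal, axis=-1)


def ssvep_supertrial(X, target_freqs, fsample, f_width=0.4, n_harmonics=2):
    # For each trial and target frequency: sum the narrow-band signals at the
    # harmonics, take the diagonal of their spatial covariance, and place that
    # block on the diagonal of an (n_channels * n_freqs) square matrix.
    n_trials, n_channels, _ = X.shape
    n_freqs = len(target_freqs)
    super_X = np.zeros((n_trials, n_channels * n_freqs, n_channels * n_freqs))
    cov = Covariances(estimator="oas")
    for trial in range(n_trials):
        for tf, f0 in enumerate(target_freqs):
            filt = sum(
                _bandpass(
                    X[trial],
                    f0 * (h + 1) - f_width / 2,
                    f0 * (h + 1) + f_width / 2,
                    5,
                    fsample,
                )
                for h in range(n_harmonics)
            )
            block = np.diag(np.diag(cov.transform(filt[np.newaxis])[0]))
            lo, hi = n_channels * tf, n_channels * (tf + 1)
            super_X[trial, lo:hi, lo:hi] = block
    return super_X


# Synthetic check: 4 trials, 8 channels, 2 s at 256 Hz, two target frequencies
rng = np.random.default_rng(0)
X = rng.standard_normal((4, 8, 512))
S = ssvep_supertrial(X, target_freqs=[7.5, 9.0], fsample=256)
print(S.shape)  # (4, 16, 16)
```

Each diagonal block carries only per-channel band power for one target frequency (and its harmonics), so the downstream MDM classifier is effectively comparing relative band power patterns across the candidate frequencies, matching the module description.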
def fit(self):
Fit the model.

Returns
- None: Models created are used in `predict()`.
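For context on the evaluation inside `fit()`: out-of-fold predictions are collected across the `StratifiedKFold` splits, a final model is refit on all of the data, and accuracy, precision, recall, and a confusion matrix are logged. The standalone sketch below mirrors that bookkeeping on randomly generated SPD matrices standing in for supertrials (an assumption made to keep it self-contained), so the scores will sit near chance.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import confusion_matrix, precision_score, recall_score
from pyriemann.classification import MDM

# Random SPD matrices standing in for supertrials, with two arbitrary classes.
rng = np.random.default_rng(42)
n_trials, dim = 24, 16
A = rng.standard_normal((n_trials, dim, dim))
X_super = A @ A.transpose(0, 2, 1) + dim * np.eye(dim)
y = np.repeat([0, 1], n_trials // 2)

# Out-of-fold predictions, as in fit()
cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
cv_preds = np.zeros(n_trials)
for train_idx, test_idx in cv.split(X_super, y):
    clf = MDM(metric=dict(mean="riemann", distance="riemann"))
    clf.fit(X_super[train_idx], y[train_idx])
    cv_preds[test_idx] = clf.predict(X_super[test_idx])

# Refit on everything, then report cross-validated metrics
final_model = MDM(metric=dict(mean="riemann", distance="riemann")).fit(X_super, y)
print("accuracy ", np.mean(cv_preds == y))
print("precision", precision_score(y, cv_preds, average="micro"))
print("recall   ", recall_score(y, cv_preds, average="micro"))
print(confusion_matrix(y, cv_preds))
```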
def predict(self, X):
Predict the class labels for the provided data.
Parameters
- X (numpy.ndarray): 3D array where shape = (trials, channels, samples)
Returns
- prediction (Prediction): Results of the predict call, containing the predicted class labels and the probabilities of the labels.
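As a final hedged sketch of the predict-time flow: a single 2D input is promoted to 3D, and the trained model returns both labels and per-class probabilities, which `predict()` then wraps in a `Prediction`. To stay self-contained, the example uses random SPD matrices in place of supertrials built from EEG, and a bare `MDM` in place of the full pipeline; both substitutions are assumptions for illustration only.

```python
import numpy as np
from pyriemann.classification import MDM

# Stand-in for a trained self.clf: an MDM fitted on random SPD "supertrials".
rng = np.random.default_rng(1)
dim, n_train = 16, 12
A = rng.standard_normal((n_train, dim, dim))
X_train = A @ A.transpose(0, 2, 1) + dim * np.eye(dim)
y_train = np.repeat([0, 1], n_train // 2)
clf = MDM(metric=dict(mean="riemann", distance="riemann")).fit(X_train, y_train)

# predict() promotes a single 2D trial to 3D before building the supertrial;
# the same promotion is shown here directly on an SPD stand-in.
new_trial = X_train[0]
if new_trial.ndim < 3:
    new_trial = new_trial[np.newaxis, ...]

labels = clf.predict(new_trial)
probabilities = clf.predict_proba(new_trial)
print(labels, probabilities)
```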