bci_essentials.classification.erp_rg_classifier
ERP RG Classifier
This classifier is used to classify ERPs using the Riemannian Geometry approach.
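The Riemannian Geometry approach treats each trial's spatial covariance matrix as a point on the manifold of symmetric positive-definite (SPD) matrices, maps those points into a tangent (vector) space, and classifies the resulting vectors with a linear model such as LDA. A minimal numpy sketch of the tangent-space map at the identity reference point (an illustration of the idea only, not pyriemann's implementation, which uses the geometric mean of the training covariances as the reference):

```python
import numpy as np

def spd_logm(C):
    """Matrix logarithm of a symmetric positive-definite matrix via
    eigendecomposition: the tangent-space map at the identity."""
    eigvals, eigvecs = np.linalg.eigh(C)
    return eigvecs @ np.diag(np.log(eigvals)) @ eigvecs.T

# An SPD covariance matrix projects to a symmetric matrix in the
# tangent space; the reference point (identity) maps to zero.
C = np.array([[2.0, 0.5], [0.5, 1.0]])
T = spd_logm(C)
assert np.allclose(T, T.T)                    # tangent vectors are symmetric
assert np.allclose(spd_logm(np.eye(2)), 0.0)  # reference maps to zero
```

Once in the tangent space, ordinary Euclidean classifiers apply, which is why the pipeline below ends in `TangentSpace` followed by `LinearDiscriminantAnalysis`.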
1"""**ERP RG Classifier** 2 3This classifier is used to classify ERPs using the Riemannian Geometry 4approach. 5 6""" 7 8# Stock libraries 9import random 10import numpy as np 11import matplotlib.pyplot as plt 12from sklearn.pipeline import make_pipeline 13from sklearn.model_selection import StratifiedKFold 14from sklearn.metrics import ( 15 confusion_matrix, 16 ConfusionMatrixDisplay, 17 precision_score, 18 recall_score, 19) 20from sklearn.discriminant_analysis import LinearDiscriminantAnalysis 21from pyriemann.estimation import XdawnCovariances 22from pyriemann.tangentspace import TangentSpace 23from pyriemann.channelselection import FlatChannelRemover 24 25# Import bci_essentials modules and methods 26from ..classification.generic_classifier import ( 27 GenericClassifier, 28 Prediction, 29 KernelResults, 30) 31from ..signal_processing import lico 32from ..channel_selection import channel_selection_by_method 33from ..utils.logger import Logger # Logger wrapper 34 35# Instantiate a logger for the module at the default level of logging.INFO 36# Logs to bci_essentials.__module__) where __module__ is the name of the module 37logger = Logger(name=__name__) 38 39 40class ErpRgClassifier(GenericClassifier): 41 """ERP RG Classifier class (*inherits from `GenericClassifier`*).""" 42 43 def set_p300_clf_settings( 44 self, 45 n_splits=3, 46 lico_expansion_factor=1, 47 oversample_ratio=0, 48 undersample_ratio=0, 49 random_seed=42, 50 covariance_estimator="oas", # Covariance estimator, see pyriemann Covariances 51 remove_flats=True, 52 ): 53 """Set P300 Classifier Settings. 54 55 Parameters 56 ---------- 57 n_splits : int, *optional* 58 Number of folds for cross-validation. 59 - Default is `3`. 60 lico_expansion_factor : int, *optional* 61 Linear Combination Oversampling expansion factor, which is the 62 factor by which the number of ERPs in the training set will be 63 expanded. 64 - Default is `1`. 65 oversample_ratio : float, *optional* 66 Traditional oversampling. 
Range is from from 0.1-1 resulting 67 from the ratio of erp to non-erp class. 0 for no oversampling. 68 - Default is `0`. 69 undersample_ratio : float, *optional* 70 Traditional undersampling. Range is from from 0.1-1 resulting 71 from the ratio of erp to non-erp class. 0 for no undersampling. 72 - Default is `0`. 73 random_seed : int, *optional* 74 Random seed. 75 - Default is `42`. 76 covariance_estimator : str, *optional* 77 Covariance estimator. See pyriemann Covariances. 78 - Default is `"oas"`. 79 remove_flats : bool, *optional* 80 Whether to remove flat channels. 81 - Default is `True`. 82 83 Returns 84 ------- 85 `None` 86 87 """ 88 self.n_splits = n_splits 89 self.lico_expansion_factor = lico_expansion_factor 90 self.oversample_ratio = oversample_ratio 91 self.undersample_ratio = undersample_ratio 92 self.random_seed = random_seed 93 self.covariance_estimator = covariance_estimator 94 95 # Define the classifier 96 self.clf = make_pipeline( 97 XdawnCovariances(estimator=self.covariance_estimator), 98 TangentSpace(metric="riemann"), 99 LinearDiscriminantAnalysis(solver="eigen", shrinkage="auto"), 100 ) 101 102 if remove_flats: 103 rf = FlatChannelRemover() 104 self.clf.steps.insert(0, ["Remove Flat Channels", rf]) 105 106 def fit( 107 self, 108 plot_cm=False, 109 plot_roc=False, 110 lico_expansion_factor=1, 111 ): 112 """Fit the model. 113 114 Parameters 115 ---------- 116 plot_cm : bool, *optional* 117 Whether to plot the confusion matrix during training. 118 - Default is `False`. 119 plot_roc : bool, *optional* 120 Whether to plot the ROC curve during training. 121 - Default is `False`. 122 lico_expansion_factor : int, *optional* 123 Linear combination oversampling expansion factor. 124 Determines the number of ERPs in the training set that will be expanded. 125 Higher value increases the oversampling, generating more synthetic 126 samples for the minority class. 127 - Default is `1`. 
128 129 Returns 130 ------- 131 `None` 132 Models created used in `predict()`. 133 134 """ 135 logger.info("Fitting the model using RG") 136 logger.info("X shape: %s", self.X.shape) 137 logger.info("y shape: %s", self.y.shape) 138 139 # Define the strategy for cross validation 140 cv = StratifiedKFold( 141 n_splits=self.n_splits, shuffle=True, random_state=self.random_seed 142 ) 143 144 # Init predictions to all false 145 cv_preds = np.zeros(len(self.y)) 146 147 def __erp_rg_kernel(X, y): 148 """ERP RG kernel. 149 150 Parameters 151 ---------- 152 X : numpy.ndarray 153 Input features (ERP data) for training. 154 3D numpy array with shape = (`n_trials`, `n_channels`, `n_samples`). 155 E.g. (100, 32, 1000) for 100 trials, 32 channels and 1000 samples per channel. 156 157 y : numpy.ndarray 158 Target labels corresponding to the input features in `X`. 159 1D numpy array with shape (n_trails, ). 160 Each label indicates the class of the corresponding trial in `X`. 161 E.g. (100, ) for 100 trials. 162 163 164 Returns 165 ------- 166 kernelResults : KernelResults 167 KernelResults object containing the following attributes: 168 model : classifier 169 The trained classification model. 170 cv_preds : numpy.ndarray 171 The predictions from the model using cross validation. 172 1D array with the same shape as `y`. 173 accuracy : float 174 The accuracy of the trained classification model. 175 precision : float 176 The precision of the trained classification model. 177 recall : float 178 The recall of the trained classification model. 
179 180 """ 181 for train_idx, test_idx in cv.split(X, y): 182 y_train, y_test = y[train_idx], y[test_idx] 183 184 X_train, X_test = X[train_idx], X[test_idx] 185 186 # LICO 187 logger.debug( 188 "Before LICO:\n\tShape X: %s\n\tShape y: %s", 189 X_train.shape, 190 y_train.shape, 191 ) 192 193 if sum(y_train) > 2: 194 if lico_expansion_factor > 1: 195 X_train, y_train = lico( 196 X_train, 197 y_train, 198 expansion_factor=lico_expansion_factor, 199 sum_num=2, 200 shuffle=False, 201 ) 202 logger.debug("y_train = %s", y_train) 203 204 logger.debug( 205 "After LICO:\n\tShape X: %s\n\tShape y: %s", 206 X_train.shape, 207 y_train.shape, 208 ) 209 210 # Oversampling 211 if self.oversample_ratio > 0: 212 p_count = sum(y_train) 213 n_count = len(y_train) - sum(y_train) 214 215 num_to_add = int( 216 np.floor((self.oversample_ratio * n_count) - p_count) 217 ) 218 219 # Add num_to_add random selections from the positive 220 true_X_train = X_train[y_train == 1] 221 222 len_X_train = len(true_X_train) 223 224 for s in range(num_to_add): 225 to_add_X = true_X_train[random.randrange(0, len_X_train), :, :] 226 227 X_train = np.append(X_train, to_add_X[np.newaxis, :], axis=0) 228 y_train = np.append(y_train, [1], axis=0) 229 230 # Undersampling 231 if self.undersample_ratio > 0: 232 p_count = sum(y_train) 233 n_count = len(y_train) - sum(y_train) 234 235 num_to_remove = int( 236 np.floor(n_count - (p_count / self.undersample_ratio)) 237 ) 238 239 ind_range = np.arange(len(y_train)) 240 ind_list = list(ind_range) 241 to_remove = [] 242 243 # Remove num_to_remove random selections from the negative 244 false_ind = list(ind_range[y_train == 0]) 245 246 for s in range(num_to_remove): 247 # select a random value from the list of false indices 248 remove_at = false_ind[random.randrange(0, len(false_ind))] 249 250 # remove that value from the false index list 251 false_ind.remove(remove_at) 252 253 # add the index to be removed to a list 254 to_remove.append(remove_at) 255 256 
remaining_ind = ind_list 257 for i in range(len(to_remove)): 258 remaining_ind.remove(to_remove[i]) 259 260 X_train = X_train[remaining_ind, :, :] 261 y_train = y_train[remaining_ind] 262 263 self.clf.fit(X_train, y_train) 264 cv_preds[test_idx] = self.clf.predict(X_test) 265 predproba = self.clf.predict_proba(X_test) 266 267 # Use pred proba to show what would be predicted 268 predprobs = predproba[:, 1] 269 real = np.where(y_test == 1) 270 271 # TODO handle exception where two probabilities are the same 272 prediction = int(np.where(predprobs == np.amax(predprobs))[0][0]) 273 274 logger.debug("y_test = %s", y_test) 275 logger.debug("predproba = %s", predproba) 276 logger.debug("real = %s", real[0]) 277 logger.debug("prediction = %s", prediction) 278 279 # Train final model with all available data 280 self.clf.fit(X, y) 281 model = self.clf 282 283 accuracy = sum(cv_preds == self.y) / len(cv_preds) 284 precision = precision_score(self.y, cv_preds) 285 recall = recall_score(self.y, cv_preds) 286 287 return KernelResults(model, cv_preds, accuracy, precision, recall) 288 289 # Check if channel selection is true 290 if self.channel_selection_setup: 291 logger.info("Doing channel selection") 292 logger.debug("Initial subset: %s", self.chs_initial_subset) 293 294 channel_selection_results = channel_selection_by_method( 295 __erp_rg_kernel, 296 self.X, 297 self.y, 298 self.channel_labels, # kernel setup 299 self.chs_method, 300 self.chs_metric, 301 self.chs_initial_subset, # wrapper setup 302 self.chs_max_time, 303 self.chs_min_channels, 304 self.chs_max_channels, 305 self.chs_performance_delta, # stopping criterion 306 self.chs_n_jobs, 307 ) # njobs, output messages 308 309 preds = channel_selection_results.best_preds 310 accuracy = channel_selection_results.best_accuracy 311 precision = channel_selection_results.best_precision 312 recall = channel_selection_results.best_recall 313 314 logger.info( 315 "The optimal subset is %s", 316 
channel_selection_results.best_channel_subset, 317 ) 318 319 self.results_df = channel_selection_results.results_df 320 self.subset = channel_selection_results.best_channel_subset 321 self.subset_defined = True 322 self.clf = channel_selection_results.best_model 323 else: 324 logger.warning("Not doing channel selection") 325 X = self.get_subset(self.X, self.subset, self.channel_labels) 326 327 current_results = __erp_rg_kernel(X, self.y) 328 self.clf = current_results.model 329 preds = current_results.cv_preds 330 accuracy = current_results.accuracy 331 precision = current_results.precision 332 recall = current_results.recall 333 334 # Log performance stats 335 # accuracy 336 accuracy = sum(preds == self.y) / len(preds) 337 self.offline_accuracy = accuracy 338 logger.info("Accuracy = %s", accuracy) 339 340 # precision 341 precision = precision_score(self.y, preds) 342 self.offline_precision = precision 343 logger.info("Precision = %s", precision) 344 345 # recall 346 recall = recall_score(self.y, preds) 347 self.offline_recall = recall 348 logger.info("Recall = %s", recall) 349 350 # confusion matrix in command line 351 cm = confusion_matrix(self.y, preds) 352 self.offline_cm = cm 353 logger.info("Confusion matrix:\n%s", cm) 354 355 if plot_cm: 356 cm = confusion_matrix(self.y, preds) 357 ConfusionMatrixDisplay(cm).plot() 358 plt.show() 359 360 if plot_roc: 361 logger.error("ROC plot has not been implemented yet") 362 363 def predict(self, X): 364 """Predict the class of the data (Unused in this classifier) 365 366 Parameters 367 ---------- 368 X : numpy.ndarray 369 3D array where shape = (n_epochs, n_channels, n_samples) 370 371 Returns 372 ------- 373 prediction : Prediction 374 Predict object. Contains the predicted labels and and the probability. 375 Because this classifier chooses the P300 object with the highest posterior probability, 376 the probability is only the posterior probability of the chosen object. 
377 378 """ 379 380 subset_X = self.get_subset(X, self.subset, self.channel_labels) 381 382 # Get posterior probability for each target 383 posterior_probabilities = self.clf.predict_proba(subset_X)[:, 1] 384 label = [int(np.argmax(posterior_probabilities) + 1)] 385 386 return Prediction(label, posterior_probabilities)
ERP RG Classifier class (inherits from GenericClassifier).
def set_p300_clf_settings(self, n_splits=3, lico_expansion_factor=1, oversample_ratio=0, undersample_ratio=0, random_seed=42, covariance_estimator='oas', remove_flats=True)
Set P300 classifier settings.

Parameters
- n_splits (int, optional): Number of folds for cross-validation. Default is 3.
- lico_expansion_factor (int, optional): Linear Combination Oversampling expansion factor, which is the factor by which the number of ERPs in the training set will be expanded. Default is 1.
- oversample_ratio (float, optional): Traditional oversampling. Range is 0.1-1, the target ratio of the ERP class to the non-ERP class. 0 for no oversampling. Default is 0.
- undersample_ratio (float, optional): Traditional undersampling. Range is 0.1-1, the target ratio of the ERP class to the non-ERP class. 0 for no undersampling. Default is 0.
- random_seed (int, optional): Random seed. Default is 42.
- covariance_estimator (str, optional): Covariance estimator. See pyriemann Covariances. Default is "oas".
- remove_flats (bool, optional): Whether to remove flat channels. Default is True.

Returns
None
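Inside `fit()`, the two ratio settings are turned into concrete sample counts: `oversample_ratio` duplicates random positive (ERP) trials until the positive/negative ratio reaches the target, and `undersample_ratio` drops random negative trials to the same end. A small sketch of that arithmetic (`resample_counts` is a hypothetical helper name for illustration):

```python
import numpy as np

def resample_counts(y, oversample_ratio=0.0, undersample_ratio=0.0):
    """Mirror the ratio arithmetic used in fit(): how many positive
    trials to add, or negative trials to drop, to reach the target
    ERP / non-ERP class ratio."""
    p_count = int(np.sum(y))      # ERP (positive) trials
    n_count = len(y) - p_count    # non-ERP (negative) trials
    num_to_add = num_to_remove = 0
    if oversample_ratio > 0:
        num_to_add = int(np.floor(oversample_ratio * n_count - p_count))
    if undersample_ratio > 0:
        num_to_remove = int(np.floor(n_count - p_count / undersample_ratio))
    return num_to_add, num_to_remove

# 10 ERP trials among 100 non-ERP trials; a 0.5 ratio needs 40 more
# positives, or equivalently 80 fewer negatives.
y = np.array([1] * 10 + [0] * 100)
assert resample_counts(y, oversample_ratio=0.5) == (40, 0)
assert resample_counts(y, undersample_ratio=0.5) == (0, 80)
```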
def fit(self, plot_cm=False, plot_roc=False, lico_expansion_factor=1)
Fit the model.

Parameters
- plot_cm (bool, optional): Whether to plot the confusion matrix during training. Default is False.
- plot_roc (bool, optional): Whether to plot the ROC curve during training. Default is False.
- lico_expansion_factor (int, optional): Linear combination oversampling expansion factor. Determines the number of ERPs in the training set that will be expanded. A higher value increases the oversampling, generating more synthetic samples for the minority class. Default is 1.

Returns
None: Models created are used in predict().
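Linear Combination Oversampling (`lico`, called with `sum_num=2` in this classifier) synthesizes additional target-class epochs by linearly combining pairs of existing target epochs. A hedged numpy sketch of that idea (assumed behavior for illustration; `lico_sketch` is not the actual `bci_essentials.signal_processing.lico` implementation):

```python
import numpy as np

rng = np.random.default_rng(42)

def lico_sketch(X_pos, expansion_factor=2):
    """Synthesize new positive-class epochs as the mean of two randomly
    chosen existing positive epochs (the sum_num=2 case), expanding the
    positive class by `expansion_factor`."""
    n = len(X_pos)
    n_new = n * (expansion_factor - 1)
    synth = np.empty((n_new,) + X_pos.shape[1:])
    for k in range(n_new):
        i, j = rng.choice(n, size=2, replace=False)
        synth[k] = (X_pos[i] + X_pos[j]) / 2.0
    return np.concatenate([X_pos, synth], axis=0)

# 5 target epochs of 4 channels x 16 samples, expanded 2x -> 10 epochs
X_pos = rng.standard_normal((5, 4, 16))
X_aug = lico_sketch(X_pos, expansion_factor=2)
assert X_aug.shape == (10, 4, 16)
assert np.allclose(X_aug[:5], X_pos)  # originals are kept
```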
def predict(self, X)
Predict the class of the data (unused in this classifier).

Parameters
- X (numpy.ndarray): 3D array where shape = (n_epochs, n_channels, n_samples).

Returns
- prediction (Prediction): Prediction object. Contains the predicted labels and the probability. Because this classifier chooses the P300 object with the highest posterior probability, the probability is only the posterior probability of the chosen object.
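The selection step in `predict()` reduces to a 1-indexed argmax over the per-object posterior probabilities of the target (P300) class:

```python
import numpy as np

# Each entry is one flashed object's posterior probability of being
# the target; the chosen label is the argmax, 1-indexed.
posterior_probabilities = np.array([0.10, 0.05, 0.72, 0.13])
label = [int(np.argmax(posterior_probabilities) + 1)]
assert label == [3]  # object 3 has the highest posterior
```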