bci_essentials.classification.erp_rg_classifier
ERP RG Classifier
This classifier is used to classify ERPs using the Riemannian Geometry approach.
1"""**ERP RG Classifier** 2 3This classifier is used to classify ERPs using the Riemannian Geometry 4approach. 5 6""" 7 8# Stock libraries 9import random 10import numpy as np 11import matplotlib.pyplot as plt 12from sklearn.pipeline import make_pipeline 13from sklearn.model_selection import StratifiedKFold 14from sklearn.metrics import ( 15 confusion_matrix, 16 ConfusionMatrixDisplay, 17 precision_score, 18 recall_score, 19) 20from sklearn.discriminant_analysis import LinearDiscriminantAnalysis 21from pyriemann.estimation import XdawnCovariances 22from pyriemann.tangentspace import TangentSpace 23from pyriemann.channelselection import FlatChannelRemover 24 25 26# Import bci_essentials modules and methods 27from ..classification.generic_classifier import ( 28 GenericClassifier, 29 Prediction, 30 KernelResults, 31) 32from ..signal_processing import lico 33from ..channel_selection import channel_selection_by_method 34from ..utils.logger import Logger # Logger wrapper 35 36# Instantiate a logger for the module at the default level of logging.INFO 37# Logs to bci_essentials.__module__) where __module__ is the name of the module 38logger = Logger(name=__name__) 39 40 41class ErpRgClassifier(GenericClassifier): 42 """ERP RG Classifier class (*inherits from `GenericClassifier`*).""" 43 44 def set_p300_clf_settings( 45 self, 46 n_splits=3, 47 lico_expansion_factor=1, 48 oversample_ratio=0, 49 undersample_ratio=0, 50 random_seed=42, 51 covariance_estimator="oas", # Covariance estimator, see pyriemann Covariances 52 remove_flats=True, 53 ): 54 """Set P300 Classifier Settings. 55 56 Parameters 57 ---------- 58 n_splits : int, *optional* 59 Number of folds for cross-validation. 60 - Default is `3`. 61 lico_expansion_factor : int, *optional* 62 Linear Combination Oversampling expansion factor, which is the 63 factor by which the number of ERPs in the training set will be 64 expanded. 65 - Default is `1`. 66 oversample_ratio : float, *optional* 67 Traditional oversampling. Range is from from 0.1-1 resulting 68 from the ratio of erp to non-erp class. 0 for no oversampling. 69 - Default is `0`. 70 undersample_ratio : float, *optional* 71 Traditional undersampling. Range is from from 0.1-1 resulting 72 from the ratio of erp to non-erp class. 0 for no undersampling. 73 - Default is `0`. 74 random_seed : int, *optional* 75 Random seed. 76 - Default is `42`. 77 covariance_estimator : str, *optional* 78 Covariance estimator. See pyriemann Covariances. 79 - Default is `"oas"`. 80 remove_flats : bool, *optional* 81 Whether to remove flat channels. 82 - Default is `True`. 83 84 Returns 85 ------- 86 `None` 87 88 """ 89 self.n_splits = n_splits 90 self.lico_expansion_factor = lico_expansion_factor 91 self.oversample_ratio = oversample_ratio 92 self.undersample_ratio = undersample_ratio 93 self.random_seed = random_seed 94 self.covariance_estimator = covariance_estimator 95 96 # Define the classifier 97 self.clf = make_pipeline( 98 XdawnCovariances(estimator=self.covariance_estimator), 99 TangentSpace(metric="riemann"), 100 LinearDiscriminantAnalysis(solver="eigen", shrinkage="auto"), 101 ) 102 103 if remove_flats: 104 rf = FlatChannelRemover() 105 self.clf.steps.insert(0, ["Remove Flat Channels", rf]) 106 107 def fit( 108 self, 109 plot_cm=False, 110 plot_roc=False, 111 lico_expansion_factor=1, 112 ): 113 """Fit the model. 114 115 Parameters 116 ---------- 117 plot_cm : bool, *optional* 118 Whether to plot the confusion matrix during training. 119 - Default is `False`. 
120 plot_roc : bool, *optional* 121 Whether to plot the ROC curve during training. 122 - Default is `False`. 123 lico_expansion_factor : int, *optional* 124 Linear combination oversampling expansion factor. 125 Determines the number of ERPs in the training set that will be expanded. 126 Higher value increases the oversampling, generating more synthetic 127 samples for the minority class. 128 - Default is `1`. 129 130 Returns 131 ------- 132 `None` 133 Models created used in `predict()`. 134 135 """ 136 logger.info("Fitting the model using RG") 137 logger.info("X shape: %s", self.X.shape) 138 logger.info("y shape: %s", self.y.shape) 139 140 # Define the strategy for cross validation 141 cv = StratifiedKFold( 142 n_splits=self.n_splits, shuffle=True, random_state=self.random_seed 143 ) 144 145 # Init predictions to all false 146 cv_preds = np.zeros(len(self.y)) 147 148 def __erp_rg_kernel(X, y): 149 """ERP RG kernel. 150 151 Parameters 152 ---------- 153 X : numpy.ndarray 154 Input features (ERP data) for training. 155 3D numpy array with shape = (`n_trials`, `n_channels`, `n_samples`). 156 E.g. (100, 32, 1000) for 100 trials, 32 channels and 1000 samples per channel. 157 158 y : numpy.ndarray 159 Target labels corresponding to the input features in `X`. 160 1D numpy array with shape (n_trails, ). 161 Each label indicates the class of the corresponding trial in `X`. 162 E.g. (100, ) for 100 trials. 163 164 165 Returns 166 ------- 167 kernelResults : KernelResults 168 KernelResults object containing the following attributes: 169 model : classifier 170 The trained classification model. 171 cv_preds : numpy.ndarray 172 The predictions from the model using cross validation. 173 1D array with the same shape as `y`. 174 accuracy : float 175 The accuracy of the trained classification model. 176 precision : float 177 The precision of the trained classification model. 178 recall : float 179 The recall of the trained classification model. 
180 181 """ 182 for train_idx, test_idx in cv.split(X, y): 183 y_train, y_test = y[train_idx], y[test_idx] 184 185 X_train, X_test = X[train_idx], X[test_idx] 186 187 # LICO 188 logger.debug( 189 "Before LICO:\n\tShape X: %s\n\tShape y: %s", 190 X_train.shape, 191 y_train.shape, 192 ) 193 194 if sum(y_train) > 2: 195 if lico_expansion_factor > 1: 196 X_train, y_train = lico( 197 X_train, 198 y_train, 199 expansion_factor=lico_expansion_factor, 200 sum_num=2, 201 shuffle=False, 202 ) 203 logger.debug("y_train = %s", y_train) 204 205 logger.debug( 206 "After LICO:\n\tShape X: %s\n\tShape y: %s", 207 X_train.shape, 208 y_train.shape, 209 ) 210 211 # Oversampling 212 if self.oversample_ratio > 0: 213 p_count = sum(y_train) 214 n_count = len(y_train) - sum(y_train) 215 216 num_to_add = int( 217 np.floor((self.oversample_ratio * n_count) - p_count) 218 ) 219 220 # Add num_to_add random selections from the positive 221 true_X_train = X_train[y_train == 1] 222 223 len_X_train = len(true_X_train) 224 225 for s in range(num_to_add): 226 to_add_X = true_X_train[random.randrange(0, len_X_train), :, :] 227 228 X_train = np.append(X_train, to_add_X[np.newaxis, :], axis=0) 229 y_train = np.append(y_train, [1], axis=0) 230 231 # Undersampling 232 if self.undersample_ratio > 0: 233 p_count = sum(y_train) 234 n_count = len(y_train) - sum(y_train) 235 236 num_to_remove = int( 237 np.floor(n_count - (p_count / self.undersample_ratio)) 238 ) 239 240 ind_range = np.arange(len(y_train)) 241 ind_list = list(ind_range) 242 to_remove = [] 243 244 # Remove num_to_remove random selections from the negative 245 false_ind = list(ind_range[y_train == 0]) 246 247 for s in range(num_to_remove): 248 # select a random value from the list of false indices 249 remove_at = false_ind[random.randrange(0, len(false_ind))] 250 251 # remove that value from the false index list 252 false_ind.remove(remove_at) 253 254 # add the index to be removed to a list 255 to_remove.append(remove_at) 256 257 remaining_ind = ind_list 258 for i in range(len(to_remove)): 259 remaining_ind.remove(to_remove[i]) 260 261 X_train = X_train[remaining_ind, :, :] 262 y_train = y_train[remaining_ind] 263 264 self.clf.fit(X_train, y_train) 265 cv_preds[test_idx] = self.clf.predict(X_test) 266 predproba = self.clf.predict_proba(X_test) 267 268 # Use pred proba to show what would be predicted 269 predprobs = predproba[:, 1] 270 real = np.where(y_test == 1) 271 272 # TODO handle exception where two probabilities are the same 273 prediction = int(np.where(predprobs == np.amax(predprobs))[0][0]) 274 275 logger.debug("y_test = %s", y_test) 276 logger.debug("predproba = %s", predproba) 277 logger.debug("real = %s", real[0]) 278 logger.debug("prediction = %s", prediction) 279 280 # Train final model with all available data 281 self.clf.fit(X, y) 282 model = self.clf 283 284 accuracy = sum(cv_preds == self.y) / len(cv_preds) 285 precision = precision_score(self.y, cv_preds) 286 recall = recall_score(self.y, cv_preds) 287 288 return KernelResults(model, cv_preds, accuracy, precision, recall) 289 290 # Check if channel selection is true 291 if self.channel_selection_setup: 292 logger.info("Doing channel selection") 293 logger.debug("Initial subset: %s", self.chs_initial_subset) 294 295 channel_selection_results = channel_selection_by_method( 296 __erp_rg_kernel, 297 self.X, 298 self.y, 299 self.channel_labels, # kernel setup 300 self.chs_method, 301 self.chs_metric, 302 self.chs_initial_subset, # wrapper setup 303 self.chs_max_time, 304 self.chs_min_channels, 305 
self.chs_max_channels, 306 self.chs_performance_delta, # stopping criterion 307 self.chs_n_jobs, 308 ) # njobs, output messages 309 310 preds = channel_selection_results.best_preds 311 accuracy = channel_selection_results.best_accuracy 312 precision = channel_selection_results.best_precision 313 recall = channel_selection_results.best_recall 314 315 logger.info( 316 "The optimal subset is %s", 317 channel_selection_results.best_channel_subset, 318 ) 319 320 self.results_df = channel_selection_results.results_df 321 self.subset = channel_selection_results.best_channel_subset 322 self.subset_defined = True 323 self.clf = channel_selection_results.best_model 324 else: 325 logger.warning("Not doing channel selection") 326 X = self.get_subset(self.X, self.subset, self.channel_labels) 327 328 current_results = __erp_rg_kernel(X, self.y) 329 self.clf = current_results.model 330 preds = current_results.cv_preds 331 accuracy = current_results.accuracy 332 precision = current_results.precision 333 recall = current_results.recall 334 335 # Log performance stats 336 # accuracy 337 accuracy = sum(preds == self.y) / len(preds) 338 self.offline_accuracy = accuracy 339 logger.info("Accuracy = %s", accuracy) 340 341 # precision 342 precision = precision_score(self.y, preds) 343 self.offline_precision = precision 344 logger.info("Precision = %s", precision) 345 346 # recall 347 recall = recall_score(self.y, preds) 348 self.offline_recall = recall 349 logger.info("Recall = %s", recall) 350 351 # confusion matrix in command line 352 cm = confusion_matrix(self.y, preds) 353 self.offline_cm = cm 354 logger.info("Confusion matrix:\n%s", cm) 355 356 if plot_cm: 357 cm = confusion_matrix(self.y, preds) 358 ConfusionMatrixDisplay(cm).plot() 359 plt.show() 360 361 if plot_roc: 362 logger.error("ROC plot has not been implemented yet") 363 364 def predict(self, X): 365 """Predict the class of the data (Unused in this classifier) 366 367 Parameters 368 ---------- 369 X : numpy.ndarray 370 3D array where shape = (n_epochs, n_channels, n_samples) 371 372 Returns 373 ------- 374 prediction : Prediction 375 Predict object. Contains the predicted labels and and the probability. 376 Because this classifier chooses the P300 object with the highest posterior probability, 377 the probability is only the posterior probability of the chosen object. 378 379 """ 380 381 subset_X = self.get_subset(X, self.subset, self.channel_labels) 382 383 # Get posterior probability for each target 384 posterior_prob = self.clf.predict_proba(subset_X)[:, 1] 385 386 label = [int(np.argmax(posterior_prob))] 387 probability = [np.max(posterior_prob)] 388 389 return Prediction(label, probability)
logger = &lt;bci_essentials.utils.logger.Logger object&gt;
ERP RG Classifier class (inherits from GenericClassifier).
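A minimal end-to-end sketch of the intended workflow. Assigning the X, y, and channel_labels attributes directly is an assumption made for illustration (those are the attributes fit() reads, inherited from GenericClassifier); the constructor arguments and the library's actual data-loading path may differ.

    import numpy as np
    from bci_essentials.classification.erp_rg_classifier import ErpRgClassifier

    # Constructor arguments, if any, are omitted here; see GenericClassifier.
    clf = ErpRgClassifier()
    clf.set_p300_clf_settings(n_splits=3, undersample_ratio=0.5, random_seed=42)

    # Hypothetical training set: 120 trials, 8 channels, 256 samples per channel;
    # label 1 = target ERP, label 0 = non-target.
    clf.X = np.random.randn(120, 8, 256)
    clf.y = np.random.randint(0, 2, 120)
    clf.channel_labels = ["Fz", "Cz", "Pz", "P3", "P4", "PO7", "PO8", "Oz"]

    clf.fit()  # cross-validates, logs accuracy/precision/recall, refits on all data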
def set_p300_clf_settings(self, n_splits=3, lico_expansion_factor=1, oversample_ratio=0, undersample_ratio=0, random_seed=42, covariance_estimator='oas', remove_flats=True):
Set P300 Classifier Settings.

Parameters
- n_splits (int, optional):
  Number of folds for cross-validation.
  - Default is 3.
- lico_expansion_factor (int, optional):
  Linear Combination Oversampling expansion factor, which is the factor by which the number of ERPs in the training set will be expanded.
  - Default is 1.
- oversample_ratio (float, optional):
  Traditional oversampling. Range is 0.1-1, expressing the desired ratio of ERP to non-ERP samples; 0 for no oversampling (see the worked sketch below).
  - Default is 0.
- undersample_ratio (float, optional):
  Traditional undersampling. Range is 0.1-1, expressing the desired ratio of ERP to non-ERP samples; 0 for no undersampling.
  - Default is 0.
- random_seed (int, optional):
  Random seed.
  - Default is 42.
- covariance_estimator (str, optional):
  Covariance estimator. See pyriemann Covariances.
  - Default is "oas".
- remove_flats (bool, optional):
  Whether to remove flat channels.
  - Default is True.

Returns
None
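To make the two resampling ratios concrete, the arithmetic below is lifted from the per-fold logic in fit() shown in the source above; the trial counts are purely illustrative.

    import numpy as np

    # Illustrative fold: 20 target (ERP) trials, 100 non-target trials.
    p_count, n_count = 20, 100

    # oversample_ratio = 0.5: randomly duplicate target trials until
    # targets are about 0.5 * non-targets.
    oversample_ratio = 0.5
    num_to_add = int(np.floor((oversample_ratio * n_count) - p_count))
    print(num_to_add)  # 30 -> 50 targets vs 100 non-targets

    # undersample_ratio = 0.5: randomly discard non-target trials until
    # non-targets are about targets / 0.5.
    undersample_ratio = 0.5
    num_to_remove = int(np.floor(n_count - (p_count / undersample_ratio)))
    print(num_to_remove)  # 60 -> 20 targets vs 40 non-targets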
def fit(self, plot_cm=False, plot_roc=False, lico_expansion_factor=1):
Fit the model.

Parameters
- plot_cm (bool, optional):
  Whether to plot the confusion matrix during training.
  - Default is False.
- plot_roc (bool, optional):
  Whether to plot the ROC curve during training.
  - Default is False.
- lico_expansion_factor (int, optional):
  Linear combination oversampling expansion factor. Determines the factor by which the ERPs in the training set will be expanded. A higher value increases the oversampling, generating more synthetic samples for the minority class.
  - Default is 1.

Returns
None. The trained models are stored internally and used in predict().
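A hedged sketch of a typical fit() call, continuing the class-level example above (it assumes clf already holds training data in X, y, and channel_labels). The internal flow follows the source shown earlier: stratified k-fold cross-validation, optional LICO and ratio-based resampling inside each fold, then a final refit on all data.

    # lico_expansion_factor > 1 turns on linear-combination oversampling (LICO)
    # within each cross-validation fold before the pipeline is fit.
    clf.fit(plot_cm=True, lico_expansion_factor=2)

    # Cross-validated performance is stored on the classifier by fit().
    print(clf.offline_accuracy, clf.offline_precision, clf.offline_recall)
    print(clf.offline_cm)  # confusion matrix of the cross-validated predictions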
def predict(self, X):
Predict the class of the data.

Parameters
- X (numpy.ndarray): 3D array where shape = (n_epochs, n_channels, n_samples)

Returns
- prediction (Prediction): Prediction object. Contains the predicted labels and the probability. Because this classifier chooses the P300 object with the highest posterior probability, the probability is only the posterior probability of the chosen object.
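A worked sketch of the selection step predict() performs in a P300 paradigm: with one epoch per flashed object, it returns the index of the epoch whose target-class posterior is highest. Shapes are illustrative, and the Prediction attribute names are assumed to mirror the Prediction(label, probability) constructor arguments seen in the source.

    import numpy as np

    # One epoch per flashed object: 6 objects, 8 channels, 256 samples.
    X_new = np.random.randn(6, 8, 256)

    prediction = clf.predict(X_new)  # assumes clf was fit as in the sketches above
    print(prediction.label)        # e.g. [3]: index of the most probable target
    print(prediction.probability)  # posterior probability of that object only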