bci_essentials.classification.erp_rg_classifier_hyperparamgridsearch
ERP RG Classifier
This classifier classifies event-related potentials (ERPs) using a Riemannian Geometry (RG) approach.
1"""**ERP RG Classifier** 2 3This classifier is used to classify ERPs using the Riemannian Geometry 4approach. 5 6""" 7 8# Stock libraries 9import numpy as np 10import matplotlib.pyplot as plt 11from sklearn.model_selection import StratifiedKFold, GridSearchCV 12from sklearn.metrics import ( 13 confusion_matrix, 14 ConfusionMatrixDisplay, 15 precision_score, 16 recall_score, 17 roc_auc_score, 18 make_scorer, 19) 20from sklearn.pipeline import Pipeline 21from pyriemann.tangentspace import TangentSpace 22from pyriemann.estimation import XdawnCovariances 23from pyriemann.channelselection import FlatChannelRemover 24from sklearn.discriminant_analysis import LinearDiscriminantAnalysis 25 26# Import bci_essentials modules and methods 27from .generic_classifier import ( 28 GenericClassifier, 29 Prediction, 30) 31from ..signal_processing import lico, random_oversampling, random_undersampling 32from ..utils.logger import Logger # Logger wrapper 33 34# Instantiate a logger for the module at the default level of logging.INFO 35# Logs to bci_essentials.__module__) where __module__ is the name of the module 36logger = Logger(name=__name__) 37 38 39class ErpRgClassifierHyperparamGridSearch(GenericClassifier): 40 """ERP RG Classifier with hyperparameter grid search 41 class (*inherits from `GenericClassifier`*).""" 42 43 def set_p300_clf_settings( 44 self, 45 n_splits=3, 46 resampling_method=None, 47 lico_expansion_factor=1, 48 oversample_ratio=0, 49 undersample_ratio=0, 50 random_seed=42, 51 remove_flats=True, 52 ): 53 """Set P300 Classifier Settings. 54 55 Parameters 56 ---------- 57 n_splits : int, *optional* 58 Number of folds for cross-validation. 59 - Default is `3`. 60 resampling_method : str, *optional*, None 61 Resampling method to use ["lico", "oversample", "undersample"]. 62 Default is None. 63 lico_expansion_factor : int, *optional* 64 Linear Combination Oversampling expansion factor, which is the 65 factor by which the number of ERPs in the training set will be 66 expanded. 67 - Default is `1`. 68 oversample_ratio : float, *optional* 69 Traditional oversampling. Range is from from 0.1-1 resulting 70 from the ratio of erp to non-erp class. 0 for no oversampling. 71 - Default is `0`. 72 undersample_ratio : float, *optional* 73 Traditional undersampling. Range is from from 0.1-1 resulting 74 from the ratio of erp to non-erp class. 0 for no undersampling. 75 - Default is `0`. 76 random_seed : int, *optional* 77 Random seed. 78 - Default is `42`. 79 remove_flats : bool, *optional* 80 Whether to remove flat channels. 81 - Default is `True`. 
82 83 Returns 84 ------- 85 `None` 86 87 """ 88 self.n_splits = n_splits 89 self.resampling_method = resampling_method 90 self.lico_expansion_factor = lico_expansion_factor 91 self.oversample_ratio = oversample_ratio 92 self.undersample_ratio = undersample_ratio 93 self.random_seed = random_seed 94 95 # # Create steps list with proper formatting 96 steps = [] 97 if remove_flats: 98 steps.append(("remove_flats", FlatChannelRemover())) 99 100 steps.extend( 101 [ 102 ("xdawn", XdawnCovariances()), 103 ("tangent", TangentSpace()), 104 ("lda", LinearDiscriminantAnalysis()), 105 ] 106 ) 107 108 # Create pipeline 109 self.clf = Pipeline(steps) 110 111 # Hyperparameters to be optimized 112 # TODO: Implement an extended nfilter set, dynamically based on the number of channels 113 # Example of dynamic nfilter set 114 # n_channels = self.X.shape[1] 115 # nfilter_set = list(range(2, n_channels+1)) # Example range from 2 to n_channels inclusive 116 # Then set "xdawn__nfilter": nfilter_set in the param_grid below 117 self.param_grid = { 118 "xdawn__nfilter": [2, 3, 4], 119 "xdawn__estimator": ["oas", "lwf"], 120 "tangent__metric": ["riemann"], 121 "lda__solver": ["lsqr", "eigen"], 122 "lda__shrinkage": np.linspace(0.5, 0.9, 5), 123 } 124 125 def fit( 126 self, 127 plot_cm=False, 128 plot_roc=False, 129 ): 130 """Fit the model. 131 132 Parameters 133 ---------- 134 plot_cm : bool, *optional* 135 Whether to plot the confusion matrix during training. 136 - Default is `False`. 137 plot_roc : bool, *optional* 138 Whether to plot the ROC curve during training. 139 - Default is `False`. 140 141 Returns 142 ------- 143 `None` 144 Models created used in `predict()`. 145 146 """ 147 logger.info("Fitting the model using RG") 148 logger.info("X shape: %s", self.X.shape) 149 logger.info("y shape: %s", self.y.shape) 150 151 # Resample data if needed 152 self.X, self.y = self.__resample_data() 153 154 # Optimize hyperparameters with cross-validation 155 self.__optimize_hyperparameters() 156 157 # Fit the model with the complete dataset and optimized hyperparameters 158 self.clf.fit(self.X, self.y) 159 160 # Get predictions for final model 161 y_pred_proba = self.clf.predict_proba(self.X)[:, 1] 162 163 # Calculate estimate of training metrics of final model 164 # TODO: Implement proper training metrics calculation, using cross validation. 
165 # self.offline_accuracy = sum(y_pred == self.y) / len(self.y) 166 # self.offline_precision = precision_score(self.y, y_pred) 167 # self.offline_recall = recall_score(self.y, y_pred) 168 169 try: 170 roc_auc = roc_auc_score(self.y, y_pred_proba) 171 logger.info(f"ROC AUC Score: {roc_auc:0.3f}") 172 except Exception as e: 173 logger.warning(f"Could not calculate ROC AUC score: {e}") 174 175 # Display training confusion matrix 176 # self.offline_cm = confusion_matrix(self.y, y_pred) 177 if plot_cm: 178 disp = ConfusionMatrixDisplay(confusion_matrix=self.offline_cm) 179 disp.plot() 180 plt.title("Training confusion matrix") 181 182 if plot_roc: 183 # TODO Implementation missing 184 pass 185 186 # Log training metrics 187 logger.info("Final model training performance metrics:") 188 logger.info(f"Accuracy: {self.offline_accuracy:0.3f} - MAY NOT BE ACCURATE") 189 logger.info(f"Precision: {self.offline_precision:0.3f} - MAY NOT BE ACCURATE") 190 logger.info(f"Recall: {self.offline_recall:0.3f} - MAY NOT BE ACCURATE") 191 logger.info(f"Confusion Matrix:\n{self.offline_cm} ") 192 logger.warning( 193 "Note: Training metrics may not be accurate due to the use of " 194 "cross-validation and resampling methods. Use with caution." 195 ) 196 197 def predict(self, X): 198 """Predict the class of the data 199 200 Parameters 201 ---------- 202 X : numpy.ndarray 203 3D array where shape = (n_epochs, n_channels, n_samples) 204 205 Returns 206 ------- 207 prediction : Prediction 208 Predict object. Contains the predicted labels and and the probability. 209 Because this classifier chooses the P300 object with the highest posterior probability, 210 the probability is only the posterior probability of the chosen object. 211 212 """ 213 214 subset_X = self.get_subset(X, self.subset, self.channel_labels) 215 216 # Get posterior probability for each target 217 posterior_prob = self.clf.predict_proba(subset_X)[:, 1] 218 219 label = [int(np.argmax(posterior_prob))] 220 probability = [np.max(posterior_prob)] 221 222 return Prediction(label, probability) 223 224 # TODO implement additional resampling methods, JIRA ticket: B4K-342 225 def __resample_data(self): 226 """Resample data based on the selected method""" 227 228 X_resampled = self.X.copy() 229 y_resampled = self.y.copy() 230 231 try: 232 if (self.resampling_method == "lico") and (self.lico_expansion_factor > 1): 233 [X_resampled, y_resampled] = lico( 234 self.X, self.y, self.lico_expansion_factor 235 ) 236 pass 237 238 elif (self.resampling_method == "oversample") and ( 239 self.oversample_ratio > 0 240 ): 241 [X_resampled, y_resampled] = random_oversampling( 242 self.X, self.y, self.oversample_ratio 243 ) 244 pass 245 246 elif (self.resampling_method == "undersample") and ( 247 self.undersample_ratio > 0 248 ): 249 [X_resampled, y_resampled] = random_undersampling( 250 self.X, self.y, self.undersample_ratio 251 ) 252 pass 253 254 logger.info(f"Resampling with {self.resampling_method} done") 255 logger.info(f"X_resampled shape: {X_resampled.shape}") 256 logger.info(f"y_resampled shape: {y_resampled.shape}") 257 258 except Exception as e: 259 logger.error( 260 f"{self.resampling_method.capitalize()} resampling method failed" 261 ) 262 logger.error(e) 263 264 return X_resampled, y_resampled 265 266 def __optimize_hyperparameters(self): 267 """Optimize hyperparameters with cross-validation using brute force grid search 268 269 Returns 270 ------- 271 `None` 272 Model with best hyperparameters to be used in `predict()`. 
273 274 """ 275 276 # Perform cross-validation 277 cv = StratifiedKFold( 278 n_splits=self.n_splits, shuffle=True, random_state=self.random_seed 279 ) 280 281 # Create custom scorer function 282 custom_scorer = make_scorer( 283 self._valid_roc_auc, response_method="predict_proba", greater_is_better=True 284 ) 285 286 # Create GridSearchCV object 287 grid_search = GridSearchCV( 288 estimator=self.clf, 289 param_grid=self.param_grid, 290 cv=cv, 291 n_jobs=-1, 292 verbose=1, 293 scoring=custom_scorer, 294 refit=True, 295 return_train_score=True, 296 ) 297 298 # Start grid search optimization 299 logger.info("Starting grid search optimization...") 300 grid_search.fit(self.X, self.y) 301 302 # Get best parameters and score 303 logger.info("Grid search optimization completed.") 304 best_params = grid_search.best_params_ 305 best_score = grid_search.best_score_ 306 307 # Report training metrics: TODO: Verify this is the right way to calculate training metrics 308 self.offline_accuracy = grid_search.best_estimator_.score(self.X, self.y) 309 self.offline_cm = confusion_matrix( 310 self.y, grid_search.best_estimator_.predict(self.X) 311 ) 312 self.offline_precision = precision_score( 313 self.y, grid_search.best_estimator_.predict(self.X) 314 ) 315 self.offline_recall = recall_score( 316 self.y, grid_search.best_estimator_.predict(self.X) 317 ) 318 319 # Update classifier with best parameters 320 self.clf.set_params(**best_params) 321 logger.info(f"Best parameters found: {best_params}") 322 logger.info(f"Best CV score: {best_score:0.3f}") 323 324 def _valid_roc_auc(self, y_true, y_pred, **kwargs): 325 """Calculate the ROC AUC score for the classifier. 326 This method is used because the stock `roc_auc_score` function 327 does not handle the case where one class is missing in the fold. 328 This method will return 0.5 in that case. 329 330 Parameters 331 ---------- 332 y_true : numpy.ndarray 333 True labels. 334 y_pred : numpy.ndarray 335 Predicted labels. 336 **kwargs : dict 337 Additional keyword arguments passed by make_scorer. 338 339 Returns 340 ------- 341 roc_auc : float 342 ROC AUC score. 343 344 """ 345 try: 346 # Check if we have both classes in the fold 347 if len(np.unique(y_true)) < 2: 348 logger.warning("Fold contains only one class") 349 return 0.5 350 351 return roc_auc_score(y_true, y_pred) 352 353 except Exception as e: 354 logger.warning(f"ROC AUC calculation failed: {e}") 355 return 0.5
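For orientation, below is a minimal standalone sketch (not part of this module) of the same Riemannian Geometry pipeline the class builds internally: flat-channel removal, Xdawn covariance estimation, tangent-space projection, and shrinkage LDA, tuned with a small grid search. The synthetic epoch array, the labels, and the reduced parameter grid are illustrative assumptions only.

```python
# Standalone sketch of the pipeline this module wraps; data and grid are illustrative.
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.pipeline import Pipeline
from pyriemann.channelselection import FlatChannelRemover
from pyriemann.estimation import XdawnCovariances
from pyriemann.tangentspace import TangentSpace

rng = np.random.default_rng(42)
X = rng.standard_normal((60, 8, 128))  # 60 epochs x 8 channels x 128 samples
y = rng.integers(0, 2, size=60)        # binary target / non-target labels

pipeline = Pipeline(
    [
        ("remove_flats", FlatChannelRemover()),
        ("xdawn", XdawnCovariances()),
        ("tangent", TangentSpace()),
        ("lda", LinearDiscriminantAnalysis()),
    ]
)

# A smaller grid than the class uses, to keep the sketch quick to run.
param_grid = {
    "xdawn__nfilter": [2, 3],
    "lda__solver": ["lsqr"],
    "lda__shrinkage": [0.5, 0.7, 0.9],
}

search = GridSearchCV(
    pipeline,
    param_grid,
    cv=StratifiedKFold(n_splits=3, shuffle=True, random_state=42),
    scoring="roc_auc",
    n_jobs=-1,
)
search.fit(X, y)
print(search.best_params_, search.best_score_)
```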
logger = <bci_essentials.utils.logger.Logger object>
class ErpRgClassifierHyperparamGridSearch(bci_essentials.classification.generic_classifier.GenericClassifier):
ERP RG Classifier with hyperparameter grid search (inherits from GenericClassifier).
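A hypothetical end-to-end usage sketch follows. The epoch array, labels, no-argument construction, and the direct assignment of clf.X / clf.y are placeholders for however the surrounding bci_essentials pipeline would normally construct the classifier and attach training data through the GenericClassifier interface; treat it as an illustration of the call order, not a prescribed API.

```python
import numpy as np
from bci_essentials.classification.erp_rg_classifier_hyperparamgridsearch import (
    ErpRgClassifierHyperparamGridSearch,
)

# Assumed: the classifier can be constructed like this in your setup.
clf = ErpRgClassifierHyperparamGridSearch()

clf.set_p300_clf_settings(
    n_splits=3,
    resampling_method="undersample",
    undersample_ratio=0.5,
    random_seed=42,
    remove_flats=True,
)

# Placeholder training data: 80 epochs x 8 channels x 128 samples, binary labels.
clf.X = np.random.randn(80, 8, 128)
clf.y = np.random.randint(0, 2, size=80)

# Grid search over clf.param_grid, then refit on the full (resampled) training set.
clf.fit()
```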
def set_p300_clf_settings(self, n_splits=3, resampling_method=None, lico_expansion_factor=1, oversample_ratio=0, undersample_ratio=0, random_seed=42, remove_flats=True)
Set P300 Classifier Settings.

Parameters

- n_splits (int, optional): Number of folds for cross-validation. Default is 3.
- resampling_method (str, optional): Resampling method to use, one of ["lico", "oversample", "undersample"]. Default is None.
- lico_expansion_factor (int, optional): Linear Combination Oversampling expansion factor, which is the factor by which the number of ERPs in the training set will be expanded. Default is 1.
- oversample_ratio (float, optional): Traditional oversampling. Range is 0.1-1, giving the resulting ratio of the ERP to non-ERP class. 0 for no oversampling. Default is 0.
- undersample_ratio (float, optional): Traditional undersampling. Range is 0.1-1, giving the resulting ratio of the ERP to non-ERP class. 0 for no undersampling. Default is 0.
- random_seed (int, optional): Random seed. Default is 42.
- remove_flats (bool, optional): Whether to remove flat channels. Default is True.

Returns

None
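As an illustration of the resampling options above, the calls below (on an already constructed classifier instance, here named clf) show one configuration per method; the specific ratios and the expansion factor are arbitrary examples.

```python
# Linear Combination Oversampling: expand the ERP (target) epochs threefold.
clf.set_p300_clf_settings(resampling_method="lico", lico_expansion_factor=3)

# Random oversampling of the ERP class, aiming for an ERP:non-ERP ratio of 0.5.
clf.set_p300_clf_settings(resampling_method="oversample", oversample_ratio=0.5)

# Random undersampling of the non-ERP class, aiming for an ERP:non-ERP ratio of 0.5.
clf.set_p300_clf_settings(resampling_method="undersample", undersample_ratio=0.5)
```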
def fit(self, plot_cm=False, plot_roc=False)
Fit the model.

Parameters

- plot_cm (bool, optional): Whether to plot the confusion matrix during training. Default is False.
- plot_roc (bool, optional): Whether to plot the ROC curve during training. Default is False.

Returns

None: Models created are used in predict().
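A short sketch of what a fit call and the subsequent inspection of the training metrics could look like. The offline_accuracy, offline_precision, offline_recall, and offline_cm attributes read here are the ones populated during hyperparameter optimization, and the source's own warning about their reliability applies.

```python
# Assumes set_p300_clf_settings() was called and training data is attached.
clf.fit(plot_cm=True)

print(f"Accuracy:  {clf.offline_accuracy:0.3f}")
print(f"Precision: {clf.offline_precision:0.3f}")
print(f"Recall:    {clf.offline_recall:0.3f}")
print(clf.offline_cm)
```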
def predict(self, X)
Predict the class of the data.

Parameters

- X (numpy.ndarray): 3D array where shape = (n_epochs, n_channels, n_samples).

Returns

- prediction (Prediction): Prediction object. Contains the predicted labels and the probability. Because this classifier chooses the P300 object with the highest posterior probability, the probability is only the posterior probability of the chosen object.
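A sketch of a prediction call on one decision window. The epoch shape is illustrative, and it is assumed that the channel bookkeeping used by get_subset (self.subset, self.channel_labels) has already been set up through the GenericClassifier machinery.

```python
import numpy as np

# One epoch per flashed object: 8 objects x 8 channels x 128 samples.
X_epochs = np.random.randn(8, 8, 128)

prediction = clf.predict(X_epochs)
# The single label is the index of the epoch (flashed object) with the highest
# P300 posterior probability; the single probability is that epoch's posterior.
print(prediction)
```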