bci_essentials.classification.erp_single_channel_classifier
ERP Single Channel Classifier
This classifier is used to classify ERPs when only a single channel (ex. ear EEG) is available.
"""**ERP Single Channel Classifier**

This classifier is used to classify ERPs when only a single channel (ex. ear EEG) is available.

"""

# Stock libraries
import random
import numpy as np
import matplotlib.pyplot as plt
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    confusion_matrix,
    ConfusionMatrixDisplay,
    precision_score,
    recall_score,
)
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Import bci_essentials modules and methods
from .generic_classifier import GenericClassifier, Prediction, KernelResults
from ..signal_processing import lico
from ..channel_selection import channel_selection_by_method
from ..utils.logger import Logger  # Logger wrapper
from ..utils.reduce_to_single_channel import ReduceToSingleChannel

# Instantiate a logger for the module at the default level of logging.INFO
# Logs to bci_essentials.__module__, where __module__ is the name of the module
logger = Logger(name=__name__)


class ErpSingleChannelClassifier(GenericClassifier):
    """ERP Single Channel Classifier class (*inherits from `GenericClassifier`*).

    The model is a pipeline of `ReduceToSingleChannel` followed by a shrinkage
    LDA (`solver="eigen"`, `shrinkage="auto"`).
    """

    def set_p300_clf_settings(
        self,
        n_splits=3,
        lico_expansion_factor=1,
        oversample_ratio=0,
        undersample_ratio=0,
        random_seed=42,
    ):
        """Set P300 Classifier Settings.

        Parameters
        ----------
        n_splits : int, *optional*
            Number of folds for cross-validation.
            - Default is `3`.
        lico_expansion_factor : int, *optional*
            Linear Combination Oversampling expansion factor, which is the
            factor by which the number of ERPs in the training set will be
            expanded. Used by `fit()` unless `fit()` is given an explicit value.
            - Default is `1`.
        oversample_ratio : float, *optional*
            Traditional oversampling. Target ratio of ERP to non-ERP trials,
            from 0.1 to 1. 0 for no oversampling.
            - Default is `0`.
        undersample_ratio : float, *optional*
            Traditional undersampling. Target ratio of ERP to non-ERP trials,
            from 0.1 to 1. 0 for no undersampling.
            - Default is `0`.
        random_seed : int, *optional*
            Seed for the cross-validation shuffle and for over/undersampling.
            - Default is `42`.

        Returns
        -------
        `None`

        """
        self.n_splits = n_splits
        self.lico_expansion_factor = lico_expansion_factor
        self.oversample_ratio = oversample_ratio
        self.undersample_ratio = undersample_ratio
        self.random_seed = random_seed

    def fit(
        self,
        plot_cm=False,
        plot_roc=False,
        lico_expansion_factor=None,
    ):
        """Fit the model.

        Cross-validated predictions are used to compute the offline
        performance metrics; the final model is then trained on all of the
        available data.

        Parameters
        ----------
        plot_cm : bool, *optional*
            Whether to plot the confusion matrix during training.
            - Default is `False`.
        plot_roc : bool, *optional*
            Whether to plot the ROC curve during training (not implemented;
            an error is logged instead).
            - Default is `False`.
        lico_expansion_factor : int, *optional*
            Linear combination oversampling expansion factor, i.e. the factor
            by which the number of ERPs in each training fold is expanded.
            LICO is only applied when this is greater than `1`.
            - Default is `None`, which uses the value stored by
              `set_p300_clf_settings()` (or `1` if it was never set).

        Returns
        -------
        `None`
            The fitted model is stored in `self.clf`; the offline metrics in
            `self.offline_accuracy`, `self.offline_precision`,
            `self.offline_recall` and `self.offline_cm`.

        """
        # Honour the configured setting; previously the value stored by
        # set_p300_clf_settings() was never read.
        if lico_expansion_factor is None:
            lico_expansion_factor = getattr(self, "lico_expansion_factor", 1)

        logger.info("Fitting the model using sLDA")
        logger.info("X shape: %s", self.X.shape)
        logger.info("y shape: %s", self.y.shape)

        # Define the strategy for cross validation
        cv = StratifiedKFold(
            n_splits=self.n_splits, shuffle=True, random_state=self.random_seed
        )

        def __erp_single_channel_kernel(X, y):
            """ERP Single Channel kernel.

            Parameters
            ----------
            X : numpy.ndarray
                Input features (ERP data) for training.
                3D numpy array with shape = (`n_trials`, `n_channels`, `n_samples`).
                E.g. (100, 1, 1000) for 100 trials, 1 channel and 1000 samples.
            y : numpy.ndarray
                Target labels corresponding to the input features in `X`.
                1D numpy array with shape (n_trials, ).
                Each label indicates the class of the corresponding trial in `X`.
                E.g. (100, ) for 100 trials.

            Returns
            -------
            kernelResults : KernelResults
                KernelResults object containing the following attributes:
                    model : classifier
                        The classification model trained on all of `X`, `y`.
                    cv_preds : numpy.ndarray
                        The predictions from the model using cross validation.
                        1D array with the same shape as `y`.
                    accuracy : float
                        Cross-validated accuracy.
                    precision : float
                        Cross-validated precision.
                    recall : float
                        Cross-validated recall.

            """
            logger.info("X shape: %s", X.shape)

            # Fresh model and prediction buffer on every call: the kernel may be
            # called once per candidate channel subset, and shared objects would
            # make every returned result alias the last subset evaluated.
            model = make_pipeline(
                ReduceToSingleChannel(),
                LinearDiscriminantAnalysis(solver="eigen", shrinkage="auto"),
            )
            cv_preds = np.zeros(len(y))  # Init predictions to all false

            # Seeded so that over/undersampling is reproducible.
            rng = random.Random(self.random_seed)

            for train_idx, test_idx in cv.split(X, y):
                X_train, y_train = X[train_idx], y[train_idx]
                X_test, y_test = X[test_idx], y[test_idx]

                # LICO
                logger.debug(
                    "Before LICO:\n\tShape X: %s\n\tShape y: %s",
                    X_train.shape,
                    y_train.shape,
                )

                if np.sum(y_train) > 2 and lico_expansion_factor > 1:
                    X_train, y_train = lico(
                        X_train,
                        y_train,
                        expansion_factor=lico_expansion_factor,
                        sum_num=2,
                        shuffle=False,
                    )
                    logger.debug("y_train = %s", y_train)

                logger.debug(
                    "After LICO:\n\tShape X: %s\n\tShape y: %s",
                    X_train.shape,
                    y_train.shape,
                )

                # Oversampling: duplicate random positive trials until the
                # positive / negative ratio reaches oversample_ratio.
                if self.oversample_ratio > 0:
                    p_count = np.sum(y_train)
                    n_count = len(y_train) - p_count
                    num_to_add = int(
                        np.floor((self.oversample_ratio * n_count) - p_count)
                    )

                    true_X_train = X_train[y_train == 1]
                    # Skip when there is nothing to add or nothing to copy from.
                    if num_to_add > 0 and len(true_X_train) > 0:
                        picks = [
                            rng.randrange(len(true_X_train))
                            for _ in range(num_to_add)
                        ]
                        X_train = np.concatenate(
                            (X_train, true_X_train[picks]), axis=0
                        )
                        y_train = np.concatenate(
                            (y_train, np.ones(num_to_add, dtype=y_train.dtype))
                        )

                # Undersampling: drop random negative trials until the
                # positive / negative ratio reaches undersample_ratio.
                if self.undersample_ratio > 0:
                    p_count = np.sum(y_train)
                    n_count = len(y_train) - p_count
                    num_to_remove = int(
                        np.floor(n_count - (p_count / self.undersample_ratio))
                    )

                    false_ind = np.flatnonzero(y_train == 0)
                    num_to_remove = min(num_to_remove, len(false_ind))
                    if num_to_remove > 0:
                        to_remove = rng.sample(false_ind.tolist(), num_to_remove)
                        # Boolean mask keeps the original trial order.
                        keep = np.ones(len(y_train), dtype=bool)
                        keep[to_remove] = False
                        X_train = X_train[keep]
                        y_train = y_train[keep]

                model.fit(X_train, y_train)
                cv_preds[test_idx] = model.predict(X_test)
                predproba = model.predict_proba(X_test)

                # Debug only: the test trial rated most likely to be an ERP
                # (ties resolve to the first index).
                predprobs = predproba[:, 1]
                real = np.where(y_test == 1)
                prediction = int(np.argmax(predprobs))

                logger.debug("y_test = %s", y_test)
                logger.debug("predproba = %s", predproba)
                logger.debug("real = %s", real[0])
                logger.debug("prediction = %s", prediction)

            # Train final model with all available data
            model.fit(X, y)

            accuracy = np.mean(cv_preds == y)
            precision = precision_score(y, cv_preds)
            recall = recall_score(y, cv_preds)

            return KernelResults(model, cv_preds, accuracy, precision, recall)

        # Check if channel selection is true
        if self.channel_selection_setup:
            logger.info("Doing channel selection")
            logger.debug("Initial subset: %s", self.chs_initial_subset)

            channel_selection_results = channel_selection_by_method(
                __erp_single_channel_kernel,
                self.X,
                self.y,
                self.channel_labels,  # kernel setup
                self.chs_method,
                self.chs_metric,
                self.chs_initial_subset,  # wrapper setup
                self.chs_max_time,
                self.chs_min_channels,
                self.chs_max_channels,
                self.chs_performance_delta,  # stopping criterion
                self.chs_n_jobs,
            )  # njobs, output messages

            preds = channel_selection_results.best_preds

            logger.info(
                "The optimal subset is %s",
                channel_selection_results.best_channel_subset,
            )

            self.results_df = channel_selection_results.results_df
            self.subset = channel_selection_results.best_channel_subset
            self.clf = channel_selection_results.best_model
        else:
            logger.warning("Not doing channel selection")
            current_results = __erp_single_channel_kernel(self.X, self.y)
            self.clf = current_results.model
            preds = current_results.cv_preds

        # Log performance stats, computed from the cross-validated predictions
        accuracy = sum(preds == self.y) / len(preds)
        self.offline_accuracy = accuracy
        logger.info("Accuracy = %s", accuracy)

        precision = precision_score(self.y, preds)
        self.offline_precision = precision
        logger.info("Precision = %s", precision)

        recall = recall_score(self.y, preds)
        self.offline_recall = recall
        logger.info("Recall = %s", recall)

        # confusion matrix in command line
        cm = confusion_matrix(self.y, preds)
        self.offline_cm = cm
        logger.info("Confusion matrix:\n%s", cm)

        if plot_cm:
            ConfusionMatrixDisplay(cm).plot()
            plt.show()

        if plot_roc:
            logger.error("ROC plot has not been implemented yet")

    def predict(self, X):
        """Predict the class of the data (Unused in this classifier).

        Parameters
        ----------
        X : numpy.ndarray
            3D array where shape = (n_trials, n_channels, n_samples).
            Not used.

        Returns
        -------
        prediction : Prediction
            Empty Prediction object.

        """
        return Prediction()
logger =
<bci_essentials.utils.logger.Logger object>
class
ErpSingleChannelClassifier(bci_essentials.classification.generic_classifier.GenericClassifier):
class ErpSingleChannelClassifier(GenericClassifier):
    """ERP Single Channel Classifier class (*inherits from `GenericClassifier`*).

    The model is a pipeline of `ReduceToSingleChannel` followed by a shrinkage
    LDA (`solver="eigen"`, `shrinkage="auto"`).
    """

    def set_p300_clf_settings(
        self,
        n_splits=3,
        lico_expansion_factor=1,
        oversample_ratio=0,
        undersample_ratio=0,
        random_seed=42,
    ):
        """Set P300 Classifier Settings.

        Parameters
        ----------
        n_splits : int, *optional*
            Number of folds for cross-validation.
            - Default is `3`.
        lico_expansion_factor : int, *optional*
            Linear Combination Oversampling expansion factor, which is the
            factor by which the number of ERPs in the training set will be
            expanded. Used by `fit()` unless `fit()` is given an explicit value.
            - Default is `1`.
        oversample_ratio : float, *optional*
            Traditional oversampling. Target ratio of ERP to non-ERP trials,
            from 0.1 to 1. 0 for no oversampling.
            - Default is `0`.
        undersample_ratio : float, *optional*
            Traditional undersampling. Target ratio of ERP to non-ERP trials,
            from 0.1 to 1. 0 for no undersampling.
            - Default is `0`.
        random_seed : int, *optional*
            Seed for the cross-validation shuffle and for over/undersampling.
            - Default is `42`.

        Returns
        -------
        `None`

        """
        self.n_splits = n_splits
        self.lico_expansion_factor = lico_expansion_factor
        self.oversample_ratio = oversample_ratio
        self.undersample_ratio = undersample_ratio
        self.random_seed = random_seed

    def fit(
        self,
        plot_cm=False,
        plot_roc=False,
        lico_expansion_factor=None,
    ):
        """Fit the model.

        Cross-validated predictions are used to compute the offline
        performance metrics; the final model is then trained on all of the
        available data.

        Parameters
        ----------
        plot_cm : bool, *optional*
            Whether to plot the confusion matrix during training.
            - Default is `False`.
        plot_roc : bool, *optional*
            Whether to plot the ROC curve during training (not implemented;
            an error is logged instead).
            - Default is `False`.
        lico_expansion_factor : int, *optional*
            Linear combination oversampling expansion factor, i.e. the factor
            by which the number of ERPs in each training fold is expanded.
            LICO is only applied when this is greater than `1`.
            - Default is `None`, which uses the value stored by
              `set_p300_clf_settings()` (or `1` if it was never set).

        Returns
        -------
        `None`
            The fitted model is stored in `self.clf`; the offline metrics in
            `self.offline_accuracy`, `self.offline_precision`,
            `self.offline_recall` and `self.offline_cm`.

        """
        # Honour the configured setting; previously the value stored by
        # set_p300_clf_settings() was never read.
        if lico_expansion_factor is None:
            lico_expansion_factor = getattr(self, "lico_expansion_factor", 1)

        logger.info("Fitting the model using sLDA")
        logger.info("X shape: %s", self.X.shape)
        logger.info("y shape: %s", self.y.shape)

        # Define the strategy for cross validation
        cv = StratifiedKFold(
            n_splits=self.n_splits, shuffle=True, random_state=self.random_seed
        )

        def __erp_single_channel_kernel(X, y):
            """ERP Single Channel kernel.

            Parameters
            ----------
            X : numpy.ndarray
                Input features (ERP data) for training.
                3D numpy array with shape = (`n_trials`, `n_channels`, `n_samples`).
                E.g. (100, 1, 1000) for 100 trials, 1 channel and 1000 samples.
            y : numpy.ndarray
                Target labels corresponding to the input features in `X`.
                1D numpy array with shape (n_trials, ).
                Each label indicates the class of the corresponding trial in `X`.
                E.g. (100, ) for 100 trials.

            Returns
            -------
            kernelResults : KernelResults
                KernelResults object containing the following attributes:
                    model : classifier
                        The classification model trained on all of `X`, `y`.
                    cv_preds : numpy.ndarray
                        The predictions from the model using cross validation.
                        1D array with the same shape as `y`.
                    accuracy : float
                        Cross-validated accuracy.
                    precision : float
                        Cross-validated precision.
                    recall : float
                        Cross-validated recall.

            """
            logger.info("X shape: %s", X.shape)

            # Fresh model and prediction buffer on every call: the kernel may be
            # called once per candidate channel subset, and shared objects would
            # make every returned result alias the last subset evaluated.
            model = make_pipeline(
                ReduceToSingleChannel(),
                LinearDiscriminantAnalysis(solver="eigen", shrinkage="auto"),
            )
            cv_preds = np.zeros(len(y))  # Init predictions to all false

            # Seeded so that over/undersampling is reproducible.
            rng = random.Random(self.random_seed)

            for train_idx, test_idx in cv.split(X, y):
                X_train, y_train = X[train_idx], y[train_idx]
                X_test, y_test = X[test_idx], y[test_idx]

                # LICO
                logger.debug(
                    "Before LICO:\n\tShape X: %s\n\tShape y: %s",
                    X_train.shape,
                    y_train.shape,
                )

                if np.sum(y_train) > 2 and lico_expansion_factor > 1:
                    X_train, y_train = lico(
                        X_train,
                        y_train,
                        expansion_factor=lico_expansion_factor,
                        sum_num=2,
                        shuffle=False,
                    )
                    logger.debug("y_train = %s", y_train)

                logger.debug(
                    "After LICO:\n\tShape X: %s\n\tShape y: %s",
                    X_train.shape,
                    y_train.shape,
                )

                # Oversampling: duplicate random positive trials until the
                # positive / negative ratio reaches oversample_ratio.
                if self.oversample_ratio > 0:
                    p_count = np.sum(y_train)
                    n_count = len(y_train) - p_count
                    num_to_add = int(
                        np.floor((self.oversample_ratio * n_count) - p_count)
                    )

                    true_X_train = X_train[y_train == 1]
                    # Skip when there is nothing to add or nothing to copy from.
                    if num_to_add > 0 and len(true_X_train) > 0:
                        picks = [
                            rng.randrange(len(true_X_train))
                            for _ in range(num_to_add)
                        ]
                        X_train = np.concatenate(
                            (X_train, true_X_train[picks]), axis=0
                        )
                        y_train = np.concatenate(
                            (y_train, np.ones(num_to_add, dtype=y_train.dtype))
                        )

                # Undersampling: drop random negative trials until the
                # positive / negative ratio reaches undersample_ratio.
                if self.undersample_ratio > 0:
                    p_count = np.sum(y_train)
                    n_count = len(y_train) - p_count
                    num_to_remove = int(
                        np.floor(n_count - (p_count / self.undersample_ratio))
                    )

                    false_ind = np.flatnonzero(y_train == 0)
                    num_to_remove = min(num_to_remove, len(false_ind))
                    if num_to_remove > 0:
                        to_remove = rng.sample(false_ind.tolist(), num_to_remove)
                        # Boolean mask keeps the original trial order.
                        keep = np.ones(len(y_train), dtype=bool)
                        keep[to_remove] = False
                        X_train = X_train[keep]
                        y_train = y_train[keep]

                model.fit(X_train, y_train)
                cv_preds[test_idx] = model.predict(X_test)
                predproba = model.predict_proba(X_test)

                # Debug only: the test trial rated most likely to be an ERP
                # (ties resolve to the first index).
                predprobs = predproba[:, 1]
                real = np.where(y_test == 1)
                prediction = int(np.argmax(predprobs))

                logger.debug("y_test = %s", y_test)
                logger.debug("predproba = %s", predproba)
                logger.debug("real = %s", real[0])
                logger.debug("prediction = %s", prediction)

            # Train final model with all available data
            model.fit(X, y)

            accuracy = np.mean(cv_preds == y)
            precision = precision_score(y, cv_preds)
            recall = recall_score(y, cv_preds)

            return KernelResults(model, cv_preds, accuracy, precision, recall)

        # Check if channel selection is true
        if self.channel_selection_setup:
            logger.info("Doing channel selection")
            logger.debug("Initial subset: %s", self.chs_initial_subset)

            channel_selection_results = channel_selection_by_method(
                __erp_single_channel_kernel,
                self.X,
                self.y,
                self.channel_labels,  # kernel setup
                self.chs_method,
                self.chs_metric,
                self.chs_initial_subset,  # wrapper setup
                self.chs_max_time,
                self.chs_min_channels,
                self.chs_max_channels,
                self.chs_performance_delta,  # stopping criterion
                self.chs_n_jobs,
            )  # njobs, output messages

            preds = channel_selection_results.best_preds

            logger.info(
                "The optimal subset is %s",
                channel_selection_results.best_channel_subset,
            )

            self.results_df = channel_selection_results.results_df
            self.subset = channel_selection_results.best_channel_subset
            self.clf = channel_selection_results.best_model
        else:
            logger.warning("Not doing channel selection")
            current_results = __erp_single_channel_kernel(self.X, self.y)
            self.clf = current_results.model
            preds = current_results.cv_preds

        # Log performance stats, computed from the cross-validated predictions
        accuracy = sum(preds == self.y) / len(preds)
        self.offline_accuracy = accuracy
        logger.info("Accuracy = %s", accuracy)

        precision = precision_score(self.y, preds)
        self.offline_precision = precision
        logger.info("Precision = %s", precision)

        recall = recall_score(self.y, preds)
        self.offline_recall = recall
        logger.info("Recall = %s", recall)

        # confusion matrix in command line
        cm = confusion_matrix(self.y, preds)
        self.offline_cm = cm
        logger.info("Confusion matrix:\n%s", cm)

        if plot_cm:
            ConfusionMatrixDisplay(cm).plot()
            plt.show()

        if plot_roc:
            logger.error("ROC plot has not been implemented yet")

    def predict(self, X):
        """Predict the class of the data (Unused in this classifier).

        Parameters
        ----------
        X : numpy.ndarray
            3D array where shape = (n_trials, n_channels, n_samples).
            Not used.

        Returns
        -------
        prediction : Prediction
            Empty Prediction object.

        """
        return Prediction()
ERP Single Channel Classifier class (inherits from GenericClassifier).
def
set_p300_clf_settings( self, n_splits=3, lico_expansion_factor=1, oversample_ratio=0, undersample_ratio=0, random_seed=42):
def set_p300_clf_settings(
    self,
    n_splits=3,
    lico_expansion_factor=1,
    oversample_ratio=0,
    undersample_ratio=0,
    random_seed=42,
):
    """Store the P300 classifier settings on the instance.

    Parameters
    ----------
    n_splits : int, *optional*
        Number of cross-validation folds.
        - Default is `3`.
    lico_expansion_factor : int, *optional*
        Linear Combination Oversampling expansion factor: the factor by
        which the number of ERPs in the training set will be expanded.
        - Default is `1`.
    oversample_ratio : float, *optional*
        Traditional oversampling. Target ratio of ERP to non-ERP trials,
        from 0.1 to 1. 0 disables oversampling.
        - Default is `0`.
    undersample_ratio : float, *optional*
        Traditional undersampling. Target ratio of ERP to non-ERP trials,
        from 0.1 to 1. 0 disables undersampling.
        - Default is `0`.
    random_seed : int, *optional*
        Random seed.
        - Default is `42`.

    Returns
    -------
    `None`

    """
    # Assign every setting in one statement, in the same order as the
    # parameter list.
    (
        self.n_splits,
        self.lico_expansion_factor,
        self.oversample_ratio,
        self.undersample_ratio,
        self.random_seed,
    ) = (
        n_splits,
        lico_expansion_factor,
        oversample_ratio,
        undersample_ratio,
        random_seed,
    )
Set P300 Classifier Settings.
Parameters
- n_splits (int, optional):
Number of folds for cross-validation.
- Default is 3.
- lico_expansion_factor (int, optional):
Linear Combination Oversampling expansion factor, which is the
factor by which the number of ERPs in the training set will be
expanded.
- Default is 1.
- oversample_ratio (float, optional):
Traditional oversampling. Target ratio of ERP to non-ERP trials, from 0.1 to 1.
0 for no oversampling.
- Default is 0.
- undersample_ratio (float, optional):
Traditional undersampling. Target ratio of ERP to non-ERP trials, from 0.1 to 1.
0 for no undersampling.
- Default is 0.
- random_seed (int, optional):
Random seed.
- Default is 42.
Returns
None
def
fit(self, plot_cm=False, plot_roc=False, lico_expansion_factor=1):
def fit(
    self,
    plot_cm=False,
    plot_roc=False,
    lico_expansion_factor=None,
):
    """Fit the model.

    Cross-validated predictions are used to compute the offline
    performance metrics; the final model is then trained on all of the
    available data.

    Parameters
    ----------
    plot_cm : bool, *optional*
        Whether to plot the confusion matrix during training.
        - Default is `False`.
    plot_roc : bool, *optional*
        Whether to plot the ROC curve during training (not implemented;
        an error is logged instead).
        - Default is `False`.
    lico_expansion_factor : int, *optional*
        Linear combination oversampling expansion factor, i.e. the factor
        by which the number of ERPs in each training fold is expanded.
        LICO is only applied when this is greater than `1`.
        - Default is `None`, which uses the value stored by
          `set_p300_clf_settings()` (or `1` if it was never set).

    Returns
    -------
    `None`
        The fitted model is stored in `self.clf`; the offline metrics in
        `self.offline_accuracy`, `self.offline_precision`,
        `self.offline_recall` and `self.offline_cm`.

    """
    # Honour the configured setting; previously the value stored by
    # set_p300_clf_settings() was never read.
    if lico_expansion_factor is None:
        lico_expansion_factor = getattr(self, "lico_expansion_factor", 1)

    logger.info("Fitting the model using sLDA")
    logger.info("X shape: %s", self.X.shape)
    logger.info("y shape: %s", self.y.shape)

    # Define the strategy for cross validation
    cv = StratifiedKFold(
        n_splits=self.n_splits, shuffle=True, random_state=self.random_seed
    )

    def __erp_single_channel_kernel(X, y):
        """ERP Single Channel kernel.

        Parameters
        ----------
        X : numpy.ndarray
            Input features (ERP data) for training.
            3D numpy array with shape = (`n_trials`, `n_channels`, `n_samples`).
            E.g. (100, 1, 1000) for 100 trials, 1 channel and 1000 samples.
        y : numpy.ndarray
            Target labels corresponding to the input features in `X`.
            1D numpy array with shape (n_trials, ).
            Each label indicates the class of the corresponding trial in `X`.
            E.g. (100, ) for 100 trials.

        Returns
        -------
        kernelResults : KernelResults
            KernelResults object containing the following attributes:
                model : classifier
                    The classification model trained on all of `X`, `y`.
                cv_preds : numpy.ndarray
                    The predictions from the model using cross validation.
                    1D array with the same shape as `y`.
                accuracy : float
                    Cross-validated accuracy.
                precision : float
                    Cross-validated precision.
                recall : float
                    Cross-validated recall.

        """
        logger.info("X shape: %s", X.shape)

        # Fresh model and prediction buffer on every call: the kernel may be
        # called once per candidate channel subset, and shared objects would
        # make every returned result alias the last subset evaluated.
        model = make_pipeline(
            ReduceToSingleChannel(),
            LinearDiscriminantAnalysis(solver="eigen", shrinkage="auto"),
        )
        cv_preds = np.zeros(len(y))  # Init predictions to all false

        # Seeded so that over/undersampling is reproducible.
        rng = random.Random(self.random_seed)

        for train_idx, test_idx in cv.split(X, y):
            X_train, y_train = X[train_idx], y[train_idx]
            X_test, y_test = X[test_idx], y[test_idx]

            # LICO
            logger.debug(
                "Before LICO:\n\tShape X: %s\n\tShape y: %s",
                X_train.shape,
                y_train.shape,
            )

            if np.sum(y_train) > 2 and lico_expansion_factor > 1:
                X_train, y_train = lico(
                    X_train,
                    y_train,
                    expansion_factor=lico_expansion_factor,
                    sum_num=2,
                    shuffle=False,
                )
                logger.debug("y_train = %s", y_train)

            logger.debug(
                "After LICO:\n\tShape X: %s\n\tShape y: %s",
                X_train.shape,
                y_train.shape,
            )

            # Oversampling: duplicate random positive trials until the
            # positive / negative ratio reaches oversample_ratio.
            if self.oversample_ratio > 0:
                p_count = np.sum(y_train)
                n_count = len(y_train) - p_count
                num_to_add = int(
                    np.floor((self.oversample_ratio * n_count) - p_count)
                )

                true_X_train = X_train[y_train == 1]
                # Skip when there is nothing to add or nothing to copy from.
                if num_to_add > 0 and len(true_X_train) > 0:
                    picks = [
                        rng.randrange(len(true_X_train))
                        for _ in range(num_to_add)
                    ]
                    X_train = np.concatenate(
                        (X_train, true_X_train[picks]), axis=0
                    )
                    y_train = np.concatenate(
                        (y_train, np.ones(num_to_add, dtype=y_train.dtype))
                    )

            # Undersampling: drop random negative trials until the
            # positive / negative ratio reaches undersample_ratio.
            if self.undersample_ratio > 0:
                p_count = np.sum(y_train)
                n_count = len(y_train) - p_count
                num_to_remove = int(
                    np.floor(n_count - (p_count / self.undersample_ratio))
                )

                false_ind = np.flatnonzero(y_train == 0)
                num_to_remove = min(num_to_remove, len(false_ind))
                if num_to_remove > 0:
                    to_remove = rng.sample(false_ind.tolist(), num_to_remove)
                    # Boolean mask keeps the original trial order.
                    keep = np.ones(len(y_train), dtype=bool)
                    keep[to_remove] = False
                    X_train = X_train[keep]
                    y_train = y_train[keep]

            model.fit(X_train, y_train)
            cv_preds[test_idx] = model.predict(X_test)
            predproba = model.predict_proba(X_test)

            # Debug only: the test trial rated most likely to be an ERP
            # (ties resolve to the first index).
            predprobs = predproba[:, 1]
            real = np.where(y_test == 1)
            prediction = int(np.argmax(predprobs))

            logger.debug("y_test = %s", y_test)
            logger.debug("predproba = %s", predproba)
            logger.debug("real = %s", real[0])
            logger.debug("prediction = %s", prediction)

        # Train final model with all available data
        model.fit(X, y)

        accuracy = np.mean(cv_preds == y)
        precision = precision_score(y, cv_preds)
        recall = recall_score(y, cv_preds)

        return KernelResults(model, cv_preds, accuracy, precision, recall)

    # Check if channel selection is true
    if self.channel_selection_setup:
        logger.info("Doing channel selection")
        logger.debug("Initial subset: %s", self.chs_initial_subset)

        channel_selection_results = channel_selection_by_method(
            __erp_single_channel_kernel,
            self.X,
            self.y,
            self.channel_labels,  # kernel setup
            self.chs_method,
            self.chs_metric,
            self.chs_initial_subset,  # wrapper setup
            self.chs_max_time,
            self.chs_min_channels,
            self.chs_max_channels,
            self.chs_performance_delta,  # stopping criterion
            self.chs_n_jobs,
        )  # njobs, output messages

        preds = channel_selection_results.best_preds

        logger.info(
            "The optimal subset is %s",
            channel_selection_results.best_channel_subset,
        )

        self.results_df = channel_selection_results.results_df
        self.subset = channel_selection_results.best_channel_subset
        self.clf = channel_selection_results.best_model
    else:
        logger.warning("Not doing channel selection")
        current_results = __erp_single_channel_kernel(self.X, self.y)
        self.clf = current_results.model
        preds = current_results.cv_preds

    # Log performance stats, computed from the cross-validated predictions
    accuracy = sum(preds == self.y) / len(preds)
    self.offline_accuracy = accuracy
    logger.info("Accuracy = %s", accuracy)

    precision = precision_score(self.y, preds)
    self.offline_precision = precision
    logger.info("Precision = %s", precision)

    recall = recall_score(self.y, preds)
    self.offline_recall = recall
    logger.info("Recall = %s", recall)

    # confusion matrix in command line
    cm = confusion_matrix(self.y, preds)
    self.offline_cm = cm
    logger.info("Confusion matrix:\n%s", cm)

    if plot_cm:
        ConfusionMatrixDisplay(cm).plot()
        plt.show()

    if plot_roc:
        logger.error("ROC plot has not been implemented yet")
Fit the model.
Parameters
- plot_cm (bool, optional):
Whether to plot the confusion matrix during training.
- Default is False.
- plot_roc (bool, optional):
Whether to plot the ROC curve during training.
- Default is False.
- lico_expansion_factor (int, optional):
Linear combination oversampling expansion factor.
Determines the number of ERPs in the training set that will be expanded.
Higher value increases the oversampling, generating more synthetic
samples for the minority class.
- Default is 1.
Returns
None: The fitted model is stored in self.clf.
def
predict(self, X):
def predict(self, X):
    """Return an empty prediction; this classifier does no online prediction.

    Parameters
    ----------
    X : numpy.ndarray
        3D array where shape = (n_trials, n_channels, n_samples).
        Ignored.

    Returns
    -------
    prediction : Prediction
        Empty Prediction object.

    """
    # `X` is intentionally ignored.
    empty_prediction = Prediction()
    return empty_prediction
Predict the class of the data (Unused in this classifier)
Parameters
- X (numpy.ndarray): 3D array where shape = (n_trials, n_channels, n_samples)
Returns
- prediction (Prediction): Empty Prediction object