bci_essentials.classification.generic_classifier
Generic classifier class for BCI Essentials
Used as the parent classifier class for other classifiers.
1"""**Generic classifier class for BCI Essentials** 2 3Used as Parent classifier class for other classifiers. 4 5""" 6 7# Stock libraries 8from abc import ABC, abstractmethod 9from dataclasses import dataclass, field 10 11import numpy as np 12from sklearn.pipeline import Pipeline 13 14from ..utils.logger import Logger # Logger wrapper 15 16# Instantiate a logger for the module at the default level of logging.INFO 17# Logs to bci_essentials.__module__) where __module__ is the name of the module 18logger = Logger(name=__name__) 19 20 21@dataclass 22class Prediction: 23 """Prediction data returned by GenericClassifer.predict() 24 25 labels : list 26 List of the predicted class labels. 27 - Default is `[]`. 28 29 probabilities : list 30 List of probabilities for each class label. If the classifier can't 31 provide probabilities, this will be an empty list `[]`. 32 - Default is `[]` 33 34 """ 35 36 labels: list = field(default_factory=list) 37 probabilities: list = field(default_factory=list) 38 39 40@dataclass 41class KernelResults: 42 """Dataclass to store output from the kernel methods 43 44 model : classifier 45 The trained classification model. 46 47 cv_preds : numpy.ndarray 48 The predictions from the model using cross-validation. 49 50 accuracy : float 51 The accuracy of the trained classification model. 52 53 precision : float 54 The precision of the trained classification model. 55 56 recall : float 57 The recall of the trained classification model. 58 """ 59 60 model: Pipeline = field(default=None) 61 cv_preds: np.ndarray = field(default_factory=np.ndarray) 62 accuracy: float = field(default=0.0) 63 precision: float = field(default=0.0) 64 recall: float = field(default=0.0) 65 66 67class GenericClassifier(ABC): 68 """The base generic classifier class for other classifiers.""" 69 70 def __init__(self, training_selection=0, subset=[]): 71 """Initializes `GenericClassifier` class. 72 73 Parameters 74 ---------- 75 training_selection : int, *optional* 76 Integer representing the object selected for training. 77 - Default is `0`. 78 subset : list of `int` or `str`, *optional* 79 List of indices (int) or labels (str) of the desired channels. 80 - Default is `[]`. 81 82 Attributes 83 ---------- 84 X : numpy.ndarray 85 Input features (training data). 86 3D numpy array with shape = (`n_samples`, `n_channels`, `n_trials`). 87 - Initial value is `np.ndarray([0])`. 88 y : numpy.ndarray 89 Target labels corresponding to input features in `X`. 90 1D numpy array with shape = (`n_samples`, ). 91 - Initial value is `np.ndarray([0])`. 92 subset_defined : bool 93 Flag indicating whether a subset is defined. 94 - Initial value is `False`. 95 subset : list of `int` or `str` 96 List of indices (int) or labels (str) of the desired channels. 97 - Initial value is parameter `subset`. 98 channel_labels : list of `str` 99 Channel labels from the entire EEG montage. 100 - Initial value is `[]`. 101 channel_selection_setup : bool 102 FLag indicating whether channel selection is set up. 103 - Initial value is `False`. 104 offline_accuracy : list of `float` 105 Stores offline accuracy values during training. 106 - Initial value is `[]`. 107 offline_precision : list of `float` 108 Stores offline precision values during training. 109 - Initial value is `[]`. 110 offline_recall : list of `float` 111 Stores offline recall values during training. 112 - Initial value is `[]`. 113 offline_trial_count : int 114 Counter to keep track of the number of offline trials 115 - Initial value is `0`. 
116 offline_trial_counts : list of `int` 117 List to store the counts of offline trials. 118 i.e. `offline_trial_count' values. 119 - Initial value is `[]`. 120 next_fit_trial : int 121 Counter to track the next trial for fitting. 122 - Initial value is `0`. 123 predictions : list 124 Stores predications made during training or testing 125 - Initial value is `[]`. 126 pred_probas : list of `float` 127 List to store predication probabilities during testing. 128 - Initial value is `[]`. 129 n_splits : int 130 Number of splits for cross-validation. Also serves as minimum required samples per class for training when running _check_ready_for_fit(). 131 - Initial value is `5`. 132 133 """ 134 logger.info("Initializing the classifier") 135 self.X = np.ndarray([0]) 136 """@private (This is just for the API docs, to avoid double listing.""" 137 self.y = np.ndarray([0]) 138 """@private (This is just for the API docs, to avoid double listing.""" 139 140 self.subset_defined = False 141 """@private (This is just for the API docs, to avoid double listing.""" 142 self.subset = subset 143 """@private (This is just for the API docs, to avoid double listing.""" 144 self.channel_labels = [] 145 """@private (This is just for the API docs, to avoid double listing.""" 146 self.channel_selection_setup = False 147 """@private (This is just for the API docs, to avoid double listing.""" 148 149 # Lists for plotting classifier performance over time 150 self.offline_accuracy = [] 151 """@private (This is just for the API docs, to avoid double listing.""" 152 self.offline_precision = [] 153 """@private (This is just for the API docs, to avoid double listing.""" 154 self.offline_recall = [] 155 """@private (This is just for the API docs, to avoid double listing.""" 156 self.offline_trial_count = 0 157 """@private (This is just for the API docs, to avoid double listing.""" 158 self.offline_trial_counts = [] 159 """@private (This is just for the API docs, to avoid double listing.""" 160 161 # For iterative fitting, 162 self.next_fit_trial = 0 163 """@private (This is just for the API docs, to avoid double listing.""" 164 165 # Keep track of predictions 166 self.predictions = [] 167 """@private (This is just for the API docs, to avoid double listing.""" 168 self.pred_probas = [] 169 """@private (This is just for the API docs, to avoid double listing.""" 170 171 # N Splits 172 self.n_splits = 5 173 """@private (This is just for the API docs, to avoid double listing.""" 174 175 def _check_ready_for_fit(self): 176 """Check if sufficient data is available for fitting with cross-validation. 177 178 This method validates that: 179 1. Training data (X) exists 180 2. At least two classes are present in the labels 181 3. Each class has at least n_splits samples (required for k-fold cross-validation) 182 183 Returns 184 ------- 185 bool 186 Returns `True` if data meets all requirements for fitting, otherwise `False`. 
187 """ 188 if self.X.size == 0: 189 logger.warning("No data available for fitting") 190 return False 191 192 unique_y = np.unique(self.y) 193 194 if len(unique_y) == 1: 195 logger.warning("Only one class available for fitting") 196 return False 197 198 class_counts = np.zeros(len(unique_y)) 199 for i, y in enumerate(unique_y): 200 class_counts[i] = np.sum(self.y == y) 201 202 # If n_splits is greater than the min number of samples in a class, return False 203 if np.min(class_counts) < self.n_splits: 204 # Future implementation: Report the class with the least number of samples 205 # Future implementation: Report the number of samples in each class 206 logger.warning( 207 "Need at least %s samples per class for cross-validation. Please collect more training data.", 208 self.n_splits, 209 ) 210 return False 211 212 return True 213 214 def get_subset(self, X=[], subset=[], channel_labels=[]): 215 """Get a subset of X according to labels or indices. 216 217 Parameters 218 ---------- 219 X : numpy.ndarray, *optional* 220 3D array containing data with `float` type. 221 222 shape = (`n_trials`,`n_channels`,`n_samples`) 223 - Default is `[]`. 224 subset : list of `int` or `str`, *optional* 225 List of indices (int) or labels (str) of the desired channels. 226 - Default is `[]`. 227 channel_labels : list of `str`, *optional* 228 Channel labels from the entire EEG montage. 229 - Default is `[]`. 230 231 Returns 232 ------- 233 X : numpy.ndarray 234 Subset of input `X` according to labels or indices. 235 3D array containing data with `float` type. 236 237 shape = (`n_trials`,`n_channels`,`n_samples`) 238 239 """ 240 241 # Check for self.subset and/or self.channel_labels 242 243 # Init 244 subset_indices = [] 245 246 # Copy the indices based on subset 247 try: 248 # Check if we can use subset indices 249 if self.subset == []: 250 return X 251 252 if type(self.subset[0]) is int: 253 logger.info("Using subset indices") 254 255 subset_indices = self.subset 256 self.subset_defined 257 258 # Or channel labels 259 if type(self.subset[0]) is str: 260 logger.info("Using channel labels and subset labels") 261 262 # Replace indices with those described by labels 263 for sl in self.subset: 264 subset_indices.append(self.channel_labels.index(sl)) 265 266 self.subset_defined = True 267 # Return for the given indices 268 try: 269 if sum(X.shape) == 0: 270 new_X = self.X[:, subset_indices, :] 271 self.X = new_X 272 else: 273 new_X = X[:, subset_indices, :] 274 X = new_X 275 return X 276 277 except Exception: 278 if sum(X.shape) == 0: 279 new_X = self.X[subset_indices, :] 280 self.X = new_X 281 282 else: 283 new_X = X[subset_indices, :] 284 X = new_X 285 return X 286 287 # notify if failed 288 except Exception: 289 logger.warning("something went wrong, no subset taken") 290 return X 291 292 def setup_channel_selection( 293 self, 294 method="SBS", 295 metric="accuracy", 296 iterative_selection=False, 297 initial_channels=[], # wrapper setup 298 max_time=999, 299 min_channels=1, 300 max_channels=999, 301 performance_delta=0.001, # stopping criterion 302 n_jobs=1, 303 record_performance=False, 304 ): 305 """Setup channel selection parameters. 306 307 Parameters 308 ---------- 309 method : str, *optional* 310 The method used to add or remove channels. 311 - Default is `"SBS"`. 312 metric : str, *optional* 313 The metric used to measure performance. 314 - Default is `"accuracy"`. 315 iterative_selection : bool, *optional* 316 Whether or not to use the previously selected subset for the initial subset. 
317 Default is `False`. 318 initial_channels : type, *optional* 319 List of channels to use as initial subset for selection. 320 If empty, `initial_channels` is set to all available channels. 321 - Default is `[]`. 322 max_time : int, *optional* 323 Maximum time in seconds allowed for channel selection. 324 - Default is `999`. 325 min_channels : int, *optional* 326 Minimum number of channels to select during channel selection. 327 - Default is `1`. 328 max_channels : int, *optional* 329 Maximum number of channels allowed in the final subset. 330 - Default is `999`. 331 performance_delta : float, *optional* 332 Smallest performance increment to allow continue of the search. 333 - Default is `0.001`. 334 n_jobs : int, *optional* 335 The number of threads to dedicate to this calculation. 336 - Default is `1`. 337 record_performance : bool, *optional* 338 Decides whether or not to record performance of channel selection. 339 - Default is `False`. 340 341 Returns 342 ------- 343 `None` 344 345 """ 346 # Add these to settings later 347 if initial_channels == []: 348 self.chs_initial_subset = self.channel_labels 349 else: 350 self.chs_initial_subset = initial_channels 351 self.chs_method = method # method to add/remove channels 352 self.chs_metric = metric # metric by which to measure performance 353 self.chs_iterative_selection = iterative_selection # whether or not to use the previously selected subset for the initial subset 354 self.chs_n_jobs = n_jobs # number of threads 355 self.chs_max_time = max_time # max time in seconds 356 self.chs_min_channels = min_channels # minimum number of channels 357 self.chs_max_channels = max_channels # maximum number of channels 358 self.chs_performance_delta = performance_delta # smallest performance increment to justify continuing search 359 self.chs_record_performance = record_performance # record performance 360 361 self.channel_selection_setup = True 362 363 # add training data, to the training set using a decision block and a label 364 def add_to_train(self, decision_block, labels, num_options=0, meta=[]): 365 """Add training data to the training set using a decision block 366 and a label. 367 368 Parameters 369 ---------- 370 decision_block : numpy.ndarray 371 Decision block containing EEG data for training. 372 3D array with shape = (`n_epochs`, `n_channels`, `n_samples`). 373 labels : numpy.ndarray 374 Labels corresponding to each epoch in `decision_block`. 375 1D array with shape = (`n_epochs`, ). 376 num_options : int, *optional* 377 Number of options available for each trial. 378 - Default is `0`. 379 meta : list, *optional* 380 Additional metadata related to the training data. 381 - Default is `[]`. 
382 383 Returns 384 ------- 385 `None` 386 387 """ 388 logger.debug("Adding to training set") 389 # n = number of channels 390 # m = number of samples 391 # p = number of epochs 392 p, n, m = decision_block.shape 393 394 self.num_options = num_options 395 self.meta = meta 396 397 if self.X.size == 0: 398 self.X = decision_block 399 self.y = labels 400 401 else: 402 self.X = np.append(self.X, decision_block, axis=0) 403 self.y = np.append(self.y, labels, axis=0) 404 405 @abstractmethod 406 def fit(self): 407 """Abstract method to fit classifier 408 409 Returns 410 ------- 411 `None` 412 413 """ 414 pass 415 416 @abstractmethod 417 def predict(self, X: np.ndarray) -> Prediction: 418 """Abstract method to predict with classifier 419 420 X : numpy.ndarray 421 3D array where shape = (trials, channels, samples) 422 423 Returns 424 ------- 425 prediction : Prediction 426 Results of predict call containing the predicted class labels, and 427 optionally the probabilities of the labels (empty list if not possible). 428 429 """ 430 pass
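The names documented below can be imported directly from this module. A minimal sketch, assuming the bci_essentials package is installed:

```python
# Import the public classes defined in this module.
from bci_essentials.classification.generic_classifier import (
    GenericClassifier,
    KernelResults,
    Prediction,
)
```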
class Prediction (dataclass)

Prediction data returned by GenericClassifier.predict().

- labels (list): List of the predicted class labels. Default is `[]`.
- probabilities (list): List of probabilities for each class label. If the classifier can't provide probabilities, this will be an empty list `[]`. Default is `[]`.
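A small usage sketch, not taken from the library itself; the label and probability values are made up:

```python
from bci_essentials.classification.generic_classifier import Prediction

# A classifier that can estimate probabilities fills both fields; the exact
# nesting of `probabilities` depends on the concrete classifier.
pred = Prediction(labels=[1], probabilities=[[0.1, 0.7, 0.2]])

# A classifier without a probability model leaves the default empty list.
pred_no_proba = Prediction(labels=[2])

print(pred.labels)                  # [1]
print(pred_no_proba.probabilities)  # []
```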
class KernelResults (dataclass)

Dataclass to store output from the kernel methods.

- model (classifier): The trained classification model.
- cv_preds (numpy.ndarray): The predictions from the model using cross-validation.
- accuracy (float): The accuracy of the trained classification model.
- precision (float): The precision of the trained classification model.
- recall (float): The recall of the trained classification model.
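A sketch of how a kernel method might package its results; the pipeline and score values below are placeholders, not output produced by this module:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

from bci_essentials.classification.generic_classifier import KernelResults

# A sklearn Pipeline stands in for the trained classification model.
model = make_pipeline(StandardScaler(), LogisticRegression())

results = KernelResults(
    model=model,
    cv_preds=np.array([0, 1, 0, 1]),  # cross-validated predictions
    accuracy=0.75,
    precision=0.74,
    recall=0.76,
)
print(results.accuracy)
```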
class GenericClassifier(ABC)

The base generic classifier class for other classifiers.
def __init__(self, training_selection=0, subset=[])

Initializes the `GenericClassifier` class.

Parameters
- training_selection (int, optional): Integer representing the object selected for training. Default is `0`.
- subset (list of `int` or `str`, optional): List of indices (int) or labels (str) of the desired channels. Default is `[]`.

Attributes
- X (numpy.ndarray): Input features (training data). 3D numpy array with shape = (`n_trials`, `n_channels`, `n_samples`). Initial value is `np.ndarray([0])`.
- y (numpy.ndarray): Target labels corresponding to input features in `X`. 1D numpy array with shape = (`n_trials`, ). Initial value is `np.ndarray([0])`.
- subset_defined (bool): Flag indicating whether a subset is defined. Initial value is `False`.
- subset (list of `int` or `str`): List of indices (int) or labels (str) of the desired channels. Initial value is the `subset` parameter.
- channel_labels (list of `str`): Channel labels from the entire EEG montage. Initial value is `[]`.
- channel_selection_setup (bool): Flag indicating whether channel selection is set up. Initial value is `False`.
- offline_accuracy (list of `float`): Stores offline accuracy values during training. Initial value is `[]`.
- offline_precision (list of `float`): Stores offline precision values during training. Initial value is `[]`.
- offline_recall (list of `float`): Stores offline recall values during training. Initial value is `[]`.
- offline_trial_count (int): Counter to keep track of the number of offline trials. Initial value is `0`.
- offline_trial_counts (list of `int`): List to store the counts of offline trials, i.e. `offline_trial_count` values. Initial value is `[]`.
- next_fit_trial (int): Counter to track the next trial for fitting. Initial value is `0`.
- predictions (list): Stores predictions made during training or testing. Initial value is `[]`.
- pred_probas (list of `float`): List to store prediction probabilities during testing. Initial value is `[]`.
- n_splits (int): Number of splits for cross-validation. Also serves as the minimum required number of samples per class for training when running `_check_ready_for_fit()`. Initial value is `5`.
def get_subset(self, X=[], subset=[], channel_labels=[])

Get a subset of X according to labels or indices.

Parameters
- X (numpy.ndarray, optional): 3D array containing data with `float` type, shape = (`n_trials`, `n_channels`, `n_samples`). Default is `[]`.
- subset (list of `int` or `str`, optional): List of indices (int) or labels (str) of the desired channels. Default is `[]`.
- channel_labels (list of `str`, optional): Channel labels from the entire EEG montage. Default is `[]`.

Returns
- X (numpy.ndarray): Subset of the input `X` according to labels or indices. 3D array containing data with `float` type, shape = (`n_trials`, `n_channels`, `n_samples`).
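Because `GenericClassifier` is abstract and `get_subset()` reads the channel configuration from the instance (`self.subset` and `self.channel_labels`), the sketch below uses a hypothetical stub subclass; the channel names and data are made up:

```python
import numpy as np

from bci_essentials.classification.generic_classifier import (
    GenericClassifier,
    Prediction,
)


# Hypothetical stub: fit() and predict() are placeholders so the abstract
# class can be instantiated for this example.
class _StubClassifier(GenericClassifier):
    def fit(self):
        return None

    def predict(self, X):
        return Prediction()


clf = _StubClassifier()
clf.channel_labels = ["Fz", "Cz", "Pz", "Oz"]  # full montage
clf.subset = ["Cz", "Pz"]                      # channels to keep

X = np.random.randn(10, 4, 256)  # 10 trials x 4 channels x 256 samples
X_subset = clf.get_subset(X)
print(X_subset.shape)  # (10, 2, 256)
```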
def setup_channel_selection(self, method="SBS", metric="accuracy", iterative_selection=False, initial_channels=[], max_time=999, min_channels=1, max_channels=999, performance_delta=0.001, n_jobs=1, record_performance=False)

Set up channel selection parameters.

Parameters
- method (str, optional): The method used to add or remove channels. Default is `"SBS"`.
- metric (str, optional): The metric used to measure performance. Default is `"accuracy"`.
- iterative_selection (bool, optional): Whether or not to use the previously selected subset as the initial subset. Default is `False`.
- initial_channels (list of `str`, optional): List of channels to use as the initial subset for selection. If empty, `initial_channels` is set to all available channels. Default is `[]`.
- max_time (int, optional): Maximum time in seconds allowed for channel selection. Default is `999`.
- min_channels (int, optional): Minimum number of channels to select during channel selection. Default is `1`.
- max_channels (int, optional): Maximum number of channels allowed in the final subset. Default is `999`.
- performance_delta (float, optional): Smallest performance increment required to continue the search. Default is `0.001`.
- n_jobs (int, optional): The number of threads to dedicate to this calculation. Default is `1`.
- record_performance (bool, optional): Whether or not to record the performance of channel selection. Default is `False`.

Returns
- `None`
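A sketch of a typical configuration call, continuing with the stub instance `clf` from the `get_subset` sketch above; all values shown are illustrative:

```python
# channel_labels must be populated so an empty initial_channels can expand
# to the full montage.
clf.channel_labels = ["Fz", "Cz", "Pz", "Oz"]

clf.setup_channel_selection(
    method="SBS",             # sequential backward selection
    metric="accuracy",
    initial_channels=[],      # empty list -> start from all of channel_labels
    min_channels=2,
    max_channels=4,
    max_time=60,              # give up after 60 seconds
    performance_delta=0.001,
    n_jobs=2,
    record_performance=True,
)
print(clf.channel_selection_setup)  # True
```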
def add_to_train(self, decision_block, labels, num_options=0, meta=[])

Add training data to the training set using a decision block and a label.

Parameters
- decision_block (numpy.ndarray): Decision block containing EEG data for training. 3D array with shape = (`n_epochs`, `n_channels`, `n_samples`).
- labels (numpy.ndarray): Labels corresponding to each epoch in `decision_block`. 1D array with shape = (`n_epochs`, ).
- num_options (int, optional): Number of options available for each trial. Default is `0`.
- meta (list, optional): Additional metadata related to the training data. Default is `[]`.

Returns
- `None`
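A sketch of accumulating training data on a fresh stub instance (the hypothetical `_StubClassifier` class comes from the `get_subset` sketch above; the EEG data is random). With the default `n_splits = 5`, each class needs at least five epochs before the internal `_check_ready_for_fit()` check passes:

```python
import numpy as np

clf = _StubClassifier()  # hypothetical stub defined in the get_subset sketch

# Two decision blocks of 6 epochs x 4 channels x 256 samples,
# three epochs per class in each block.
for _ in range(2):
    decision_block = np.random.randn(6, 4, 256)
    labels = np.array([0, 0, 0, 1, 1, 1])
    clf.add_to_train(decision_block, labels)

print(clf.X.shape)  # (12, 4, 256)
print(clf.y.shape)  # (12,)

# Six epochs per class >= n_splits (5), so the private readiness check passes.
print(clf._check_ready_for_fit())  # True
```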
def fit(self) (abstract)

Abstract method to fit the classifier.

Returns
- `None`
def predict(self, X: np.ndarray) -> Prediction (abstract)

Abstract method to predict with the classifier.

Parameters
- X (numpy.ndarray): 3D array where shape = (`trials`, `channels`, `samples`).

Returns
- prediction (Prediction): Results of the predict call, containing the predicted class labels and, optionally, the probabilities of the labels (an empty list if not possible).
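To show how the abstract interface fits together end to end, here is a minimal hypothetical subclass. The nearest-class-mean rule is purely illustrative and is not how the library's real classifiers work:

```python
import numpy as np

from bci_essentials.classification.generic_classifier import (
    GenericClassifier,
    Prediction,
)


class NearestMeanClassifier(GenericClassifier):
    """Toy subclass: assigns each trial to the class with the nearest mean."""

    def fit(self):
        # Flatten each trial to a feature vector and store one mean per class.
        features = self.X.reshape(self.X.shape[0], -1)
        self.classes_ = np.unique(self.y)
        self.means_ = np.stack(
            [features[self.y == c].mean(axis=0) for c in self.classes_]
        )

    def predict(self, X):
        features = X.reshape(X.shape[0], -1)
        # Euclidean distance of every trial to every class mean.
        dists = np.linalg.norm(
            features[:, None, :] - self.means_[None, :, :], axis=2
        )
        labels = self.classes_[np.argmin(dists, axis=1)].tolist()
        # No probability model here, so probabilities stay an empty list.
        return Prediction(labels=labels, probabilities=[])


clf = NearestMeanClassifier()
clf.add_to_train(np.random.randn(6, 4, 128), np.array([0, 0, 0, 1, 1, 1]))
clf.fit()
print(clf.predict(np.random.randn(2, 4, 128)).labels)  # e.g. [0, 1]
```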