bci_essentials.classification.generic_classifier

Generic classifier class for BCI Essentials

Used as Parent classifier class for other classifiers.

  1"""**Generic classifier class for BCI Essentials**
  2
  3Used as Parent classifier class for other classifiers.
  4
  5"""
  6
  7# Stock libraries
  8from abc import ABC, abstractmethod
  9from dataclasses import dataclass, field
 10
 11import numpy as np
 12from sklearn.pipeline import Pipeline
 13
 14from ..utils.logger import Logger  # Logger wrapper
 15
 16# Instantiate a logger for the module at the default level of logging.INFO
 17# Logs to `bci_essentials.<module>`, where `<module>` is the name of this module
 18logger = Logger(name=__name__)
 19
 20
 21@dataclass
 22class Prediction:
 23    """Prediction data returned by GenericClassifer.predict()
 24
 25    labels : list
 26        List of the predicted class labels.
 27        - Default is `[]`.
 28
 29    probabilities : list
 30        List of probabilities for each class label. If the classifier can't
 31        provide probabilities, this will be an empty list `[]`.
 32        - Default is `[]`
 33
 34    """
 35
 36    labels: list = field(default_factory=list)
 37    probabilities: list = field(default_factory=list)
 38
 39
 40@dataclass
 41class KernelResults:
 42    """Dataclass to store output from the kernel methods
 43
 44    model : classifier
 45        The trained classification model.
 46
 47    cv_preds : numpy.ndarray
 48        The predictions from the model using cross-validation.
 49
 50    accuracy : float
 51        The accuracy of the trained classification model.
 52
 53    precision : float
 54        The precision of the trained classification model.
 55
 56    recall : float
 57        The recall of the trained classification model.
 58    """
 59
 60    model: Pipeline = field(default=None)
 61    cv_preds: np.ndarray = field(default_factory=np.ndarray)
 62    accuracy: float = field(default=0.0)
 63    precision: float = field(default=0.0)
 64    recall: float = field(default=0.0)
 65
 66
 67class GenericClassifier(ABC):
 68    """The base generic classifier class for other classifiers."""
 69
 70    def __init__(self, training_selection=0, subset=[]):
 71        """Initializes `GenericClassifier` class.
 72
 73        Parameters
 74        ----------
 75        training_selection : int, *optional*
 76            Integer representing the object selected for training.
 77            - Default is `0`.
 78        subset : list of `int` or `str`, *optional*
 79            List of indices (int) or labels (str) of the desired channels.
 80            - Default is `[]`.
 81
 82        Attributes
 83        ----------
 84        X : numpy.ndarray
 85            Input features (training data).
 86            3D numpy array with shape = (`n_samples`, `n_channels`, `n_trials`).
 87            - Initial value is `np.ndarray([0])`.
 88        y : numpy.ndarray
 89            Target labels corresponding to input features in `X`.
 90            1D numpy array with shape = (`n_samples`, ).
 91            - Initial value is `np.ndarray([0])`.
 92        subset_defined : bool
 93            Flag indicating whether a subset is defined.
 94            - Initial value is `False`.
 95        subset : list of `int` or `str`
 96            List of indices (int) or labels (str) of the desired channels.
 97            - Initial value is parameter `subset`.
 98        channel_labels : list of `str`
 99            Channel labels from the entire EEG montage.
100            - Initial value is `[]`.
101        channel_selection_setup : bool
102            FLag indicating whether channel selection is set up.
103            - Initial value is `False`.
104        offline_accuracy : list of `float`
105            Stores offline accuracy values during training.
106            - Initial value is `[]`.
107        offline_precision : list of `float`
108            Stores offline precision values during training.
109            - Initial value is `[]`.
110        offline_recall : list of `float`
111            Stores offline recall values during training.
112            - Initial value is `[]`.
113        offline_trial_count : int
114            Counter to keep track of the number of offline trials
115            - Initial value is `0`.
116        offline_trial_counts : list of `int`
117            List to store the counts of offline trials.
118            i.e. `offline_trial_count' values.
119            - Initial value is `[]`.
120        next_fit_trial : int
121            Counter to track the next trial for fitting.
122            - Initial value is `0`.
123        predictions : list
124            Stores predications made during training or testing
125            - Initial value is `[]`.
126        pred_probas : list of `float`
127            List to store predication probabilities during testing.
128            - Initial value is `[]`.
129        n_splits : int
130            Number of splits for cross-validation. Also serves as minimum required samples per class for training when running _check_ready_for_fit().
131            - Initial value is `5`.
132
133        """
134        logger.info("Initializing the classifier")
135        self.X = np.ndarray([0])
136        """@private (This is just for the API docs, to avoid double listing."""
137        self.y = np.ndarray([0])
138        """@private (This is just for the API docs, to avoid double listing."""
139
140        self.subset_defined = False
141        """@private (This is just for the API docs, to avoid double listing."""
142        self.subset = subset
143        """@private (This is just for the API docs, to avoid double listing."""
144        self.channel_labels = []
145        """@private (This is just for the API docs, to avoid double listing."""
146        self.channel_selection_setup = False
147        """@private (This is just for the API docs, to avoid double listing."""
148
149        # Lists for plotting classifier performance over time
150        self.offline_accuracy = []
151        """@private (This is just for the API docs, to avoid double listing."""
152        self.offline_precision = []
153        """@private (This is just for the API docs, to avoid double listing."""
154        self.offline_recall = []
155        """@private (This is just for the API docs, to avoid double listing."""
156        self.offline_trial_count = 0
157        """@private (This is just for the API docs, to avoid double listing."""
158        self.offline_trial_counts = []
159        """@private (This is just for the API docs, to avoid double listing."""
160
161        # For iterative fitting,
162        self.next_fit_trial = 0
163        """@private (This is just for the API docs, to avoid double listing."""
164
165        # Keep track of predictions
166        self.predictions = []
167        """@private (This is just for the API docs, to avoid double listing."""
168        self.pred_probas = []
169        """@private (This is just for the API docs, to avoid double listing."""
170
171        # N Splits
172        self.n_splits = 5
173        """@private (This is just for the API docs, to avoid double listing."""
174
175    def _check_ready_for_fit(self):
176        """Check if sufficient data is available for fitting with cross-validation.
177
178        This method validates that:
179        1. Training data (X) exists
180        2. At least two classes are present in the labels
181        3. Each class has at least n_splits samples (required for k-fold cross-validation)
182
183        Returns
184        -------
185        bool
186            Returns `True` if data meets all requirements for fitting, otherwise `False`.
187        """
188        if self.X.size == 0:
189            logger.warning("No data available for fitting")
190            return False
191
192        unique_y = np.unique(self.y)
193
194        if len(unique_y) == 1:
195            logger.warning("Only one class available for fitting")
196            return False
197
198        class_counts = np.zeros(len(unique_y))
199        for i, y in enumerate(unique_y):
200            class_counts[i] = np.sum(self.y == y)
201
202        # If n_splits is greater than the min number of samples in a class, return False
203        if np.min(class_counts) < self.n_splits:
204            # Future implementation: Report the class with the least number of samples
205            # Future implementation: Report the number of samples in each class
206            logger.warning(
207                "Need at least %s samples per class for cross-validation. Please collect more training data.",
208                self.n_splits,
209            )
210            return False
211
212        return True
213
214    def get_subset(self, X=[], subset=[], channel_labels=[]):
215        """Get a subset of X according to labels or indices.
216
217        Parameters
218        ----------
219        X : numpy.ndarray, *optional*
220            3D array containing data with `float` type.
221
222            shape = (`n_trials`,`n_channels`,`n_samples`)
223            - Default is `[]`.
224        subset : list of `int` or `str`, *optional*
225            List of indices (int) or labels (str) of the desired channels.
226            - Default is `[]`.
227        channel_labels : list of `str`, *optional*
228            Channel labels from the entire EEG montage.
229            - Default is `[]`.
230
231        Returns
232        -------
233        X : numpy.ndarray
234            Subset of input `X` according to labels or indices.
235            3D array containing data with `float` type.
236
237            shape = (`n_trials`,`n_channels`,`n_samples`)
238
239        """
240
241        # Check for self.subset and/or self.channel_labels
242
243        # Init
244        subset_indices = []
245
246        # Copy the indices based on subset
247        try:
248            # Check if we can use subset indices
249            if self.subset == []:
250                return X
251
252            if type(self.subset[0]) is int:
253                logger.info("Using subset indices")
254
255                subset_indices = self.subset
256                self.subset_defined
257
258            # Or channel labels
259            if type(self.subset[0]) is str:
260                logger.info("Using channel labels and subset labels")
261
262                # Replace indices with those described by labels
263                for sl in self.subset:
264                    subset_indices.append(self.channel_labels.index(sl))
265
266                self.subset_defined = True
267            # Return for the given indices
268            try:
269                if sum(X.shape) == 0:
270                    new_X = self.X[:, subset_indices, :]
271                    self.X = new_X
272                else:
273                    new_X = X[:, subset_indices, :]
274                    X = new_X
275                    return X
276
277            except Exception:
278                if sum(X.shape) == 0:
279                    new_X = self.X[subset_indices, :]
280                    self.X = new_X
281
282                else:
283                    new_X = X[subset_indices, :]
284                    X = new_X
285                    return X
286
287        # notify if failed
288        except Exception:
289            logger.warning("something went wrong, no subset taken")
290            return X
291
292    def setup_channel_selection(
293        self,
294        method="SBS",
295        metric="accuracy",
296        iterative_selection=False,
297        initial_channels=[],  # wrapper setup
298        max_time=999,
299        min_channels=1,
300        max_channels=999,
301        performance_delta=0.001,  # stopping criterion
302        n_jobs=1,
303        record_performance=False,
304    ):
305        """Setup channel selection parameters.
306
307        Parameters
308        ----------
309        method : str, *optional*
310            The method used to add or remove channels.
311            - Default is `"SBS"`.
312        metric : str, *optional*
313            The metric used to measure performance.
314            - Default is `"accuracy"`.
315        iterative_selection : bool, *optional*
316            Whether or not to use the previously selected subset for the initial subset.
317            Default is `False`.
318        initial_channels : type, *optional*
319            List of channels to use as initial subset for selection.
320            If empty, `initial_channels` is set to all available channels.
321            - Default is `[]`.
322        max_time : int, *optional*
323            Maximum time in seconds allowed for channel selection.
324            - Default is `999`.
325        min_channels : int, *optional*
326            Minimum number of channels to select during channel selection.
327            - Default is `1`.
328        max_channels : int, *optional*
329            Maximum number of channels allowed in the final subset.
330            - Default is `999`.
331        performance_delta : float, *optional*
332            Smallest performance increment to allow continue of the search.
333            - Default is `0.001`.
334        n_jobs : int, *optional*
335            The number of threads to dedicate to this calculation.
336            - Default is `1`.
337        record_performance : bool, *optional*
338            Decides whether or not to record performance of channel selection.
339            - Default is `False`.
340
341        Returns
342        -------
343        `None`
344
345        """
346        # Add these to settings later
347        if initial_channels == []:
348            self.chs_initial_subset = self.channel_labels
349        else:
350            self.chs_initial_subset = initial_channels
351        self.chs_method = method  # method to add/remove channels
352        self.chs_metric = metric  # metric by which to measure performance
353        self.chs_iterative_selection = iterative_selection  # whether or not to use the previously selected subset for the initial subset
354        self.chs_n_jobs = n_jobs  # number of threads
355        self.chs_max_time = max_time  # max time in seconds
356        self.chs_min_channels = min_channels  # minimum number of channels
357        self.chs_max_channels = max_channels  # maximum number of channels
358        self.chs_performance_delta = performance_delta  # smallest performance increment to justify continuing search
359        self.chs_record_performance = record_performance  # record performance
360
361        self.channel_selection_setup = True
362
363    # add training data, to the training set using a decision block and a label
364    def add_to_train(self, decision_block, labels, num_options=0, meta=[]):
365        """Add training data to the training set using a decision block
366        and a label.
367
368        Parameters
369        ----------
370        decision_block : numpy.ndarray
371            Decision block containing EEG data for training.
372            3D array with shape = (`n_epochs`, `n_channels`, `n_samples`).
373        labels : numpy.ndarray
374            Labels corresponding to each epoch in `decision_block`.
375            1D array with shape = (`n_epochs`, ).
376        num_options : int, *optional*
377            Number of options available for each trial.
378            - Default is `0`.
379        meta : list, *optional*
380            Additional metadata related to the training data.
381            - Default is `[]`.
382
383        Returns
384        -------
385        `None`
386
387        """
388        logger.debug("Adding to training set")
389        # n = number of channels
390        # m = number of samples
391        # p = number of epochs
392        p, n, m = decision_block.shape
393
394        self.num_options = num_options
395        self.meta = meta
396
397        if self.X.size == 0:
398            self.X = decision_block
399            self.y = labels
400
401        else:
402            self.X = np.append(self.X, decision_block, axis=0)
403            self.y = np.append(self.y, labels, axis=0)
404
405    @abstractmethod
406    def fit(self):
407        """Abstract method to fit classifier
408
409        Returns
410        -------
411        `None`
412
413        """
414        pass
415
416    @abstractmethod
417    def predict(self, X: np.ndarray) -> Prediction:
418        """Abstract method to predict with classifier
419
420        X : numpy.ndarray
421            3D array where shape = (trials, channels, samples)
422
423        Returns
424        -------
425        prediction : Prediction
426            Results of predict call containing the predicted class labels, and
427            optionally the probabilities of the labels (empty list if not possible).
428
429        """
430        pass
@dataclass
class Prediction:
@dataclass
class Prediction:
    """Result container returned by `GenericClassifier.predict()`.

    Both fields use a list factory, so every instance gets its own
    independent (initially empty) lists.

    Attributes
    ----------
    labels : list
        Predicted class labels.
        - Default is `[]`.
    probabilities : list
        Probabilities for each class label. Stays an empty list `[]` when
        the classifier cannot provide probabilities.
        - Default is `[]`.

    """

    labels: list = field(default_factory=list)
    probabilities: list = field(default_factory=list)

Prediction data returned by GenericClassifier.predict()

labels : list List of the predicted class labels. - Default is [].

probabilities : list List of probabilities for each class label. If the classifier can't provide probabilities, this will be an empty list []. - Default is []

Prediction(labels: list = <factory>, probabilities: list = <factory>)
labels: list
probabilities: list
@dataclass
class KernelResults:
@dataclass
class KernelResults:
    """Dataclass to store output from the kernel methods.

    Every field has a default, so `KernelResults()` can be created empty
    and filled in later.

    Attributes
    ----------
    model : sklearn.pipeline.Pipeline or None
        The trained classification model.
        - Default is `None`.
    cv_preds : numpy.ndarray
        The predictions from the model using cross-validation.
        - Default is a new empty 1D array.
    accuracy : float
        The accuracy of the trained classification model.
        - Default is `0.0`.
    precision : float
        The precision of the trained classification model.
        - Default is `0.0`.
    recall : float
        The recall of the trained classification model.
        - Default is `0.0`.

    """

    model: Pipeline = None
    # `np.ndarray` itself cannot serve as the factory: its constructor
    # requires a `shape` argument, so calling it with no arguments raises
    # TypeError. Build an explicit empty array per instance instead.
    cv_preds: np.ndarray = field(default_factory=lambda: np.empty(0))
    accuracy: float = 0.0
    precision: float = 0.0
    recall: float = 0.0

Dataclass to store output from the kernel methods

model : classifier The trained classification model.

cv_preds : numpy.ndarray The predictions from the model using cross-validation.

accuracy : float The accuracy of the trained classification model.

precision : float The precision of the trained classification model.

recall : float The recall of the trained classification model.

KernelResults( model: sklearn.pipeline.Pipeline = None, cv_preds: numpy.ndarray = <factory>, accuracy: float = 0.0, precision: float = 0.0, recall: float = 0.0)
model: sklearn.pipeline.Pipeline = None
cv_preds: numpy.ndarray
accuracy: float = 0.0
precision: float = 0.0
recall: float = 0.0
class GenericClassifier(abc.ABC):
 68class GenericClassifier(ABC):
 69    """The base generic classifier class for other classifiers."""
 70
 71    def __init__(self, training_selection=0, subset=[]):
 72        """Initializes `GenericClassifier` class.
 73
 74        Parameters
 75        ----------
 76        training_selection : int, *optional*
 77            Integer representing the object selected for training.
 78            - Default is `0`.
 79        subset : list of `int` or `str`, *optional*
 80            List of indices (int) or labels (str) of the desired channels.
 81            - Default is `[]`.
 82
 83        Attributes
 84        ----------
 85        X : numpy.ndarray
 86            Input features (training data).
 87            3D numpy array with shape = (`n_samples`, `n_channels`, `n_trials`).
 88            - Initial value is `np.ndarray([0])`.
 89        y : numpy.ndarray
 90            Target labels corresponding to input features in `X`.
 91            1D numpy array with shape = (`n_samples`, ).
 92            - Initial value is `np.ndarray([0])`.
 93        subset_defined : bool
 94            Flag indicating whether a subset is defined.
 95            - Initial value is `False`.
 96        subset : list of `int` or `str`
 97            List of indices (int) or labels (str) of the desired channels.
 98            - Initial value is parameter `subset`.
 99        channel_labels : list of `str`
100            Channel labels from the entire EEG montage.
101            - Initial value is `[]`.
102        channel_selection_setup : bool
103            FLag indicating whether channel selection is set up.
104            - Initial value is `False`.
105        offline_accuracy : list of `float`
106            Stores offline accuracy values during training.
107            - Initial value is `[]`.
108        offline_precision : list of `float`
109            Stores offline precision values during training.
110            - Initial value is `[]`.
111        offline_recall : list of `float`
112            Stores offline recall values during training.
113            - Initial value is `[]`.
114        offline_trial_count : int
115            Counter to keep track of the number of offline trials
116            - Initial value is `0`.
117        offline_trial_counts : list of `int`
118            List to store the counts of offline trials.
119            i.e. `offline_trial_count' values.
120            - Initial value is `[]`.
121        next_fit_trial : int
122            Counter to track the next trial for fitting.
123            - Initial value is `0`.
124        predictions : list
125            Stores predications made during training or testing
126            - Initial value is `[]`.
127        pred_probas : list of `float`
128            List to store predication probabilities during testing.
129            - Initial value is `[]`.
130        n_splits : int
131            Number of splits for cross-validation. Also serves as minimum required samples per class for training when running _check_ready_for_fit().
132            - Initial value is `5`.
133
134        """
135        logger.info("Initializing the classifier")
136        self.X = np.ndarray([0])
137        """@private (This is just for the API docs, to avoid double listing."""
138        self.y = np.ndarray([0])
139        """@private (This is just for the API docs, to avoid double listing."""
140
141        self.subset_defined = False
142        """@private (This is just for the API docs, to avoid double listing."""
143        self.subset = subset
144        """@private (This is just for the API docs, to avoid double listing."""
145        self.channel_labels = []
146        """@private (This is just for the API docs, to avoid double listing."""
147        self.channel_selection_setup = False
148        """@private (This is just for the API docs, to avoid double listing."""
149
150        # Lists for plotting classifier performance over time
151        self.offline_accuracy = []
152        """@private (This is just for the API docs, to avoid double listing."""
153        self.offline_precision = []
154        """@private (This is just for the API docs, to avoid double listing."""
155        self.offline_recall = []
156        """@private (This is just for the API docs, to avoid double listing."""
157        self.offline_trial_count = 0
158        """@private (This is just for the API docs, to avoid double listing."""
159        self.offline_trial_counts = []
160        """@private (This is just for the API docs, to avoid double listing."""
161
162        # For iterative fitting,
163        self.next_fit_trial = 0
164        """@private (This is just for the API docs, to avoid double listing."""
165
166        # Keep track of predictions
167        self.predictions = []
168        """@private (This is just for the API docs, to avoid double listing."""
169        self.pred_probas = []
170        """@private (This is just for the API docs, to avoid double listing."""
171
172        # N Splits
173        self.n_splits = 5
174        """@private (This is just for the API docs, to avoid double listing."""
175
176    def _check_ready_for_fit(self):
177        """Check if sufficient data is available for fitting with cross-validation.
178
179        This method validates that:
180        1. Training data (X) exists
181        2. At least two classes are present in the labels
182        3. Each class has at least n_splits samples (required for k-fold cross-validation)
183
184        Returns
185        -------
186        bool
187            Returns `True` if data meets all requirements for fitting, otherwise `False`.
188        """
189        if self.X.size == 0:
190            logger.warning("No data available for fitting")
191            return False
192
193        unique_y = np.unique(self.y)
194
195        if len(unique_y) == 1:
196            logger.warning("Only one class available for fitting")
197            return False
198
199        class_counts = np.zeros(len(unique_y))
200        for i, y in enumerate(unique_y):
201            class_counts[i] = np.sum(self.y == y)
202
203        # If n_splits is greater than the min number of samples in a class, return False
204        if np.min(class_counts) < self.n_splits:
205            # Future implementation: Report the class with the least number of samples
206            # Future implementation: Report the number of samples in each class
207            logger.warning(
208                "Need at least %s samples per class for cross-validation. Please collect more training data.",
209                self.n_splits,
210            )
211            return False
212
213        return True
214
215    def get_subset(self, X=[], subset=[], channel_labels=[]):
216        """Get a subset of X according to labels or indices.
217
218        Parameters
219        ----------
220        X : numpy.ndarray, *optional*
221            3D array containing data with `float` type.
222
223            shape = (`n_trials`,`n_channels`,`n_samples`)
224            - Default is `[]`.
225        subset : list of `int` or `str`, *optional*
226            List of indices (int) or labels (str) of the desired channels.
227            - Default is `[]`.
228        channel_labels : list of `str`, *optional*
229            Channel labels from the entire EEG montage.
230            - Default is `[]`.
231
232        Returns
233        -------
234        X : numpy.ndarray
235            Subset of input `X` according to labels or indices.
236            3D array containing data with `float` type.
237
238            shape = (`n_trials`,`n_channels`,`n_samples`)
239
240        """
241
242        # Check for self.subset and/or self.channel_labels
243
244        # Init
245        subset_indices = []
246
247        # Copy the indices based on subset
248        try:
249            # Check if we can use subset indices
250            if self.subset == []:
251                return X
252
253            if type(self.subset[0]) is int:
254                logger.info("Using subset indices")
255
256                subset_indices = self.subset
257                self.subset_defined
258
259            # Or channel labels
260            if type(self.subset[0]) is str:
261                logger.info("Using channel labels and subset labels")
262
263                # Replace indices with those described by labels
264                for sl in self.subset:
265                    subset_indices.append(self.channel_labels.index(sl))
266
267                self.subset_defined = True
268            # Return for the given indices
269            try:
270                if sum(X.shape) == 0:
271                    new_X = self.X[:, subset_indices, :]
272                    self.X = new_X
273                else:
274                    new_X = X[:, subset_indices, :]
275                    X = new_X
276                    return X
277
278            except Exception:
279                if sum(X.shape) == 0:
280                    new_X = self.X[subset_indices, :]
281                    self.X = new_X
282
283                else:
284                    new_X = X[subset_indices, :]
285                    X = new_X
286                    return X
287
288        # notify if failed
289        except Exception:
290            logger.warning("something went wrong, no subset taken")
291            return X
292
293    def setup_channel_selection(
294        self,
295        method="SBS",
296        metric="accuracy",
297        iterative_selection=False,
298        initial_channels=[],  # wrapper setup
299        max_time=999,
300        min_channels=1,
301        max_channels=999,
302        performance_delta=0.001,  # stopping criterion
303        n_jobs=1,
304        record_performance=False,
305    ):
306        """Setup channel selection parameters.
307
308        Parameters
309        ----------
310        method : str, *optional*
311            The method used to add or remove channels.
312            - Default is `"SBS"`.
313        metric : str, *optional*
314            The metric used to measure performance.
315            - Default is `"accuracy"`.
316        iterative_selection : bool, *optional*
317            Whether or not to use the previously selected subset for the initial subset.
318            Default is `False`.
319        initial_channels : type, *optional*
320            List of channels to use as initial subset for selection.
321            If empty, `initial_channels` is set to all available channels.
322            - Default is `[]`.
323        max_time : int, *optional*
324            Maximum time in seconds allowed for channel selection.
325            - Default is `999`.
326        min_channels : int, *optional*
327            Minimum number of channels to select during channel selection.
328            - Default is `1`.
329        max_channels : int, *optional*
330            Maximum number of channels allowed in the final subset.
331            - Default is `999`.
332        performance_delta : float, *optional*
333            Smallest performance increment to allow continue of the search.
334            - Default is `0.001`.
335        n_jobs : int, *optional*
336            The number of threads to dedicate to this calculation.
337            - Default is `1`.
338        record_performance : bool, *optional*
339            Decides whether or not to record performance of channel selection.
340            - Default is `False`.
341
342        Returns
343        -------
344        `None`
345
346        """
347        # Add these to settings later
348        if initial_channels == []:
349            self.chs_initial_subset = self.channel_labels
350        else:
351            self.chs_initial_subset = initial_channels
352        self.chs_method = method  # method to add/remove channels
353        self.chs_metric = metric  # metric by which to measure performance
354        self.chs_iterative_selection = iterative_selection  # whether or not to use the previously selected subset for the initial subset
355        self.chs_n_jobs = n_jobs  # number of threads
356        self.chs_max_time = max_time  # max time in seconds
357        self.chs_min_channels = min_channels  # minimum number of channels
358        self.chs_max_channels = max_channels  # maximum number of channels
359        self.chs_performance_delta = performance_delta  # smallest performance increment to justify continuing search
360        self.chs_record_performance = record_performance  # record performance
361
362        self.channel_selection_setup = True
363
364    # add training data, to the training set using a decision block and a label
365    def add_to_train(self, decision_block, labels, num_options=0, meta=[]):
366        """Add training data to the training set using a decision block
367        and a label.
368
369        Parameters
370        ----------
371        decision_block : numpy.ndarray
372            Decision block containing EEG data for training.
373            3D array with shape = (`n_epochs`, `n_channels`, `n_samples`).
374        labels : numpy.ndarray
375            Labels corresponding to each epoch in `decision_block`.
376            1D array with shape = (`n_epochs`, ).
377        num_options : int, *optional*
378            Number of options available for each trial.
379            - Default is `0`.
380        meta : list, *optional*
381            Additional metadata related to the training data.
382            - Default is `[]`.
383
384        Returns
385        -------
386        `None`
387
388        """
389        logger.debug("Adding to training set")
390        # n = number of channels
391        # m = number of samples
392        # p = number of epochs
393        p, n, m = decision_block.shape
394
395        self.num_options = num_options
396        self.meta = meta
397
398        if self.X.size == 0:
399            self.X = decision_block
400            self.y = labels
401
402        else:
403            self.X = np.append(self.X, decision_block, axis=0)
404            self.y = np.append(self.y, labels, axis=0)
405
406    @abstractmethod
407    def fit(self):
408        """Abstract method to fit classifier
409
410        Returns
411        -------
412        `None`
413
414        """
415        pass
416
417    @abstractmethod
418    def predict(self, X: np.ndarray) -> Prediction:
419        """Abstract method to predict with classifier
420
421        X : numpy.ndarray
422            3D array where shape = (trials, channels, samples)
423
424        Returns
425        -------
426        prediction : Prediction
427            Results of predict call containing the predicted class labels, and
428            optionally the probabilities of the labels (empty list if not possible).
429
430        """
431        pass

The base generic classifier class for other classifiers.

GenericClassifier(training_selection=0, subset=[])
 71    def __init__(self, training_selection=0, subset=[]):
 72        """Initializes `GenericClassifier` class.
 73
 74        Parameters
 75        ----------
 76        training_selection : int, *optional*
 77            Integer representing the object selected for training.
 78            - Default is `0`.
 79        subset : list of `int` or `str`, *optional*
 80            List of indices (int) or labels (str) of the desired channels.
 81            - Default is `[]`.
 82
 83        Attributes
 84        ----------
 85        X : numpy.ndarray
 86            Input features (training data).
 87            3D numpy array with shape = (`n_samples`, `n_channels`, `n_trials`).
 88            - Initial value is `np.ndarray([0])`.
 89        y : numpy.ndarray
 90            Target labels corresponding to input features in `X`.
 91            1D numpy array with shape = (`n_samples`, ).
 92            - Initial value is `np.ndarray([0])`.
 93        subset_defined : bool
 94            Flag indicating whether a subset is defined.
 95            - Initial value is `False`.
 96        subset : list of `int` or `str`
 97            List of indices (int) or labels (str) of the desired channels.
 98            - Initial value is parameter `subset`.
 99        channel_labels : list of `str`
100            Channel labels from the entire EEG montage.
101            - Initial value is `[]`.
102        channel_selection_setup : bool
103            FLag indicating whether channel selection is set up.
104            - Initial value is `False`.
105        offline_accuracy : list of `float`
106            Stores offline accuracy values during training.
107            - Initial value is `[]`.
108        offline_precision : list of `float`
109            Stores offline precision values during training.
110            - Initial value is `[]`.
111        offline_recall : list of `float`
112            Stores offline recall values during training.
113            - Initial value is `[]`.
114        offline_trial_count : int
115            Counter to keep track of the number of offline trials
116            - Initial value is `0`.
117        offline_trial_counts : list of `int`
118            List to store the counts of offline trials.
119            i.e. `offline_trial_count' values.
120            - Initial value is `[]`.
121        next_fit_trial : int
122            Counter to track the next trial for fitting.
123            - Initial value is `0`.
124        predictions : list
125            Stores predications made during training or testing
126            - Initial value is `[]`.
127        pred_probas : list of `float`
128            List to store predication probabilities during testing.
129            - Initial value is `[]`.
130        n_splits : int
131            Number of splits for cross-validation. Also serves as minimum required samples per class for training when running _check_ready_for_fit().
132            - Initial value is `5`.
133
134        """
135        logger.info("Initializing the classifier")
136        self.X = np.ndarray([0])
137        """@private (This is just for the API docs, to avoid double listing."""
138        self.y = np.ndarray([0])
139        """@private (This is just for the API docs, to avoid double listing."""
140
141        self.subset_defined = False
142        """@private (This is just for the API docs, to avoid double listing."""
143        self.subset = subset
144        """@private (This is just for the API docs, to avoid double listing."""
145        self.channel_labels = []
146        """@private (This is just for the API docs, to avoid double listing."""
147        self.channel_selection_setup = False
148        """@private (This is just for the API docs, to avoid double listing."""
149
150        # Lists for plotting classifier performance over time
151        self.offline_accuracy = []
152        """@private (This is just for the API docs, to avoid double listing."""
153        self.offline_precision = []
154        """@private (This is just for the API docs, to avoid double listing."""
155        self.offline_recall = []
156        """@private (This is just for the API docs, to avoid double listing."""
157        self.offline_trial_count = 0
158        """@private (This is just for the API docs, to avoid double listing."""
159        self.offline_trial_counts = []
160        """@private (This is just for the API docs, to avoid double listing."""
161
162        # For iterative fitting,
163        self.next_fit_trial = 0
164        """@private (This is just for the API docs, to avoid double listing."""
165
166        # Keep track of predictions
167        self.predictions = []
168        """@private (This is just for the API docs, to avoid double listing."""
169        self.pred_probas = []
170        """@private (This is just for the API docs, to avoid double listing."""
171
172        # N Splits
173        self.n_splits = 5
174        """@private (This is just for the API docs, to avoid double listing."""

Initializes GenericClassifier class.

Parameters
  • training_selection (int, optional): Integer representing the object selected for training.
    • Default is 0.
  • subset (list of int or str, optional): List of indices (int) or labels (str) of the desired channels.
    • Default is [].
Attributes
  • X (numpy.ndarray): Input features (training data). 3D numpy array with shape = (n_epochs, n_channels, n_samples) once training data has been added.
    • Initial value is np.ndarray([0]).
  • y (numpy.ndarray): Target labels corresponding to input features in X. 1D numpy array with shape = (n_epochs, ).
    • Initial value is np.ndarray([0]).
  • subset_defined (bool): Flag indicating whether a subset is defined.
    • Initial value is False.
  • subset (list of int or str): List of indices (int) or labels (str) of the desired channels.
    • Initial value is parameter subset.
  • channel_labels (list of str): Channel labels from the entire EEG montage.
    • Initial value is [].
  • channel_selection_setup (bool): Flag indicating whether channel selection is set up.
    • Initial value is False.
  • offline_accuracy (list of float): Stores offline accuracy values during training.
    • Initial value is [].
  • offline_precision (list of float): Stores offline precision values during training.
    • Initial value is [].
  • offline_recall (list of float): Stores offline recall values during training.
    • Initial value is [].
  • offline_trial_count (int): Counter to keep track of the number of offline trials
    • Initial value is 0.
  • offline_trial_counts (list of int): List to store the counts of offline trials, i.e. `offline_trial_count` values.
    • Initial value is [].
  • next_fit_trial (int): Counter to track the next trial for fitting.
    • Initial value is 0.
  • predictions (list): Stores predictions made during training or testing.
    • Initial value is [].
  • pred_probas (list of float): List to store prediction probabilities during testing.
    • Initial value is [].
  • n_splits (int): Number of splits for cross-validation. Also serves as minimum required samples per class for training when running _check_ready_for_fit().
    • Initial value is 5.
def get_subset(self, X=[], subset=[], channel_labels=[]):
215    def get_subset(self, X=[], subset=[], channel_labels=[]):
216        """Get a subset of X according to labels or indices.
217
218        Parameters
219        ----------
220        X : numpy.ndarray, *optional*
221            3D array containing data with `float` type.
222
223            shape = (`n_trials`,`n_channels`,`n_samples`)
224            - Default is `[]`.
225        subset : list of `int` or `str`, *optional*
226            List of indices (int) or labels (str) of the desired channels.
227            - Default is `[]`.
228        channel_labels : list of `str`, *optional*
229            Channel labels from the entire EEG montage.
230            - Default is `[]`.
231
232        Returns
233        -------
234        X : numpy.ndarray
235            Subset of input `X` according to labels or indices.
236            3D array containing data with `float` type.
237
238            shape = (`n_trials`,`n_channels`,`n_samples`)
239
240        """
241
242        # Check for self.subset and/or self.channel_labels
243
244        # Init
245        subset_indices = []
246
247        # Copy the indices based on subset
248        try:
249            # Check if we can use subset indices
250            if self.subset == []:
251                return X
252
253            if type(self.subset[0]) is int:
254                logger.info("Using subset indices")
255
256                subset_indices = self.subset
257                self.subset_defined
258
259            # Or channel labels
260            if type(self.subset[0]) is str:
261                logger.info("Using channel labels and subset labels")
262
263                # Replace indices with those described by labels
264                for sl in self.subset:
265                    subset_indices.append(self.channel_labels.index(sl))
266
267                self.subset_defined = True
268            # Return for the given indices
269            try:
270                if sum(X.shape) == 0:
271                    new_X = self.X[:, subset_indices, :]
272                    self.X = new_X
273                else:
274                    new_X = X[:, subset_indices, :]
275                    X = new_X
276                    return X
277
278            except Exception:
279                if sum(X.shape) == 0:
280                    new_X = self.X[subset_indices, :]
281                    self.X = new_X
282
283                else:
284                    new_X = X[subset_indices, :]
285                    X = new_X
286                    return X
287
288        # notify if failed
289        except Exception:
290            logger.warning("something went wrong, no subset taken")
291            return X

Get a subset of X according to labels or indices.

Parameters
  • X (numpy.ndarray, optional): 3D array containing data with float type.

    shape = (n_trials,n_channels,n_samples)

    • Default is [].
  • subset (list of int or str, optional): List of indices (int) or labels (str) of the desired channels.
    • Default is [].
  • channel_labels (list of str, optional): Channel labels from the entire EEG montage.
    • Default is [].
Returns
  • X (numpy.ndarray): Subset of input X according to labels or indices. 3D array containing data with float type.

    shape = (n_trials,n_channels,n_samples)

def setup_channel_selection( self, method='SBS', metric='accuracy', iterative_selection=False, initial_channels=[], max_time=999, min_channels=1, max_channels=999, performance_delta=0.001, n_jobs=1, record_performance=False):
293    def setup_channel_selection(
294        self,
295        method="SBS",
296        metric="accuracy",
297        iterative_selection=False,
298        initial_channels=[],  # wrapper setup
299        max_time=999,
300        min_channels=1,
301        max_channels=999,
302        performance_delta=0.001,  # stopping criterion
303        n_jobs=1,
304        record_performance=False,
305    ):
306        """Setup channel selection parameters.
307
308        Parameters
309        ----------
310        method : str, *optional*
311            The method used to add or remove channels.
312            - Default is `"SBS"`.
313        metric : str, *optional*
314            The metric used to measure performance.
315            - Default is `"accuracy"`.
316        iterative_selection : bool, *optional*
317            Whether or not to use the previously selected subset for the initial subset.
318            Default is `False`.
319        initial_channels : type, *optional*
320            List of channels to use as initial subset for selection.
321            If empty, `initial_channels` is set to all available channels.
322            - Default is `[]`.
323        max_time : int, *optional*
324            Maximum time in seconds allowed for channel selection.
325            - Default is `999`.
326        min_channels : int, *optional*
327            Minimum number of channels to select during channel selection.
328            - Default is `1`.
329        max_channels : int, *optional*
330            Maximum number of channels allowed in the final subset.
331            - Default is `999`.
332        performance_delta : float, *optional*
333            Smallest performance increment to allow continue of the search.
334            - Default is `0.001`.
335        n_jobs : int, *optional*
336            The number of threads to dedicate to this calculation.
337            - Default is `1`.
338        record_performance : bool, *optional*
339            Decides whether or not to record performance of channel selection.
340            - Default is `False`.
341
342        Returns
343        -------
344        `None`
345
346        """
347        # Add these to settings later
348        if initial_channels == []:
349            self.chs_initial_subset = self.channel_labels
350        else:
351            self.chs_initial_subset = initial_channels
352        self.chs_method = method  # method to add/remove channels
353        self.chs_metric = metric  # metric by which to measure performance
354        self.chs_iterative_selection = iterative_selection  # whether or not to use the previously selected subset for the initial subset
355        self.chs_n_jobs = n_jobs  # number of threads
356        self.chs_max_time = max_time  # max time in seconds
357        self.chs_min_channels = min_channels  # minimum number of channels
358        self.chs_max_channels = max_channels  # maximum number of channels
359        self.chs_performance_delta = performance_delta  # smallest performance increment to justify continuing search
360        self.chs_record_performance = record_performance  # record performance
361
362        self.channel_selection_setup = True

Setup channel selection parameters.

Parameters
  • method (str, optional): The method used to add or remove channels.
    • Default is "SBS".
  • metric (str, optional): The metric used to measure performance.
    • Default is "accuracy".
  • iterative_selection (bool, optional): Whether or not to use the previously selected subset for the initial subset. Default is False.
  • initial_channels (list, optional): List of channels to use as initial subset for selection. If empty, initial_channels is set to all available channels.
    • Default is [].
  • max_time (int, optional): Maximum time in seconds allowed for channel selection.
    • Default is 999.
  • min_channels (int, optional): Minimum number of channels to select during channel selection.
    • Default is 1.
  • max_channels (int, optional): Maximum number of channels allowed in the final subset.
    • Default is 999.
  • performance_delta (float, optional): Smallest performance increment that justifies continuing the search.
    • Default is 0.001.
  • n_jobs (int, optional): The number of threads to dedicate to this calculation.
    • Default is 1.
  • record_performance (bool, optional): Decides whether or not to record performance of channel selection.
    • Default is False.
Returns
  • None
def add_to_train(self, decision_block, labels, num_options=0, meta=[]):
365    def add_to_train(self, decision_block, labels, num_options=0, meta=[]):
366        """Add training data to the training set using a decision block
367        and a label.
368
369        Parameters
370        ----------
371        decision_block : numpy.ndarray
372            Decision block containing EEG data for training.
373            3D array with shape = (`n_epochs`, `n_channels`, `n_samples`).
374        labels : numpy.ndarray
375            Labels corresponding to each epoch in `decision_block`.
376            1D array with shape = (`n_epochs`, ).
377        num_options : int, *optional*
378            Number of options available for each trial.
379            - Default is `0`.
380        meta : list, *optional*
381            Additional metadata related to the training data.
382            - Default is `[]`.
383
384        Returns
385        -------
386        `None`
387
388        """
389        logger.debug("Adding to training set")
390        # n = number of channels
391        # m = number of samples
392        # p = number of epochs
393        p, n, m = decision_block.shape
394
395        self.num_options = num_options
396        self.meta = meta
397
398        if self.X.size == 0:
399            self.X = decision_block
400            self.y = labels
401
402        else:
403            self.X = np.append(self.X, decision_block, axis=0)
404            self.y = np.append(self.y, labels, axis=0)

Add training data to the training set using a decision block and a label.

Parameters
  • decision_block (numpy.ndarray): Decision block containing EEG data for training. 3D array with shape = (n_epochs, n_channels, n_samples).
  • labels (numpy.ndarray): Labels corresponding to each epoch in decision_block. 1D array with shape = (n_epochs, ).
  • num_options (int, optional): Number of options available for each trial.
    • Default is 0.
  • meta (list, optional): Additional metadata related to the training data.
    • Default is [].
Returns
  • None
@abstractmethod
def fit(self):
406    @abstractmethod
407    def fit(self):
408        """Abstract method to fit classifier
409
410        Returns
411        -------
412        `None`
413
414        """
415        pass

Abstract method to fit classifier

Returns
  • None
@abstractmethod
def predict( self, X: numpy.ndarray) -> Prediction:
417    @abstractmethod
418    def predict(self, X: np.ndarray) -> Prediction:
419        """Abstract method to predict with classifier
420
421        X : numpy.ndarray
422            3D array where shape = (trials, channels, samples)
423
424        Returns
425        -------
426        prediction : Prediction
427            Results of predict call containing the predicted class labels, and
428            optionally the probabilities of the labels (empty list if not possible).
429
430        """
431        pass

Abstract method to predict with classifier

Parameters
  • X (numpy.ndarray): 3D array where shape = (trials, channels, samples)

Returns
  • prediction (Prediction): Results of predict call containing the predicted class labels, and optionally the probabilities of the labels (empty list if not possible).