bci_essentials.classification.ssvep_basic_tf_classifier

SSVEP Basic Training-Free Classifier

Classifies SSVEP based on relative bandpower, taking only the maximum.

 1"""**SSVEP Basic Training-Free Classifier**
 2
 3Classifies SSVEP based on relative bandpower, taking only the maximum.
 4
 5"""
 6
 7# Stock libraries
 8import numpy as np
 9from scipy import signal
10
11# Import bci_essentials modules and methods
12from ..classification.generic_classifier import GenericClassifier, Prediction
13from ..utils.logger import Logger  # Logger wrapper
14
15# Instantiate a logger for the module at the default level of logging.INFO
 16# Logs to (bci_essentials.__module__), where __module__ is the name of the module
17logger = Logger(name=__name__)
18
19
class SsvepBasicTrainFreeClassifier(GenericClassifier):
    """SSVEP Basic Training-Free Classifier class
    (*inherits from GenericClassifier*).

    For every trial, the channels are averaged into one signal, its PSD is
    estimated with Welch's method, and the target frequency whose nearest
    PSD bin holds the most power is chosen as the label.

    """

    def set_ssvep_settings(self, sampling_freq, target_freqs):
        """Set the SSVEP settings.

        Parameters
        ----------
        sampling_freq : int
            Sampling frequency of the EEG data.
        target_freqs : list of `int`
            List of the target frequencies for SSVEP detection.

        Returns
        -------
        `None`
            The settings are stored on the instance and read by `predict()`.

        """
        self.sampling_freq = sampling_freq
        self.target_freqs = target_freqs
        # Reset so the next `predict()` call logs its one-time setup message.
        self.setup = False

    def fit(self):
        """Fit the model.

        This classifier is training-free, so nothing is fitted; calling this
        method only logs a warning.

        Returns
        -------
        `None`

        """
        logger.warning("This classifier does not require training.")

    def predict(self, X):
        """Predict the class labels for the provided data.

        Parameters
        ----------
        X : numpy.ndarray
            3D array where shape = (trials, channels, samples)

        Returns
        -------
        prediction : Prediction
            Results of predict call containing the predicted class labels
            (one index into `target_freqs` per trial, stored as floats).
            Probabilities are not passed to `Prediction`.

        """
        n_trials, _n_channels, n_samples = X.shape

        # The first call after the settings were (re)set only logs a message;
        # there is no state to build.
        if self.setup is False:
            logger.info("Setting up the training free classifier")
            self.setup = True

        # Without trials there is no spectrum to search: return empty labels
        # instead of looking up frequency bins in an empty PSD.
        if n_trials == 0:
            return Prediction(labels=np.zeros(0))

        # Build one augmented channel by averaging across the channel axis
        augmented_X = np.mean(X, axis=1)

        # Get the PSD estimate using Welch's method, one row per trial
        f, Pxx = signal.welch(augmented_X, fs=self.sampling_freq, nperseg=n_samples)

        # The closest frequency bin of each target depends only on `f`,
        # so it is looked up once instead of once per trial.
        target_bins = np.array(
            [int(np.argmin(np.abs(f - tf))) for tf in self.target_freqs],
            dtype=np.intp,
        )

        # (trials, targets) table of power at the target bins; the vote for a
        # trial is the first target with the greatest power.
        target_power = Pxx[:, target_bins]

        # Labels stay floating point, matching the previous np.zeros buffer.
        labels = np.argmax(target_power, axis=1).astype(float)

        return Prediction(labels=labels)
class SsvepBasicTrainFreeClassifier(bci_essentials.classification.generic_classifier.GenericClassifier):
 21class SsvepBasicTrainFreeClassifier(GenericClassifier):
 22    """SSVEP Basic Training-Free Classifier class
 23    (*inherits from GenericClassifier*).
 24
 25    """
 26
 27    def set_ssvep_settings(self, sampling_freq, target_freqs):
 28        """Set the SSVEP settings.
 29
 30        Parameters
 31        ----------
 32        sampling_freq : int
 33            Sampling frequency of the EEG data.
 34        target_freqs : list of `int`
 35            List of the target frequencies for SSVEP detection.
 36
 37        Returns
 38        -------
 39        `None`
 40            Models created used in `fit()`.
 41
 42        """
 43        self.sampling_freq = sampling_freq
 44        self.target_freqs = target_freqs
 45        self.setup = False
 46
 47    def fit(self):
 48        """Fit the model.
 49
 50        Returns
 51        -------
 52        `None`
 53            Models created used in `predict()`.
 54
 55        """
 56        logger.warning("This classifier does not require training.")
 57
 58    def predict(self, X):
 59        """Predict the class labels for the provided data.
 60
 61        Parameters
 62        ----------
 63        X : numpy.ndarray
 64            3D array where shape = (trials, channels, samples)
 65
 66        Returns
 67        -------
 68        prediction : Prediction
 69            Results of predict call containing the predicted class labels.  Probabilities
 70            are not available (empty list).
 71
 72        """
 73        # get the shape
 74        n_trials, n_channels, n_samples = X.shape
 75        # The first time it is called it must be set up
 76        if self.setup is False:
 77            logger.info("Setting up the training free classifier")
 78
 79            self.setup = True
 80
 81        # Build one augmented channel, here by just adding them all together
 82        augmented_X = np.mean(X, axis=1)
 83
 84        # Get the PSD estimate using Welch's method
 85        f, Pxx = signal.welch(augmented_X, fs=self.sampling_freq, nperseg=n_samples)
 86
 87        # Get a vote for each trial
 88        prediction = np.zeros(n_trials)
 89        for trial in range(n_trials):
 90            # Get the frequency with the greatest PSD
 91            Pxx_of_f_bins = np.zeros(len(self.target_freqs))
 92            for i, tf in enumerate(self.target_freqs):
 93                # Get the closest frequency bin
 94                closest_freq_bin = np.argmin(np.abs(f - tf))
 95
 96                Pxx_of_f_bins[i] = Pxx[trial][int(closest_freq_bin)]
 97
 98            prediction[trial] = np.argmax(Pxx_of_f_bins)
 99
100        return Prediction(labels=prediction)

SSVEP Basic Training-Free Classifier class (inherits from GenericClassifier).

def set_ssvep_settings(self, sampling_freq, target_freqs):
27    def set_ssvep_settings(self, sampling_freq, target_freqs):
28        """Set the SSVEP settings.
29
30        Parameters
31        ----------
32        sampling_freq : int
33            Sampling frequency of the EEG data.
34        target_freqs : list of `int`
35            List of the target frequencies for SSVEP detection.
36
37        Returns
38        -------
39        `None`
40            Models created used in `fit()`.
41
42        """
43        self.sampling_freq = sampling_freq
44        self.target_freqs = target_freqs
45        self.setup = False

Set the SSVEP settings.

Parameters
  • sampling_freq (int): Sampling frequency of the EEG data.
  • target_freqs (list of int): List of the target frequencies for SSVEP detection.
Returns
  • None: The settings are stored on the instance for later use by predict().
def fit(self):
47    def fit(self):
48        """Fit the model.
49
50        Returns
51        -------
52        `None`
53            Models created used in `predict()`.
54
55        """
56        logger.warning("This classifier does not require training.")

Fit the model.

Returns
  • None: Nothing is fitted; this classifier does not require training.
def predict(self, X):
 58    def predict(self, X):
 59        """Predict the class labels for the provided data.
 60
 61        Parameters
 62        ----------
 63        X : numpy.ndarray
 64            3D array where shape = (trials, channels, samples)
 65
 66        Returns
 67        -------
 68        prediction : Prediction
 69            Results of predict call containing the predicted class labels.  Probabilities
 70            are not available (empty list).
 71
 72        """
 73        # get the shape
 74        n_trials, n_channels, n_samples = X.shape
 75        # The first time it is called it must be set up
 76        if self.setup is False:
 77            logger.info("Setting up the training free classifier")
 78
 79            self.setup = True
 80
 81        # Build one augmented channel, here by just adding them all together
 82        augmented_X = np.mean(X, axis=1)
 83
 84        # Get the PSD estimate using Welch's method
 85        f, Pxx = signal.welch(augmented_X, fs=self.sampling_freq, nperseg=n_samples)
 86
 87        # Get a vote for each trial
 88        prediction = np.zeros(n_trials)
 89        for trial in range(n_trials):
 90            # Get the frequency with the greatest PSD
 91            Pxx_of_f_bins = np.zeros(len(self.target_freqs))
 92            for i, tf in enumerate(self.target_freqs):
 93                # Get the closest frequency bin
 94                closest_freq_bin = np.argmin(np.abs(f - tf))
 95
 96                Pxx_of_f_bins[i] = Pxx[trial][int(closest_freq_bin)]
 97
 98            prediction[trial] = np.argmax(Pxx_of_f_bins)
 99
100        return Prediction(labels=prediction)

Predict the class labels for the provided data.

Parameters
  • X (numpy.ndarray): 3D array where shape = (trials, channels, samples)
Returns
  • prediction (Prediction): Results of predict call containing the predicted class labels. Probabilities are not available (empty list).