bci_essentials.classification.ssvep_basic_tf_classifier
SSVEP Basic Training-Free Classifier
Classifies SSVEP trials by comparing the spectral power at each target frequency and selecting the target with the maximum power.
"""**SSVEP Basic Training-Free Classifier**

Classifies SSVEP trials by comparing the power spectral density (PSD) at each
target frequency and selecting the target with the maximum power.

"""

# Stock libraries
import numpy as np
from scipy import signal

# Import bci_essentials modules and methods
from ..classification.generic_classifier import GenericClassifier, Prediction
from ..utils.logger import Logger  # Logger wrapper

# Instantiate a logger for the module at the default level of logging.INFO
# Logs to bci_essentials.__module__, where __module__ is the name of the module
logger = Logger(name=__name__)


class SsvepBasicTrainFreeClassifier(GenericClassifier):
    """SSVEP Basic Training-Free Classifier class
    (*inherits from GenericClassifier*).

    For each trial, predicts the index of the target frequency whose closest
    PSD bin (Welch estimate of the channel-averaged signal) has the greatest
    power. No training is needed; call `set_ssvep_settings()` before
    `predict()`.

    """

    def set_ssvep_settings(self, sampling_freq, target_freqs):
        """Set the SSVEP settings.

        Parameters
        ----------
        sampling_freq : int
            Sampling frequency of the EEG data.
        target_freqs : list of `int`
            List of the target frequencies for SSVEP detection. Predicted
            labels are indices into this list.

        Returns
        -------
        `None`
            The settings are stored on the instance for use by `predict()`.

        """
        self.sampling_freq = sampling_freq
        self.target_freqs = target_freqs
        # Reset so the next predict() call logs its one-time setup message.
        self.setup = False

    def fit(self):
        """Fit the model.

        This classifier is training-free, so this only logs a warning.

        Returns
        -------
        `None`

        """
        logger.warning("This classifier does not require training.")

    def predict(self, X):
        """Predict the class labels for the provided data.

        Each trial is averaged across channels and its PSD is estimated with
        Welch's method. The label is the index (into `target_freqs`) of the
        target whose closest frequency bin has the greatest power.

        Parameters
        ----------
        X : numpy.ndarray
            3D array where shape = (trials, channels, samples)

        Returns
        -------
        prediction : Prediction
            Results of predict call containing the predicted class labels
            (float-valued indices into `target_freqs`). Probabilities are
            not provided; only `labels` is passed to `Prediction`.

        Raises
        ------
        ValueError
            If `X` is not 3D or `target_freqs` is empty.

        """
        # Unpacking the shape also rejects inputs that are not 3D.
        _, _, n_samples = X.shape

        target_freqs = np.asarray(self.target_freqs, dtype=float)
        if target_freqs.size == 0:
            raise ValueError("target_freqs must contain at least one frequency")

        # The first time it is called it must be set up
        if self.setup is False:
            logger.info("Setting up the training free classifier")

            self.setup = True

        # Build one augmented channel per trial by averaging across channels
        augmented_X = np.mean(X, axis=1)

        # Get the PSD estimate using Welch's method (segment = whole trial)
        f, Pxx = signal.welch(augmented_X, fs=self.sampling_freq, nperseg=n_samples)

        # The frequency axis is shared by every trial, so find the bin closest
        # to each target once rather than once per trial.
        closest_bins = np.argmin(
            np.abs(f[np.newaxis, :] - target_freqs[:, np.newaxis]), axis=1
        )

        # Vote for the target whose bin has the greatest PSD in each trial.
        # argmax keeps the first maximum on ties; float dtype is preserved
        # for compatibility with the previous per-trial implementation.
        prediction = np.argmax(Pxx[:, closest_bins], axis=1).astype(float)

        return Prediction(labels=prediction)
logger =
<bci_essentials.utils.logger.Logger object>
class
SsvepBasicTrainFreeClassifier(bci_essentials.classification.generic_classifier.GenericClassifier):
class SsvepBasicTrainFreeClassifier(GenericClassifier):
    """SSVEP Basic Training-Free Classifier class
    (*inherits from GenericClassifier*).

    For each trial, predicts the index of the target frequency whose closest
    PSD bin (Welch estimate of the channel-averaged signal) has the greatest
    power. No training is needed; call `set_ssvep_settings()` before
    `predict()`.

    """

    def set_ssvep_settings(self, sampling_freq, target_freqs):
        """Set the SSVEP settings.

        Parameters
        ----------
        sampling_freq : int
            Sampling frequency of the EEG data.
        target_freqs : list of `int`
            List of the target frequencies for SSVEP detection. Predicted
            labels are indices into this list.

        Returns
        -------
        `None`
            The settings are stored on the instance for use by `predict()`.

        """
        self.sampling_freq = sampling_freq
        self.target_freqs = target_freqs
        # Reset so the next predict() call logs its one-time setup message.
        self.setup = False

    def fit(self):
        """Fit the model.

        This classifier is training-free, so this only logs a warning.

        Returns
        -------
        `None`

        """
        logger.warning("This classifier does not require training.")

    def predict(self, X):
        """Predict the class labels for the provided data.

        Each trial is averaged across channels and its PSD is estimated with
        Welch's method. The label is the index (into `target_freqs`) of the
        target whose closest frequency bin has the greatest power.

        Parameters
        ----------
        X : numpy.ndarray
            3D array where shape = (trials, channels, samples)

        Returns
        -------
        prediction : Prediction
            Results of predict call containing the predicted class labels
            (float-valued indices into `target_freqs`). Probabilities are
            not provided; only `labels` is passed to `Prediction`.

        Raises
        ------
        ValueError
            If `X` is not 3D or `target_freqs` is empty.

        """
        # Unpacking the shape also rejects inputs that are not 3D.
        _, _, n_samples = X.shape

        target_freqs = np.asarray(self.target_freqs, dtype=float)
        if target_freqs.size == 0:
            raise ValueError("target_freqs must contain at least one frequency")

        # The first time it is called it must be set up
        if self.setup is False:
            logger.info("Setting up the training free classifier")

            self.setup = True

        # Build one augmented channel per trial by averaging across channels
        augmented_X = np.mean(X, axis=1)

        # Get the PSD estimate using Welch's method (segment = whole trial)
        f, Pxx = signal.welch(augmented_X, fs=self.sampling_freq, nperseg=n_samples)

        # The frequency axis is shared by every trial, so find the bin closest
        # to each target once rather than once per trial.
        closest_bins = np.argmin(
            np.abs(f[np.newaxis, :] - target_freqs[:, np.newaxis]), axis=1
        )

        # Vote for the target whose bin has the greatest PSD in each trial.
        # argmax keeps the first maximum on ties; float dtype is preserved
        # for compatibility with the previous per-trial implementation.
        prediction = np.argmax(Pxx[:, closest_bins], axis=1).astype(float)

        return Prediction(labels=prediction)
SSVEP Basic Training-Free Classifier class (inherits from GenericClassifier).
def
set_ssvep_settings(self, sampling_freq, target_freqs):
def set_ssvep_settings(self, sampling_freq, target_freqs):
    """Store the SSVEP settings on the classifier.

    Parameters
    ----------
    sampling_freq : int
        Sampling frequency of the EEG data.
    target_freqs : list of `int`
        List of the target frequencies for SSVEP detection.

    Returns
    -------
    `None`
        The settings are kept on the instance for later use by `predict()`.

    """
    # Record the acquisition rate and the candidate stimulus frequencies,
    # then mark the classifier as not yet set up.
    self.sampling_freq, self.target_freqs, self.setup = (
        sampling_freq,
        target_freqs,
        False,
    )
Set the SSVEP settings.
Parameters
- sampling_freq (int): Sampling frequency of the EEG data.
- target_freqs (list of int): List of the target frequencies for SSVEP detection.
Returns
None: The settings are stored on the instance for later use in predict().
def
fit(self):
def
predict(self, X):
def predict(self, X):
    """Predict the class labels for the provided data.

    Each trial is averaged across channels and its PSD is estimated with
    Welch's method. The label is the index (into `target_freqs`) of the
    target whose closest frequency bin has the greatest power.

    Parameters
    ----------
    X : numpy.ndarray
        3D array where shape = (trials, channels, samples)

    Returns
    -------
    prediction : Prediction
        Results of predict call containing the predicted class labels
        (float-valued indices into `target_freqs`). Probabilities are
        not provided; only `labels` is passed to `Prediction`.

    Raises
    ------
    ValueError
        If `X` is not 3D or `target_freqs` is empty.

    """
    # Unpacking the shape also rejects inputs that are not 3D.
    _, _, n_samples = X.shape

    target_freqs = np.asarray(self.target_freqs, dtype=float)
    if target_freqs.size == 0:
        raise ValueError("target_freqs must contain at least one frequency")

    # The first time it is called it must be set up
    if self.setup is False:
        logger.info("Setting up the training free classifier")

        self.setup = True

    # Build one augmented channel per trial by averaging across channels
    augmented_X = np.mean(X, axis=1)

    # Get the PSD estimate using Welch's method (segment = whole trial)
    f, Pxx = signal.welch(augmented_X, fs=self.sampling_freq, nperseg=n_samples)

    # The frequency axis is shared by every trial, so find the bin closest
    # to each target once rather than once per trial.
    closest_bins = np.argmin(
        np.abs(f[np.newaxis, :] - target_freqs[:, np.newaxis]), axis=1
    )

    # Vote for the target whose bin has the greatest PSD in each trial.
    # argmax keeps the first maximum on ties; float dtype is preserved
    # for compatibility with the previous per-trial implementation.
    prediction = np.argmax(Pxx[:, closest_bins], axis=1).astype(float)

    return Prediction(labels=prediction)
Predict the class labels for the provided data.
Parameters
- X (numpy.ndarray): 3D array where shape = (trials, channels, samples)
Returns
- prediction (Prediction): Results of predict call containing the predicted class labels. Probabilities are not available (empty list).