bci_essentials.bci_controller

Module for managing BCI data.

This module provides data classes for different BCI paradigms.

It includes the loading of offline data in xdf format or the live streaming of LSL data.

The loaded/streamed data is added to a buffer such that offline and online processing pipelines are identical.

Data is pre-processed (using the signal_processing module), divided into trials, and classified (using one of the classification sub-modules).

Classes
  • BciController : For processing continuous data in trials of a defined length.
  1"""Module for managing BCI data.
  2
  3This module provides data classes for different BCI paradigms.
  4
  5It includes the loading of offline data in `xdf` format
  6or the live streaming of LSL data.
  7
  8The loaded/streamed data is added to a buffer such that offline and
  9online processing pipelines are identical.
 10
 11Data is pre-processed (using the `signal_processing` module), divided into trials,
 12and classified (using one of the `classification` sub-modules).
 13
 14Classes
 15-------
 16- `BciController` : For processing continuous data in trials of a defined length.
 17
 18"""
 19
 20import time
 21import os
 22import numpy as np
 23from enum import Enum
 24
 25from .paradigm.paradigm import Paradigm
 26from .data_tank.data_tank import DataTank
 27from .classification.generic_classifier import GenericClassifier
 28from .io.sources import EegSource, MarkerSource
 29from .io.messenger import Messenger
 30from .utils.logger import Logger
 31
 32# Instantiate a logger for the module at the default level of logging.INFO
 33# Logs to bci_essentials.__module__) where __module__ is the name of the module
 34logger = Logger(name=__name__)
 35
 36
 37class MarkerTypes(Enum):
 38    TRIAL_STARTED = "Trial Started"
 39    TRIAL_ENDS = "Trial Ends"
 40    TRAINING_COMPLETE = "Training Complete"
 41    TRAIN_CLASSIFIER = "Train Classifier"
 42    DONE_RS_COLLECTION = "Done with all RS collection"
 43    UPDATE_CLASSIFIER = "Update Classifier"
 44
 45
 46# EEG data
 47class BciController:
 48    """
 49    Class that holds, trials, processes, and classifies EEG data.
 50    This class is used for processing of continuous EEG data in trials of a defined length.
 51    """
 52
 53    # 0. Special methods (e.g. __init__)
 54    def __init__(
 55        self,
 56        classifier: GenericClassifier,
 57        eeg_source: EegSource,
 58        marker_source: MarkerSource | None = None,
 59        paradigm: Paradigm | None = None,
 60        data_tank: DataTank | None = None,
 61        messenger: Messenger | None = None,
 62    ):
 63        """Initializes `BciController` class.
 64
 65        Parameters
 66        ----------
 67        classifier : GenericClassifier
 68            The classifier used by BciController.
 69        eeg_source : EegSource
 70            Source of EEG data and timestamps, this could be from a file or headset via LSL, etc.
 71        marker_source : EegSource
 72            Source of Marker/Control data and timestamps, this could be from a file or Unity via LSL, etc.
 73            - Default is `None`.
 74        paradigm : Paradigm
 75            The paradigm used by BciController. This defines the processing and reshaping steps for the EEG data.
 76        data_tank : DataTank
 77            DataTank object to handle the storage of EEG trials and labels.
 78            - Default is `None`.
 79        messenger: Messenger
 80            Messenger object to handle events from BciController, ex: acknowledging markers and
 81            predictions.
 82            - Default is `None`.
 83
 84        """
 85
 86        # Ensure the incoming dependencies are the right type
 87        assert isinstance(classifier, GenericClassifier)
 88        assert isinstance(eeg_source, EegSource)
 89        assert isinstance(marker_source, MarkerSource | None)
 90        assert isinstance(paradigm, Paradigm | None)
 91        assert isinstance(data_tank, DataTank | None)
 92        assert isinstance(messenger, Messenger | None)
 93
 94        self._classifier = classifier
 95        self.__eeg_source = eeg_source
 96        self.__marker_source = marker_source
 97        self.__paradigm = paradigm
 98        self.__data_tank = data_tank
 99        self._messenger = messenger
100
101        self.headset_string = self.__eeg_source.name
102        self.fsample = self.__eeg_source.fsample
103        self.n_channels = self.__eeg_source.n_channels
104        self.ch_type = self.__eeg_source.channel_types
105        self.ch_units = self.__eeg_source.channel_units
106        self.channel_labels = self.__eeg_source.channel_labels
107
108        self.__data_tank.set_source_data(
109            self.headset_string,
110            self.fsample,
111            self.n_channels,
112            self.ch_type,
113            self.ch_units,
114            self.channel_labels,
115        )
116
117        # Switch any trigger channels to stim, this is for mne/bids export (?)
118        self.ch_type = [type.replace("trg", "stim") for type in self.ch_type]
119
120        self._classifier.channel_labels = self.channel_labels
121
122        logger.info(self.headset_string)
123        logger.info(self.channel_labels)
124
125        # Initialize data and timestamp arrays to the right dimensions, but zero elements
126        self.marker_data = np.zeros((0, 1))
127        self.marker_timestamps = np.zeros((0))
128        self.bci_controller = np.zeros((0, self.n_channels))
129        self.eeg_timestamps = np.zeros((0))
130
131        # Initialize marker methods dictionary
132        self.marker_methods = {
133            MarkerTypes.DONE_RS_COLLECTION.value: self.__process_resting_state_data,
134            MarkerTypes.TRIAL_STARTED.value: self.__log_trial_start,
135            MarkerTypes.TRIAL_ENDS.value: self.__handle_trial_end,
136            MarkerTypes.TRAINING_COMPLETE.value: self.__update_and_train_classifier,
137            MarkerTypes.TRAIN_CLASSIFIER.value: self.__update_and_train_classifier,
138            MarkerTypes.UPDATE_CLASSIFIER.value: self.__update_and_train_classifier,
139        }
140
141        self.ping_count = 0
142        self.n_samples = 0
143        self.time_units = ""
144
145    # 1. Core public API methods
146    def setup(
147        self,
148        online=True,
149        train_complete=False,
150        train_lock=False,
151        auto_save_epochs=True,
152    ):
153        """Configure processing loop.
154
155        This should be called before starting the loop with run() or step().
156
157        Calling after will reset the loop state.
158
159        The processing loop reads in EEG and marker data and processes it.
160        The loop can be run in "offline" or "online" modes:
161        - If in `online` mode, then the loop will continuously try to read
162        in data from the `BciController` object and process it. The loop will
163        terminate when `max_loops` is reached, or when manually terminated.
164        - If in `offline` mode, then the loop will read in all of the data
165        at once, process it, and then terminate.
166
167        Parameters
168        ----------
169        online : bool, *optional*
170            Flag to indicate if the data will be processed in `online` mode.
171            - `True`: The data will be processed in `online` mode.
172            - `False`: The data will be processed in `offline` mode.
173            - Default is `True`.
174        train_complete : bool, *optional*
175            Flag to indicate if the classifier has been trained.
176            - `True`: The classifier has been trained.
177            - `False`: The classifier has not been trained.
178            - Default is `False`.
179        train_lock : bool, *optional*
180            Flag to indicate if the classifier is locked (ie. no more training).
181            - `True`: The classifier is locked.
182            - `False`: The classifier is not locked.
183            - Default is `False`.
184        auto_save_epochs : bool, *optional*
185            Flag to indicate if labeled epochs should be automatically saved to a temp file so they can be reloaded if Bessy crashes.
186            - `True`: Epochs will be saved to a temp file.
187            - `False`: Epochs will not be saved to a temp file.
188
189
190        Returns
191        -------
192        `None`
193
194        """
195        self.online = online
196        self.train_complete = train_complete
197        self.train_lock = train_lock
198        self.auto_save_epochs = auto_save_epochs
199
200        # initialize the numbers of markers and trials to zero
201        self.marker_count = 0
202        self.current_num_trials = 0
203        self.n_trials = 0
204
205        self.num_online_selections = 0
206        self.online_selection_indices = []
207        self.online_selections = []
208
209        # Check for a temp_epochs file
210        if online:
211            self.__load_temp_epochs_if_available()
212
213    def step(self):
214        """Runs a single BciController processing step.
215
216        See setup() for configuration of processing.
217
218        The method:
219        1. Pulls data from sources (EEG and markers).
220        2. Run a while loop to process markers as long as there are unprocessed markers.
221        3. The while loop processes the markers in the following order:
222            - First checks if the marker is a known command marker from self.marker_methods.
223            - Then checks if it's an event marker (contains commas)
224            - If neither, logs a warning about unknown marker type
225        3. If the marker is a command marker, handles it by calling __handle_command_marker().
226        4. If the marker is an event marker, handles it by calling __handle_event_marker().
227        5. If the command or event marker handling return continue_flag as True, increment the marker count and process the next marker.
228            - Note: If there is an unknown marker type, the marker count is still incremented and processing continues.
229        6. If the command or event marker handling return continue_flag as False, break out of the while loop and end the step.
230
231        Parameters
232        ----------
233        `None`
234
235        Returns
236        ------
237        `None`
238
239        """
240        # read from sources to get new data.
241        # This puts command markers in the marker_data array and
242        # event markers in the event_marker_strings array
243        self._pull_data_from_sources()
244
245        # Process markers while there are unprocessed markers
246        # REMOVE COMMENT: check if there is an available command marker, if not, break and wait for more data
247        while len(self.marker_timestamps) > self.marker_count:
248            # Get the current marker
249            current_step_marker = self.marker_data[self.marker_count]  # String
250            current_timestamp = self.marker_timestamps[self.marker_count]  # Float
251
252            # If marker is empty, skip it
253            if not current_step_marker:
254                logger.warning("Empty marker received")
255                self.marker_count += 1
256                continue
257
258            # If messenger is available, send feedback for each marker received
259            if self._messenger is not None:
260                self._messenger.marker_received(current_step_marker)
261
262            # Process markers in order specified in the docstrings
263            # First check if it's a known command marker
264            if current_step_marker in self.marker_methods:
265                continue_flag = self.__handle_command_marker(current_step_marker)
266            # Then check if it's an event marker (contains commas)
267            elif "," in current_step_marker:
268                continue_flag = self.__handle_event_marker(
269                    current_step_marker, current_timestamp
270                )
271            # Otherwise, log a warning about unknown marker type
272            else:
273                # Log warning for unknown marker types
274                logger.warning("Unknown marker type received: %s", current_step_marker)
275                continue_flag = True
276
277            # Check if we should continue processing markers in the while loop
278            # if continue_flag is False, then break out of the while loop
279            # else, increment the marker count and process the next marker
280            if continue_flag is False:
281                break
282            else:
283                logger.info("Processed Marker: %s", current_step_marker)
284                self.marker_count += 1
285
286    def run(self, max_loops: int = 1000000):
287        """Runs BciController processing in a loop.
288
289        See setup() for configuration of processing.
290
291        Parameters
292        ----------
293        max_loops : int, *optional*
294            Maximum number of loops to run, default is `1000000`.
295
296        Returns
297        ------
298        `None`
299
300        """
301        # if offline, then all data is already loaded, only need to loop once
302        if self.online is False:
303            self.loops = max_loops - 1
304        else:
305            self.loops = 0
306
307        # Initialize the event marker buffer
308        self.event_marker_buffer = []
309        self.event_timestamp_buffer = []
310
311        # start the main loop, stops after pulling new data, max_loops times
312        while self.loops < max_loops:
313            # print out loop status
314            if self.loops % 100 == 0:
315                logger.debug(self.loops)
316
317            if self.loops == max_loops - 1:
318                logger.debug("last loop")
319
320            # read from sources and process
321            self.step()
322
323            # Wait a short period of time and then try to pull more data
324            if self.online:
325                time.sleep(0.00001)
326
327            self.loops += 1
328
329    # 2. Protected methods (single underscore)
330    def _pull_data_from_sources(self):
331        """Get pull data from EEG and optionally, the marker source.
332
333        This method will fill up the marker_data, bci_controller and corresponding timestamp arrays.
334
335        Parameters
336        ----------
337        `None`
338
339        Returns
340        -------
341        `None`
342
343        """
344        # Get new data from source, whatever it is
345        self.__pull_marker_data_from_source()
346        self.__pull_eeg_data_from_source()
347
348        # If the outlet exists send a ping
349        if self._messenger is not None:
350            self.ping_count += 1
351            self._messenger.ping()
352
353    # 3. Private methods (double underscore)
354    # 3a. Private methods for retrieving data from sources
355    def __pull_marker_data_from_source(self):
356        """Pulls marker samples from source, sanity checks and appends to buffer.
357
358        Parameters
359        ----------
360        `None`
361
362        Returns
363        -------
364        `None`
365
366        """
367
368        # if there isn't a marker source, abort
369        if self.__marker_source is None:
370            return
371
372        # read in the data
373        markers, timestamps = self.__marker_source.get_markers()
374        markers = np.array(markers)
375        timestamps = np.array(timestamps)
376
377        if markers.size == 0:
378            return
379
380        if markers.ndim != 2:
381            logger.warning("discarded invalid marker data")
382            return
383
384        # apply time correction
385        time_correction = self.__marker_source.time_correction()
386        timestamps = [timestamps[i] + time_correction for i in range(len(timestamps))]
387
388        for i, marker in enumerate(markers):
389            marker = marker[0]
390            if "Ping" in marker:
391                continue
392
393            # Add all markers to the controller
394            self.marker_data = np.append(self.marker_data, marker)
395            self.marker_timestamps = np.append(self.marker_timestamps, timestamps[i])
396
397            # Add all markers to the data tank
398            self.__data_tank.add_raw_markers(
399                np.array([marker]), np.array([timestamps[i]])
400            )
401
402    def __pull_eeg_data_from_source(self):
403        """Pulls eeg samples from source, sanity checks and appends to buffer.
404
405        Parameters
406        ----------
407        `None`
408
409        Returns
410        -------
411        `None`
412
413        """
414
415        # read in the data
416        eeg, timestamps = self.__eeg_source.get_samples()
417        eeg = np.array(eeg)
418        timestamps = np.array(timestamps)
419
420        if eeg.size == 0:
421            return
422
423        if eeg.ndim != 2:
424            logger.warning("discarded invalid eeg data")
425            return
426
427        # if time is in milliseconds, divide by 1000, works for sampling rates above 10Hz
428        try:
429            if self.time_units == "milliseconds":
430                timestamps = [(timestamps[i] / 1000) for i in range(len(timestamps))]
431
432        # If time units are not defined then define them
433        except Exception:
434            dif_low = -2
435            dif_high = -1
436            while timestamps[dif_high] - timestamps[dif_low] == 0:
437                dif_low -= 1
438                dif_high -= 1
439
440            if timestamps[dif_high] - timestamps[dif_low] > 0.1:
441                timestamps = [(timestamps[i] / 1000) for i in range(len(timestamps))]
442                self.time_units = "milliseconds"
443            else:
444                self.time_units = "seconds"
445
446        # apply time correction, this is essential for headsets like neurosity which have their own clock
447        time_correction = self.__eeg_source.time_correction()
448        timestamps = [timestamps[i] + time_correction for i in range(len(timestamps))]
449
450        self.__data_tank.add_raw_eeg(eeg.T, timestamps)
451
452        # Update latest EEG timestamp
453        self.latest_eeg_timestamp = timestamps[-1]
454
455    # 3b. Private methods for data processing and classification
456    def __process_resting_state_data(self):
457        """Handles the resting state data by packaging it and adding it to the data tank.
458
459        Parameters
460        ----------
461        `None`
462
463        Returns
464        ------
465        continue_flag : bool
466            Flag indicating to continue the while loop in step().
467
468        """
469        (
470            self.bci_controller,
471            self.eeg_timestamps,
472        ) = self.__data_tank.get_raw_eeg()
473
474        resting_state_data = self.__paradigm.package_resting_state_data(
475            self.marker_data,
476            self.marker_timestamps,
477            self.bci_controller,
478            self.eeg_timestamps,
479            self.fsample,
480        )
481
482        self.__data_tank.add_resting_state_data(resting_state_data)
483
484        return True  # Continue processing
485
486    def __process_and_classify(self):
487        """Process the markers and classify the data.
488
489        Parameters
490        ----------
491        `None`
492
493        Returns
494        -------
495        success_string : str
496            String indicating if the processing and classification was successful.
497            Potential values are "Success", "Skip", "Wait".
498            - "Success": The processing and classification was successful.
499            - "Skip": EEG is either absent entirely or contains lost packets.
500            - "Wait": The processing is waiting for more data.
501
502        """
503
504        eeg_start_time, eeg_end_time = self.__paradigm.get_eeg_start_and_end_times(
505            self.event_marker_buffer, self.event_timestamp_buffer
506        )
507
508        # No we actually need to wait until we have all the data for these markers
509        eeg, timestamps = self.__data_tank.get_raw_eeg()
510
511        # Check if there is available EEG data
512        if len(eeg) == 0:
513            logger.warning("No EEG data available")
514            return "Skip"
515
516        # If the last timestamp is less than the end time, then we don't have the necessary EEG to process
517        if timestamps[-1] < eeg_end_time:
518            return "Wait"
519
520        # Check if EEG sampling is continuous over this time period
521        start_indices = np.where(timestamps > eeg_start_time)[0]
522        if len(start_indices) == 0:
523            logger.warning("No timestamps exceed eeg_start_time")
524            return "Skip"
525        start_index = start_indices[0]
526        end_index = np.where(timestamps < eeg_end_time)[0][-1]
527
528        time_diffs = np.diff(timestamps[start_index:end_index])
529        if np.any(time_diffs > 2 / self.fsample):
530            logger.warning("Time gaps in EEG data")
531            return "Skip"
532
533        X, y = self.__paradigm.process_markers(
534            self.event_marker_buffer,
535            self.event_timestamp_buffer,
536            eeg,
537            timestamps,
538            self.fsample,
539        )
540
541        sum_new_labeled_trials = np.sum(y != -1)
542
543        # Add the epochs to the data tank
544        self.__data_tank.add_epochs(X, y)
545
546        # Save epochs to temp_epochs file
547        if self.auto_save_epochs and self.online and sum_new_labeled_trials > 0:
548            paradigm_str = self.__paradigm.paradigm_name
549
550            with open(self.temp_epochs, "wb") as f:
551                np.savez(
552                    f,
553                    X=self.__data_tank.epochs,
554                    y=self.__data_tank.labels,
555                    paradigm=paradigm_str,
556                )
557
558        # If either there are no labels OR iterative training is on, then make a prediction
559        if self.train_complete:
560            if -1 in y or self.__paradigm.iterative_training:
561                prediction = self._classifier.predict(X)
562                self.__send_prediction(prediction)
563
564        self.event_marker_buffer = []
565        self.event_timestamp_buffer = []
566
567        return "Success"
568
569    def __update_and_train_classifier(self):
570        """Updates the classifier if required.
571
572        Parameters
573        ----------
574        `None`
575
576        Returns
577        -------
578        continue_flag : bool
579            Flag indicating to continue the while loop in step().
580        """
581        if self.train_lock is False:
582            # Pull the epochs from the data tank and pass them to the classifier
583            X, y = self.__data_tank.get_epochs(latest=True)
584
585            # Remove epochs with label -1
586            ind_to_remove = []
587            for i, label in enumerate(y):
588                if label == -1:
589                    ind_to_remove.append(i)
590            X = np.delete(X, ind_to_remove, axis=0)
591            y = np.delete(y, ind_to_remove, axis=0)
592
593            # Check that there are epochs
594            if len(y) > 0:
595                self._classifier.add_to_train(X, y)
596
597            if self._classifier._check_ready_for_fit():
598                self._classifier.fit()
599                self.train_complete = True
600
601        return True
602
603    # 3c. Private methods for event handling (trial and markers) and messaging
604    def __log_trial_start(self):
605        """Logs the start of a trial.
606
607        Parameters
608        ----------
609        `None`
610
611        Returns
612        -------
613        continue_flag : bool
614            Flag indicating to continue the while loop in step().
615
616        """
617        logger.debug("Trial started, incrementing marker count and continuing")
618        # Note that a marker occured, but do nothing else
619        return True  # Continue processing
620
621    def __handle_trial_end(self):
622        """Handles the end of a trial. Processes and classifies trial data if required.
623
624        Parameters
625        ----------
626        `None`
627
628        Returns
629        ------
630        success_flag : bool
631            Flag indicating if the processing and classification was successful.
632            - Returns `True` if not classifying.
633        """
634        # If we are classifying based on trials, then process the trial,
635        if self.__paradigm.classify_each_trial:
636            success_string = self.__process_and_classify()
637            if success_string == "Wait":
638                logger.debug("Processing of trial not run: waiting for more data")
639                return False
640            if success_string == "Skip":
641                logger.warning("Processing of trial failed: skipping trial")
642                self.event_marker_buffer = []
643                self.event_timestamp_buffer = []
644                self.marker_count += 1
645                return False
646
647        return True  # Return True by default if not classifying
648
649    def __handle_event_marker(self, marker, timestamp):
650        """Processes and classifies event markers.
651
652        Parameters
653        ----------
654        marker : str
655            Event marker string containing comma-separated values.
656            - Format depends on paradigm implementation.
657        timestamp : float
658            Timestamp of the marker in seconds (after time correction).
659
660
661        Returns
662        ------
663        continue_flag : bool
664            Flag indicating to continue the while loop in step().
665
666        """
667        # Add the marker to the event marker buffer
668        self.event_marker_buffer.append(marker)
669        self.event_timestamp_buffer.append(timestamp)
670
671        # If classification is on epochs, then update epochs, maybe classify, and clear the buffer
672        if self.__paradigm.classify_each_epoch:
673            success_string = self.__process_and_classify()
674            if success_string == "Wait":
675                logger.debug("Processing of epoch not run: waiting for more data")
676                self.event_marker_buffer = []
677                self.event_timestamp_buffer = []
678                return False  # Stop processing
679            elif success_string == "Skip":
680                logger.warning("Processing of epoch failed: skipping epoch")
681                self.event_marker_buffer = []
682                self.event_timestamp_buffer = []
683
684                self.marker_count += 1
685                return False  # Stop processing
686
687        return True  # Continue processing
688
689    def __handle_command_marker(self, marker: str) -> bool:
690        """Processes a command marker by invoking its associated method.
691
692        The command marker string is assumed to be in the self.marker_methods dictionary.
693        The associated method is retrieved and called.
694        The return value of the method is used to determine if processing should continue.
695
696        Parameters
697        ----------
698        marker : str
699            A command marker string (assumed to be in self.marker_methods).
700
701        Returns
702        -------
703        bool
704            A flag indicating if the processing should continue.
705
706        """
707        command_marker_method = self.marker_methods[marker]  # Retrieve method
708        continue_flag = command_marker_method()  # Call method
709
710        # Debug level logging if continue_flag is FALSE
711        if continue_flag is False:
712            logger.debug("Command marker '%s' set continue_flag to FALSE", marker)
713
714        return continue_flag
715
716    def __send_prediction(self, prediction):
717        """Send a prediction to the messenger object.
718
719        Parameters
720        ----------
721        `None`
722
723        Returns
724        -------
725        `None`
726
727        """
728        if self._messenger is not None:
729            logger.debug("Sending prediction: %s", prediction)
730            self._messenger.prediction(prediction)
731        elif self._messenger is None and self.online is True:
732            # If running in online mode and messenger is not available, log a warning
733            logger.warning(
734                "Messenger not available (self._messenger is None). Prediction not sent: %s",
735                prediction,
736            )
737
738    def __load_temp_epochs_if_available(self, reload_data_time: int = 300):
739        """Load temp_epochs if available and valid.
740
741        Parameters
742        ----------
743        reload_data_time : int, *optional*
744            Time in seconds of the last temp_epochs file to reload the data from.
745            Default is `300` seconds (5 minutes).
746
747        Returns
748        -------
749        `None`
750
751        """
752        self.temp_epochs = os.path.join(
753            os.path.dirname(os.path.dirname(__file__)), "temp_epochs.npz"
754        )
755
756        if not os.path.exists(self.temp_epochs):
757            return
758
759        # If temp_epochs is older than `reload_data_time`, delete it
760        if os.path.getmtime(self.temp_epochs) < (time.time() - reload_data_time):
761            os.remove(self.temp_epochs)
762            logger.info("Deleted old temp_epochs file.")
763            return
764
765        # Load the temp_epochs file
766        with open(self.temp_epochs, "rb") as f:
767            npz = np.load(f, allow_pickle=True)
768            X = npz["X"]
769            y = npz["y"]
770            paradigm_str = npz["paradigm"].item()
771
772        # If the paradigm is different, delete the file
773        if self.__paradigm.paradigm_name != paradigm_str:
774            logger.warning(
775                "Paradigm in temp_epochs file does not match current paradigm. Deleting file."
776            )
777            os.remove(self.temp_epochs)
778            return
779
780        # If the paradigm is the same, then add the epochs to the data tank
781        logger.info("Loading epochs from temp_epochs file.")
782        logger.info("X shape: %s", X.shape)
783        logger.info("y shape: %s", y.shape)
784        self.__data_tank.add_epochs(X, y)
785
786        # If there are epochs in the data tank, then train the classifier
787        if len(self.__data_tank.labels) > 0:
788            self.__update_and_train_classifier()
class MarkerTypes(enum.Enum):
class MarkerTypes(Enum):
    """Command-marker strings that `BciController` dispatches on.

    Each member's value is the exact marker text expected from the marker
    source; `BciController.__init__` maps these values to handler methods
    in its `marker_methods` dictionary.
    """

    TRIAL_STARTED = "Trial Started"
    TRIAL_ENDS = "Trial Ends"
    TRAINING_COMPLETE = "Training Complete"
    TRAIN_CLASSIFIER = "Train Classifier"
    DONE_RS_COLLECTION = "Done with all RS collection"
    UPDATE_CLASSIFIER = "Update Classifier"

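Command-marker strings recognized by `BciController`; each member's value is the exact marker text that triggers one of its handler methods.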

TRIAL_STARTED = <MarkerTypes.TRIAL_STARTED: 'Trial Started'>
TRIAL_ENDS = <MarkerTypes.TRIAL_ENDS: 'Trial Ends'>
TRAINING_COMPLETE = <MarkerTypes.TRAINING_COMPLETE: 'Training Complete'>
TRAIN_CLASSIFIER = <MarkerTypes.TRAIN_CLASSIFIER: 'Train Classifier'>
DONE_RS_COLLECTION = <MarkerTypes.DONE_RS_COLLECTION: 'Done with all RS collection'>
UPDATE_CLASSIFIER = <MarkerTypes.UPDATE_CLASSIFIER: 'Update Classifier'>
class BciController:
 48class BciController:
 49    """
 50    Class that holds, trials, processes, and classifies EEG data.
 51    This class is used for processing of continuous EEG data in trials of a defined length.
 52    """
 53
 54    # 0. Special methods (e.g. __init__)
 55    def __init__(
 56        self,
 57        classifier: GenericClassifier,
 58        eeg_source: EegSource,
 59        marker_source: MarkerSource | None = None,
 60        paradigm: Paradigm | None = None,
 61        data_tank: DataTank | None = None,
 62        messenger: Messenger | None = None,
 63    ):
 64        """Initializes `BciController` class.
 65
 66        Parameters
 67        ----------
 68        classifier : GenericClassifier
 69            The classifier used by BciController.
 70        eeg_source : EegSource
 71            Source of EEG data and timestamps, this could be from a file or headset via LSL, etc.
 72        marker_source : EegSource
 73            Source of Marker/Control data and timestamps, this could be from a file or Unity via LSL, etc.
 74            - Default is `None`.
 75        paradigm : Paradigm
 76            The paradigm used by BciController. This defines the processing and reshaping steps for the EEG data.
 77        data_tank : DataTank
 78            DataTank object to handle the storage of EEG trials and labels.
 79            - Default is `None`.
 80        messenger: Messenger
 81            Messenger object to handle events from BciController, ex: acknowledging markers and
 82            predictions.
 83            - Default is `None`.
 84
 85        """
 86
 87        # Ensure the incoming dependencies are the right type
 88        assert isinstance(classifier, GenericClassifier)
 89        assert isinstance(eeg_source, EegSource)
 90        assert isinstance(marker_source, MarkerSource | None)
 91        assert isinstance(paradigm, Paradigm | None)
 92        assert isinstance(data_tank, DataTank | None)
 93        assert isinstance(messenger, Messenger | None)
 94
 95        self._classifier = classifier
 96        self.__eeg_source = eeg_source
 97        self.__marker_source = marker_source
 98        self.__paradigm = paradigm
 99        self.__data_tank = data_tank
100        self._messenger = messenger
101
102        self.headset_string = self.__eeg_source.name
103        self.fsample = self.__eeg_source.fsample
104        self.n_channels = self.__eeg_source.n_channels
105        self.ch_type = self.__eeg_source.channel_types
106        self.ch_units = self.__eeg_source.channel_units
107        self.channel_labels = self.__eeg_source.channel_labels
108
109        self.__data_tank.set_source_data(
110            self.headset_string,
111            self.fsample,
112            self.n_channels,
113            self.ch_type,
114            self.ch_units,
115            self.channel_labels,
116        )
117
118        # Switch any trigger channels to stim, this is for mne/bids export (?)
119        self.ch_type = [type.replace("trg", "stim") for type in self.ch_type]
120
121        self._classifier.channel_labels = self.channel_labels
122
123        logger.info(self.headset_string)
124        logger.info(self.channel_labels)
125
126        # Initialize data and timestamp arrays to the right dimensions, but zero elements
127        self.marker_data = np.zeros((0, 1))
128        self.marker_timestamps = np.zeros((0))
129        self.bci_controller = np.zeros((0, self.n_channels))
130        self.eeg_timestamps = np.zeros((0))
131
132        # Initialize marker methods dictionary
133        self.marker_methods = {
134            MarkerTypes.DONE_RS_COLLECTION.value: self.__process_resting_state_data,
135            MarkerTypes.TRIAL_STARTED.value: self.__log_trial_start,
136            MarkerTypes.TRIAL_ENDS.value: self.__handle_trial_end,
137            MarkerTypes.TRAINING_COMPLETE.value: self.__update_and_train_classifier,
138            MarkerTypes.TRAIN_CLASSIFIER.value: self.__update_and_train_classifier,
139            MarkerTypes.UPDATE_CLASSIFIER.value: self.__update_and_train_classifier,
140        }
141
142        self.ping_count = 0
143        self.n_samples = 0
144        self.time_units = ""
145
146    # 1. Core public API methods
147    def setup(
148        self,
149        online=True,
150        train_complete=False,
151        train_lock=False,
152        auto_save_epochs=True,
153    ):
154        """Configure processing loop.
155
156        This should be called before starting the loop with run() or step().
157
158        Calling after will reset the loop state.
159
160        The processing loop reads in EEG and marker data and processes it.
161        The loop can be run in "offline" or "online" modes:
162        - If in `online` mode, then the loop will continuously try to read
163        in data from the `BciController` object and process it. The loop will
164        terminate when `max_loops` is reached, or when manually terminated.
165        - If in `offline` mode, then the loop will read in all of the data
166        at once, process it, and then terminate.
167
168        Parameters
169        ----------
170        online : bool, *optional*
171            Flag to indicate if the data will be processed in `online` mode.
172            - `True`: The data will be processed in `online` mode.
173            - `False`: The data will be processed in `offline` mode.
174            - Default is `True`.
175        train_complete : bool, *optional*
176            Flag to indicate if the classifier has been trained.
177            - `True`: The classifier has been trained.
178            - `False`: The classifier has not been trained.
179            - Default is `False`.
180        train_lock : bool, *optional*
181            Flag to indicate if the classifier is locked (ie. no more training).
182            - `True`: The classifier is locked.
183            - `False`: The classifier is not locked.
184            - Default is `False`.
185        auto_save_epochs : bool, *optional*
186            Flag to indicate if labeled epochs should be automatically saved to a temp file so they can be reloaded if Bessy crashes.
187            - `True`: Epochs will be saved to a temp file.
188            - `False`: Epochs will not be saved to a temp file.
189
190
191        Returns
192        -------
193        `None`
194
195        """
196        self.online = online
197        self.train_complete = train_complete
198        self.train_lock = train_lock
199        self.auto_save_epochs = auto_save_epochs
200
201        # initialize the numbers of markers and trials to zero
202        self.marker_count = 0
203        self.current_num_trials = 0
204        self.n_trials = 0
205
206        self.num_online_selections = 0
207        self.online_selection_indices = []
208        self.online_selections = []
209
210        # Check for a temp_epochs file
211        if online:
212            self.__load_temp_epochs_if_available()
213
214    def step(self):
215        """Runs a single BciController processing step.
216
217        See setup() for configuration of processing.
218
219        The method:
220        1. Pulls data from sources (EEG and markers).
221        2. Run a while loop to process markers as long as there are unprocessed markers.
222        3. The while loop processes the markers in the following order:
223            - First checks if the marker is a known command marker from self.marker_methods.
224            - Then checks if it's an event marker (contains commas)
225            - If neither, logs a warning about unknown marker type
226        3. If the marker is a command marker, handles it by calling __handle_command_marker().
227        4. If the marker is an event marker, handles it by calling __handle_event_marker().
228        5. If the command or event marker handling return continue_flag as True, increment the marker count and process the next marker.
229            - Note: If there is an unknown marker type, the marker count is still incremented and processing continues.
230        6. If the command or event marker handling return continue_flag as False, break out of the while loop and end the step.
231
232        Parameters
233        ----------
234        `None`
235
236        Returns
237        ------
238        `None`
239
240        """
241        # read from sources to get new data.
242        # This puts command markers in the marker_data array and
243        # event markers in the event_marker_strings array
244        self._pull_data_from_sources()
245
246        # Process markers while there are unprocessed markers
247        # REMOVE COMMENT: check if there is an available command marker, if not, break and wait for more data
248        while len(self.marker_timestamps) > self.marker_count:
249            # Get the current marker
250            current_step_marker = self.marker_data[self.marker_count]  # String
251            current_timestamp = self.marker_timestamps[self.marker_count]  # Float
252
253            # If marker is empty, skip it
254            if not current_step_marker:
255                logger.warning("Empty marker received")
256                self.marker_count += 1
257                continue
258
259            # If messenger is available, send feedback for each marker received
260            if self._messenger is not None:
261                self._messenger.marker_received(current_step_marker)
262
263            # Process markers in order specified in the docstrings
264            # First check if it's a known command marker
265            if current_step_marker in self.marker_methods:
266                continue_flag = self.__handle_command_marker(current_step_marker)
267            # Then check if it's an event marker (contains commas)
268            elif "," in current_step_marker:
269                continue_flag = self.__handle_event_marker(
270                    current_step_marker, current_timestamp
271                )
272            # Otherwise, log a warning about unknown marker type
273            else:
274                # Log warning for unknown marker types
275                logger.warning("Unknown marker type received: %s", current_step_marker)
276                continue_flag = True
277
278            # Check if we should continue processing markers in the while loop
279            # if continue_flag is False, then break out of the while loop
280            # else, increment the marker count and process the next marker
281            if continue_flag is False:
282                break
283            else:
284                logger.info("Processed Marker: %s", current_step_marker)
285                self.marker_count += 1
286
287    def run(self, max_loops: int = 1000000):
288        """Runs BciController processing in a loop.
289
290        See setup() for configuration of processing.
291
292        Parameters
293        ----------
294        max_loops : int, *optional*
295            Maximum number of loops to run, default is `1000000`.
296
297        Returns
298        ------
299        `None`
300
301        """
302        # if offline, then all data is already loaded, only need to loop once
303        if self.online is False:
304            self.loops = max_loops - 1
305        else:
306            self.loops = 0
307
308        # Initialize the event marker buffer
309        self.event_marker_buffer = []
310        self.event_timestamp_buffer = []
311
312        # start the main loop, stops after pulling new data, max_loops times
313        while self.loops < max_loops:
314            # print out loop status
315            if self.loops % 100 == 0:
316                logger.debug(self.loops)
317
318            if self.loops == max_loops - 1:
319                logger.debug("last loop")
320
321            # read from sources and process
322            self.step()
323
324            # Wait a short period of time and then try to pull more data
325            if self.online:
326                time.sleep(0.00001)
327
328            self.loops += 1
329
330    # 2. Protected methods (single underscore)
331    def _pull_data_from_sources(self):
332        """Get pull data from EEG and optionally, the marker source.
333
334        This method will fill up the marker_data, bci_controller and corresponding timestamp arrays.
335
336        Parameters
337        ----------
338        `None`
339
340        Returns
341        -------
342        `None`
343
344        """
345        # Get new data from source, whatever it is
346        self.__pull_marker_data_from_source()
347        self.__pull_eeg_data_from_source()
348
349        # If the outlet exists send a ping
350        if self._messenger is not None:
351            self.ping_count += 1
352            self._messenger.ping()
353
354    # 3. Private methods (double underscore)
355    # 3a. Private methods for retrieving data from sources
356    def __pull_marker_data_from_source(self):
357        """Pulls marker samples from source, sanity checks and appends to buffer.
358
359        Parameters
360        ----------
361        `None`
362
363        Returns
364        -------
365        `None`
366
367        """
368
369        # if there isn't a marker source, abort
370        if self.__marker_source is None:
371            return
372
373        # read in the data
374        markers, timestamps = self.__marker_source.get_markers()
375        markers = np.array(markers)
376        timestamps = np.array(timestamps)
377
378        if markers.size == 0:
379            return
380
381        if markers.ndim != 2:
382            logger.warning("discarded invalid marker data")
383            return
384
385        # apply time correction
386        time_correction = self.__marker_source.time_correction()
387        timestamps = [timestamps[i] + time_correction for i in range(len(timestamps))]
388
389        for i, marker in enumerate(markers):
390            marker = marker[0]
391            if "Ping" in marker:
392                continue
393
394            # Add all markers to the controller
395            self.marker_data = np.append(self.marker_data, marker)
396            self.marker_timestamps = np.append(self.marker_timestamps, timestamps[i])
397
398            # Add all markers to the data tank
399            self.__data_tank.add_raw_markers(
400                np.array([marker]), np.array([timestamps[i]])
401            )
402
403    def __pull_eeg_data_from_source(self):
404        """Pulls eeg samples from source, sanity checks and appends to buffer.
405
406        Parameters
407        ----------
408        `None`
409
410        Returns
411        -------
412        `None`
413
414        """
415
416        # read in the data
417        eeg, timestamps = self.__eeg_source.get_samples()
418        eeg = np.array(eeg)
419        timestamps = np.array(timestamps)
420
421        if eeg.size == 0:
422            return
423
424        if eeg.ndim != 2:
425            logger.warning("discarded invalid eeg data")
426            return
427
428        # if time is in milliseconds, divide by 1000, works for sampling rates above 10Hz
429        try:
430            if self.time_units == "milliseconds":
431                timestamps = [(timestamps[i] / 1000) for i in range(len(timestamps))]
432
433        # If time units are not defined then define them
434        except Exception:
435            dif_low = -2
436            dif_high = -1
437            while timestamps[dif_high] - timestamps[dif_low] == 0:
438                dif_low -= 1
439                dif_high -= 1
440
441            if timestamps[dif_high] - timestamps[dif_low] > 0.1:
442                timestamps = [(timestamps[i] / 1000) for i in range(len(timestamps))]
443                self.time_units = "milliseconds"
444            else:
445                self.time_units = "seconds"
446
447        # apply time correction, this is essential for headsets like neurosity which have their own clock
448        time_correction = self.__eeg_source.time_correction()
449        timestamps = [timestamps[i] + time_correction for i in range(len(timestamps))]
450
451        self.__data_tank.add_raw_eeg(eeg.T, timestamps)
452
453        # Update latest EEG timestamp
454        self.latest_eeg_timestamp = timestamps[-1]
455
456    # 3b. Private methods for data processing and classification
457    def __process_resting_state_data(self):
458        """Handles the resting state data by packaging it and adding it to the data tank.
459
460        Parameters
461        ----------
462        `None`
463
464        Returns
465        ------
466        continue_flag : bool
467            Flag indicating to continue the while loop in step().
468
469        """
470        (
471            self.bci_controller,
472            self.eeg_timestamps,
473        ) = self.__data_tank.get_raw_eeg()
474
475        resting_state_data = self.__paradigm.package_resting_state_data(
476            self.marker_data,
477            self.marker_timestamps,
478            self.bci_controller,
479            self.eeg_timestamps,
480            self.fsample,
481        )
482
483        self.__data_tank.add_resting_state_data(resting_state_data)
484
485        return True  # Continue processing
486
487    def __process_and_classify(self):
488        """Process the markers and classify the data.
489
490        Parameters
491        ----------
492        `None`
493
494        Returns
495        -------
496        success_string : str
497            String indicating if the processing and classification was successful.
498            Potential values are "Success", "Skip", "Wait".
499            - "Success": The processing and classification was successful.
500            - "Skip": EEG is either absent entirely or contains lost packets.
501            - "Wait": The processing is waiting for more data.
502
503        """
504
505        eeg_start_time, eeg_end_time = self.__paradigm.get_eeg_start_and_end_times(
506            self.event_marker_buffer, self.event_timestamp_buffer
507        )
508
509        # No we actually need to wait until we have all the data for these markers
510        eeg, timestamps = self.__data_tank.get_raw_eeg()
511
512        # Check if there is available EEG data
513        if len(eeg) == 0:
514            logger.warning("No EEG data available")
515            return "Skip"
516
517        # If the last timestamp is less than the end time, then we don't have the necessary EEG to process
518        if timestamps[-1] < eeg_end_time:
519            return "Wait"
520
521        # Check if EEG sampling is continuous over this time period
522        start_indices = np.where(timestamps > eeg_start_time)[0]
523        if len(start_indices) == 0:
524            logger.warning("No timestamps exceed eeg_start_time")
525            return "Skip"
526        start_index = start_indices[0]
527        end_index = np.where(timestamps < eeg_end_time)[0][-1]
528
529        time_diffs = np.diff(timestamps[start_index:end_index])
530        if np.any(time_diffs > 2 / self.fsample):
531            logger.warning("Time gaps in EEG data")
532            return "Skip"
533
534        X, y = self.__paradigm.process_markers(
535            self.event_marker_buffer,
536            self.event_timestamp_buffer,
537            eeg,
538            timestamps,
539            self.fsample,
540        )
541
542        sum_new_labeled_trials = np.sum(y != -1)
543
544        # Add the epochs to the data tank
545        self.__data_tank.add_epochs(X, y)
546
547        # Save epochs to temp_epochs file
548        if self.auto_save_epochs and self.online and sum_new_labeled_trials > 0:
549            paradigm_str = self.__paradigm.paradigm_name
550
551            with open(self.temp_epochs, "wb") as f:
552                np.savez(
553                    f,
554                    X=self.__data_tank.epochs,
555                    y=self.__data_tank.labels,
556                    paradigm=paradigm_str,
557                )
558
559        # If either there are no labels OR iterative training is on, then make a prediction
560        if self.train_complete:
561            if -1 in y or self.__paradigm.iterative_training:
562                prediction = self._classifier.predict(X)
563                self.__send_prediction(prediction)
564
565        self.event_marker_buffer = []
566        self.event_timestamp_buffer = []
567
568        return "Success"
569
570    def __update_and_train_classifier(self):
571        """Updates the classifier if required.
572
573        Parameters
574        ----------
575        `None`
576
577        Returns
578        -------
579        continue_flag : bool
580            Flag indicating to continue the while loop in step().
581        """
582        if self.train_lock is False:
583            # Pull the epochs from the data tank and pass them to the classifier
584            X, y = self.__data_tank.get_epochs(latest=True)
585
586            # Remove epochs with label -1
587            ind_to_remove = []
588            for i, label in enumerate(y):
589                if label == -1:
590                    ind_to_remove.append(i)
591            X = np.delete(X, ind_to_remove, axis=0)
592            y = np.delete(y, ind_to_remove, axis=0)
593
594            # Check that there are epochs
595            if len(y) > 0:
596                self._classifier.add_to_train(X, y)
597
598            if self._classifier._check_ready_for_fit():
599                self._classifier.fit()
600                self.train_complete = True
601
602        return True
603
604    # 3c. Private methods for event handling (trial and markers) and messaging
605    def __log_trial_start(self):
606        """Logs the start of a trial.
607
608        Parameters
609        ----------
610        `None`
611
612        Returns
613        -------
614        continue_flag : bool
615            Flag indicating to continue the while loop in step().
616
617        """
618        logger.debug("Trial started, incrementing marker count and continuing")
619        # Note that a marker occured, but do nothing else
620        return True  # Continue processing
621
622    def __handle_trial_end(self):
623        """Handles the end of a trial. Processes and classifies trial data if required.
624
625        Parameters
626        ----------
627        `None`
628
629        Returns
630        ------
631        success_flag : bool
632            Flag indicating if the processing and classification was successful.
633            - Returns `True` if not classifying.
634        """
635        # If we are classifying based on trials, then process the trial,
636        if self.__paradigm.classify_each_trial:
637            success_string = self.__process_and_classify()
638            if success_string == "Wait":
639                logger.debug("Processing of trial not run: waiting for more data")
640                return False
641            if success_string == "Skip":
642                logger.warning("Processing of trial failed: skipping trial")
643                self.event_marker_buffer = []
644                self.event_timestamp_buffer = []
645                self.marker_count += 1
646                return False
647
648        return True  # Return True by default if not classifying
649
650    def __handle_event_marker(self, marker, timestamp):
651        """Processes and classifies event markers.
652
653        Parameters
654        ----------
655        marker : str
656            Event marker string containing comma-separated values.
657            - Format depends on paradigm implementation.
658        timestamp : float
659            Timestamp of the marker in seconds (after time correction).
660
661
662        Returns
663        ------
664        continue_flag : bool
665            Flag indicating to continue the while loop in step().
666
667        """
668        # Add the marker to the event marker buffer
669        self.event_marker_buffer.append(marker)
670        self.event_timestamp_buffer.append(timestamp)
671
672        # If classification is on epochs, then update epochs, maybe classify, and clear the buffer
673        if self.__paradigm.classify_each_epoch:
674            success_string = self.__process_and_classify()
675            if success_string == "Wait":
676                logger.debug("Processing of epoch not run: waiting for more data")
677                self.event_marker_buffer = []
678                self.event_timestamp_buffer = []
679                return False  # Stop processing
680            elif success_string == "Skip":
681                logger.warning("Processing of epoch failed: skipping epoch")
682                self.event_marker_buffer = []
683                self.event_timestamp_buffer = []
684
685                self.marker_count += 1
686                return False  # Stop processing
687
688        return True  # Continue processing
689
690    def __handle_command_marker(self, marker: str) -> bool:
691        """Processes a command marker by invoking its associated method.
692
693        The command marker string is assumed to be in the self.marker_methods dictionary.
694        The associated method is retrieved and called.
695        The return value of the method is used to determine if processing should continue.
696
697        Parameters
698        ----------
699        marker : str
700            A command marker string (assumed to be in self.marker_methods).
701
702        Returns
703        -------
704        bool
705            A flag indicating if the processing should continue.
706
707        """
708        command_marker_method = self.marker_methods[marker]  # Retrieve method
709        continue_flag = command_marker_method()  # Call method
710
711        # Debug level logging if continue_flag is FALSE
712        if continue_flag is False:
713            logger.debug("Command marker '%s' set continue_flag to FALSE", marker)
714
715        return continue_flag
716
717    def __send_prediction(self, prediction):
718        """Send a prediction to the messenger object.
719
720        Parameters
721        ----------
722        `None`
723
724        Returns
725        -------
726        `None`
727
728        """
729        if self._messenger is not None:
730            logger.debug("Sending prediction: %s", prediction)
731            self._messenger.prediction(prediction)
732        elif self._messenger is None and self.online is True:
733            # If running in online mode and messenger is not available, log a warning
734            logger.warning(
735                "Messenger not available (self._messenger is None). Prediction not sent: %s",
736                prediction,
737            )
738
739    def __load_temp_epochs_if_available(self, reload_data_time: int = 300):
740        """Load temp_epochs if available and valid.
741
742        Parameters
743        ----------
744        reload_data_time : int, *optional*
745            Time in seconds of the last temp_epochs file to reload the data from.
746            Default is `300` seconds (5 minutes).
747
748        Returns
749        -------
750        `None`
751
752        """
753        self.temp_epochs = os.path.join(
754            os.path.dirname(os.path.dirname(__file__)), "temp_epochs.npz"
755        )
756
757        if not os.path.exists(self.temp_epochs):
758            return
759
760        # If temp_epochs is older than `reload_data_time`, delete it
761        if os.path.getmtime(self.temp_epochs) < (time.time() - reload_data_time):
762            os.remove(self.temp_epochs)
763            logger.info("Deleted old temp_epochs file.")
764            return
765
766        # Load the temp_epochs file
767        with open(self.temp_epochs, "rb") as f:
768            npz = np.load(f, allow_pickle=True)
769            X = npz["X"]
770            y = npz["y"]
771            paradigm_str = npz["paradigm"].item()
772
773        # If the paradigm is different, delete the file
774        if self.__paradigm.paradigm_name != paradigm_str:
775            logger.warning(
776                "Paradigm in temp_epochs file does not match current paradigm. Deleting file."
777            )
778            os.remove(self.temp_epochs)
779            return
780
781        # If the paradigm is the same, then add the epochs to the data tank
782        logger.info("Loading epochs from temp_epochs file.")
783        logger.info("X shape: %s", X.shape)
784        logger.info("y shape: %s", y.shape)
785        self.__data_tank.add_epochs(X, y)
786
787        # If there are epochs in the data tank, then train the classifier
788        if len(self.__data_tank.labels) > 0:
789            self.__update_and_train_classifier()

Class that holds, splits into trials, processes, and classifies EEG data. This class is used for processing continuous EEG data in trials of a defined length.

 55    def __init__(
 56        self,
 57        classifier: GenericClassifier,
 58        eeg_source: EegSource,
 59        marker_source: MarkerSource | None = None,
 60        paradigm: Paradigm | None = None,
 61        data_tank: DataTank | None = None,
 62        messenger: Messenger | None = None,
 63    ):
 64        """Initializes `BciController` class.
 65
 66        Parameters
 67        ----------
 68        classifier : GenericClassifier
 69            The classifier used by BciController.
 70        eeg_source : EegSource
 71            Source of EEG data and timestamps, this could be from a file or headset via LSL, etc.
 72        marker_source : EegSource
 73            Source of Marker/Control data and timestamps, this could be from a file or Unity via LSL, etc.
 74            - Default is `None`.
 75        paradigm : Paradigm
 76            The paradigm used by BciController. This defines the processing and reshaping steps for the EEG data.
 77        data_tank : DataTank
 78            DataTank object to handle the storage of EEG trials and labels.
 79            - Default is `None`.
 80        messenger: Messenger
 81            Messenger object to handle events from BciController, ex: acknowledging markers and
 82            predictions.
 83            - Default is `None`.
 84
 85        """
 86
 87        # Ensure the incoming dependencies are the right type
 88        assert isinstance(classifier, GenericClassifier)
 89        assert isinstance(eeg_source, EegSource)
 90        assert isinstance(marker_source, MarkerSource | None)
 91        assert isinstance(paradigm, Paradigm | None)
 92        assert isinstance(data_tank, DataTank | None)
 93        assert isinstance(messenger, Messenger | None)
 94
 95        self._classifier = classifier
 96        self.__eeg_source = eeg_source
 97        self.__marker_source = marker_source
 98        self.__paradigm = paradigm
 99        self.__data_tank = data_tank
100        self._messenger = messenger
101
102        self.headset_string = self.__eeg_source.name
103        self.fsample = self.__eeg_source.fsample
104        self.n_channels = self.__eeg_source.n_channels
105        self.ch_type = self.__eeg_source.channel_types
106        self.ch_units = self.__eeg_source.channel_units
107        self.channel_labels = self.__eeg_source.channel_labels
108
109        self.__data_tank.set_source_data(
110            self.headset_string,
111            self.fsample,
112            self.n_channels,
113            self.ch_type,
114            self.ch_units,
115            self.channel_labels,
116        )
117
118        # Switch any trigger channels to stim, this is for mne/bids export (?)
119        self.ch_type = [type.replace("trg", "stim") for type in self.ch_type]
120
121        self._classifier.channel_labels = self.channel_labels
122
123        logger.info(self.headset_string)
124        logger.info(self.channel_labels)
125
126        # Initialize data and timestamp arrays to the right dimensions, but zero elements
127        self.marker_data = np.zeros((0, 1))
128        self.marker_timestamps = np.zeros((0))
129        self.bci_controller = np.zeros((0, self.n_channels))
130        self.eeg_timestamps = np.zeros((0))
131
132        # Initialize marker methods dictionary
133        self.marker_methods = {
134            MarkerTypes.DONE_RS_COLLECTION.value: self.__process_resting_state_data,
135            MarkerTypes.TRIAL_STARTED.value: self.__log_trial_start,
136            MarkerTypes.TRIAL_ENDS.value: self.__handle_trial_end,
137            MarkerTypes.TRAINING_COMPLETE.value: self.__update_and_train_classifier,
138            MarkerTypes.TRAIN_CLASSIFIER.value: self.__update_and_train_classifier,
139            MarkerTypes.UPDATE_CLASSIFIER.value: self.__update_and_train_classifier,
140        }
141
142        self.ping_count = 0
143        self.n_samples = 0
144        self.time_units = ""

Initializes BciController class.

Parameters
  • classifier (GenericClassifier): The classifier used by BciController.
  • eeg_source (EegSource): Source of EEG data and timestamps, this could be from a file or headset via LSL, etc.
  • marker_source (MarkerSource): Source of Marker/Control data and timestamps, this could be from a file or Unity via LSL, etc.
    • Default is None.
  • paradigm (Paradigm): The paradigm used by BciController. This defines the processing and reshaping steps for the EEG data.
  • data_tank (DataTank): DataTank object to handle the storage of EEG trials and labels.
    • Default is None.
  • messenger (Messenger): Messenger object to handle events from BciController, ex: acknowledging markers and predictions.
    • Default is None.
headset_string
fsample
n_channels
ch_type
ch_units
channel_labels
marker_data
marker_timestamps
bci_controller
eeg_timestamps
marker_methods
ping_count
n_samples
time_units
def setup( self, online=True, train_complete=False, train_lock=False, auto_save_epochs=True):
147    def setup(
148        self,
149        online=True,
150        train_complete=False,
151        train_lock=False,
152        auto_save_epochs=True,
153    ):
154        """Configure processing loop.
155
156        This should be called before starting the loop with run() or step().
157
158        Calling after will reset the loop state.
159
160        The processing loop reads in EEG and marker data and processes it.
161        The loop can be run in "offline" or "online" modes:
162        - If in `online` mode, then the loop will continuously try to read
163        in data from the `BciController` object and process it. The loop will
164        terminate when `max_loops` is reached, or when manually terminated.
165        - If in `offline` mode, then the loop will read in all of the data
166        at once, process it, and then terminate.
167
168        Parameters
169        ----------
170        online : bool, *optional*
171            Flag to indicate if the data will be processed in `online` mode.
172            - `True`: The data will be processed in `online` mode.
173            - `False`: The data will be processed in `offline` mode.
174            - Default is `True`.
175        train_complete : bool, *optional*
176            Flag to indicate if the classifier has been trained.
177            - `True`: The classifier has been trained.
178            - `False`: The classifier has not been trained.
179            - Default is `False`.
180        train_lock : bool, *optional*
181            Flag to indicate if the classifier is locked (ie. no more training).
182            - `True`: The classifier is locked.
183            - `False`: The classifier is not locked.
184            - Default is `False`.
185        auto_save_epochs : bool, *optional*
186            Flag to indicate if labeled epochs should be automatically saved to a temp file so they can be reloaded if Bessy crashes.
187            - `True`: Epochs will be saved to a temp file.
188            - `False`: Epochs will not be saved to a temp file.
189
190
191        Returns
192        -------
193        `None`
194
195        """
196        self.online = online
197        self.train_complete = train_complete
198        self.train_lock = train_lock
199        self.auto_save_epochs = auto_save_epochs
200
201        # initialize the numbers of markers and trials to zero
202        self.marker_count = 0
203        self.current_num_trials = 0
204        self.n_trials = 0
205
206        self.num_online_selections = 0
207        self.online_selection_indices = []
208        self.online_selections = []
209
210        # Check for a temp_epochs file
211        if online:
212            self.__load_temp_epochs_if_available()

Configure processing loop.

This should be called before starting the loop with run() or step().

Calling after will reset the loop state.

The processing loop reads in EEG and marker data and processes it. The loop can be run in "offline" or "online" modes:

  • If in online mode, then the loop will continuously try to read in data from the BciController object and process it. The loop will terminate when max_loops is reached, or when manually terminated.
  • If in offline mode, then the loop will read in all of the data at once, process it, and then terminate.
Parameters
  • online (bool, optional): Flag to indicate if the data will be processed in online mode.
    • True: The data will be processed in online mode.
    • False: The data will be processed in offline mode.
    • Default is True.
  • train_complete (bool, optional): Flag to indicate if the classifier has been trained.
    • True: The classifier has been trained.
    • False: The classifier has not been trained.
    • Default is False.
  • train_lock (bool, optional): Flag to indicate if the classifier is locked (i.e. no more training).
    • True: The classifier is locked.
    • False: The classifier is not locked.
    • Default is False.
  • auto_save_epochs (bool, optional): Flag to indicate if labeled epochs should be automatically saved to a temp file so they can be reloaded if Bessy crashes.
    • True: Epochs will be saved to a temp file.
    • False: Epochs will not be saved to a temp file.
    • Default is True.
Returns
  • None
def step(self):
214    def step(self):
215        """Runs a single BciController processing step.
216
217        See setup() for configuration of processing.
218
219        The method:
220        1. Pulls data from sources (EEG and markers).
221        2. Run a while loop to process markers as long as there are unprocessed markers.
222        3. The while loop processes the markers in the following order:
223            - First checks if the marker is a known command marker from self.marker_methods.
224            - Then checks if it's an event marker (contains commas)
225            - If neither, logs a warning about unknown marker type
226        3. If the marker is a command marker, handles it by calling __handle_command_marker().
227        4. If the marker is an event marker, handles it by calling __handle_event_marker().
228        5. If the command or event marker handling return continue_flag as True, increment the marker count and process the next marker.
229            - Note: If there is an unknown marker type, the marker count is still incremented and processing continues.
230        6. If the command or event marker handling return continue_flag as False, break out of the while loop and end the step.
231
232        Parameters
233        ----------
234        `None`
235
236        Returns
237        ------
238        `None`
239
240        """
241        # read from sources to get new data.
242        # This puts command markers in the marker_data array and
243        # event markers in the event_marker_strings array
244        self._pull_data_from_sources()
245
246        # Process markers while there are unprocessed markers
247        # REMOVE COMMENT: check if there is an available command marker, if not, break and wait for more data
248        while len(self.marker_timestamps) > self.marker_count:
249            # Get the current marker
250            current_step_marker = self.marker_data[self.marker_count]  # String
251            current_timestamp = self.marker_timestamps[self.marker_count]  # Float
252
253            # If marker is empty, skip it
254            if not current_step_marker:
255                logger.warning("Empty marker received")
256                self.marker_count += 1
257                continue
258
259            # If messenger is available, send feedback for each marker received
260            if self._messenger is not None:
261                self._messenger.marker_received(current_step_marker)
262
263            # Process markers in order specified in the docstrings
264            # First check if it's a known command marker
265            if current_step_marker in self.marker_methods:
266                continue_flag = self.__handle_command_marker(current_step_marker)
267            # Then check if it's an event marker (contains commas)
268            elif "," in current_step_marker:
269                continue_flag = self.__handle_event_marker(
270                    current_step_marker, current_timestamp
271                )
272            # Otherwise, log a warning about unknown marker type
273            else:
274                # Log warning for unknown marker types
275                logger.warning("Unknown marker type received: %s", current_step_marker)
276                continue_flag = True
277
278            # Check if we should continue processing markers in the while loop
279            # if continue_flag is False, then break out of the while loop
280            # else, increment the marker count and process the next marker
281            if continue_flag is False:
282                break
283            else:
284                logger.info("Processed Marker: %s", current_step_marker)
285                self.marker_count += 1

Runs a single BciController processing step.

See setup() for configuration of processing.

The method:

  1. Pulls data from sources (EEG and markers).
  2. Runs a while loop to process markers as long as there are unprocessed markers.
  3. The while loop processes the markers in the following order:
    • Empty markers are logged as a warning and skipped.
    • First checks if the marker is a known command marker from self.marker_methods.
    • Then checks if it's an event marker (contains commas).
    • If neither, logs a warning about an unknown marker type.
  4. If the marker is a command marker, handles it by calling __handle_command_marker().
  5. If the marker is an event marker, handles it by calling __handle_event_marker().
  6. If the command or event marker handling returns continue_flag as True, increments the marker count and processes the next marker.
    • Note: If there is an unknown marker type, the marker count is still incremented and processing continues.
  7. If the command or event marker handling returns continue_flag as False, breaks out of the while loop and ends the step.
Parameters
  • None
Returns
  • None
def run(self, max_loops: int = 1000000):
287    def run(self, max_loops: int = 1000000):
288        """Runs BciController processing in a loop.
289
290        See setup() for configuration of processing.
291
292        Parameters
293        ----------
294        max_loops : int, *optional*
295            Maximum number of loops to run, default is `1000000`.
296
297        Returns
298        ------
299        `None`
300
301        """
302        # if offline, then all data is already loaded, only need to loop once
303        if self.online is False:
304            self.loops = max_loops - 1
305        else:
306            self.loops = 0
307
308        # Initialize the event marker buffer
309        self.event_marker_buffer = []
310        self.event_timestamp_buffer = []
311
312        # start the main loop, stops after pulling new data, max_loops times
313        while self.loops < max_loops:
314            # print out loop status
315            if self.loops % 100 == 0:
316                logger.debug(self.loops)
317
318            if self.loops == max_loops - 1:
319                logger.debug("last loop")
320
321            # read from sources and process
322            self.step()
323
324            # Wait a short period of time and then try to pull more data
325            if self.online:
326                time.sleep(0.00001)
327
328            self.loops += 1

Runs BciController processing in a loop.

See setup() for configuration of processing.

Parameters
  • max_loops (int, optional): Maximum number of loops to run, default is 1000000.
Returns
  • None