bci_essentials.bci_controller

Module for managing BCI data.

This module provides data classes for different BCI paradigms.

It includes the loading of offline data in xdf format or the live streaming of LSL data.

The loaded/streamed data is added to a buffer such that offline and online processing pipelines are identical.

Data is pre-processed (using the signal_processing module), divided into trials, and classified (using one of the classification sub-modules).

Classes
  • BciController : For processing continuous data in trials of a defined length.
  1"""Module for managing BCI data.
  2
  3This module provides data classes for different BCI paradigms.
  4
  5It includes the loading of offline data in `xdf` format
  6or the live streaming of LSL data.
  7
  8The loaded/streamed data is added to a buffer such that offline and
  9online processing pipelines are identical.
 10
 11Data is pre-processed (using the `signal_processing` module), divided into trials,
 12and classified (using one of the `classification` sub-modules).
 13
 14Classes
 15-------
 16- `BciController` : For processing continuous data in trials of a defined length.
 17
 18"""
 19
 20import time
 21import os
 22import numpy as np
 23from enum import Enum
 24
 25from .paradigm.paradigm import Paradigm
 26from .data_tank.data_tank import DataTank
 27from .classification.generic_classifier import GenericClassifier
 28from .io.sources import EegSource, MarkerSource
 29from .io.messenger import Messenger
 30from .utils.logger import Logger
 31
# Module-level logger named after this module via __name__
# (here "bci_essentials.bci_controller").
# NOTE(review): the Logger wrapper is expected to default to logging.INFO;
# confirm in bci_essentials.utils.logger.
logger = Logger(name=__name__)
 35
 36
class MarkerTypes(Enum):
    """Command marker strings that `BciController` dispatches on.

    Each value is the exact marker string received from the marker source.
    `BciController.marker_methods` maps these values to handler methods.
    """

    # Only logged; no data processing is triggered.
    TRIAL_STARTED = "Trial Started"
    # Processes/classifies the buffered trial when the paradigm classifies each trial.
    TRIAL_ENDS = "Trial Ends"
    # The next three all trigger updating and (if ready) fitting the classifier.
    TRAINING_COMPLETE = "Training Complete"
    TRAIN_CLASSIFIER = "Train Classifier"
    # Packages resting-state data and stores it in the data tank.
    DONE_RS_COLLECTION = "Done with all RS collection"
    UPDATE_CLASSIFIER = "Update Classifier"
 45
 46# EEG data
 47class BciController:
 48    """
 49    Class that holds, trials, processes, and classifies EEG data.
 50    This class is used for processing of continuous EEG data in trials of a defined length.
 51    """
 52
 53    # 0. Special methods (e.g. __init__)
 54    def __init__(
 55        self,
 56        classifier: GenericClassifier,
 57        eeg_source: EegSource,
 58        marker_source: MarkerSource | None = None,
 59        paradigm: Paradigm | None = None,
 60        data_tank: DataTank | None = None,
 61        messenger: Messenger | None = None,
 62    ):
 63        """Initializes `BciController` class.
 64
 65        Parameters
 66        ----------
 67        classifier : GenericClassifier
 68            The classifier used by BciController.
 69        eeg_source : EegSource
 70            Source of EEG data and timestamps, this could be from a file or headset via LSL, etc.
 71        marker_source : EegSource
 72            Source of Marker/Control data and timestamps, this could be from a file or Unity via LSL, etc.
 73            - Default is `None`.
 74        paradigm : Paradigm
 75            The paradigm used by BciController. This defines the processing and reshaping steps for the EEG data.
 76        data_tank : DataTank
 77            DataTank object to handle the storage of EEG trials and labels.
 78            - Default is `None`.
 79        messenger: Messenger
 80            Messenger object to handle events from BciController, ex: acknowledging markers and
 81            predictions.
 82            - Default is `None`.
 83
 84        """
 85
 86        # Ensure the incoming dependencies are the right type
 87        assert isinstance(classifier, GenericClassifier)
 88        assert isinstance(eeg_source, EegSource)
 89        assert isinstance(marker_source, MarkerSource | None)
 90        assert isinstance(paradigm, Paradigm | None)
 91        assert isinstance(data_tank, DataTank | None)
 92        assert isinstance(messenger, Messenger | None)
 93
 94        self._classifier = classifier
 95        self.__eeg_source = eeg_source
 96        self.__marker_source = marker_source
 97        self.__paradigm = paradigm
 98        self.__data_tank = data_tank
 99        self._messenger = messenger
100
101        self.headset_string = self.__eeg_source.name
102        self.fsample = self.__eeg_source.fsample
103        self.n_channels = self.__eeg_source.n_channels
104        self.ch_type = self.__eeg_source.channel_types
105        self.ch_units = self.__eeg_source.channel_units
106        self.channel_labels = self.__eeg_source.channel_labels
107
108        # Emily EGI fix
109        # Set default channel types if none
110        if self.ch_type is None:
111            logger.warning("Channel types are none, setting all to 'eeg'")
112            self.ch_type = ["eeg"] * self.n_channels
113
114        self.__data_tank.set_source_data(
115            self.headset_string,
116            self.fsample,
117            self.n_channels,
118            self.ch_type,
119            self.ch_units,
120            self.channel_labels,
121        )
122
123        # Switch any trigger channels to stim, this is for mne/bids export (?)
124        self.ch_type = [type.replace("trg", "stim") for type in self.ch_type]
125
126        self._classifier.channel_labels = self.channel_labels
127
128        logger.info(self.headset_string)
129        logger.info(self.channel_labels)
130
131        # Initialize data and timestamp arrays to the right dimensions, but zero elements
132        self.marker_data = np.zeros((0, 1))
133        self.marker_timestamps = np.zeros((0))
134        self.bci_controller = np.zeros((0, self.n_channels))
135        self.eeg_timestamps = np.zeros((0))
136
137        # Initialize marker methods dictionary
138        self.marker_methods = {
139            MarkerTypes.DONE_RS_COLLECTION.value: self.__process_resting_state_data,
140            MarkerTypes.TRIAL_STARTED.value: self.__log_trial_start,
141            MarkerTypes.TRIAL_ENDS.value: self.__handle_trial_end,
142            MarkerTypes.TRAINING_COMPLETE.value: self.__update_and_train_classifier,
143            MarkerTypes.TRAIN_CLASSIFIER.value: self.__update_and_train_classifier,
144            MarkerTypes.UPDATE_CLASSIFIER.value: self.__update_and_train_classifier,
145        }
146
147        self.step_count = 0
148        self.ping_interval = 1000
149        self.n_samples = 0
150        self.time_units = ""
151
152    # 1. Core public API methods
153    def setup(
154        self,
155        online=True,
156        train_complete=False,
157        train_lock=False,
158        auto_save_epochs=True,
159    ):
160        """Configure processing loop.
161
162        This should be called before starting the loop with run() or step().
163
164        Calling after will reset the loop state.
165
166        The processing loop reads in EEG and marker data and processes it.
167        The loop can be run in "offline" or "online" modes:
168        - If in `online` mode, then the loop will continuously try to read
169        in data from the `BciController` object and process it. The loop will
170        terminate when `max_loops` is reached, or when manually terminated.
171        - If in `offline` mode, then the loop will read in all of the data
172        at once, process it, and then terminate.
173
174        Parameters
175        ----------
176        online : bool, *optional*
177            Flag to indicate if the data will be processed in `online` mode.
178            - `True`: The data will be processed in `online` mode.
179            - `False`: The data will be processed in `offline` mode.
180            - Default is `True`.
181        train_complete : bool, *optional*
182            Flag to indicate if the classifier has been trained.
183            - `True`: The classifier has been trained.
184            - `False`: The classifier has not been trained.
185            - Default is `False`.
186        train_lock : bool, *optional*
187            Flag to indicate if the classifier is locked (ie. no more training).
188            - `True`: The classifier is locked.
189            - `False`: The classifier is not locked.
190            - Default is `False`.
191        auto_save_epochs : bool, *optional*
192            Flag to indicate if labeled epochs should be automatically saved to a temp file so they can be reloaded if Bessy crashes.
193            - `True`: Epochs will be saved to a temp file.
194            - `False`: Epochs will not be saved to a temp file.
195
196
197        Returns
198        -------
199        `None`
200
201        """
202        self.online = online
203        self.train_complete = train_complete
204        self.train_lock = train_lock
205        self.auto_save_epochs = auto_save_epochs
206
207        # initialize the numbers of markers and trials to zero
208        self.marker_count = 0
209        self.current_num_trials = 0
210        self.n_trials = 0
211
212        self.num_online_selections = 0
213        self.online_selection_indices = []
214        self.online_selections = []
215
216        # Check for a temp_epochs file
217        if online:
218            self.__load_temp_epochs_if_available()
219
220    def step(self):
221        """Runs a single BciController processing step.
222
223        See setup() for configuration of processing.
224
225        The method:
226        1. Pulls data from sources (EEG and markers).
227        2. Run a while loop to process markers as long as there are unprocessed markers.
228        3. The while loop processes the markers in the following order:
229            - First checks if the marker is a known command marker from self.marker_methods.
230            - Then checks if it's an event marker (contains commas)
231            - If neither, logs a warning about unknown marker type
232        3. If the marker is a command marker, handles it by calling __handle_command_marker().
233        4. If the marker is an event marker, handles it by calling __handle_event_marker().
234        5. If the command or event marker handling return continue_flag as True, increment the marker count and process the next marker.
235            - Note: If there is an unknown marker type, the marker count is still incremented and processing continues.
236        6. If the command or event marker handling return continue_flag as False, break out of the while loop and end the step.
237
238        Parameters
239        ----------
240        `None`
241
242        Returns
243        ------
244        `None`
245
246        """
247        # read from sources to get new data.
248        # This puts command markers in the marker_data array and
249        # event markers in the event_marker_strings array
250        self._pull_data_from_sources()
251
252        # Process markers while there are unprocessed markers
253        # REMOVE COMMENT: check if there is an available command marker, if not, break and wait for more data
254        while len(self.marker_timestamps) > self.marker_count:
255            # Get the current marker
256            current_step_marker = self.marker_data[self.marker_count]  # String
257            current_timestamp = self.marker_timestamps[self.marker_count]  # Float
258
259            # If marker is empty, skip it
260            if not current_step_marker:
261                logger.warning("Empty marker received")
262                self.marker_count += 1
263                continue
264
265            # If messenger is available, send feedback for each marker received
266            if self._messenger is not None:
267                self._messenger.marker_received(current_step_marker)
268
269            # Process markers in order specified in the docstrings
270            # First check if it's a known command marker
271            if current_step_marker in self.marker_methods:
272                continue_flag = self.__handle_command_marker(current_step_marker)
273            # Then check if it's an event marker (contains commas)
274            elif "," in current_step_marker:
275                continue_flag = self.__handle_event_marker(
276                    current_step_marker, current_timestamp
277                )
278            # Otherwise, log a warning about unknown marker type
279            else:
280                # Log warning for unknown marker types
281                logger.warning("Unknown marker type received: %s", current_step_marker)
282                continue_flag = True
283
284            # Check if we should continue processing markers in the while loop
285            # if continue_flag is False, then break out of the while loop
286            # else, increment the marker count and process the next marker
287            if continue_flag is False:
288                break
289            else:
290                logger.info("Processed Marker: %s", current_step_marker)
291                self.marker_count += 1
292
293        self.step_count += 1
294        if self.step_count % self.ping_interval == 0:
295            if self._messenger is not None:
296                self._messenger.ping()
297
298    def run(self, max_loops: int = 1000000, ping_interval: int = 100):
299        """Runs BciController processing in a loop.
300
301        See setup() for configuration of processing.
302
303        Parameters
304        ----------
305        max_loops : int, *optional*
306            Maximum number of loops to run, default is `1000000`.
307        ping_interval : int, *optional*
308            Number of steps between each messenger ping.
309
310        Returns
311        ------
312        `None`
313
314        """
315        # if offline, then all data is already loaded, only need to loop once
316        if self.online is False:
317            self.loops = max_loops - 1
318        else:
319            self.loops = 0
320
321        self.ping_interval = ping_interval
322
323        # Initialize the event marker buffer
324        self.event_marker_buffer = []
325        self.event_timestamp_buffer = []
326
327        # start the main loop, stops after pulling new data, max_loops times
328        while self.loops < max_loops:
329            # print out loop status
330            if self.loops % 100 == 0:
331                logger.debug(self.loops)
332
333            if self.loops == max_loops - 1:
334                logger.debug("last loop")
335
336            # read from sources and process
337            self.step()
338
339            # Wait a short period of time and then try to pull more data
340            if self.online:
341                time.sleep(0.00001)
342
343            self.loops += 1
344
345    # 2. Protected methods (single underscore)
346    def _pull_data_from_sources(self):
347        """Get pull data from EEG and optionally, the marker source.
348
349        This method will fill up the marker_data, bci_controller and corresponding timestamp arrays.
350
351        Parameters
352        ----------
353        `None`
354
355        Returns
356        -------
357        `None`
358
359        """
360        # Get new data from source, whatever it is
361        self.__pull_marker_data_from_source()
362        self.__pull_eeg_data_from_source()
363
364    # 3. Private methods (double underscore)
365    # 3a. Private methods for retrieving data from sources
366    def __pull_marker_data_from_source(self):
367        """Pulls marker samples from source, sanity checks and appends to buffer.
368
369        Parameters
370        ----------
371        `None`
372
373        Returns
374        -------
375        `None`
376
377        """
378
379        # if there isn't a marker source, abort
380        if self.__marker_source is None:
381            return
382
383        # read in the data
384        markers, timestamps = self.__marker_source.get_markers()
385        markers = np.array(markers)
386        timestamps = np.array(timestamps)
387
388        if markers.size == 0:
389            return
390
391        if markers.ndim != 2:
392            logger.warning("discarded invalid marker data")
393            return
394
395        # apply time correction
396        time_correction = self.__marker_source.time_correction()
397        timestamps = [timestamps[i] + time_correction for i in range(len(timestamps))]
398
399        for i, marker in enumerate(markers):
400            marker = marker[0]
401            if "Ping" in marker:
402                continue
403
404            # Add all markers to the controller
405            self.marker_data = np.append(self.marker_data, marker)
406            self.marker_timestamps = np.append(self.marker_timestamps, timestamps[i])
407
408            # Add all markers to the data tank
409            self.__data_tank.add_raw_markers(
410                np.array([marker]), np.array([timestamps[i]])
411            )
412
413    def __pull_eeg_data_from_source(self):
414        """Pulls eeg samples from source, sanity checks and appends to buffer.
415
416        Parameters
417        ----------
418        `None`
419
420        Returns
421        -------
422        `None`
423
424        """
425
426        # read in the data
427        eeg, timestamps = self.__eeg_source.get_samples()
428        eeg = np.array(eeg)
429        timestamps = np.array(timestamps)
430
431        if eeg.size == 0:
432            return
433
434        if eeg.ndim != 2:
435            logger.warning("discarded invalid eeg data")
436            return
437
438        # if time is in milliseconds, divide by 1000, works for sampling rates above 10Hz
439        try:
440            if self.time_units == "milliseconds":
441                timestamps = [(timestamps[i] / 1000) for i in range(len(timestamps))]
442
443        # If time units are not defined then define them
444        except Exception:
445            dif_low = -2
446            dif_high = -1
447            while timestamps[dif_high] - timestamps[dif_low] == 0:
448                dif_low -= 1
449                dif_high -= 1
450
451            if timestamps[dif_high] - timestamps[dif_low] > 0.1:
452                timestamps = [(timestamps[i] / 1000) for i in range(len(timestamps))]
453                self.time_units = "milliseconds"
454            else:
455                self.time_units = "seconds"
456
457        # apply time correction, this is essential for headsets like neurosity which have their own clock
458        time_correction = self.__eeg_source.time_correction()
459        timestamps = [timestamps[i] + time_correction for i in range(len(timestamps))]
460
461        self.__data_tank.add_raw_eeg(eeg.T, timestamps)
462
463        # Update latest EEG timestamp
464        self.latest_eeg_timestamp = timestamps[-1]
465
466    # 3b. Private methods for data processing and classification
467    def __process_resting_state_data(self):
468        """Handles the resting state data by packaging it and adding it to the data tank.
469
470        Parameters
471        ----------
472        `None`
473
474        Returns
475        ------
476        continue_flag : bool
477            Flag indicating to continue the while loop in step().
478
479        """
480        (
481            self.bci_controller,
482            self.eeg_timestamps,
483        ) = self.__data_tank.get_raw_eeg()
484
485        resting_state_data = self.__paradigm.package_resting_state_data(
486            self.marker_data,
487            self.marker_timestamps,
488            self.bci_controller,
489            self.eeg_timestamps,
490            self.fsample,
491        )
492
493        self.__data_tank.add_resting_state_data(resting_state_data)
494
495        return True  # Continue processing
496
497    def __process_and_classify(self):
498        """Process the markers and classify the data.
499
500        Parameters
501        ----------
502        `None`
503
504        Returns
505        -------
506        success_string : str
507            String indicating if the processing and classification was successful.
508            Potential values are "Success", "Skip", "Wait".
509            - "Success": The processing and classification was successful.
510            - "Skip": EEG is either absent entirely or contains lost packets.
511            - "Wait": The processing is waiting for more data.
512
513        """
514
515        eeg_start_time, eeg_end_time = self.__paradigm.get_eeg_start_and_end_times(
516            self.event_marker_buffer, self.event_timestamp_buffer
517        )
518
519        # No we actually need to wait until we have all the data for these markers
520        eeg, timestamps = self.__data_tank.get_raw_eeg()
521
522        # Check if there is available EEG data
523        if len(eeg) == 0:
524            logger.warning("No EEG data available")
525            return "Skip"
526
527        # If the last timestamp is less than the end time, then we don't have the necessary EEG to process
528        if timestamps[-1] < eeg_end_time:
529            return "Wait"
530
531        # Check if EEG sampling is continuous over this time period
532        start_indices = np.where(timestamps > eeg_start_time)[0]
533        if len(start_indices) == 0:
534            logger.warning("No timestamps exceed eeg_start_time")
535            return "Skip"
536        start_index = start_indices[0]
537        end_index = np.where(timestamps < eeg_end_time)[0][-1]
538
539        time_diffs = np.diff(timestamps[start_index:end_index])
540        if np.any(time_diffs > 2 / self.fsample):
541            logger.warning("Time gaps in EEG data")
542            return "Skip"
543
544        X, y = self.__paradigm.process_markers(
545            self.event_marker_buffer,
546            self.event_timestamp_buffer,
547            eeg,
548            timestamps,
549            self.fsample,
550        )
551
552        sum_new_labeled_trials = np.sum(y != -1)
553
554        # Add the epochs to the data tank
555        self.__data_tank.add_epochs(X, y)
556
557        # Save epochs to temp_epochs file
558        if self.auto_save_epochs and self.online and sum_new_labeled_trials > 0:
559            paradigm_str = self.__paradigm.paradigm_name
560
561            with open(self.temp_epochs, "wb") as f:
562                np.savez(
563                    f,
564                    X=self.__data_tank.epochs,
565                    y=self.__data_tank.labels,
566                    paradigm=paradigm_str,
567                )
568
569        # If either there are no labels OR iterative training is on, then make a prediction
570        if self.train_complete:
571            if -1 in y or self.__paradigm.iterative_training:
572                prediction = self._classifier.predict(X)
573                self.__send_prediction(prediction)
574
575        self.event_marker_buffer = []
576        self.event_timestamp_buffer = []
577
578        return "Success"
579
580    def __update_and_train_classifier(self):
581        """Updates the classifier if required.
582
583        Parameters
584        ----------
585        `None`
586
587        Returns
588        -------
589        continue_flag : bool
590            Flag indicating to continue the while loop in step().
591        """
592        if self.train_lock is False:
593            # Pull the epochs from the data tank and pass them to the classifier
594            X, y = self.__data_tank.get_epochs(latest=True)
595
596            # Remove epochs with label -1
597            ind_to_remove = []
598            for i, label in enumerate(y):
599                if label == -1:
600                    ind_to_remove.append(i)
601            X = np.delete(X, ind_to_remove, axis=0)
602            y = np.delete(y, ind_to_remove, axis=0)
603
604            # Check that there are epochs
605            if len(y) > 0:
606                self._classifier.add_to_train(X, y)
607
608            if self._classifier._check_ready_for_fit():
609                self._classifier.fit()
610                self.train_complete = True
611
612        return True
613
614    # 3c. Private methods for event handling (trial and markers) and messaging
615    def __log_trial_start(self):
616        """Logs the start of a trial.
617
618        Parameters
619        ----------
620        `None`
621
622        Returns
623        -------
624        continue_flag : bool
625            Flag indicating to continue the while loop in step().
626
627        """
628        logger.debug("Trial started, incrementing marker count and continuing")
629        # Note that a marker occured, but do nothing else
630        return True  # Continue processing
631
632    def __handle_trial_end(self):
633        """Handles the end of a trial. Processes and classifies trial data if required.
634
635        Parameters
636        ----------
637        `None`
638
639        Returns
640        ------
641        success_flag : bool
642            Flag indicating if the processing and classification was successful.
643            - Returns `True` if not classifying.
644        """
645        # If we are classifying based on trials, then process the trial,
646        if self.__paradigm.classify_each_trial:
647            success_string = self.__process_and_classify()
648            if success_string == "Wait":
649                logger.debug("Processing of trial not run: waiting for more data")
650                return False
651            if success_string == "Skip":
652                logger.warning("Processing of trial failed: skipping trial")
653                self.event_marker_buffer = []
654                self.event_timestamp_buffer = []
655                self.marker_count += 1
656                return False
657
658        return True  # Return True by default if not classifying
659
660    def __handle_event_marker(self, marker, timestamp):
661        """Processes and classifies event markers.
662
663        Parameters
664        ----------
665        marker : str
666            Event marker string containing comma-separated values.
667            - Format depends on paradigm implementation.
668        timestamp : float
669            Timestamp of the marker in seconds (after time correction).
670
671
672        Returns
673        ------
674        continue_flag : bool
675            Flag indicating to continue the while loop in step().
676
677        """
678        # Add the marker to the event marker buffer
679        self.event_marker_buffer.append(marker)
680        self.event_timestamp_buffer.append(timestamp)
681
682        # If classification is on epochs, then update epochs, maybe classify, and clear the buffer
683        if self.__paradigm.classify_each_epoch:
684            success_string = self.__process_and_classify()
685            if success_string == "Wait":
686                logger.debug("Processing of epoch not run: waiting for more data")
687                self.event_marker_buffer = []
688                self.event_timestamp_buffer = []
689                return False  # Stop processing
690            elif success_string == "Skip":
691                logger.warning("Processing of epoch failed: skipping epoch")
692                self.event_marker_buffer = []
693                self.event_timestamp_buffer = []
694
695                self.marker_count += 1
696                return False  # Stop processing
697
698        return True  # Continue processing
699
700    def __handle_command_marker(self, marker: str) -> bool:
701        """Processes a command marker by invoking its associated method.
702
703        The command marker string is assumed to be in the self.marker_methods dictionary.
704        The associated method is retrieved and called.
705        The return value of the method is used to determine if processing should continue.
706
707        Parameters
708        ----------
709        marker : str
710            A command marker string (assumed to be in self.marker_methods).
711
712        Returns
713        -------
714        bool
715            A flag indicating if the processing should continue.
716
717        """
718        command_marker_method = self.marker_methods[marker]  # Retrieve method
719        continue_flag = command_marker_method()  # Call method
720
721        # Debug level logging if continue_flag is FALSE
722        if continue_flag is False:
723            logger.debug("Command marker '%s' set continue_flag to FALSE", marker)
724
725        return continue_flag
726
727    def __send_prediction(self, prediction):
728        """Send a prediction to the messenger object.
729
730        Parameters
731        ----------
732        `None`
733
734        Returns
735        -------
736        `None`
737
738        """
739        if self._messenger is not None:
740            logger.debug("Sending prediction: %s", prediction)
741            self._messenger.prediction(prediction)
742        elif self._messenger is None and self.online is True:
743            # If running in online mode and messenger is not available, log a warning
744            logger.warning(
745                "Messenger not available (self._messenger is None). Prediction not sent: %s",
746                prediction,
747            )
748
749    def __load_temp_epochs_if_available(self, reload_data_time: int = 300):
750        """Load temp_epochs if available and valid.
751
752        Parameters
753        ----------
754        reload_data_time : int, *optional*
755            Time in seconds of the last temp_epochs file to reload the data from.
756            Default is `300` seconds (5 minutes).
757
758        Returns
759        -------
760        `None`
761
762        """
763        self.temp_epochs = os.path.join(
764            os.path.dirname(os.path.dirname(__file__)), "temp_epochs.npz"
765        )
766
767        if not os.path.exists(self.temp_epochs):
768            return
769
770        # If temp_epochs is older than `reload_data_time`, delete it
771        if os.path.getmtime(self.temp_epochs) < (time.time() - reload_data_time):
772            os.remove(self.temp_epochs)
773            logger.info("Deleted old temp_epochs file.")
774            return
775
776        # Load the temp_epochs file
777        with open(self.temp_epochs, "rb") as f:
778            npz = np.load(f, allow_pickle=True)
779            X = npz["X"]
780            y = npz["y"]
781            paradigm_str = npz["paradigm"].item()
782
783        # If the paradigm is different, delete the file
784        if self.__paradigm.paradigm_name != paradigm_str:
785            logger.warning(
786                "Paradigm in temp_epochs file does not match current paradigm. Deleting file."
787            )
788            os.remove(self.temp_epochs)
789            return
790
791        # If the paradigm is the same, then add the epochs to the data tank
792        logger.info("Loading epochs from temp_epochs file.")
793        logger.info("X shape: %s", X.shape)
794        logger.info("y shape: %s", y.shape)
795        self.__data_tank.add_epochs(X, y)
796
797        # If there are epochs in the data tank, then train the classifier
798        if len(self.__data_tank.labels) > 0:
799            self.__update_and_train_classifier()
class MarkerTypes(enum.Enum):
class MarkerTypes(Enum):
    """Command marker strings that `BciController` dispatches on.

    Each value is the exact marker string received from the marker source.
    `BciController.marker_methods` maps these values to handler methods.
    """

    # Only logged; no data processing is triggered.
    TRIAL_STARTED = "Trial Started"
    # Processes/classifies the buffered trial when the paradigm classifies each trial.
    TRIAL_ENDS = "Trial Ends"
    # The next three all trigger updating and (if ready) fitting the classifier.
    TRAINING_COMPLETE = "Training Complete"
    TRAIN_CLASSIFIER = "Train Classifier"
    # Packages resting-state data and stores it in the data tank.
    DONE_RS_COLLECTION = "Done with all RS collection"
    UPDATE_CLASSIFIER = "Update Classifier"

Enumeration of the command marker strings that `BciController` recognizes and dispatches to handler methods.

TRIAL_STARTED = <MarkerTypes.TRIAL_STARTED: 'Trial Started'>
TRIAL_ENDS = <MarkerTypes.TRIAL_ENDS: 'Trial Ends'>
TRAINING_COMPLETE = <MarkerTypes.TRAINING_COMPLETE: 'Training Complete'>
TRAIN_CLASSIFIER = <MarkerTypes.TRAIN_CLASSIFIER: 'Train Classifier'>
DONE_RS_COLLECTION = <MarkerTypes.DONE_RS_COLLECTION: 'Done with all RS collection'>
UPDATE_CLASSIFIER = <MarkerTypes.UPDATE_CLASSIFIER: 'Update Classifier'>
class BciController:
 48class BciController:
 49    """
 50    Class that holds, trials, processes, and classifies EEG data.
 51    This class is used for processing of continuous EEG data in trials of a defined length.
 52    """
 53
 54    # 0. Special methods (e.g. __init__)
 55    def __init__(
 56        self,
 57        classifier: GenericClassifier,
 58        eeg_source: EegSource,
 59        marker_source: MarkerSource | None = None,
 60        paradigm: Paradigm | None = None,
 61        data_tank: DataTank | None = None,
 62        messenger: Messenger | None = None,
 63    ):
 64        """Initializes `BciController` class.
 65
 66        Parameters
 67        ----------
 68        classifier : GenericClassifier
 69            The classifier used by BciController.
 70        eeg_source : EegSource
 71            Source of EEG data and timestamps, this could be from a file or headset via LSL, etc.
 72        marker_source : EegSource
 73            Source of Marker/Control data and timestamps, this could be from a file or Unity via LSL, etc.
 74            - Default is `None`.
 75        paradigm : Paradigm
 76            The paradigm used by BciController. This defines the processing and reshaping steps for the EEG data.
 77        data_tank : DataTank
 78            DataTank object to handle the storage of EEG trials and labels.
 79            - Default is `None`.
 80        messenger: Messenger
 81            Messenger object to handle events from BciController, ex: acknowledging markers and
 82            predictions.
 83            - Default is `None`.
 84
 85        """
 86
 87        # Ensure the incoming dependencies are the right type
 88        assert isinstance(classifier, GenericClassifier)
 89        assert isinstance(eeg_source, EegSource)
 90        assert isinstance(marker_source, MarkerSource | None)
 91        assert isinstance(paradigm, Paradigm | None)
 92        assert isinstance(data_tank, DataTank | None)
 93        assert isinstance(messenger, Messenger | None)
 94
 95        self._classifier = classifier
 96        self.__eeg_source = eeg_source
 97        self.__marker_source = marker_source
 98        self.__paradigm = paradigm
 99        self.__data_tank = data_tank
100        self._messenger = messenger
101
102        self.headset_string = self.__eeg_source.name
103        self.fsample = self.__eeg_source.fsample
104        self.n_channels = self.__eeg_source.n_channels
105        self.ch_type = self.__eeg_source.channel_types
106        self.ch_units = self.__eeg_source.channel_units
107        self.channel_labels = self.__eeg_source.channel_labels
108
109        # Emily EGI fix
110        # Set default channel types if none
111        if self.ch_type is None:
112            logger.warning("Channel types are none, setting all to 'eeg'")
113            self.ch_type = ["eeg"] * self.n_channels
114
115        self.__data_tank.set_source_data(
116            self.headset_string,
117            self.fsample,
118            self.n_channels,
119            self.ch_type,
120            self.ch_units,
121            self.channel_labels,
122        )
123
124        # Switch any trigger channels to stim, this is for mne/bids export (?)
125        self.ch_type = [type.replace("trg", "stim") for type in self.ch_type]
126
127        self._classifier.channel_labels = self.channel_labels
128
129        logger.info(self.headset_string)
130        logger.info(self.channel_labels)
131
132        # Initialize data and timestamp arrays to the right dimensions, but zero elements
133        self.marker_data = np.zeros((0, 1))
134        self.marker_timestamps = np.zeros((0))
135        self.bci_controller = np.zeros((0, self.n_channels))
136        self.eeg_timestamps = np.zeros((0))
137
138        # Initialize marker methods dictionary
139        self.marker_methods = {
140            MarkerTypes.DONE_RS_COLLECTION.value: self.__process_resting_state_data,
141            MarkerTypes.TRIAL_STARTED.value: self.__log_trial_start,
142            MarkerTypes.TRIAL_ENDS.value: self.__handle_trial_end,
143            MarkerTypes.TRAINING_COMPLETE.value: self.__update_and_train_classifier,
144            MarkerTypes.TRAIN_CLASSIFIER.value: self.__update_and_train_classifier,
145            MarkerTypes.UPDATE_CLASSIFIER.value: self.__update_and_train_classifier,
146        }
147
148        self.step_count = 0
149        self.ping_interval = 1000
150        self.n_samples = 0
151        self.time_units = ""
152
153    # 1. Core public API methods
154    def setup(
155        self,
156        online=True,
157        train_complete=False,
158        train_lock=False,
159        auto_save_epochs=True,
160    ):
161        """Configure processing loop.
162
163        This should be called before starting the loop with run() or step().
164
165        Calling after will reset the loop state.
166
167        The processing loop reads in EEG and marker data and processes it.
168        The loop can be run in "offline" or "online" modes:
169        - If in `online` mode, then the loop will continuously try to read
170        in data from the `BciController` object and process it. The loop will
171        terminate when `max_loops` is reached, or when manually terminated.
172        - If in `offline` mode, then the loop will read in all of the data
173        at once, process it, and then terminate.
174
175        Parameters
176        ----------
177        online : bool, *optional*
178            Flag to indicate if the data will be processed in `online` mode.
179            - `True`: The data will be processed in `online` mode.
180            - `False`: The data will be processed in `offline` mode.
181            - Default is `True`.
182        train_complete : bool, *optional*
183            Flag to indicate if the classifier has been trained.
184            - `True`: The classifier has been trained.
185            - `False`: The classifier has not been trained.
186            - Default is `False`.
187        train_lock : bool, *optional*
188            Flag to indicate if the classifier is locked (ie. no more training).
189            - `True`: The classifier is locked.
190            - `False`: The classifier is not locked.
191            - Default is `False`.
192        auto_save_epochs : bool, *optional*
193            Flag to indicate if labeled epochs should be automatically saved to a temp file so they can be reloaded if Bessy crashes.
194            - `True`: Epochs will be saved to a temp file.
195            - `False`: Epochs will not be saved to a temp file.
196
197
198        Returns
199        -------
200        `None`
201
202        """
203        self.online = online
204        self.train_complete = train_complete
205        self.train_lock = train_lock
206        self.auto_save_epochs = auto_save_epochs
207
208        # initialize the numbers of markers and trials to zero
209        self.marker_count = 0
210        self.current_num_trials = 0
211        self.n_trials = 0
212
213        self.num_online_selections = 0
214        self.online_selection_indices = []
215        self.online_selections = []
216
217        # Check for a temp_epochs file
218        if online:
219            self.__load_temp_epochs_if_available()
220
221    def step(self):
222        """Runs a single BciController processing step.
223
224        See setup() for configuration of processing.
225
226        The method:
227        1. Pulls data from sources (EEG and markers).
228        2. Run a while loop to process markers as long as there are unprocessed markers.
229        3. The while loop processes the markers in the following order:
230            - First checks if the marker is a known command marker from self.marker_methods.
231            - Then checks if it's an event marker (contains commas)
232            - If neither, logs a warning about unknown marker type
233        3. If the marker is a command marker, handles it by calling __handle_command_marker().
234        4. If the marker is an event marker, handles it by calling __handle_event_marker().
235        5. If the command or event marker handling return continue_flag as True, increment the marker count and process the next marker.
236            - Note: If there is an unknown marker type, the marker count is still incremented and processing continues.
237        6. If the command or event marker handling return continue_flag as False, break out of the while loop and end the step.
238
239        Parameters
240        ----------
241        `None`
242
243        Returns
244        ------
245        `None`
246
247        """
248        # read from sources to get new data.
249        # This puts command markers in the marker_data array and
250        # event markers in the event_marker_strings array
251        self._pull_data_from_sources()
252
253        # Process markers while there are unprocessed markers
254        # REMOVE COMMENT: check if there is an available command marker, if not, break and wait for more data
255        while len(self.marker_timestamps) > self.marker_count:
256            # Get the current marker
257            current_step_marker = self.marker_data[self.marker_count]  # String
258            current_timestamp = self.marker_timestamps[self.marker_count]  # Float
259
260            # If marker is empty, skip it
261            if not current_step_marker:
262                logger.warning("Empty marker received")
263                self.marker_count += 1
264                continue
265
266            # If messenger is available, send feedback for each marker received
267            if self._messenger is not None:
268                self._messenger.marker_received(current_step_marker)
269
270            # Process markers in order specified in the docstrings
271            # First check if it's a known command marker
272            if current_step_marker in self.marker_methods:
273                continue_flag = self.__handle_command_marker(current_step_marker)
274            # Then check if it's an event marker (contains commas)
275            elif "," in current_step_marker:
276                continue_flag = self.__handle_event_marker(
277                    current_step_marker, current_timestamp
278                )
279            # Otherwise, log a warning about unknown marker type
280            else:
281                # Log warning for unknown marker types
282                logger.warning("Unknown marker type received: %s", current_step_marker)
283                continue_flag = True
284
285            # Check if we should continue processing markers in the while loop
286            # if continue_flag is False, then break out of the while loop
287            # else, increment the marker count and process the next marker
288            if continue_flag is False:
289                break
290            else:
291                logger.info("Processed Marker: %s", current_step_marker)
292                self.marker_count += 1
293
294        self.step_count += 1
295        if self.step_count % self.ping_interval == 0:
296            if self._messenger is not None:
297                self._messenger.ping()
298
299    def run(self, max_loops: int = 1000000, ping_interval: int = 100):
300        """Runs BciController processing in a loop.
301
302        See setup() for configuration of processing.
303
304        Parameters
305        ----------
306        max_loops : int, *optional*
307            Maximum number of loops to run, default is `1000000`.
308        ping_interval : int, *optional*
309            Number of steps between each messenger ping.
310
311        Returns
312        ------
313        `None`
314
315        """
316        # if offline, then all data is already loaded, only need to loop once
317        if self.online is False:
318            self.loops = max_loops - 1
319        else:
320            self.loops = 0
321
322        self.ping_interval = ping_interval
323
324        # Initialize the event marker buffer
325        self.event_marker_buffer = []
326        self.event_timestamp_buffer = []
327
328        # start the main loop, stops after pulling new data, max_loops times
329        while self.loops < max_loops:
330            # print out loop status
331            if self.loops % 100 == 0:
332                logger.debug(self.loops)
333
334            if self.loops == max_loops - 1:
335                logger.debug("last loop")
336
337            # read from sources and process
338            self.step()
339
340            # Wait a short period of time and then try to pull more data
341            if self.online:
342                time.sleep(0.00001)
343
344            self.loops += 1
345
346    # 2. Protected methods (single underscore)
347    def _pull_data_from_sources(self):
348        """Get pull data from EEG and optionally, the marker source.
349
350        This method will fill up the marker_data, bci_controller and corresponding timestamp arrays.
351
352        Parameters
353        ----------
354        `None`
355
356        Returns
357        -------
358        `None`
359
360        """
361        # Get new data from source, whatever it is
362        self.__pull_marker_data_from_source()
363        self.__pull_eeg_data_from_source()
364
365    # 3. Private methods (double underscore)
366    # 3a. Private methods for retrieving data from sources
367    def __pull_marker_data_from_source(self):
368        """Pulls marker samples from source, sanity checks and appends to buffer.
369
370        Parameters
371        ----------
372        `None`
373
374        Returns
375        -------
376        `None`
377
378        """
379
380        # if there isn't a marker source, abort
381        if self.__marker_source is None:
382            return
383
384        # read in the data
385        markers, timestamps = self.__marker_source.get_markers()
386        markers = np.array(markers)
387        timestamps = np.array(timestamps)
388
389        if markers.size == 0:
390            return
391
392        if markers.ndim != 2:
393            logger.warning("discarded invalid marker data")
394            return
395
396        # apply time correction
397        time_correction = self.__marker_source.time_correction()
398        timestamps = [timestamps[i] + time_correction for i in range(len(timestamps))]
399
400        for i, marker in enumerate(markers):
401            marker = marker[0]
402            if "Ping" in marker:
403                continue
404
405            # Add all markers to the controller
406            self.marker_data = np.append(self.marker_data, marker)
407            self.marker_timestamps = np.append(self.marker_timestamps, timestamps[i])
408
409            # Add all markers to the data tank
410            self.__data_tank.add_raw_markers(
411                np.array([marker]), np.array([timestamps[i]])
412            )
413
414    def __pull_eeg_data_from_source(self):
415        """Pulls eeg samples from source, sanity checks and appends to buffer.
416
417        Parameters
418        ----------
419        `None`
420
421        Returns
422        -------
423        `None`
424
425        """
426
427        # read in the data
428        eeg, timestamps = self.__eeg_source.get_samples()
429        eeg = np.array(eeg)
430        timestamps = np.array(timestamps)
431
432        if eeg.size == 0:
433            return
434
435        if eeg.ndim != 2:
436            logger.warning("discarded invalid eeg data")
437            return
438
439        # if time is in milliseconds, divide by 1000, works for sampling rates above 10Hz
440        try:
441            if self.time_units == "milliseconds":
442                timestamps = [(timestamps[i] / 1000) for i in range(len(timestamps))]
443
444        # If time units are not defined then define them
445        except Exception:
446            dif_low = -2
447            dif_high = -1
448            while timestamps[dif_high] - timestamps[dif_low] == 0:
449                dif_low -= 1
450                dif_high -= 1
451
452            if timestamps[dif_high] - timestamps[dif_low] > 0.1:
453                timestamps = [(timestamps[i] / 1000) for i in range(len(timestamps))]
454                self.time_units = "milliseconds"
455            else:
456                self.time_units = "seconds"
457
458        # apply time correction, this is essential for headsets like neurosity which have their own clock
459        time_correction = self.__eeg_source.time_correction()
460        timestamps = [timestamps[i] + time_correction for i in range(len(timestamps))]
461
462        self.__data_tank.add_raw_eeg(eeg.T, timestamps)
463
464        # Update latest EEG timestamp
465        self.latest_eeg_timestamp = timestamps[-1]
466
467    # 3b. Private methods for data processing and classification
468    def __process_resting_state_data(self):
469        """Handles the resting state data by packaging it and adding it to the data tank.
470
471        Parameters
472        ----------
473        `None`
474
475        Returns
476        ------
477        continue_flag : bool
478            Flag indicating to continue the while loop in step().
479
480        """
481        (
482            self.bci_controller,
483            self.eeg_timestamps,
484        ) = self.__data_tank.get_raw_eeg()
485
486        resting_state_data = self.__paradigm.package_resting_state_data(
487            self.marker_data,
488            self.marker_timestamps,
489            self.bci_controller,
490            self.eeg_timestamps,
491            self.fsample,
492        )
493
494        self.__data_tank.add_resting_state_data(resting_state_data)
495
496        return True  # Continue processing
497
498    def __process_and_classify(self):
499        """Process the markers and classify the data.
500
501        Parameters
502        ----------
503        `None`
504
505        Returns
506        -------
507        success_string : str
508            String indicating if the processing and classification was successful.
509            Potential values are "Success", "Skip", "Wait".
510            - "Success": The processing and classification was successful.
511            - "Skip": EEG is either absent entirely or contains lost packets.
512            - "Wait": The processing is waiting for more data.
513
514        """
515
516        eeg_start_time, eeg_end_time = self.__paradigm.get_eeg_start_and_end_times(
517            self.event_marker_buffer, self.event_timestamp_buffer
518        )
519
520        # No we actually need to wait until we have all the data for these markers
521        eeg, timestamps = self.__data_tank.get_raw_eeg()
522
523        # Check if there is available EEG data
524        if len(eeg) == 0:
525            logger.warning("No EEG data available")
526            return "Skip"
527
528        # If the last timestamp is less than the end time, then we don't have the necessary EEG to process
529        if timestamps[-1] < eeg_end_time:
530            return "Wait"
531
532        # Check if EEG sampling is continuous over this time period
533        start_indices = np.where(timestamps > eeg_start_time)[0]
534        if len(start_indices) == 0:
535            logger.warning("No timestamps exceed eeg_start_time")
536            return "Skip"
537        start_index = start_indices[0]
538        end_index = np.where(timestamps < eeg_end_time)[0][-1]
539
540        time_diffs = np.diff(timestamps[start_index:end_index])
541        if np.any(time_diffs > 2 / self.fsample):
542            logger.warning("Time gaps in EEG data")
543            return "Skip"
544
545        X, y = self.__paradigm.process_markers(
546            self.event_marker_buffer,
547            self.event_timestamp_buffer,
548            eeg,
549            timestamps,
550            self.fsample,
551        )
552
553        sum_new_labeled_trials = np.sum(y != -1)
554
555        # Add the epochs to the data tank
556        self.__data_tank.add_epochs(X, y)
557
558        # Save epochs to temp_epochs file
559        if self.auto_save_epochs and self.online and sum_new_labeled_trials > 0:
560            paradigm_str = self.__paradigm.paradigm_name
561
562            with open(self.temp_epochs, "wb") as f:
563                np.savez(
564                    f,
565                    X=self.__data_tank.epochs,
566                    y=self.__data_tank.labels,
567                    paradigm=paradigm_str,
568                )
569
570        # If either there are no labels OR iterative training is on, then make a prediction
571        if self.train_complete:
572            if -1 in y or self.__paradigm.iterative_training:
573                prediction = self._classifier.predict(X)
574                self.__send_prediction(prediction)
575
576        self.event_marker_buffer = []
577        self.event_timestamp_buffer = []
578
579        return "Success"
580
581    def __update_and_train_classifier(self):
582        """Updates the classifier if required.
583
584        Parameters
585        ----------
586        `None`
587
588        Returns
589        -------
590        continue_flag : bool
591            Flag indicating to continue the while loop in step().
592        """
593        if self.train_lock is False:
594            # Pull the epochs from the data tank and pass them to the classifier
595            X, y = self.__data_tank.get_epochs(latest=True)
596
597            # Remove epochs with label -1
598            ind_to_remove = []
599            for i, label in enumerate(y):
600                if label == -1:
601                    ind_to_remove.append(i)
602            X = np.delete(X, ind_to_remove, axis=0)
603            y = np.delete(y, ind_to_remove, axis=0)
604
605            # Check that there are epochs
606            if len(y) > 0:
607                self._classifier.add_to_train(X, y)
608
609            if self._classifier._check_ready_for_fit():
610                self._classifier.fit()
611                self.train_complete = True
612
613        return True
614
615    # 3c. Private methods for event handling (trial and markers) and messaging
616    def __log_trial_start(self):
617        """Logs the start of a trial.
618
619        Parameters
620        ----------
621        `None`
622
623        Returns
624        -------
625        continue_flag : bool
626            Flag indicating to continue the while loop in step().
627
628        """
629        logger.debug("Trial started, incrementing marker count and continuing")
630        # Note that a marker occured, but do nothing else
631        return True  # Continue processing
632
633    def __handle_trial_end(self):
634        """Handles the end of a trial. Processes and classifies trial data if required.
635
636        Parameters
637        ----------
638        `None`
639
640        Returns
641        ------
642        success_flag : bool
643            Flag indicating if the processing and classification was successful.
644            - Returns `True` if not classifying.
645        """
646        # If we are classifying based on trials, then process the trial,
647        if self.__paradigm.classify_each_trial:
648            success_string = self.__process_and_classify()
649            if success_string == "Wait":
650                logger.debug("Processing of trial not run: waiting for more data")
651                return False
652            if success_string == "Skip":
653                logger.warning("Processing of trial failed: skipping trial")
654                self.event_marker_buffer = []
655                self.event_timestamp_buffer = []
656                self.marker_count += 1
657                return False
658
659        return True  # Return True by default if not classifying
660
661    def __handle_event_marker(self, marker, timestamp):
662        """Processes and classifies event markers.
663
664        Parameters
665        ----------
666        marker : str
667            Event marker string containing comma-separated values.
668            - Format depends on paradigm implementation.
669        timestamp : float
670            Timestamp of the marker in seconds (after time correction).
671
672
673        Returns
674        ------
675        continue_flag : bool
676            Flag indicating to continue the while loop in step().
677
678        """
679        # Add the marker to the event marker buffer
680        self.event_marker_buffer.append(marker)
681        self.event_timestamp_buffer.append(timestamp)
682
683        # If classification is on epochs, then update epochs, maybe classify, and clear the buffer
684        if self.__paradigm.classify_each_epoch:
685            success_string = self.__process_and_classify()
686            if success_string == "Wait":
687                logger.debug("Processing of epoch not run: waiting for more data")
688                self.event_marker_buffer = []
689                self.event_timestamp_buffer = []
690                return False  # Stop processing
691            elif success_string == "Skip":
692                logger.warning("Processing of epoch failed: skipping epoch")
693                self.event_marker_buffer = []
694                self.event_timestamp_buffer = []
695
696                self.marker_count += 1
697                return False  # Stop processing
698
699        return True  # Continue processing
700
701    def __handle_command_marker(self, marker: str) -> bool:
702        """Processes a command marker by invoking its associated method.
703
704        The command marker string is assumed to be in the self.marker_methods dictionary.
705        The associated method is retrieved and called.
706        The return value of the method is used to determine if processing should continue.
707
708        Parameters
709        ----------
710        marker : str
711            A command marker string (assumed to be in self.marker_methods).
712
713        Returns
714        -------
715        bool
716            A flag indicating if the processing should continue.
717
718        """
719        command_marker_method = self.marker_methods[marker]  # Retrieve method
720        continue_flag = command_marker_method()  # Call method
721
722        # Debug level logging if continue_flag is FALSE
723        if continue_flag is False:
724            logger.debug("Command marker '%s' set continue_flag to FALSE", marker)
725
726        return continue_flag
727
728    def __send_prediction(self, prediction):
729        """Send a prediction to the messenger object.
730
731        Parameters
732        ----------
733        `None`
734
735        Returns
736        -------
737        `None`
738
739        """
740        if self._messenger is not None:
741            logger.debug("Sending prediction: %s", prediction)
742            self._messenger.prediction(prediction)
743        elif self._messenger is None and self.online is True:
744            # If running in online mode and messenger is not available, log a warning
745            logger.warning(
746                "Messenger not available (self._messenger is None). Prediction not sent: %s",
747                prediction,
748            )
749
750    def __load_temp_epochs_if_available(self, reload_data_time: int = 300):
751        """Load temp_epochs if available and valid.
752
753        Parameters
754        ----------
755        reload_data_time : int, *optional*
756            Time in seconds of the last temp_epochs file to reload the data from.
757            Default is `300` seconds (5 minutes).
758
759        Returns
760        -------
761        `None`
762
763        """
764        self.temp_epochs = os.path.join(
765            os.path.dirname(os.path.dirname(__file__)), "temp_epochs.npz"
766        )
767
768        if not os.path.exists(self.temp_epochs):
769            return
770
771        # If temp_epochs is older than `reload_data_time`, delete it
772        if os.path.getmtime(self.temp_epochs) < (time.time() - reload_data_time):
773            os.remove(self.temp_epochs)
774            logger.info("Deleted old temp_epochs file.")
775            return
776
777        # Load the temp_epochs file
778        with open(self.temp_epochs, "rb") as f:
779            npz = np.load(f, allow_pickle=True)
780            X = npz["X"]
781            y = npz["y"]
782            paradigm_str = npz["paradigm"].item()
783
784        # If the paradigm is different, delete the file
785        if self.__paradigm.paradigm_name != paradigm_str:
786            logger.warning(
787                "Paradigm in temp_epochs file does not match current paradigm. Deleting file."
788            )
789            os.remove(self.temp_epochs)
790            return
791
792        # If the paradigm is the same, then add the epochs to the data tank
793        logger.info("Loading epochs from temp_epochs file.")
794        logger.info("X shape: %s", X.shape)
795        logger.info("y shape: %s", y.shape)
796        self.__data_tank.add_epochs(X, y)
797
798        # If there are epochs in the data tank, then train the classifier
799        if len(self.__data_tank.labels) > 0:
800            self.__update_and_train_classifier()

Class that holds EEG data, divides it into trials, and processes and classifies those trials. This class is used for processing continuous EEG data in trials of a defined length.

 55    def __init__(
 56        self,
 57        classifier: GenericClassifier,
 58        eeg_source: EegSource,
 59        marker_source: MarkerSource | None = None,
 60        paradigm: Paradigm | None = None,
 61        data_tank: DataTank | None = None,
 62        messenger: Messenger | None = None,
 63    ):
 64        """Initializes `BciController` class.
 65
 66        Parameters
 67        ----------
 68        classifier : GenericClassifier
 69            The classifier used by BciController.
 70        eeg_source : EegSource
 71            Source of EEG data and timestamps, this could be from a file or headset via LSL, etc.
 72        marker_source : EegSource
 73            Source of Marker/Control data and timestamps, this could be from a file or Unity via LSL, etc.
 74            - Default is `None`.
 75        paradigm : Paradigm
 76            The paradigm used by BciController. This defines the processing and reshaping steps for the EEG data.
 77        data_tank : DataTank
 78            DataTank object to handle the storage of EEG trials and labels.
 79            - Default is `None`.
 80        messenger: Messenger
 81            Messenger object to handle events from BciController, ex: acknowledging markers and
 82            predictions.
 83            - Default is `None`.
 84
 85        """
 86
 87        # Ensure the incoming dependencies are the right type
 88        assert isinstance(classifier, GenericClassifier)
 89        assert isinstance(eeg_source, EegSource)
 90        assert isinstance(marker_source, MarkerSource | None)
 91        assert isinstance(paradigm, Paradigm | None)
 92        assert isinstance(data_tank, DataTank | None)
 93        assert isinstance(messenger, Messenger | None)
 94
 95        self._classifier = classifier
 96        self.__eeg_source = eeg_source
 97        self.__marker_source = marker_source
 98        self.__paradigm = paradigm
 99        self.__data_tank = data_tank
100        self._messenger = messenger
101
102        self.headset_string = self.__eeg_source.name
103        self.fsample = self.__eeg_source.fsample
104        self.n_channels = self.__eeg_source.n_channels
105        self.ch_type = self.__eeg_source.channel_types
106        self.ch_units = self.__eeg_source.channel_units
107        self.channel_labels = self.__eeg_source.channel_labels
108
109        # Emily EGI fix
110        # Set default channel types if none
111        if self.ch_type is None:
112            logger.warning("Channel types are none, setting all to 'eeg'")
113            self.ch_type = ["eeg"] * self.n_channels
114
115        self.__data_tank.set_source_data(
116            self.headset_string,
117            self.fsample,
118            self.n_channels,
119            self.ch_type,
120            self.ch_units,
121            self.channel_labels,
122        )
123
124        # Switch any trigger channels to stim, this is for mne/bids export (?)
125        self.ch_type = [type.replace("trg", "stim") for type in self.ch_type]
126
127        self._classifier.channel_labels = self.channel_labels
128
129        logger.info(self.headset_string)
130        logger.info(self.channel_labels)
131
132        # Initialize data and timestamp arrays to the right dimensions, but zero elements
133        self.marker_data = np.zeros((0, 1))
134        self.marker_timestamps = np.zeros((0))
135        self.bci_controller = np.zeros((0, self.n_channels))
136        self.eeg_timestamps = np.zeros((0))
137
138        # Initialize marker methods dictionary
139        self.marker_methods = {
140            MarkerTypes.DONE_RS_COLLECTION.value: self.__process_resting_state_data,
141            MarkerTypes.TRIAL_STARTED.value: self.__log_trial_start,
142            MarkerTypes.TRIAL_ENDS.value: self.__handle_trial_end,
143            MarkerTypes.TRAINING_COMPLETE.value: self.__update_and_train_classifier,
144            MarkerTypes.TRAIN_CLASSIFIER.value: self.__update_and_train_classifier,
145            MarkerTypes.UPDATE_CLASSIFIER.value: self.__update_and_train_classifier,
146        }
147
148        self.step_count = 0
149        self.ping_interval = 1000
150        self.n_samples = 0
151        self.time_units = ""

Initializes BciController class.

Parameters
  • classifier (GenericClassifier): The classifier used by BciController.
  • eeg_source (EegSource): Source of EEG data and timestamps; this could be from a file or headset via LSL, etc.
  • marker_source (MarkerSource): Source of marker/control data and timestamps; this could be from a file or Unity via LSL, etc.
    • Default is None.
  • paradigm (Paradigm): The paradigm used by BciController. This defines the processing and reshaping steps for the EEG data.
    • Default is None.
  • data_tank (DataTank): DataTank object to handle the storage of EEG trials and labels.
    • Default is None.
  • messenger (Messenger): Messenger object to handle events from BciController, ex: acknowledging markers and predictions.
    • Default is None.
headset_string
fsample
n_channels
ch_type
ch_units
channel_labels
marker_data
marker_timestamps
bci_controller
eeg_timestamps
marker_methods
step_count
ping_interval
n_samples
time_units
def setup( self, online=True, train_complete=False, train_lock=False, auto_save_epochs=True):
154    def setup(
155        self,
156        online=True,
157        train_complete=False,
158        train_lock=False,
159        auto_save_epochs=True,
160    ):
161        """Configure processing loop.
162
163        This should be called before starting the loop with run() or step().
164
165        Calling after will reset the loop state.
166
167        The processing loop reads in EEG and marker data and processes it.
168        The loop can be run in "offline" or "online" modes:
169        - If in `online` mode, then the loop will continuously try to read
170        in data from the `BciController` object and process it. The loop will
171        terminate when `max_loops` is reached, or when manually terminated.
172        - If in `offline` mode, then the loop will read in all of the data
173        at once, process it, and then terminate.
174
175        Parameters
176        ----------
177        online : bool, *optional*
178            Flag to indicate if the data will be processed in `online` mode.
179            - `True`: The data will be processed in `online` mode.
180            - `False`: The data will be processed in `offline` mode.
181            - Default is `True`.
182        train_complete : bool, *optional*
183            Flag to indicate if the classifier has been trained.
184            - `True`: The classifier has been trained.
185            - `False`: The classifier has not been trained.
186            - Default is `False`.
187        train_lock : bool, *optional*
188            Flag to indicate if the classifier is locked (ie. no more training).
189            - `True`: The classifier is locked.
190            - `False`: The classifier is not locked.
191            - Default is `False`.
192        auto_save_epochs : bool, *optional*
193            Flag to indicate if labeled epochs should be automatically saved to a temp file so they can be reloaded if Bessy crashes.
194            - `True`: Epochs will be saved to a temp file.
195            - `False`: Epochs will not be saved to a temp file.
196
197
198        Returns
199        -------
200        `None`
201
202        """
203        self.online = online
204        self.train_complete = train_complete
205        self.train_lock = train_lock
206        self.auto_save_epochs = auto_save_epochs
207
208        # initialize the numbers of markers and trials to zero
209        self.marker_count = 0
210        self.current_num_trials = 0
211        self.n_trials = 0
212
213        self.num_online_selections = 0
214        self.online_selection_indices = []
215        self.online_selections = []
216
217        # Check for a temp_epochs file
218        if online:
219            self.__load_temp_epochs_if_available()

Configure processing loop.

This should be called before starting the loop with run() or step().

Calling after will reset the loop state.

The processing loop reads in EEG and marker data and processes it. The loop can be run in "offline" or "online" modes:

  • If in online mode, then the loop will continuously try to read in data from the BciController object and process it. The loop will terminate when max_loops is reached, or when manually terminated.
  • If in offline mode, then the loop will read in all of the data at once, process it, and then terminate.
Parameters
  • online (bool, optional): Flag to indicate if the data will be processed in online mode.
    • True: The data will be processed in online mode.
    • False: The data will be processed in offline mode.
    • Default is True.
  • train_complete (bool, optional): Flag to indicate if the classifier has been trained.
    • True: The classifier has been trained.
    • False: The classifier has not been trained.
    • Default is False.
  • train_lock (bool, optional): Flag to indicate if the classifier is locked (ie. no more training).
    • True: The classifier is locked.
    • False: The classifier is not locked.
    • Default is False.
  • auto_save_epochs (bool, optional): Flag to indicate if labeled epochs should be automatically saved to a temp file so they can be reloaded if Bessy crashes.
    • True: Epochs will be saved to a temp file.
    • False: Epochs will not be saved to a temp file.
    • Default is True.
Returns
  • None
def step(self):
221    def step(self):
222        """Runs a single BciController processing step.
223
224        See setup() for configuration of processing.
225
226        The method:
227        1. Pulls data from sources (EEG and markers).
228        2. Run a while loop to process markers as long as there are unprocessed markers.
229        3. The while loop processes the markers in the following order:
230            - First checks if the marker is a known command marker from self.marker_methods.
231            - Then checks if it's an event marker (contains commas)
232            - If neither, logs a warning about unknown marker type
233        3. If the marker is a command marker, handles it by calling __handle_command_marker().
234        4. If the marker is an event marker, handles it by calling __handle_event_marker().
235        5. If the command or event marker handling return continue_flag as True, increment the marker count and process the next marker.
236            - Note: If there is an unknown marker type, the marker count is still incremented and processing continues.
237        6. If the command or event marker handling return continue_flag as False, break out of the while loop and end the step.
238
239        Parameters
240        ----------
241        `None`
242
243        Returns
244        ------
245        `None`
246
247        """
248        # read from sources to get new data.
249        # This puts command markers in the marker_data array and
250        # event markers in the event_marker_strings array
251        self._pull_data_from_sources()
252
253        # Process markers while there are unprocessed markers
254        # REMOVE COMMENT: check if there is an available command marker, if not, break and wait for more data
255        while len(self.marker_timestamps) > self.marker_count:
256            # Get the current marker
257            current_step_marker = self.marker_data[self.marker_count]  # String
258            current_timestamp = self.marker_timestamps[self.marker_count]  # Float
259
260            # If marker is empty, skip it
261            if not current_step_marker:
262                logger.warning("Empty marker received")
263                self.marker_count += 1
264                continue
265
266            # If messenger is available, send feedback for each marker received
267            if self._messenger is not None:
268                self._messenger.marker_received(current_step_marker)
269
270            # Process markers in order specified in the docstrings
271            # First check if it's a known command marker
272            if current_step_marker in self.marker_methods:
273                continue_flag = self.__handle_command_marker(current_step_marker)
274            # Then check if it's an event marker (contains commas)
275            elif "," in current_step_marker:
276                continue_flag = self.__handle_event_marker(
277                    current_step_marker, current_timestamp
278                )
279            # Otherwise, log a warning about unknown marker type
280            else:
281                # Log warning for unknown marker types
282                logger.warning("Unknown marker type received: %s", current_step_marker)
283                continue_flag = True
284
285            # Check if we should continue processing markers in the while loop
286            # if continue_flag is False, then break out of the while loop
287            # else, increment the marker count and process the next marker
288            if continue_flag is False:
289                break
290            else:
291                logger.info("Processed Marker: %s", current_step_marker)
292                self.marker_count += 1
293
294        self.step_count += 1
295        if self.step_count % self.ping_interval == 0:
296            if self._messenger is not None:
297                self._messenger.ping()

Runs a single BciController processing step.

See setup() for configuration of processing.

The method:

  1. Pulls data from sources (EEG and markers).
  2. Runs a while loop to process markers as long as there are unprocessed markers.
  3. The while loop processes the markers in the following order:
    • First checks if the marker is a known command marker from self.marker_methods.
    • Then checks if it's an event marker (contains commas)
    • If neither, logs a warning about unknown marker type
  4. If the marker is a command marker, handles it by calling __handle_command_marker().
  5. If the marker is an event marker, handles it by calling __handle_event_marker().
  6. If the command or event marker handler returns continue_flag as True, the marker count is incremented and the next marker is processed.
    • Note: If there is an unknown marker type, the marker count is still incremented and processing continues.
  7. If the command or event marker handler returns continue_flag as False, the while loop is exited and the step ends.
Parameters
  • None
Returns
  • None
def run(self, max_loops: int = 1000000, ping_interval: int = 100):
299    def run(self, max_loops: int = 1000000, ping_interval: int = 100):
300        """Runs BciController processing in a loop.
301
302        See setup() for configuration of processing.
303
304        Parameters
305        ----------
306        max_loops : int, *optional*
307            Maximum number of loops to run, default is `1000000`.
308        ping_interval : int, *optional*
309            Number of steps between each messenger ping.
310
311        Returns
312        ------
313        `None`
314
315        """
316        # if offline, then all data is already loaded, only need to loop once
317        if self.online is False:
318            self.loops = max_loops - 1
319        else:
320            self.loops = 0
321
322        self.ping_interval = ping_interval
323
324        # Initialize the event marker buffer
325        self.event_marker_buffer = []
326        self.event_timestamp_buffer = []
327
328        # start the main loop, stops after pulling new data, max_loops times
329        while self.loops < max_loops:
330            # print out loop status
331            if self.loops % 100 == 0:
332                logger.debug(self.loops)
333
334            if self.loops == max_loops - 1:
335                logger.debug("last loop")
336
337            # read from sources and process
338            self.step()
339
340            # Wait a short period of time and then try to pull more data
341            if self.online:
342                time.sleep(0.00001)
343
344            self.loops += 1

Runs BciController processing in a loop.

See setup() for configuration of processing.

Parameters
  • max_loops (int, optional): Maximum number of loops to run, default is 1000000.
  • ping_interval (int, optional): Number of steps between each messenger ping, default is 100.
Returns
  • None