bci_essentials.bci_controller
Module for managing BCI data.
This module provides the `BciController` class, which manages data for different BCI paradigms.
It includes the loading of offline data in xdf format
or the live streaming of LSL data.
The loaded/streamed data is added to a buffer such that offline and online processing pipelines are identical.
Data is pre-processed (using the signal_processing module), divided into trials,
and classified (using one of the classification sub-modules).
Classes
BciController: For processing continuous data in trials of a defined length.
"""Module for managing BCI data.

This module provides the `BciController` class, which manages data for
different BCI paradigms.

It includes the loading of offline data in `xdf` format
or the live streaming of LSL data.

The loaded/streamed data is added to a buffer such that offline and
online processing pipelines are identical.

Data is pre-processed (using the `signal_processing` module), divided into trials,
and classified (using one of the `classification` sub-modules).

Classes
-------
- `MarkerTypes` : Command marker strings understood by `BciController`.
- `BciController` : For processing continuous data in trials of a defined length.

"""

import time
import os
import numpy as np
from enum import Enum

from .paradigm.paradigm import Paradigm
from .data_tank.data_tank import DataTank
from .classification.generic_classifier import GenericClassifier
from .io.sources import EegSource, MarkerSource
from .io.messenger import Messenger
from .utils.logger import Logger

# Instantiate a logger for the module at the default level of logging.INFO
# Logs to bci_essentials.__module__ where __module__ is the name of the module
logger = Logger(name=__name__)


class MarkerTypes(Enum):
    """Command marker strings recognised by `BciController`.

    Each member's value is the exact marker string received from the marker
    source; `BciController` maps these values to handler methods in its
    `marker_methods` dictionary.
    """

    TRIAL_STARTED = "Trial Started"
    TRIAL_ENDS = "Trial Ends"
    TRAINING_COMPLETE = "Training Complete"
    TRAIN_CLASSIFIER = "Train Classifier"
    DONE_RS_COLLECTION = "Done with all RS collection"
    UPDATE_CLASSIFIER = "Update Classifier"


# EEG data
class BciController:
    """
    Class that holds, splits into trials, processes, and classifies EEG data.
    This class is used for processing of continuous EEG data in trials of a defined length.
    """

    # 0. Special methods (e.g. __init__)
    def __init__(
        self,
        classifier: GenericClassifier,
        eeg_source: EegSource,
        marker_source: MarkerSource | None = None,
        paradigm: Paradigm | None = None,
        data_tank: DataTank | None = None,
        messenger: Messenger | None = None,
    ):
        """Initializes `BciController` class.

        Parameters
        ----------
        classifier : GenericClassifier
            The classifier used by BciController.
        eeg_source : EegSource
            Source of EEG data and timestamps, this could be from a file or headset via LSL, etc.
        marker_source : MarkerSource
            Source of Marker/Control data and timestamps, this could be from a file or Unity via LSL, etc.
            - Default is `None`.
        paradigm : Paradigm
            The paradigm used by BciController. This defines the processing and reshaping steps for the EEG data.
            - Default is `None`.
        data_tank : DataTank
            DataTank object to handle the storage of EEG trials and labels.
            - Default is `None`.
        messenger: Messenger
            Messenger object to handle events from BciController, ex: acknowledging markers and
            predictions.
            - Default is `None`.

        Notes
        -----
        Although `paradigm` and `data_tank` default to `None`, this class
        dereferences both without a `None` check (`data_tank` already in this
        constructor), so in practice both must be provided.

        """

        # Ensure the incoming dependencies are the right type
        assert isinstance(classifier, GenericClassifier)
        assert isinstance(eeg_source, EegSource)
        assert isinstance(marker_source, MarkerSource | None)
        assert isinstance(paradigm, Paradigm | None)
        assert isinstance(data_tank, DataTank | None)
        assert isinstance(messenger, Messenger | None)

        self._classifier = classifier
        self.__eeg_source = eeg_source
        self.__marker_source = marker_source
        self.__paradigm = paradigm
        self.__data_tank = data_tank
        self._messenger = messenger

        self.headset_string = self.__eeg_source.name
        self.fsample = self.__eeg_source.fsample
        self.n_channels = self.__eeg_source.n_channels
        self.ch_type = self.__eeg_source.channel_types
        self.ch_units = self.__eeg_source.channel_units
        self.channel_labels = self.__eeg_source.channel_labels

        self.__data_tank.set_source_data(
            self.headset_string,
            self.fsample,
            self.n_channels,
            self.ch_type,
            self.ch_units,
            self.channel_labels,
        )

        # Switch any trigger channels to stim, this is for mne/bids export (?)
        self.ch_type = [
            channel_type.replace("trg", "stim") for channel_type in self.ch_type
        ]

        self._classifier.channel_labels = self.channel_labels

        logger.info(self.headset_string)
        logger.info(self.channel_labels)

        # Initialize data and timestamp arrays to the right dimensions, but zero elements.
        # `bci_controller` holds the raw EEG buffer (refreshed from the data tank
        # when resting-state data is processed).
        self.marker_data = np.zeros((0, 1))
        self.marker_timestamps = np.zeros((0))
        self.bci_controller = np.zeros((0, self.n_channels))
        self.eeg_timestamps = np.zeros((0))

        # Initialize marker methods dictionary: command marker string -> handler.
        # Every handler returns a continue_flag consumed by step().
        self.marker_methods = {
            MarkerTypes.DONE_RS_COLLECTION.value: self.__process_resting_state_data,
            MarkerTypes.TRIAL_STARTED.value: self.__log_trial_start,
            MarkerTypes.TRIAL_ENDS.value: self.__handle_trial_end,
            MarkerTypes.TRAINING_COMPLETE.value: self.__update_and_train_classifier,
            MarkerTypes.TRAIN_CLASSIFIER.value: self.__update_and_train_classifier,
            MarkerTypes.UPDATE_CLASSIFIER.value: self.__update_and_train_classifier,
        }

        self.ping_count = 0
        self.n_samples = 0
        # "" means "not yet determined"; set to "seconds" or "milliseconds"
        # by the first EEG chunk from which the units can be inferred.
        self.time_units = ""

    # 1. Core public API methods
    def setup(
        self,
        online=True,
        train_complete=False,
        train_lock=False,
        auto_save_epochs=True,
    ):
        """Configure processing loop.

        This should be called before starting the loop with run() or step().

        Calling after will reset the loop state.

        The processing loop reads in EEG and marker data and processes it.
        The loop can be run in "offline" or "online" modes:
            - If in `online` mode, then the loop will continuously try to read
            in data from the `BciController` object and process it. The loop will
            terminate when `max_loops` (see run()) is reached, or when manually terminated.
            - If in `offline` mode, then the loop will read in all of the data
            at once, process it, and then terminate.

        Parameters
        ----------
        online : bool, *optional*
            Flag to indicate if the data will be processed in `online` mode.
            - `True`: The data will be processed in `online` mode.
            - `False`: The data will be processed in `offline` mode.
            - Default is `True`.
        train_complete : bool, *optional*
            Flag to indicate if the classifier has been trained.
            - `True`: The classifier has been trained.
            - `False`: The classifier has not been trained.
            - Default is `False`.
        train_lock : bool, *optional*
            Flag to indicate if the classifier is locked (ie. no more training).
            - `True`: The classifier is locked.
            - `False`: The classifier is not locked.
            - Default is `False`.
        auto_save_epochs : bool, *optional*
            Flag to indicate if labeled epochs should be automatically saved to a temp file so they can be reloaded if Bessy crashes.
            - `True`: Epochs will be saved to a temp file.
            - `False`: Epochs will not be saved to a temp file.
            - Default is `True`.

        Returns
        -------
        `None`

        """
        self.online = online
        self.train_complete = train_complete
        self.train_lock = train_lock
        self.auto_save_epochs = auto_save_epochs

        # initialize the numbers of markers and trials to zero
        self.marker_count = 0
        self.current_num_trials = 0
        self.n_trials = 0

        self.num_online_selections = 0
        self.online_selection_indices = []
        self.online_selections = []

        # Initialize the event marker buffers here (run() also resets them) so
        # that step() can be driven directly after setup() without run().
        self.event_marker_buffer = []
        self.event_timestamp_buffer = []

        # Check for a temp_epochs file
        if online:
            self.__load_temp_epochs_if_available()

    def step(self):
        """Runs a single BciController processing step.

        See setup() for configuration of processing.

        The method:
            1. Pulls data from sources (EEG and markers).
            2. Runs a while loop to process markers as long as there are unprocessed markers.
            3. The while loop processes the markers in the following order:
                - First checks if the marker is a known command marker from self.marker_methods.
                - Then checks if it's an event marker (contains commas).
                - If neither, logs a warning about unknown marker type.
            4. If the marker is a command marker, handles it by calling __handle_command_marker().
            5. If the marker is an event marker, handles it by calling __handle_event_marker().
            6. If the command or event marker handling return continue_flag as True, increment the marker count and process the next marker.
                - Note: If there is an unknown marker type, the marker count is still incremented and processing continues.
            7. If the command or event marker handling return continue_flag as False, break out of the while loop and end the step.

        Parameters
        ----------
        `None`

        Returns
        ------
        `None`

        """
        # Read from sources to get new data. All non-"Ping" markers (command
        # and event markers alike) are appended to the marker_data array.
        self._pull_data_from_sources()

        # Process markers while there are unprocessed markers
        while len(self.marker_timestamps) > self.marker_count:
            # Get the current marker
            current_step_marker = self.marker_data[self.marker_count]  # String
            current_timestamp = self.marker_timestamps[self.marker_count]  # Float

            # If marker is empty, skip it
            if not current_step_marker:
                logger.warning("Empty marker received")
                self.marker_count += 1
                continue

            # If messenger is available, send feedback for each marker received
            if self._messenger is not None:
                self._messenger.marker_received(current_step_marker)

            # Process markers in order specified in the docstrings
            # First check if it's a known command marker
            if current_step_marker in self.marker_methods:
                continue_flag = self.__handle_command_marker(current_step_marker)
            # Then check if it's an event marker (contains commas)
            elif "," in current_step_marker:
                continue_flag = self.__handle_event_marker(
                    current_step_marker, current_timestamp
                )
            # Otherwise, log a warning about unknown marker type
            else:
                logger.warning("Unknown marker type received: %s", current_step_marker)
                continue_flag = True

            # A False flag means "stop for now": the handler either wants more
            # data (the same marker is retried next step) or has already
            # advanced marker_count itself.
            if continue_flag is False:
                break

            logger.info("Processed Marker: %s", current_step_marker)
            self.marker_count += 1

    def run(self, max_loops: int = 1000000):
        """Runs BciController processing in a loop.

        See setup() for configuration of processing.

        Parameters
        ----------
        max_loops : int, *optional*
            Maximum number of loops to run, default is `1000000`.

        Returns
        ------
        `None`

        """
        # if offline, then all data is already loaded, only need to loop once
        if self.online is False:
            self.loops = max_loops - 1
        else:
            self.loops = 0

        # Initialize the event marker buffer
        self.event_marker_buffer = []
        self.event_timestamp_buffer = []

        # start the main loop, stops after pulling new data, max_loops times
        while self.loops < max_loops:
            # print out loop status
            if self.loops % 100 == 0:
                logger.debug(self.loops)

            if self.loops == max_loops - 1:
                logger.debug("last loop")

            # read from sources and process
            self.step()

            # Wait a short period of time and then try to pull more data
            if self.online:
                time.sleep(0.00001)

            self.loops += 1

    # 2. Protected methods (single underscore)
    def _pull_data_from_sources(self):
        """Pull new data from the EEG source and, if present, the marker source.

        Marker samples are appended to marker_data / marker_timestamps and to
        the data tank; EEG samples are appended to the data tank.

        Parameters
        ----------
        `None`

        Returns
        -------
        `None`

        """
        # Get new data from source, whatever it is
        self.__pull_marker_data_from_source()
        self.__pull_eeg_data_from_source()

        # If the outlet exists send a ping
        if self._messenger is not None:
            self.ping_count += 1
            self._messenger.ping()

    # 3. Private methods (double underscore)
    # 3a. Private methods for retrieving data from sources
    def __pull_marker_data_from_source(self):
        """Pulls marker samples from source, sanity checks and appends to buffer.

        Markers containing "Ping" are dropped. Timestamps are shifted by the
        marker source's time_correction() before being stored.

        Parameters
        ----------
        `None`

        Returns
        -------
        `None`

        """

        # if there isn't a marker source, abort
        if self.__marker_source is None:
            return

        # read in the data
        markers, timestamps = self.__marker_source.get_markers()
        markers = np.array(markers)
        timestamps = np.array(timestamps)

        if markers.size == 0:
            return

        # Each marker sample is expected to be a row whose first element is the marker string
        if markers.ndim != 2:
            logger.warning("discarded invalid marker data")
            return

        # apply time correction
        timestamps = timestamps + self.__marker_source.time_correction()

        for sample, timestamp in zip(markers, timestamps, strict=True):
            marker = sample[0]
            if "Ping" in marker:
                continue

            # Add all markers to the controller
            self.marker_data = np.append(self.marker_data, marker)
            self.marker_timestamps = np.append(self.marker_timestamps, timestamp)

            # Add all markers to the data tank
            self.__data_tank.add_raw_markers(np.array([marker]), np.array([timestamp]))

    @staticmethod
    def __detect_time_units(timestamps):
        """Infer whether EEG timestamps are in seconds or milliseconds.

        Uses the last non-zero step between consecutive timestamps: a step
        larger than 0.1 is taken to mean milliseconds (valid for sampling
        rates above 10 Hz).

        Parameters
        ----------
        timestamps : array-like
            Timestamps of the current EEG chunk.

        Returns
        -------
        str
            "milliseconds", "seconds", or "" if the units cannot be inferred
            from this chunk (fewer than two distinct timestamps).

        """
        steps = np.diff(np.asarray(timestamps, dtype=float))
        nonzero_steps = steps[steps != 0]
        if nonzero_steps.size == 0:
            return ""
        return "milliseconds" if nonzero_steps[-1] > 0.1 else "seconds"

    def __pull_eeg_data_from_source(self):
        """Pulls eeg samples from source, sanity checks and appends to buffer.

        The timestamp units are inferred from the first chunk that allows it;
        millisecond timestamps are converted to seconds. Timestamps are then
        shifted by the EEG source's time_correction().

        Parameters
        ----------
        `None`

        Returns
        -------
        `None`

        """

        # read in the data
        eeg, timestamps = self.__eeg_source.get_samples()
        eeg = np.array(eeg)
        timestamps = np.array(timestamps)

        if eeg.size == 0:
            return

        if eeg.ndim != 2:
            logger.warning("discarded invalid eeg data")
            return

        # If time units are not defined yet then define them. This must be an
        # explicit check: time_units starts as "", not as a missing attribute.
        if self.time_units == "":
            self.time_units = self.__detect_time_units(timestamps)
            if self.time_units:
                logger.info("EEG timestamps detected to be in %s", self.time_units)

        # if time is in milliseconds, divide by 1000, works for sampling rates above 10Hz
        if self.time_units == "milliseconds":
            timestamps = timestamps / 1000

        # apply time correction, this is essential for headsets like neurosity which have their own clock
        time_correction = self.__eeg_source.time_correction()
        timestamps = list(timestamps + time_correction)

        self.__data_tank.add_raw_eeg(eeg.T, timestamps)

        # Update latest EEG timestamp
        self.latest_eeg_timestamp = timestamps[-1]

    # 3b. Private methods for data processing and classification
    def __process_resting_state_data(self):
        """Handles the resting state data by packaging it and adding it to the data tank.

        Parameters
        ----------
        `None`

        Returns
        ------
        continue_flag : bool
            Flag indicating to continue the while loop in step().

        """
        (
            self.bci_controller,
            self.eeg_timestamps,
        ) = self.__data_tank.get_raw_eeg()

        resting_state_data = self.__paradigm.package_resting_state_data(
            self.marker_data,
            self.marker_timestamps,
            self.bci_controller,
            self.eeg_timestamps,
            self.fsample,
        )

        self.__data_tank.add_resting_state_data(resting_state_data)

        return True  # Continue processing

    def __process_and_classify(self):
        """Process the buffered event markers and classify the data.

        Parameters
        ----------
        `None`

        Returns
        -------
        success_string : str
            String indicating if the processing and classification was successful.
            Potential values are "Success", "Skip", "Wait".
            - "Success": The processing and classification was successful.
            - "Skip": EEG is either absent entirely, does not cover the marker
              window, or contains lost packets.
            - "Wait": The processing is waiting for more data.

        """

        eeg_start_time, eeg_end_time = self.__paradigm.get_eeg_start_and_end_times(
            self.event_marker_buffer, self.event_timestamp_buffer
        )

        # Now we actually need to wait until we have all the data for these markers
        eeg, timestamps = self.__data_tank.get_raw_eeg()

        # Check if there is available EEG data
        if len(eeg) == 0:
            logger.warning("No EEG data available")
            return "Skip"

        # If the last timestamp is less than the end time, then we don't have the necessary EEG to process
        if timestamps[-1] < eeg_end_time:
            return "Wait"

        # Check if EEG sampling is continuous over this time period
        start_indices = np.where(timestamps > eeg_start_time)[0]
        if len(start_indices) == 0:
            logger.warning("No timestamps exceed eeg_start_time")
            return "Skip"
        end_indices = np.where(timestamps < eeg_end_time)[0]
        if len(end_indices) == 0:
            # All buffered EEG starts after the window ends; it can never be filled
            logger.warning("No timestamps precede eeg_end_time")
            return "Skip"
        start_index = start_indices[0]
        end_index = end_indices[-1]

        # A step of more than two sample periods indicates lost packets
        time_diffs = np.diff(timestamps[start_index:end_index])
        if np.any(time_diffs > 2 / self.fsample):
            logger.warning("Time gaps in EEG data")
            return "Skip"

        X, y = self.__paradigm.process_markers(
            self.event_marker_buffer,
            self.event_timestamp_buffer,
            eeg,
            timestamps,
            self.fsample,
        )

        # Label -1 marks unlabeled epochs
        sum_new_labeled_trials = np.sum(y != -1)

        # Add the epochs to the data tank
        self.__data_tank.add_epochs(X, y)

        # Save epochs to temp_epochs file (self.temp_epochs is set by setup() in online mode)
        if self.auto_save_epochs and self.online and sum_new_labeled_trials > 0:
            paradigm_str = self.__paradigm.paradigm_name

            with open(self.temp_epochs, "wb") as f:
                np.savez(
                    f,
                    X=self.__data_tank.epochs,
                    y=self.__data_tank.labels,
                    paradigm=paradigm_str,
                )

        # If either there are no labels OR iterative training is on, then make a prediction
        if self.train_complete:
            if -1 in y or self.__paradigm.iterative_training:
                prediction = self._classifier.predict(X)
                self.__send_prediction(prediction)

        self.event_marker_buffer = []
        self.event_timestamp_buffer = []

        return "Success"

    def __update_and_train_classifier(self):
        """Updates the classifier if required.

        Pulls the latest epochs from the data tank, drops unlabeled epochs
        (label -1), adds the rest to the classifier and fits it once the
        classifier reports it is ready. Does nothing when train_lock is set.

        Parameters
        ----------
        `None`

        Returns
        -------
        continue_flag : bool
            Flag indicating to continue the while loop in step().
        """
        if self.train_lock is False:
            # Pull the epochs from the data tank and pass them to the classifier
            X, y = self.__data_tank.get_epochs(latest=True)

            # Remove epochs with label -1
            y = np.asarray(y)
            is_labeled = y != -1
            X = np.asarray(X)[is_labeled]
            y = y[is_labeled]

            # Check that there are epochs
            if len(y) > 0:
                self._classifier.add_to_train(X, y)

            if self._classifier._check_ready_for_fit():
                self._classifier.fit()
                self.train_complete = True

        return True

    # 3c. Private methods for event handling (trial and markers) and messaging
    def __log_trial_start(self):
        """Logs the start of a trial.

        Parameters
        ----------
        `None`

        Returns
        -------
        continue_flag : bool
            Flag indicating to continue the while loop in step().

        """
        logger.debug("Trial started, incrementing marker count and continuing")
        # Note that a marker occurred, but do nothing else
        return True  # Continue processing

    def __handle_trial_end(self):
        """Handles the end of a trial. Processes and classifies trial data if required.

        Parameters
        ----------
        `None`

        Returns
        ------
        continue_flag : bool
            Flag indicating to continue the while loop in step().
            - Returns `True` if not classifying each trial or on success.
            - Returns `False` when waiting for data (the marker is retried next
              step) or after skipping the trial (marker_count already advanced here).
        """
        # If we are classifying based on trials, then process the trial,
        if self.__paradigm.classify_each_trial:
            success_string = self.__process_and_classify()
            if success_string == "Wait":
                logger.debug("Processing of trial not run: waiting for more data")
                return False
            if success_string == "Skip":
                logger.warning("Processing of trial failed: skipping trial")
                self.event_marker_buffer = []
                self.event_timestamp_buffer = []
                self.marker_count += 1
                return False

        return True  # Return True by default if not classifying

    def __handle_event_marker(self, marker, timestamp):
        """Buffers an event marker and, if classifying each epoch, processes it.

        Parameters
        ----------
        marker : str
            Event marker string containing comma-separated values.
            - Format depends on paradigm implementation.
        timestamp : float
            Timestamp of the marker (after time correction).

        Returns
        ------
        continue_flag : bool
            Flag indicating to continue the while loop in step().

        """
        # Add the marker to the event marker buffer
        self.event_marker_buffer.append(marker)
        self.event_timestamp_buffer.append(timestamp)

        # If classification is on epochs, then update epochs, maybe classify, and clear the buffer
        if self.__paradigm.classify_each_epoch:
            success_string = self.__process_and_classify()
            if success_string == "Wait":
                logger.debug("Processing of epoch not run: waiting for more data")
                # The same marker is re-read and re-buffered on the next step
                self.event_marker_buffer = []
                self.event_timestamp_buffer = []
                return False  # Stop processing
            elif success_string == "Skip":
                logger.warning("Processing of epoch failed: skipping epoch")
                self.event_marker_buffer = []
                self.event_timestamp_buffer = []

                self.marker_count += 1
                return False  # Stop processing

        return True  # Continue processing

    def __handle_command_marker(self, marker: str) -> bool:
        """Processes a command marker by invoking its associated method.

        The command marker string is assumed to be in the self.marker_methods dictionary.
        The associated method is retrieved and called.
        The return value of the method is used to determine if processing should continue.

        Parameters
        ----------
        marker : str
            A command marker string (assumed to be in self.marker_methods).

        Returns
        -------
        bool
            A flag indicating if the processing should continue.

        """
        command_marker_method = self.marker_methods[marker]  # Retrieve method
        continue_flag = command_marker_method()  # Call method

        # Debug level logging if continue_flag is FALSE
        if continue_flag is False:
            logger.debug("Command marker '%s' set continue_flag to FALSE", marker)

        return continue_flag

    def __send_prediction(self, prediction):
        """Send a prediction to the messenger object.

        Parameters
        ----------
        prediction
            The value returned by the classifier's predict().

        Returns
        -------
        `None`

        """
        if self._messenger is not None:
            logger.debug("Sending prediction: %s", prediction)
            self._messenger.prediction(prediction)
        elif self.online is True:
            # If running in online mode and messenger is not available, log a warning
            logger.warning(
                "Messenger not available (self._messenger is None). Prediction not sent: %s",
                prediction,
            )

    def __load_temp_epochs_if_available(self, reload_data_time: int = 300):
        """Load temp_epochs if available and valid.

        Also sets self.temp_epochs, the path later used to save epochs.

        Parameters
        ----------
        reload_data_time : int, *optional*
            Maximum age in seconds of a temp_epochs file for its data to be reloaded;
            older files are deleted.
            Default is `300` seconds (5 minutes).

        Returns
        -------
        `None`

        """
        self.temp_epochs = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), "temp_epochs.npz"
        )

        if not os.path.exists(self.temp_epochs):
            return

        # If temp_epochs is older than `reload_data_time`, delete it
        if os.path.getmtime(self.temp_epochs) < (time.time() - reload_data_time):
            os.remove(self.temp_epochs)
            logger.info("Deleted old temp_epochs file.")
            return

        # Load the temp_epochs file.
        # NOTE(review): allow_pickle=True executes pickled objects from this file;
        # acceptable only because the file is written by this class itself.
        with open(self.temp_epochs, "rb") as f:
            npz = np.load(f, allow_pickle=True)
            X = npz["X"]
            y = npz["y"]
            paradigm_str = npz["paradigm"].item()

        # If the paradigm is different, delete the file
        if self.__paradigm.paradigm_name != paradigm_str:
            logger.warning(
                "Paradigm in temp_epochs file does not match current paradigm. Deleting file."
            )
            os.remove(self.temp_epochs)
            return

        # If the paradigm is the same, then add the epochs to the data tank
        logger.info("Loading epochs from temp_epochs file.")
        logger.info("X shape: %s", X.shape)
        logger.info("y shape: %s", y.shape)
        self.__data_tank.add_epochs(X, y)

        # If there are epochs in the data tank, then train the classifier
        if len(self.__data_tank.labels) > 0:
            self.__update_and_train_classifier()
class MarkerTypes(Enum):
    """Command marker strings recognised by `BciController`.

    Each member's value is the exact marker string received from the marker
    source; `BciController` maps these values to handler methods in its
    `marker_methods` dictionary.
    """

    TRIAL_STARTED = "Trial Started"
    TRIAL_ENDS = "Trial Ends"
    TRAINING_COMPLETE = "Training Complete"
    TRAIN_CLASSIFIER = "Train Classifier"
    DONE_RS_COLLECTION = "Done with all RS collection"
    UPDATE_CLASSIFIER = "Update Classifier"
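Command marker strings recognised by `BciController`. Each member's value is the exact marker string that `BciController` maps to a handler method.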
48class BciController: 49 """ 50 Class that holds, trials, processes, and classifies EEG data. 51 This class is used for processing of continuous EEG data in trials of a defined length. 52 """ 53 54 # 0. Special methods (e.g. __init__) 55 def __init__( 56 self, 57 classifier: GenericClassifier, 58 eeg_source: EegSource, 59 marker_source: MarkerSource | None = None, 60 paradigm: Paradigm | None = None, 61 data_tank: DataTank | None = None, 62 messenger: Messenger | None = None, 63 ): 64 """Initializes `BciController` class. 65 66 Parameters 67 ---------- 68 classifier : GenericClassifier 69 The classifier used by BciController. 70 eeg_source : EegSource 71 Source of EEG data and timestamps, this could be from a file or headset via LSL, etc. 72 marker_source : EegSource 73 Source of Marker/Control data and timestamps, this could be from a file or Unity via LSL, etc. 74 - Default is `None`. 75 paradigm : Paradigm 76 The paradigm used by BciController. This defines the processing and reshaping steps for the EEG data. 77 data_tank : DataTank 78 DataTank object to handle the storage of EEG trials and labels. 79 - Default is `None`. 80 messenger: Messenger 81 Messenger object to handle events from BciController, ex: acknowledging markers and 82 predictions. 83 - Default is `None`. 
84 85 """ 86 87 # Ensure the incoming dependencies are the right type 88 assert isinstance(classifier, GenericClassifier) 89 assert isinstance(eeg_source, EegSource) 90 assert isinstance(marker_source, MarkerSource | None) 91 assert isinstance(paradigm, Paradigm | None) 92 assert isinstance(data_tank, DataTank | None) 93 assert isinstance(messenger, Messenger | None) 94 95 self._classifier = classifier 96 self.__eeg_source = eeg_source 97 self.__marker_source = marker_source 98 self.__paradigm = paradigm 99 self.__data_tank = data_tank 100 self._messenger = messenger 101 102 self.headset_string = self.__eeg_source.name 103 self.fsample = self.__eeg_source.fsample 104 self.n_channels = self.__eeg_source.n_channels 105 self.ch_type = self.__eeg_source.channel_types 106 self.ch_units = self.__eeg_source.channel_units 107 self.channel_labels = self.__eeg_source.channel_labels 108 109 self.__data_tank.set_source_data( 110 self.headset_string, 111 self.fsample, 112 self.n_channels, 113 self.ch_type, 114 self.ch_units, 115 self.channel_labels, 116 ) 117 118 # Switch any trigger channels to stim, this is for mne/bids export (?) 
119 self.ch_type = [type.replace("trg", "stim") for type in self.ch_type] 120 121 self._classifier.channel_labels = self.channel_labels 122 123 logger.info(self.headset_string) 124 logger.info(self.channel_labels) 125 126 # Initialize data and timestamp arrays to the right dimensions, but zero elements 127 self.marker_data = np.zeros((0, 1)) 128 self.marker_timestamps = np.zeros((0)) 129 self.bci_controller = np.zeros((0, self.n_channels)) 130 self.eeg_timestamps = np.zeros((0)) 131 132 # Initialize marker methods dictionary 133 self.marker_methods = { 134 MarkerTypes.DONE_RS_COLLECTION.value: self.__process_resting_state_data, 135 MarkerTypes.TRIAL_STARTED.value: self.__log_trial_start, 136 MarkerTypes.TRIAL_ENDS.value: self.__handle_trial_end, 137 MarkerTypes.TRAINING_COMPLETE.value: self.__update_and_train_classifier, 138 MarkerTypes.TRAIN_CLASSIFIER.value: self.__update_and_train_classifier, 139 MarkerTypes.UPDATE_CLASSIFIER.value: self.__update_and_train_classifier, 140 } 141 142 self.ping_count = 0 143 self.n_samples = 0 144 self.time_units = "" 145 146 # 1. Core public API methods 147 def setup( 148 self, 149 online=True, 150 train_complete=False, 151 train_lock=False, 152 auto_save_epochs=True, 153 ): 154 """Configure processing loop. 155 156 This should be called before starting the loop with run() or step(). 157 158 Calling after will reset the loop state. 159 160 The processing loop reads in EEG and marker data and processes it. 161 The loop can be run in "offline" or "online" modes: 162 - If in `online` mode, then the loop will continuously try to read 163 in data from the `BciController` object and process it. The loop will 164 terminate when `max_loops` is reached, or when manually terminated. 165 - If in `offline` mode, then the loop will read in all of the data 166 at once, process it, and then terminate. 167 168 Parameters 169 ---------- 170 online : bool, *optional* 171 Flag to indicate if the data will be processed in `online` mode. 
172 - `True`: The data will be processed in `online` mode. 173 - `False`: The data will be processed in `offline` mode. 174 - Default is `True`. 175 train_complete : bool, *optional* 176 Flag to indicate if the classifier has been trained. 177 - `True`: The classifier has been trained. 178 - `False`: The classifier has not been trained. 179 - Default is `False`. 180 train_lock : bool, *optional* 181 Flag to indicate if the classifier is locked (ie. no more training). 182 - `True`: The classifier is locked. 183 - `False`: The classifier is not locked. 184 - Default is `False`. 185 auto_save_epochs : bool, *optional* 186 Flag to indicate if labeled epochs should be automatically saved to a temp file so they can be reloaded if Bessy crashes. 187 - `True`: Epochs will be saved to a temp file. 188 - `False`: Epochs will not be saved to a temp file. 189 190 191 Returns 192 ------- 193 `None` 194 195 """ 196 self.online = online 197 self.train_complete = train_complete 198 self.train_lock = train_lock 199 self.auto_save_epochs = auto_save_epochs 200 201 # initialize the numbers of markers and trials to zero 202 self.marker_count = 0 203 self.current_num_trials = 0 204 self.n_trials = 0 205 206 self.num_online_selections = 0 207 self.online_selection_indices = [] 208 self.online_selections = [] 209 210 # Check for a temp_epochs file 211 if online: 212 self.__load_temp_epochs_if_available() 213 214 def step(self): 215 """Runs a single BciController processing step. 216 217 See setup() for configuration of processing. 218 219 The method: 220 1. Pulls data from sources (EEG and markers). 221 2. Run a while loop to process markers as long as there are unprocessed markers. 222 3. The while loop processes the markers in the following order: 223 - First checks if the marker is a known command marker from self.marker_methods. 224 - Then checks if it's an event marker (contains commas) 225 - If neither, logs a warning about unknown marker type 226 3. 
If the marker is a command marker, handles it by calling __handle_command_marker(). 227 4. If the marker is an event marker, handles it by calling __handle_event_marker(). 228 5. If the command or event marker handling return continue_flag as True, increment the marker count and process the next marker. 229 - Note: If there is an unknown marker type, the marker count is still incremented and processing continues. 230 6. If the command or event marker handling return continue_flag as False, break out of the while loop and end the step. 231 232 Parameters 233 ---------- 234 `None` 235 236 Returns 237 ------ 238 `None` 239 240 """ 241 # read from sources to get new data. 242 # This puts command markers in the marker_data array and 243 # event markers in the event_marker_strings array 244 self._pull_data_from_sources() 245 246 # Process markers while there are unprocessed markers 247 # REMOVE COMMENT: check if there is an available command marker, if not, break and wait for more data 248 while len(self.marker_timestamps) > self.marker_count: 249 # Get the current marker 250 current_step_marker = self.marker_data[self.marker_count] # String 251 current_timestamp = self.marker_timestamps[self.marker_count] # Float 252 253 # If marker is empty, skip it 254 if not current_step_marker: 255 logger.warning("Empty marker received") 256 self.marker_count += 1 257 continue 258 259 # If messenger is available, send feedback for each marker received 260 if self._messenger is not None: 261 self._messenger.marker_received(current_step_marker) 262 263 # Process markers in order specified in the docstrings 264 # First check if it's a known command marker 265 if current_step_marker in self.marker_methods: 266 continue_flag = self.__handle_command_marker(current_step_marker) 267 # Then check if it's an event marker (contains commas) 268 elif "," in current_step_marker: 269 continue_flag = self.__handle_event_marker( 270 current_step_marker, current_timestamp 271 ) 272 # Otherwise, log a 
warning about unknown marker type 273 else: 274 # Log warning for unknown marker types 275 logger.warning("Unknown marker type received: %s", current_step_marker) 276 continue_flag = True 277 278 # Check if we should continue processing markers in the while loop 279 # if continue_flag is False, then break out of the while loop 280 # else, increment the marker count and process the next marker 281 if continue_flag is False: 282 break 283 else: 284 logger.info("Processed Marker: %s", current_step_marker) 285 self.marker_count += 1 286 287 def run(self, max_loops: int = 1000000): 288 """Runs BciController processing in a loop. 289 290 See setup() for configuration of processing. 291 292 Parameters 293 ---------- 294 max_loops : int, *optional* 295 Maximum number of loops to run, default is `1000000`. 296 297 Returns 298 ------ 299 `None` 300 301 """ 302 # if offline, then all data is already loaded, only need to loop once 303 if self.online is False: 304 self.loops = max_loops - 1 305 else: 306 self.loops = 0 307 308 # Initialize the event marker buffer 309 self.event_marker_buffer = [] 310 self.event_timestamp_buffer = [] 311 312 # start the main loop, stops after pulling new data, max_loops times 313 while self.loops < max_loops: 314 # print out loop status 315 if self.loops % 100 == 0: 316 logger.debug(self.loops) 317 318 if self.loops == max_loops - 1: 319 logger.debug("last loop") 320 321 # read from sources and process 322 self.step() 323 324 # Wait a short period of time and then try to pull more data 325 if self.online: 326 time.sleep(0.00001) 327 328 self.loops += 1 329 330 # 2. Protected methods (single underscore) 331 def _pull_data_from_sources(self): 332 """Get pull data from EEG and optionally, the marker source. 333 334 This method will fill up the marker_data, bci_controller and corresponding timestamp arrays. 
335 336 Parameters 337 ---------- 338 `None` 339 340 Returns 341 ------- 342 `None` 343 344 """ 345 # Get new data from source, whatever it is 346 self.__pull_marker_data_from_source() 347 self.__pull_eeg_data_from_source() 348 349 # If the outlet exists send a ping 350 if self._messenger is not None: 351 self.ping_count += 1 352 self._messenger.ping() 353 354 # 3. Private methods (double underscore) 355 # 3a. Private methods for retrieving data from sources 356 def __pull_marker_data_from_source(self): 357 """Pulls marker samples from source, sanity checks and appends to buffer. 358 359 Parameters 360 ---------- 361 `None` 362 363 Returns 364 ------- 365 `None` 366 367 """ 368 369 # if there isn't a marker source, abort 370 if self.__marker_source is None: 371 return 372 373 # read in the data 374 markers, timestamps = self.__marker_source.get_markers() 375 markers = np.array(markers) 376 timestamps = np.array(timestamps) 377 378 if markers.size == 0: 379 return 380 381 if markers.ndim != 2: 382 logger.warning("discarded invalid marker data") 383 return 384 385 # apply time correction 386 time_correction = self.__marker_source.time_correction() 387 timestamps = [timestamps[i] + time_correction for i in range(len(timestamps))] 388 389 for i, marker in enumerate(markers): 390 marker = marker[0] 391 if "Ping" in marker: 392 continue 393 394 # Add all markers to the controller 395 self.marker_data = np.append(self.marker_data, marker) 396 self.marker_timestamps = np.append(self.marker_timestamps, timestamps[i]) 397 398 # Add all markers to the data tank 399 self.__data_tank.add_raw_markers( 400 np.array([marker]), np.array([timestamps[i]]) 401 ) 402 403 def __pull_eeg_data_from_source(self): 404 """Pulls eeg samples from source, sanity checks and appends to buffer. 
405 406 Parameters 407 ---------- 408 `None` 409 410 Returns 411 ------- 412 `None` 413 414 """ 415 416 # read in the data 417 eeg, timestamps = self.__eeg_source.get_samples() 418 eeg = np.array(eeg) 419 timestamps = np.array(timestamps) 420 421 if eeg.size == 0: 422 return 423 424 if eeg.ndim != 2: 425 logger.warning("discarded invalid eeg data") 426 return 427 428 # if time is in milliseconds, divide by 1000, works for sampling rates above 10Hz 429 try: 430 if self.time_units == "milliseconds": 431 timestamps = [(timestamps[i] / 1000) for i in range(len(timestamps))] 432 433 # If time units are not defined then define them 434 except Exception: 435 dif_low = -2 436 dif_high = -1 437 while timestamps[dif_high] - timestamps[dif_low] == 0: 438 dif_low -= 1 439 dif_high -= 1 440 441 if timestamps[dif_high] - timestamps[dif_low] > 0.1: 442 timestamps = [(timestamps[i] / 1000) for i in range(len(timestamps))] 443 self.time_units = "milliseconds" 444 else: 445 self.time_units = "seconds" 446 447 # apply time correction, this is essential for headsets like neurosity which have their own clock 448 time_correction = self.__eeg_source.time_correction() 449 timestamps = [timestamps[i] + time_correction for i in range(len(timestamps))] 450 451 self.__data_tank.add_raw_eeg(eeg.T, timestamps) 452 453 # Update latest EEG timestamp 454 self.latest_eeg_timestamp = timestamps[-1] 455 456 # 3b. Private methods for data processing and classification 457 def __process_resting_state_data(self): 458 """Handles the resting state data by packaging it and adding it to the data tank. 459 460 Parameters 461 ---------- 462 `None` 463 464 Returns 465 ------ 466 continue_flag : bool 467 Flag indicating to continue the while loop in step(). 
468 469 """ 470 ( 471 self.bci_controller, 472 self.eeg_timestamps, 473 ) = self.__data_tank.get_raw_eeg() 474 475 resting_state_data = self.__paradigm.package_resting_state_data( 476 self.marker_data, 477 self.marker_timestamps, 478 self.bci_controller, 479 self.eeg_timestamps, 480 self.fsample, 481 ) 482 483 self.__data_tank.add_resting_state_data(resting_state_data) 484 485 return True # Continue processing 486 487 def __process_and_classify(self): 488 """Process the markers and classify the data. 489 490 Parameters 491 ---------- 492 `None` 493 494 Returns 495 ------- 496 success_string : str 497 String indicating if the processing and classification was successful. 498 Potential values are "Success", "Skip", "Wait". 499 - "Success": The processing and classification was successful. 500 - "Skip": EEG is either absent entirely or contains lost packets. 501 - "Wait": The processing is waiting for more data. 502 503 """ 504 505 eeg_start_time, eeg_end_time = self.__paradigm.get_eeg_start_and_end_times( 506 self.event_marker_buffer, self.event_timestamp_buffer 507 ) 508 509 # No we actually need to wait until we have all the data for these markers 510 eeg, timestamps = self.__data_tank.get_raw_eeg() 511 512 # Check if there is available EEG data 513 if len(eeg) == 0: 514 logger.warning("No EEG data available") 515 return "Skip" 516 517 # If the last timestamp is less than the end time, then we don't have the necessary EEG to process 518 if timestamps[-1] < eeg_end_time: 519 return "Wait" 520 521 # Check if EEG sampling is continuous over this time period 522 start_indices = np.where(timestamps > eeg_start_time)[0] 523 if len(start_indices) == 0: 524 logger.warning("No timestamps exceed eeg_start_time") 525 return "Skip" 526 start_index = start_indices[0] 527 end_index = np.where(timestamps < eeg_end_time)[0][-1] 528 529 time_diffs = np.diff(timestamps[start_index:end_index]) 530 if np.any(time_diffs > 2 / self.fsample): 531 logger.warning("Time gaps in EEG data") 
532 return "Skip" 533 534 X, y = self.__paradigm.process_markers( 535 self.event_marker_buffer, 536 self.event_timestamp_buffer, 537 eeg, 538 timestamps, 539 self.fsample, 540 ) 541 542 sum_new_labeled_trials = np.sum(y != -1) 543 544 # Add the epochs to the data tank 545 self.__data_tank.add_epochs(X, y) 546 547 # Save epochs to temp_epochs file 548 if self.auto_save_epochs and self.online and sum_new_labeled_trials > 0: 549 paradigm_str = self.__paradigm.paradigm_name 550 551 with open(self.temp_epochs, "wb") as f: 552 np.savez( 553 f, 554 X=self.__data_tank.epochs, 555 y=self.__data_tank.labels, 556 paradigm=paradigm_str, 557 ) 558 559 # If either there are no labels OR iterative training is on, then make a prediction 560 if self.train_complete: 561 if -1 in y or self.__paradigm.iterative_training: 562 prediction = self._classifier.predict(X) 563 self.__send_prediction(prediction) 564 565 self.event_marker_buffer = [] 566 self.event_timestamp_buffer = [] 567 568 return "Success" 569 570 def __update_and_train_classifier(self): 571 """Updates the classifier if required. 572 573 Parameters 574 ---------- 575 `None` 576 577 Returns 578 ------- 579 continue_flag : bool 580 Flag indicating to continue the while loop in step(). 581 """ 582 if self.train_lock is False: 583 # Pull the epochs from the data tank and pass them to the classifier 584 X, y = self.__data_tank.get_epochs(latest=True) 585 586 # Remove epochs with label -1 587 ind_to_remove = [] 588 for i, label in enumerate(y): 589 if label == -1: 590 ind_to_remove.append(i) 591 X = np.delete(X, ind_to_remove, axis=0) 592 y = np.delete(y, ind_to_remove, axis=0) 593 594 # Check that there are epochs 595 if len(y) > 0: 596 self._classifier.add_to_train(X, y) 597 598 if self._classifier._check_ready_for_fit(): 599 self._classifier.fit() 600 self.train_complete = True 601 602 return True 603 604 # 3c. 
Private methods for event handling (trial and markers) and messaging 605 def __log_trial_start(self): 606 """Logs the start of a trial. 607 608 Parameters 609 ---------- 610 `None` 611 612 Returns 613 ------- 614 continue_flag : bool 615 Flag indicating to continue the while loop in step(). 616 617 """ 618 logger.debug("Trial started, incrementing marker count and continuing") 619 # Note that a marker occured, but do nothing else 620 return True # Continue processing 621 622 def __handle_trial_end(self): 623 """Handles the end of a trial. Processes and classifies trial data if required. 624 625 Parameters 626 ---------- 627 `None` 628 629 Returns 630 ------ 631 success_flag : bool 632 Flag indicating if the processing and classification was successful. 633 - Returns `True` if not classifying. 634 """ 635 # If we are classifying based on trials, then process the trial, 636 if self.__paradigm.classify_each_trial: 637 success_string = self.__process_and_classify() 638 if success_string == "Wait": 639 logger.debug("Processing of trial not run: waiting for more data") 640 return False 641 if success_string == "Skip": 642 logger.warning("Processing of trial failed: skipping trial") 643 self.event_marker_buffer = [] 644 self.event_timestamp_buffer = [] 645 self.marker_count += 1 646 return False 647 648 return True # Return True by default if not classifying 649 650 def __handle_event_marker(self, marker, timestamp): 651 """Processes and classifies event markers. 652 653 Parameters 654 ---------- 655 marker : str 656 Event marker string containing comma-separated values. 657 - Format depends on paradigm implementation. 658 timestamp : float 659 Timestamp of the marker in seconds (after time correction). 660 661 662 Returns 663 ------ 664 continue_flag : bool 665 Flag indicating to continue the while loop in step(). 
666 667 """ 668 # Add the marker to the event marker buffer 669 self.event_marker_buffer.append(marker) 670 self.event_timestamp_buffer.append(timestamp) 671 672 # If classification is on epochs, then update epochs, maybe classify, and clear the buffer 673 if self.__paradigm.classify_each_epoch: 674 success_string = self.__process_and_classify() 675 if success_string == "Wait": 676 logger.debug("Processing of epoch not run: waiting for more data") 677 self.event_marker_buffer = [] 678 self.event_timestamp_buffer = [] 679 return False # Stop processing 680 elif success_string == "Skip": 681 logger.warning("Processing of epoch failed: skipping epoch") 682 self.event_marker_buffer = [] 683 self.event_timestamp_buffer = [] 684 685 self.marker_count += 1 686 return False # Stop processing 687 688 return True # Continue processing 689 690 def __handle_command_marker(self, marker: str) -> bool: 691 """Processes a command marker by invoking its associated method. 692 693 The command marker string is assumed to be in the self.marker_methods dictionary. 694 The associated method is retrieved and called. 695 The return value of the method is used to determine if processing should continue. 696 697 Parameters 698 ---------- 699 marker : str 700 A command marker string (assumed to be in self.marker_methods). 701 702 Returns 703 ------- 704 bool 705 A flag indicating if the processing should continue. 706 707 """ 708 command_marker_method = self.marker_methods[marker] # Retrieve method 709 continue_flag = command_marker_method() # Call method 710 711 # Debug level logging if continue_flag is FALSE 712 if continue_flag is False: 713 logger.debug("Command marker '%s' set continue_flag to FALSE", marker) 714 715 return continue_flag 716 717 def __send_prediction(self, prediction): 718 """Send a prediction to the messenger object. 
719 720 Parameters 721 ---------- 722 `None` 723 724 Returns 725 ------- 726 `None` 727 728 """ 729 if self._messenger is not None: 730 logger.debug("Sending prediction: %s", prediction) 731 self._messenger.prediction(prediction) 732 elif self._messenger is None and self.online is True: 733 # If running in online mode and messenger is not available, log a warning 734 logger.warning( 735 "Messenger not available (self._messenger is None). Prediction not sent: %s", 736 prediction, 737 ) 738 739 def __load_temp_epochs_if_available(self, reload_data_time: int = 300): 740 """Load temp_epochs if available and valid. 741 742 Parameters 743 ---------- 744 reload_data_time : int, *optional* 745 Time in seconds of the last temp_epochs file to reload the data from. 746 Default is `300` seconds (5 minutes). 747 748 Returns 749 ------- 750 `None` 751 752 """ 753 self.temp_epochs = os.path.join( 754 os.path.dirname(os.path.dirname(__file__)), "temp_epochs.npz" 755 ) 756 757 if not os.path.exists(self.temp_epochs): 758 return 759 760 # If temp_epochs is older than `reload_data_time`, delete it 761 if os.path.getmtime(self.temp_epochs) < (time.time() - reload_data_time): 762 os.remove(self.temp_epochs) 763 logger.info("Deleted old temp_epochs file.") 764 return 765 766 # Load the temp_epochs file 767 with open(self.temp_epochs, "rb") as f: 768 npz = np.load(f, allow_pickle=True) 769 X = npz["X"] 770 y = npz["y"] 771 paradigm_str = npz["paradigm"].item() 772 773 # If the paradigm is different, delete the file 774 if self.__paradigm.paradigm_name != paradigm_str: 775 logger.warning( 776 "Paradigm in temp_epochs file does not match current paradigm. Deleting file." 
777 ) 778 os.remove(self.temp_epochs) 779 return 780 781 # If the paradigm is the same, then add the epochs to the data tank 782 logger.info("Loading epochs from temp_epochs file.") 783 logger.info("X shape: %s", X.shape) 784 logger.info("y shape: %s", y.shape) 785 self.__data_tank.add_epochs(X, y) 786 787 # If there are epochs in the data tank, then train the classifier 788 if len(self.__data_tank.labels) > 0: 789 self.__update_and_train_classifier()
Class that holds EEG data, divides it into trials, processes it, and classifies it. This class is used for processing continuous EEG data in trials of a defined length.
def __init__(
    self,
    classifier: GenericClassifier,
    eeg_source: EegSource,
    marker_source: MarkerSource | None = None,
    paradigm: Paradigm | None = None,
    data_tank: DataTank | None = None,
    messenger: Messenger | None = None,
):
    """Initializes `BciController` class.

    Stores the injected dependencies, copies the headset metadata from
    `eeg_source`, forwards that metadata to the data tank and the
    classifier, and allocates empty marker/EEG buffers.

    Parameters
    ----------
    classifier : GenericClassifier
        The classifier used by BciController.
    eeg_source : EegSource
        Source of EEG data and timestamps, this could be from a file or headset via LSL, etc.
    marker_source : MarkerSource
        Source of Marker/Control data and timestamps, this could be from a file or Unity via LSL, etc.
        - Default is `None`.
    paradigm : Paradigm
        The paradigm used by BciController. This defines the processing and reshaping steps for the EEG data.
        - Default is `None`.
    data_tank : DataTank
        DataTank object to handle the storage of EEG trials and labels.
        - Default is `None`.
    messenger : Messenger
        Messenger object to handle events from BciController, ex: acknowledging markers and
        predictions.
        - Default is `None`.

    Raises
    ------
    AssertionError
        If a dependency is not an instance of its expected type
        (only while assertions are enabled, i.e. not under ``python -O``).

    Notes
    -----
    NOTE(review): `data_tank` is dereferenced unconditionally below
    (``set_source_data``), so passing `None` fails with AttributeError.
    `paradigm` is likewise used by the processing methods without a
    `None` check, so in practice both must be supplied.
    """

    # Ensure the incoming dependencies are the right type
    assert isinstance(
        classifier, GenericClassifier
    ), "classifier must be a GenericClassifier"
    assert isinstance(eeg_source, EegSource), "eeg_source must be an EegSource"
    assert isinstance(
        marker_source, MarkerSource | None
    ), "marker_source must be a MarkerSource or None"
    assert isinstance(
        paradigm, Paradigm | None
    ), "paradigm must be a Paradigm or None"
    assert isinstance(
        data_tank, DataTank | None
    ), "data_tank must be a DataTank or None"
    assert isinstance(
        messenger, Messenger | None
    ), "messenger must be a Messenger or None"

    self._classifier = classifier
    self.__eeg_source = eeg_source
    self.__marker_source = marker_source
    self.__paradigm = paradigm
    self.__data_tank = data_tank
    self._messenger = messenger

    # Copy the headset metadata from the EEG source
    self.headset_string = self.__eeg_source.name
    self.fsample = self.__eeg_source.fsample
    self.n_channels = self.__eeg_source.n_channels
    self.ch_type = self.__eeg_source.channel_types
    self.ch_units = self.__eeg_source.channel_units
    self.channel_labels = self.__eeg_source.channel_labels

    # The data tank receives the channel types as reported by the source,
    # i.e. before the trg -> stim renaming below.
    self.__data_tank.set_source_data(
        self.headset_string,
        self.fsample,
        self.n_channels,
        self.ch_type,
        self.ch_units,
        self.channel_labels,
    )

    # Switch any trigger channels to stim, presumably for mne/bids export
    self.ch_type = [ch.replace("trg", "stim") for ch in self.ch_type]

    self._classifier.channel_labels = self.channel_labels

    logger.info(self.headset_string)
    logger.info(self.channel_labels)

    # Initialize data and timestamp arrays to the right dimensions, but zero elements
    self.marker_data = np.zeros((0, 1))
    self.marker_timestamps = np.zeros((0))
    # NOTE(review): despite its name, `bci_controller` is the raw EEG buffer;
    # it is later overwritten with the output of DataTank.get_raw_eeg().
    self.bci_controller = np.zeros((0, self.n_channels))
    self.eeg_timestamps = np.zeros((0))

    # Dispatch table: command-marker string -> handler. Each handler returns
    # a continue flag that step() uses to decide whether to keep going.
    self.marker_methods = {
        MarkerTypes.DONE_RS_COLLECTION.value: self.__process_resting_state_data,
        MarkerTypes.TRIAL_STARTED.value: self.__log_trial_start,
        MarkerTypes.TRIAL_ENDS.value: self.__handle_trial_end,
        MarkerTypes.TRAINING_COMPLETE.value: self.__update_and_train_classifier,
        MarkerTypes.TRAIN_CLASSIFIER.value: self.__update_and_train_classifier,
        MarkerTypes.UPDATE_CLASSIFIER.value: self.__update_and_train_classifier,
    }

    self.ping_count = 0
    self.n_samples = 0
    # Empty until known; compared against "milliseconds" when EEG is pulled
    self.time_units = ""
Initializes BciController class.
Parameters
- classifier (GenericClassifier): The classifier used by BciController.
- eeg_source (EegSource): Source of EEG data and timestamps, this could be from a file or headset via LSL, etc.
- marker_source (MarkerSource):
Source of Marker/Control data and timestamps, this could be from a file or Unity via LSL, etc.
- Default is None.
- paradigm (Paradigm): The paradigm used by BciController. This defines the processing and reshaping steps for the EEG data. Default is None.
- data_tank (DataTank):
DataTank object to handle the storage of EEG trials and labels.
- Default is None.
- messenger (Messenger):
Messenger object to handle events from BciController, e.g. acknowledging markers and predictions.
- Default is None.
147 def setup( 148 self, 149 online=True, 150 train_complete=False, 151 train_lock=False, 152 auto_save_epochs=True, 153 ): 154 """Configure processing loop. 155 156 This should be called before starting the loop with run() or step(). 157 158 Calling after will reset the loop state. 159 160 The processing loop reads in EEG and marker data and processes it. 161 The loop can be run in "offline" or "online" modes: 162 - If in `online` mode, then the loop will continuously try to read 163 in data from the `BciController` object and process it. The loop will 164 terminate when `max_loops` is reached, or when manually terminated. 165 - If in `offline` mode, then the loop will read in all of the data 166 at once, process it, and then terminate. 167 168 Parameters 169 ---------- 170 online : bool, *optional* 171 Flag to indicate if the data will be processed in `online` mode. 172 - `True`: The data will be processed in `online` mode. 173 - `False`: The data will be processed in `offline` mode. 174 - Default is `True`. 175 train_complete : bool, *optional* 176 Flag to indicate if the classifier has been trained. 177 - `True`: The classifier has been trained. 178 - `False`: The classifier has not been trained. 179 - Default is `False`. 180 train_lock : bool, *optional* 181 Flag to indicate if the classifier is locked (ie. no more training). 182 - `True`: The classifier is locked. 183 - `False`: The classifier is not locked. 184 - Default is `False`. 185 auto_save_epochs : bool, *optional* 186 Flag to indicate if labeled epochs should be automatically saved to a temp file so they can be reloaded if Bessy crashes. 187 - `True`: Epochs will be saved to a temp file. 188 - `False`: Epochs will not be saved to a temp file. 
189 190 191 Returns 192 ------- 193 `None` 194 195 """ 196 self.online = online 197 self.train_complete = train_complete 198 self.train_lock = train_lock 199 self.auto_save_epochs = auto_save_epochs 200 201 # initialize the numbers of markers and trials to zero 202 self.marker_count = 0 203 self.current_num_trials = 0 204 self.n_trials = 0 205 206 self.num_online_selections = 0 207 self.online_selection_indices = [] 208 self.online_selections = [] 209 210 # Check for a temp_epochs file 211 if online: 212 self.__load_temp_epochs_if_available()
Configure processing loop.
This should be called before starting the loop with run() or step().
Calling it again after the loop has started resets the loop state.
The processing loop reads in EEG and marker data and processes it. The loop can be run in "offline" or "online" modes:
- If in online mode, the loop will continuously try to read in data from the BciController object and process it. The loop will terminate when max_loops is reached, or when manually terminated.
- If in offline mode, the loop will read in all of the data at once, process it, and then terminate.
Parameters
- online (bool, optional):
Flag to indicate if the data will be processed in online mode.
- True: The data will be processed in online mode.
- False: The data will be processed in offline mode.
- Default is True.
- train_complete (bool, optional):
Flag to indicate if the classifier has been trained.
- True: The classifier has been trained.
- False: The classifier has not been trained.
- Default is False.
- train_lock (bool, optional):
Flag to indicate if the classifier is locked (i.e. no more training).
- True: The classifier is locked.
- False: The classifier is not locked.
- Default is False.
- auto_save_epochs (bool, optional):
Flag to indicate if labeled epochs should be automatically saved to a temp file so they can be reloaded if Bessy crashes.
- True: Epochs will be saved to a temp file.
- False: Epochs will not be saved to a temp file.
- Default is True.
Returns
None
def step(self):
    """Runs a single BciController processing step.

    See setup() for configuration of processing.

    The method:
    1. Pulls new data from the sources (EEG and markers).
    2. Loops over the unprocessed markers, handling each one in order:
       - Empty markers are skipped with a warning.
       - Known command markers (keys of self.marker_methods) are handled
         by __handle_command_marker().
       - Event markers (strings containing a comma) are handled by
         __handle_event_marker().
       - Any other marker is logged as an unknown marker type.
    3. If the handler returns continue_flag as True (and for unknown
       markers), the marker count is incremented and the next marker is
       processed.
    4. If the handler returns continue_flag as False, the loop stops and
       the step ends. step() itself does not increment the marker count in
       that case, so the same marker is revisited on the next step (handlers
       that want to skip the marker increment the count themselves).

    If a messenger is configured, marker_received() is called for every
    non-empty marker examined, including a marker that is revisited.

    Parameters
    ----------
    `None`

    Returns
    -------
    `None`

    """
    # Read from the sources. Every marker that does not contain "Ping"
    # (command or event) is appended to self.marker_data / self.marker_timestamps.
    self._pull_data_from_sources()

    # Process markers while there are unprocessed markers
    while len(self.marker_timestamps) > self.marker_count:
        current_step_marker = self.marker_data[self.marker_count]  # String
        current_timestamp = self.marker_timestamps[self.marker_count]  # Float

        # If marker is empty, skip it
        if not current_step_marker:
            logger.warning("Empty marker received")
            self.marker_count += 1
            continue

        # If messenger is available, send feedback for each marker examined.
        # NOTE(review): a marker revisited after a handler returned False is
        # acknowledged again on the next step.
        if self._messenger is not None:
            self._messenger.marker_received(current_step_marker)

        if current_step_marker in self.marker_methods:
            # Known command marker
            continue_flag = self.__handle_command_marker(current_step_marker)
        elif "," in current_step_marker:
            # Event marker (comma-separated fields)
            continue_flag = self.__handle_event_marker(
                current_step_marker, current_timestamp
            )
        else:
            logger.warning("Unknown marker type received: %s", current_step_marker)
            continue_flag = True

        # Only an explicit False stops the loop; any other value continues.
        if continue_flag is False:
            break
        logger.info("Processed Marker: %s", current_step_marker)
        self.marker_count += 1
Runs a single BciController processing step.
See setup() for configuration of processing.
The method:
- Pulls data from sources (EEG and markers).
- Runs a while loop to process markers as long as there are unprocessed markers.
- The while loop processes the markers in the following order:
- First checks if the marker is a known command marker from self.marker_methods.
- Then checks if it's an event marker (contains commas)
- If neither, logs a warning about unknown marker type
- If the marker is a command marker, handles it by calling __handle_command_marker().
- If the marker is an event marker, handles it by calling __handle_event_marker().
- If the command or event marker handler returns continue_flag as True, increment the marker count and process the next marker.
- Note: If there is an unknown marker type, the marker count is still incremented and processing continues.
- If the command or event marker handler returns continue_flag as False, break out of the while loop and end the step.
Parameters
None
Returns
None
def run(self, max_loops: int = 1000000):
    """Runs BciController processing in a loop.

    See setup() for configuration of processing.

    In offline mode the loop counter starts at ``max_loops - 1``, so step()
    runs exactly once (all data is already loaded). In online mode step()
    is called up to ``max_loops`` times, with a short sleep after each call.
    The event marker buffers are reset at the start of every call.

    Parameters
    ----------
    max_loops : int, *optional*
        Maximum number of loops to run, default is `1000000`.

    Returns
    -------
    `None`

    """
    last_loop = max_loops - 1
    # Offline data is fully loaded up front, so jump straight to the last loop
    self.loops = last_loop if self.online is False else 0

    # Start every run with empty event marker buffers
    self.event_marker_buffer = []
    self.event_timestamp_buffer = []

    while self.loops < max_loops:
        # Periodic progress output
        if not self.loops % 100:
            logger.debug(self.loops)
        if self.loops == last_loop:
            logger.debug("last loop")

        self.step()

        # Pause briefly before polling the sources again
        if self.online:
            time.sleep(0.00001)

        self.loops += 1
Runs BciController processing in a loop.
See setup() for configuration of processing.
Parameters
- max_loops (int, optional):
Maximum number of loops to run, default is 1000000.
Returns
None