bci_essentials.bci_controller
Module for managing BCI data.
This module provides data classes for different BCI paradigms.
It includes the loading of offline data in xdf format
or the live streaming of LSL data.
The loaded/streamed data is added to a buffer such that offline and online processing pipelines are identical.
Data is pre-processed (using the signal_processing module), divided into trials,
and classified (using one of the classification sub-modules).
Classes
BciController: For processing continuous data in trials of a defined length.
"""Module for managing BCI data.

This module provides data classes for different BCI paradigms.

It includes the loading of offline data in `xdf` format
or the live streaming of LSL data.

The loaded/streamed data is added to a buffer such that offline and
online processing pipelines are identical.

Data is pre-processed (using the `signal_processing` module), divided into trials,
and classified (using one of the `classification` sub-modules).

Classes
-------
- `BciController` : For processing continuous data in trials of a defined length.

"""

import time
import os
import numpy as np
from enum import Enum

from .paradigm.paradigm import Paradigm
from .data_tank.data_tank import DataTank
from .classification.generic_classifier import GenericClassifier
from .io.sources import EegSource, MarkerSource
from .io.messenger import Messenger
from .utils.logger import Logger

# Instantiate a logger for the module at the default level of logging.INFO
# Logs to bci_essentials.__module__ where __module__ is the name of the module
logger = Logger(name=__name__)


class MarkerTypes(Enum):
    """Command marker strings that `BciController` maps to handler methods."""

    TRIAL_STARTED = "Trial Started"
    TRIAL_ENDS = "Trial Ends"
    TRAINING_COMPLETE = "Training Complete"
    TRAIN_CLASSIFIER = "Train Classifier"
    DONE_RS_COLLECTION = "Done with all RS collection"
    UPDATE_CLASSIFIER = "Update Classifier"


# EEG data
class BciController:
    """
    Class that holds EEG data, divides it into trials, processes and classifies it.
    This class is used for processing of continuous EEG data in trials of a defined length.
    """

    # 0. Special methods (e.g. __init__)
    def __init__(
        self,
        classifier: GenericClassifier,
        eeg_source: EegSource,
        marker_source: MarkerSource | None = None,
        paradigm: Paradigm | None = None,
        data_tank: DataTank | None = None,
        messenger: Messenger | None = None,
    ):
        """Initializes `BciController` class.

        Parameters
        ----------
        classifier : GenericClassifier
            The classifier used by BciController.
        eeg_source : EegSource
            Source of EEG data and timestamps, this could be from a file or headset via LSL, etc.
        marker_source : MarkerSource
            Source of Marker/Control data and timestamps, this could be from a file or Unity via LSL, etc.
            - Default is `None`.
        paradigm : Paradigm
            The paradigm used by BciController. This defines the processing and reshaping steps for the EEG data.
            - Default is `None`. The paradigm is dereferenced when event,
            "Trial Ends" or resting-state markers are handled and when
            temp epochs are reloaded, so those code paths need one.
        data_tank : DataTank
            DataTank object to handle the storage of EEG trials and labels.
            - The signature default is `None`, but a DataTank is required.
        messenger: Messenger
            Messenger object to handle events from BciController, ex: acknowledging markers and
            predictions.
            - Default is `None`.

        Raises
        ------
        ValueError
            If `data_tank` is `None`.

        """

        # Ensure the incoming dependencies are the right type
        assert isinstance(classifier, GenericClassifier)
        assert isinstance(eeg_source, EegSource)
        assert isinstance(marker_source, MarkerSource | None)
        assert isinstance(paradigm, Paradigm | None)
        assert isinstance(data_tank, DataTank | None)
        assert isinstance(messenger, Messenger | None)

        # The data tank is used unconditionally a few lines below and by every
        # data-pulling method, so fail here with a readable message rather
        # than with an AttributeError on None.
        if data_tank is None:
            raise ValueError(
                "BciController requires a DataTank instance: `data_tank` must not be None."
            )

        self._classifier = classifier
        self.__eeg_source = eeg_source
        self.__marker_source = marker_source
        self.__paradigm = paradigm
        self.__data_tank = data_tank
        self._messenger = messenger

        self.headset_string = self.__eeg_source.name
        self.fsample = self.__eeg_source.fsample
        self.n_channels = self.__eeg_source.n_channels
        self.ch_type = self.__eeg_source.channel_types
        self.ch_units = self.__eeg_source.channel_units
        self.channel_labels = self.__eeg_source.channel_labels

        # Emily EGI fix
        # Set default channel types if none
        if self.ch_type is None:
            logger.warning("Channel types are none, setting all to 'eeg'")
            self.ch_type = ["eeg"] * self.n_channels

        # NOTE: the data tank receives the channel types *before* the
        # trg -> stim substitution below.
        self.__data_tank.set_source_data(
            self.headset_string,
            self.fsample,
            self.n_channels,
            self.ch_type,
            self.ch_units,
            self.channel_labels,
        )

        # Switch any trigger channels to stim, this is for mne/bids export (?)
        self.ch_type = [
            channel_type.replace("trg", "stim") for channel_type in self.ch_type
        ]

        self._classifier.channel_labels = self.channel_labels

        logger.info(self.headset_string)
        logger.info(self.channel_labels)

        # Initialize data and timestamp arrays to the right dimensions, but zero elements
        self.marker_data = np.zeros((0, 1))
        self.marker_timestamps = np.zeros((0))
        self.bci_controller = np.zeros((0, self.n_channels))
        self.eeg_timestamps = np.zeros((0))

        # Timestamp of the newest EEG sample pulled so far (None until data arrives)
        self.latest_eeg_timestamp = None

        # Event markers waiting to be processed. Created here (and reset in
        # setup()/run()) so that step() can be driven directly without run().
        self.event_marker_buffer = []
        self.event_timestamp_buffer = []

        # Initialize marker methods dictionary
        self.marker_methods = {
            MarkerTypes.DONE_RS_COLLECTION.value: self.__process_resting_state_data,
            MarkerTypes.TRIAL_STARTED.value: self.__log_trial_start,
            MarkerTypes.TRIAL_ENDS.value: self.__handle_trial_end,
            MarkerTypes.TRAINING_COMPLETE.value: self.__update_and_train_classifier,
            MarkerTypes.TRAIN_CLASSIFIER.value: self.__update_and_train_classifier,
            MarkerTypes.UPDATE_CLASSIFIER.value: self.__update_and_train_classifier,
        }

        self.step_count = 0
        self.ping_interval = 1000
        self.n_samples = 0
        self.time_units = ""

    # 1. Core public API methods
    def setup(
        self,
        online=True,
        train_complete=False,
        train_lock=False,
        auto_save_epochs=True,
    ):
        """Configure processing loop.

        This should be called before starting the loop with run() or step().

        Calling after will reset the loop state.

        The processing loop reads in EEG and marker data and processes it.
        The loop can be run in "offline" or "online" modes:
        - If in `online` mode, then the loop will continuously try to read
        in data from the `BciController` object and process it. The loop will
        terminate when `max_loops` is reached, or when manually terminated.
        - If in `offline` mode, then the loop will read in all of the data
        at once, process it, and then terminate.

        Parameters
        ----------
        online : bool, *optional*
            Flag to indicate if the data will be processed in `online` mode.
            - `True`: The data will be processed in `online` mode.
            - `False`: The data will be processed in `offline` mode.
            - Default is `True`.
        train_complete : bool, *optional*
            Flag to indicate if the classifier has been trained.
            - `True`: The classifier has been trained.
            - `False`: The classifier has not been trained.
            - Default is `False`.
        train_lock : bool, *optional*
            Flag to indicate if the classifier is locked (ie. no more training).
            - `True`: The classifier is locked.
            - `False`: The classifier is not locked.
            - Default is `False`.
        auto_save_epochs : bool, *optional*
            Flag to indicate if labeled epochs should be automatically saved to a temp file so they can be reloaded if Bessy crashes.
            - `True`: Epochs will be saved to a temp file.
            - `False`: Epochs will not be saved to a temp file.
            - Default is `True`.


        Returns
        -------
        `None`

        """
        self.online = online
        self.train_complete = train_complete
        self.train_lock = train_lock
        self.auto_save_epochs = auto_save_epochs

        # initialize the numbers of markers and trials to zero
        self.marker_count = 0
        self.current_num_trials = 0
        self.n_trials = 0

        self.num_online_selections = 0
        self.online_selection_indices = []
        self.online_selections = []

        # marker_count is rewound to zero above, so every marker is handled
        # again; drop stale buffered event markers to avoid duplicating them.
        self.event_marker_buffer = []
        self.event_timestamp_buffer = []

        # Check for a temp_epochs file
        if online:
            self.__load_temp_epochs_if_available()

    def step(self):
        """Runs a single BciController processing step.

        See setup() for configuration of processing.

        The method:
        1. Pulls data from sources (EEG and markers).
        2. Run a while loop to process markers as long as there are unprocessed markers.
        3. The while loop processes the markers in the following order:
            - First checks if the marker is a known command marker from self.marker_methods.
            - Then checks if it's an event marker (contains commas)
            - If neither, logs a warning about unknown marker type
        4. If the marker is a command marker, handles it by calling __handle_command_marker().
        5. If the marker is an event marker, handles it by calling __handle_event_marker().
        6. If the command or event marker handling return continue_flag as True, increment the marker count and process the next marker.
            - Note: If there is an unknown marker type, the marker count is still incremented and processing continues.
        7. If the command or event marker handling return continue_flag as False, break out of the while loop and end the step.

        After the marker loop the step counter is incremented and, every
        `ping_interval` steps, the messenger (if any) is pinged.

        Parameters
        ----------
        `None`

        Returns
        -------
        `None`

        """
        # Read from sources to get new data. Every non-"Ping" marker (command
        # and event alike) is appended to marker_data / marker_timestamps.
        self._pull_data_from_sources()

        # Process markers while there are unprocessed markers
        while len(self.marker_timestamps) > self.marker_count:
            # Get the current marker
            current_step_marker = self.marker_data[self.marker_count]  # String
            current_timestamp = self.marker_timestamps[self.marker_count]  # Float

            # If marker is empty, skip it
            if not current_step_marker:
                logger.warning("Empty marker received")
                self.marker_count += 1
                continue

            # If messenger is available, send feedback for each marker received
            if self._messenger is not None:
                self._messenger.marker_received(current_step_marker)

            # Process markers in order specified in the docstring
            # First check if it's a known command marker
            if current_step_marker in self.marker_methods:
                continue_flag = self.__handle_command_marker(current_step_marker)
            # Then check if it's an event marker (contains commas)
            elif "," in current_step_marker:
                continue_flag = self.__handle_event_marker(
                    current_step_marker, current_timestamp
                )
            # Otherwise, log a warning about unknown marker type
            else:
                logger.warning("Unknown marker type received: %s", current_step_marker)
                continue_flag = True

            # A False flag ends this step; the handler decides whether the
            # marker is retried (count untouched) or skipped (count bumped).
            if continue_flag is False:
                break

            logger.info("Processed Marker: %s", current_step_marker)
            self.marker_count += 1

        self.step_count += 1
        if self.step_count % self.ping_interval == 0:
            if self._messenger is not None:
                self._messenger.ping()

    def run(self, max_loops: int = 1000000, ping_interval: int = 100):
        """Runs BciController processing in a loop.

        See setup() for configuration of processing.

        Parameters
        ----------
        max_loops : int, *optional*
            Maximum number of loops to run, default is `1000000`.
        ping_interval : int, *optional*
            Number of steps between each messenger ping, default is `100`.
            Replaces the interval stored on the instance.

        Returns
        -------
        `None`

        """
        # if offline, then all data is already loaded, only need to loop once
        if self.online is False:
            self.loops = max_loops - 1
        else:
            self.loops = 0

        self.ping_interval = ping_interval

        # Initialize the event marker buffer
        self.event_marker_buffer = []
        self.event_timestamp_buffer = []

        # start the main loop, stops after pulling new data, max_loops times
        while self.loops < max_loops:
            # print out loop status
            if self.loops % 100 == 0:
                logger.debug(self.loops)

            if self.loops == max_loops - 1:
                logger.debug("last loop")

            # read from sources and process
            self.step()

            # Wait a short period of time and then try to pull more data
            if self.online:
                time.sleep(0.00001)

            self.loops += 1

    # 2. Protected methods (single underscore)
    def _pull_data_from_sources(self):
        """Pull data from the EEG source and, optionally, the marker source.

        Markers are pulled first, then EEG. Markers are appended to
        `marker_data` / `marker_timestamps` and to the data tank; EEG samples
        are handed to the data tank.

        Parameters
        ----------
        `None`

        Returns
        -------
        `None`

        """
        # Get new data from source, whatever it is
        self.__pull_marker_data_from_source()
        self.__pull_eeg_data_from_source()

    # 3. Private methods (double underscore)
    # 3a. Private methods for retrieving data from sources
    def __pull_marker_data_from_source(self):
        """Pulls marker samples from source, sanity checks and appends to buffer.

        Does nothing when there is no marker source, when no markers were
        returned, or when the returned markers are not 2-dimensional.
        Markers containing "Ping" are dropped.

        Parameters
        ----------
        `None`

        Returns
        -------
        `None`

        """

        # if there isn't a marker source, abort
        if self.__marker_source is None:
            return

        # read in the data
        markers, timestamps = self.__marker_source.get_markers()
        markers = np.array(markers)
        timestamps = np.array(timestamps)

        if markers.size == 0:
            return

        if markers.ndim != 2:
            logger.warning("discarded invalid marker data")
            return

        # apply time correction
        time_correction = self.__marker_source.time_correction()
        timestamps = [timestamp + time_correction for timestamp in timestamps]

        for marker_row, timestamp in zip(markers, timestamps):
            marker = marker_row[0]
            if "Ping" in marker:
                continue

            # Add all markers to the controller
            self.marker_data = np.append(self.marker_data, marker)
            self.marker_timestamps = np.append(self.marker_timestamps, timestamp)

            # Add all markers to the data tank
            self.__data_tank.add_raw_markers(
                np.array([marker]), np.array([timestamp])
            )

    def __pull_eeg_data_from_source(self):
        """Pulls eeg samples from source, sanity checks and appends to buffer.

        Does nothing when no samples were returned or when the returned
        samples are not 2-dimensional. Otherwise the (transposed) samples and
        corrected timestamps are added to the data tank and
        `latest_eeg_timestamp` is updated.

        Parameters
        ----------
        `None`

        Returns
        -------
        `None`

        """

        # read in the data
        eeg, timestamps = self.__eeg_source.get_samples()
        eeg = np.array(eeg)
        timestamps = np.array(timestamps)

        if eeg.size == 0:
            return

        if eeg.ndim != 2:
            logger.warning("discarded invalid eeg data")
            return

        # if time is in milliseconds, divide by 1000, works for sampling rates above 10Hz
        # NOTE(review): `time_units` is initialised to "" in __init__, so the
        # comparison below does not raise and the auto-detection in the
        # `except` branch looks unreachable in practice. Left as-is because
        # enabling it would change timestamp handling; confirm the intent.
        try:
            if self.time_units == "milliseconds":
                timestamps = [(timestamp / 1000) for timestamp in timestamps]

        # If time units are not defined then define them
        except Exception:
            dif_low = -2
            dif_high = -1
            while timestamps[dif_high] - timestamps[dif_low] == 0:
                dif_low -= 1
                dif_high -= 1

            if timestamps[dif_high] - timestamps[dif_low] > 0.1:
                timestamps = [(timestamp / 1000) for timestamp in timestamps]
                self.time_units = "milliseconds"
            else:
                self.time_units = "seconds"

        # apply time correction, this is essential for headsets like neurosity which have their own clock
        time_correction = self.__eeg_source.time_correction()
        timestamps = [timestamp + time_correction for timestamp in timestamps]

        self.__data_tank.add_raw_eeg(eeg.T, timestamps)

        # Update latest EEG timestamp
        self.latest_eeg_timestamp = timestamps[-1]

    # 3b. Private methods for data processing and classification
    def __process_resting_state_data(self):
        """Handles the resting state data by packaging it and adding it to the data tank.

        Parameters
        ----------
        `None`

        Returns
        -------
        continue_flag : bool
            Flag indicating to continue the while loop in step(). Always `True`.

        """
        (
            self.bci_controller,
            self.eeg_timestamps,
        ) = self.__data_tank.get_raw_eeg()

        resting_state_data = self.__paradigm.package_resting_state_data(
            self.marker_data,
            self.marker_timestamps,
            self.bci_controller,
            self.eeg_timestamps,
            self.fsample,
        )

        self.__data_tank.add_resting_state_data(resting_state_data)

        return True  # Continue processing

    def __process_and_classify(self):
        """Process the markers and classify the data.

        Parameters
        ----------
        `None`

        Returns
        -------
        success_string : str
            String indicating if the processing and classification was successful.
            Potential values are "Success", "Skip", "Wait".
            - "Success": The processing and classification was successful.
            - "Skip": EEG is either absent entirely or contains lost packets.
            - "Wait": The processing is waiting for more data.

        """

        eeg_start_time, eeg_end_time = self.__paradigm.get_eeg_start_and_end_times(
            self.event_marker_buffer, self.event_timestamp_buffer
        )

        # Now we actually need to wait until we have all the data for these markers
        eeg, timestamps = self.__data_tank.get_raw_eeg()

        # Check if there is available EEG data
        if len(eeg) == 0:
            logger.warning("No EEG data available")
            return "Skip"

        # If the last timestamp is less than the end time, then we don't have the necessary EEG to process
        if timestamps[-1] < eeg_end_time:
            return "Wait"

        # Check if EEG sampling is continuous over this time period
        start_indices = np.where(timestamps > eeg_start_time)[0]
        if len(start_indices) == 0:
            logger.warning("No timestamps exceed eeg_start_time")
            return "Skip"
        start_index = start_indices[0]

        # Guard the [-1] lookup: if every sample is at or after the end time
        # there is no EEG covering this window.
        end_indices = np.where(timestamps < eeg_end_time)[0]
        if len(end_indices) == 0:
            logger.warning("No timestamps precede eeg_end_time")
            return "Skip"
        end_index = end_indices[-1]

        # A gap of more than two sample periods is treated as lost packets
        time_diffs = np.diff(timestamps[start_index:end_index])
        if np.any(time_diffs > 2 / self.fsample):
            logger.warning("Time gaps in EEG data")
            return "Skip"

        X, y = self.__paradigm.process_markers(
            self.event_marker_buffer,
            self.event_timestamp_buffer,
            eeg,
            timestamps,
            self.fsample,
        )

        # A label of -1 marks an unlabeled epoch
        sum_new_labeled_trials = np.sum(y != -1)

        # Add the epochs to the data tank
        self.__data_tank.add_epochs(X, y)

        # Save epochs to temp_epochs file
        # (self.temp_epochs is set by setup() when running online)
        if self.auto_save_epochs and self.online and sum_new_labeled_trials > 0:
            paradigm_str = self.__paradigm.paradigm_name

            with open(self.temp_epochs, "wb") as f:
                np.savez(
                    f,
                    X=self.__data_tank.epochs,
                    y=self.__data_tank.labels,
                    paradigm=paradigm_str,
                )

        # Once trained, predict if any epoch is unlabeled OR iterative training is on
        if self.train_complete:
            if -1 in y or self.__paradigm.iterative_training:
                prediction = self._classifier.predict(X)
                self.__send_prediction(prediction)

        self.event_marker_buffer = []
        self.event_timestamp_buffer = []

        return "Success"

    def __update_and_train_classifier(self):
        """Updates the classifier if required.

        When training is not locked, the latest epochs are pulled from the
        data tank, epochs labeled -1 are dropped, the remainder is added to
        the classifier's training set, and the classifier is fitted if it
        reports being ready.

        Parameters
        ----------
        `None`

        Returns
        -------
        continue_flag : bool
            Flag indicating to continue the while loop in step(). Always `True`.
        """
        if self.train_lock is False:
            # Pull the epochs from the data tank and pass them to the classifier
            X, y = self.__data_tank.get_epochs(latest=True)

            # Remove epochs with label -1
            ind_to_remove = [i for i, label in enumerate(y) if label == -1]
            X = np.delete(X, ind_to_remove, axis=0)
            y = np.delete(y, ind_to_remove, axis=0)

            # Check that there are epochs
            if len(y) > 0:
                self._classifier.add_to_train(X, y)

                if self._classifier._check_ready_for_fit():
                    self._classifier.fit()
                    self.train_complete = True

        return True

    # 3c. Private methods for event handling (trial and markers) and messaging
    def __log_trial_start(self):
        """Logs the start of a trial.

        Parameters
        ----------
        `None`

        Returns
        -------
        continue_flag : bool
            Flag indicating to continue the while loop in step(). Always `True`.

        """
        logger.debug("Trial started, incrementing marker count and continuing")
        # Note that a marker occurred, but do nothing else
        return True  # Continue processing

    def __handle_trial_end(self):
        """Handles the end of a trial. Processes and classifies trial data if required.

        Parameters
        ----------
        `None`

        Returns
        -------
        success_flag : bool
            Flag indicating if the processing and classification was successful.
            - Returns `True` if not classifying, or if classification succeeded.
            - Returns `False` on "Wait" (marker is retried on the next step).
            - Returns `False` on "Skip" (buffers are cleared and the marker
            count is advanced here so the marker is not retried).
        """
        # If we are classifying based on trials, then process the trial,
        if self.__paradigm.classify_each_trial:
            success_string = self.__process_and_classify()
            if success_string == "Wait":
                logger.debug("Processing of trial not run: waiting for more data")
                return False
            if success_string == "Skip":
                logger.warning("Processing of trial failed: skipping trial")
                self.event_marker_buffer = []
                self.event_timestamp_buffer = []
                self.marker_count += 1
                return False

        return True  # Return True by default if not classifying

    def __handle_event_marker(self, marker, timestamp):
        """Processes and classifies event markers.

        Parameters
        ----------
        marker : str
            Event marker string containing comma-separated values.
            - Format depends on paradigm implementation.
        timestamp : float
            Timestamp of the marker in seconds (after time correction).


        Returns
        -------
        continue_flag : bool
            Flag indicating to continue the while loop in step().
            - `False` on "Wait": buffers are cleared and the marker is
            retried (and re-buffered) on the next step.
            - `False` on "Skip": buffers are cleared and the marker count is
            advanced here so the marker is not retried.

        """
        # Add the marker to the event marker buffer
        self.event_marker_buffer.append(marker)
        self.event_timestamp_buffer.append(timestamp)

        # If classification is on epochs, then update epochs, maybe classify, and clear the buffer
        if self.__paradigm.classify_each_epoch:
            success_string = self.__process_and_classify()
            if success_string == "Wait":
                logger.debug("Processing of epoch not run: waiting for more data")
                self.event_marker_buffer = []
                self.event_timestamp_buffer = []
                return False  # Stop processing
            elif success_string == "Skip":
                logger.warning("Processing of epoch failed: skipping epoch")
                self.event_marker_buffer = []
                self.event_timestamp_buffer = []

                self.marker_count += 1
                return False  # Stop processing

        return True  # Continue processing

    def __handle_command_marker(self, marker: str) -> bool:
        """Processes a command marker by invoking its associated method.

        The command marker string is assumed to be in the self.marker_methods dictionary.
        The associated method is retrieved and called.
        The return value of the method is used to determine if processing should continue.

        Parameters
        ----------
        marker : str
            A command marker string (assumed to be in self.marker_methods).

        Returns
        -------
        bool
            A flag indicating if the processing should continue.

        """
        command_marker_method = self.marker_methods[marker]  # Retrieve method
        continue_flag = command_marker_method()  # Call method

        # Debug level logging if continue_flag is FALSE
        if continue_flag is False:
            logger.debug("Command marker '%s' set continue_flag to FALSE", marker)

        return continue_flag

    def __send_prediction(self, prediction):
        """Send a prediction to the messenger object.

        If there is no messenger, a warning is logged when running online and
        nothing happens when running offline.

        Parameters
        ----------
        prediction
            The value returned by the classifier's `predict()`; passed to the
            messenger unchanged.

        Returns
        -------
        `None`

        """
        if self._messenger is not None:
            logger.debug("Sending prediction: %s", prediction)
            self._messenger.prediction(prediction)
        elif self.online is True:
            # If running in online mode and messenger is not available, log a warning
            logger.warning(
                "Messenger not available (self._messenger is None). Prediction not sent: %s",
                prediction,
            )

    def __load_temp_epochs_if_available(self, reload_data_time: int = 300):
        """Load temp_epochs if available and valid.

        Sets `self.temp_epochs` to the path of the temp file. A file older
        than `reload_data_time`, or one saved for a different paradigm, is
        deleted instead of loaded.

        Parameters
        ----------
        reload_data_time : int, *optional*
            Maximum age, in seconds, of a temp_epochs file that will be reloaded.
            Default is `300` seconds (5 minutes).

        Returns
        -------
        `None`

        """
        self.temp_epochs = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), "temp_epochs.npz"
        )

        if not os.path.exists(self.temp_epochs):
            return

        # If temp_epochs is older than `reload_data_time`, delete it
        if os.path.getmtime(self.temp_epochs) < (time.time() - reload_data_time):
            os.remove(self.temp_epochs)
            logger.info("Deleted old temp_epochs file.")
            return

        # Load the temp_epochs file
        # NOTE(review): allow_pickle=True will unpickle arbitrary objects from
        # this file; acceptable only while the file is written exclusively by
        # this class. Confirm whether the stored arrays actually need pickling.
        with open(self.temp_epochs, "rb") as f:
            npz = np.load(f, allow_pickle=True)
            X = npz["X"]
            y = npz["y"]
            paradigm_str = npz["paradigm"].item()

        # If the paradigm is different, delete the file
        if self.__paradigm.paradigm_name != paradigm_str:
            logger.warning(
                "Paradigm in temp_epochs file does not match current paradigm. Deleting file."
            )
            os.remove(self.temp_epochs)
            return

        # If the paradigm is the same, then add the epochs to the data tank
        logger.info("Loading epochs from temp_epochs file.")
        logger.info("X shape: %s", X.shape)
        logger.info("y shape: %s", y.shape)
        self.__data_tank.add_epochs(X, y)

        # If there are epochs in the data tank, then train the classifier
        if len(self.__data_tank.labels) > 0:
            self.__update_and_train_classifier()
class MarkerTypes(Enum):
    """Named string constants for control markers.

    The values cover trial boundaries ("Trial Started", "Trial Ends"),
    classifier training/update requests, and the end of resting-state
    collection.
    """

    TRIAL_STARTED = "Trial Started"
    TRIAL_ENDS = "Trial Ends"
    TRAINING_COMPLETE = "Training Complete"
    TRAIN_CLASSIFIER = "Train Classifier"
    DONE_RS_COLLECTION = "Done with all RS collection"
    UPDATE_CLASSIFIER = "Update Classifier"
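An enumeration of the command-marker strings (trial start/end, classifier training/update requests, end of resting-state collection) that `BciController` maps to handler methods.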
48class BciController: 49 """ 50 Class that holds, trials, processes, and classifies EEG data. 51 This class is used for processing of continuous EEG data in trials of a defined length. 52 """ 53 54 # 0. Special methods (e.g. __init__) 55 def __init__( 56 self, 57 classifier: GenericClassifier, 58 eeg_source: EegSource, 59 marker_source: MarkerSource | None = None, 60 paradigm: Paradigm | None = None, 61 data_tank: DataTank | None = None, 62 messenger: Messenger | None = None, 63 ): 64 """Initializes `BciController` class. 65 66 Parameters 67 ---------- 68 classifier : GenericClassifier 69 The classifier used by BciController. 70 eeg_source : EegSource 71 Source of EEG data and timestamps, this could be from a file or headset via LSL, etc. 72 marker_source : EegSource 73 Source of Marker/Control data and timestamps, this could be from a file or Unity via LSL, etc. 74 - Default is `None`. 75 paradigm : Paradigm 76 The paradigm used by BciController. This defines the processing and reshaping steps for the EEG data. 77 data_tank : DataTank 78 DataTank object to handle the storage of EEG trials and labels. 79 - Default is `None`. 80 messenger: Messenger 81 Messenger object to handle events from BciController, ex: acknowledging markers and 82 predictions. 83 - Default is `None`. 
84 85 """ 86 87 # Ensure the incoming dependencies are the right type 88 assert isinstance(classifier, GenericClassifier) 89 assert isinstance(eeg_source, EegSource) 90 assert isinstance(marker_source, MarkerSource | None) 91 assert isinstance(paradigm, Paradigm | None) 92 assert isinstance(data_tank, DataTank | None) 93 assert isinstance(messenger, Messenger | None) 94 95 self._classifier = classifier 96 self.__eeg_source = eeg_source 97 self.__marker_source = marker_source 98 self.__paradigm = paradigm 99 self.__data_tank = data_tank 100 self._messenger = messenger 101 102 self.headset_string = self.__eeg_source.name 103 self.fsample = self.__eeg_source.fsample 104 self.n_channels = self.__eeg_source.n_channels 105 self.ch_type = self.__eeg_source.channel_types 106 self.ch_units = self.__eeg_source.channel_units 107 self.channel_labels = self.__eeg_source.channel_labels 108 109 # Emily EGI fix 110 # Set default channel types if none 111 if self.ch_type is None: 112 logger.warning("Channel types are none, setting all to 'eeg'") 113 self.ch_type = ["eeg"] * self.n_channels 114 115 self.__data_tank.set_source_data( 116 self.headset_string, 117 self.fsample, 118 self.n_channels, 119 self.ch_type, 120 self.ch_units, 121 self.channel_labels, 122 ) 123 124 # Switch any trigger channels to stim, this is for mne/bids export (?) 
125 self.ch_type = [type.replace("trg", "stim") for type in self.ch_type] 126 127 self._classifier.channel_labels = self.channel_labels 128 129 logger.info(self.headset_string) 130 logger.info(self.channel_labels) 131 132 # Initialize data and timestamp arrays to the right dimensions, but zero elements 133 self.marker_data = np.zeros((0, 1)) 134 self.marker_timestamps = np.zeros((0)) 135 self.bci_controller = np.zeros((0, self.n_channels)) 136 self.eeg_timestamps = np.zeros((0)) 137 138 # Initialize marker methods dictionary 139 self.marker_methods = { 140 MarkerTypes.DONE_RS_COLLECTION.value: self.__process_resting_state_data, 141 MarkerTypes.TRIAL_STARTED.value: self.__log_trial_start, 142 MarkerTypes.TRIAL_ENDS.value: self.__handle_trial_end, 143 MarkerTypes.TRAINING_COMPLETE.value: self.__update_and_train_classifier, 144 MarkerTypes.TRAIN_CLASSIFIER.value: self.__update_and_train_classifier, 145 MarkerTypes.UPDATE_CLASSIFIER.value: self.__update_and_train_classifier, 146 } 147 148 self.step_count = 0 149 self.ping_interval = 1000 150 self.n_samples = 0 151 self.time_units = "" 152 153 # 1. Core public API methods 154 def setup( 155 self, 156 online=True, 157 train_complete=False, 158 train_lock=False, 159 auto_save_epochs=True, 160 ): 161 """Configure processing loop. 162 163 This should be called before starting the loop with run() or step(). 164 165 Calling after will reset the loop state. 166 167 The processing loop reads in EEG and marker data and processes it. 168 The loop can be run in "offline" or "online" modes: 169 - If in `online` mode, then the loop will continuously try to read 170 in data from the `BciController` object and process it. The loop will 171 terminate when `max_loops` is reached, or when manually terminated. 172 - If in `offline` mode, then the loop will read in all of the data 173 at once, process it, and then terminate. 
174 175 Parameters 176 ---------- 177 online : bool, *optional* 178 Flag to indicate if the data will be processed in `online` mode. 179 - `True`: The data will be processed in `online` mode. 180 - `False`: The data will be processed in `offline` mode. 181 - Default is `True`. 182 train_complete : bool, *optional* 183 Flag to indicate if the classifier has been trained. 184 - `True`: The classifier has been trained. 185 - `False`: The classifier has not been trained. 186 - Default is `False`. 187 train_lock : bool, *optional* 188 Flag to indicate if the classifier is locked (ie. no more training). 189 - `True`: The classifier is locked. 190 - `False`: The classifier is not locked. 191 - Default is `False`. 192 auto_save_epochs : bool, *optional* 193 Flag to indicate if labeled epochs should be automatically saved to a temp file so they can be reloaded if Bessy crashes. 194 - `True`: Epochs will be saved to a temp file. 195 - `False`: Epochs will not be saved to a temp file. 196 197 198 Returns 199 ------- 200 `None` 201 202 """ 203 self.online = online 204 self.train_complete = train_complete 205 self.train_lock = train_lock 206 self.auto_save_epochs = auto_save_epochs 207 208 # initialize the numbers of markers and trials to zero 209 self.marker_count = 0 210 self.current_num_trials = 0 211 self.n_trials = 0 212 213 self.num_online_selections = 0 214 self.online_selection_indices = [] 215 self.online_selections = [] 216 217 # Check for a temp_epochs file 218 if online: 219 self.__load_temp_epochs_if_available() 220 221 def step(self): 222 """Runs a single BciController processing step. 223 224 See setup() for configuration of processing. 225 226 The method: 227 1. Pulls data from sources (EEG and markers). 228 2. Run a while loop to process markers as long as there are unprocessed markers. 229 3. The while loop processes the markers in the following order: 230 - First checks if the marker is a known command marker from self.marker_methods. 
231 - Then checks if it's an event marker (contains commas) 232 - If neither, logs a warning about unknown marker type 233 3. If the marker is a command marker, handles it by calling __handle_command_marker(). 234 4. If the marker is an event marker, handles it by calling __handle_event_marker(). 235 5. If the command or event marker handling return continue_flag as True, increment the marker count and process the next marker. 236 - Note: If there is an unknown marker type, the marker count is still incremented and processing continues. 237 6. If the command or event marker handling return continue_flag as False, break out of the while loop and end the step. 238 239 Parameters 240 ---------- 241 `None` 242 243 Returns 244 ------ 245 `None` 246 247 """ 248 # read from sources to get new data. 249 # This puts command markers in the marker_data array and 250 # event markers in the event_marker_strings array 251 self._pull_data_from_sources() 252 253 # Process markers while there are unprocessed markers 254 # REMOVE COMMENT: check if there is an available command marker, if not, break and wait for more data 255 while len(self.marker_timestamps) > self.marker_count: 256 # Get the current marker 257 current_step_marker = self.marker_data[self.marker_count] # String 258 current_timestamp = self.marker_timestamps[self.marker_count] # Float 259 260 # If marker is empty, skip it 261 if not current_step_marker: 262 logger.warning("Empty marker received") 263 self.marker_count += 1 264 continue 265 266 # If messenger is available, send feedback for each marker received 267 if self._messenger is not None: 268 self._messenger.marker_received(current_step_marker) 269 270 # Process markers in order specified in the docstrings 271 # First check if it's a known command marker 272 if current_step_marker in self.marker_methods: 273 continue_flag = self.__handle_command_marker(current_step_marker) 274 # Then check if it's an event marker (contains commas) 275 elif "," in 
current_step_marker: 276 continue_flag = self.__handle_event_marker( 277 current_step_marker, current_timestamp 278 ) 279 # Otherwise, log a warning about unknown marker type 280 else: 281 # Log warning for unknown marker types 282 logger.warning("Unknown marker type received: %s", current_step_marker) 283 continue_flag = True 284 285 # Check if we should continue processing markers in the while loop 286 # if continue_flag is False, then break out of the while loop 287 # else, increment the marker count and process the next marker 288 if continue_flag is False: 289 break 290 else: 291 logger.info("Processed Marker: %s", current_step_marker) 292 self.marker_count += 1 293 294 self.step_count += 1 295 if self.step_count % self.ping_interval == 0: 296 if self._messenger is not None: 297 self._messenger.ping() 298 299 def run(self, max_loops: int = 1000000, ping_interval: int = 100): 300 """Runs BciController processing in a loop. 301 302 See setup() for configuration of processing. 303 304 Parameters 305 ---------- 306 max_loops : int, *optional* 307 Maximum number of loops to run, default is `1000000`. 308 ping_interval : int, *optional* 309 Number of steps between each messenger ping. 
310 311 Returns 312 ------ 313 `None` 314 315 """ 316 # if offline, then all data is already loaded, only need to loop once 317 if self.online is False: 318 self.loops = max_loops - 1 319 else: 320 self.loops = 0 321 322 self.ping_interval = ping_interval 323 324 # Initialize the event marker buffer 325 self.event_marker_buffer = [] 326 self.event_timestamp_buffer = [] 327 328 # start the main loop, stops after pulling new data, max_loops times 329 while self.loops < max_loops: 330 # print out loop status 331 if self.loops % 100 == 0: 332 logger.debug(self.loops) 333 334 if self.loops == max_loops - 1: 335 logger.debug("last loop") 336 337 # read from sources and process 338 self.step() 339 340 # Wait a short period of time and then try to pull more data 341 if self.online: 342 time.sleep(0.00001) 343 344 self.loops += 1 345 346 # 2. Protected methods (single underscore) 347 def _pull_data_from_sources(self): 348 """Get pull data from EEG and optionally, the marker source. 349 350 This method will fill up the marker_data, bci_controller and corresponding timestamp arrays. 351 352 Parameters 353 ---------- 354 `None` 355 356 Returns 357 ------- 358 `None` 359 360 """ 361 # Get new data from source, whatever it is 362 self.__pull_marker_data_from_source() 363 self.__pull_eeg_data_from_source() 364 365 # 3. Private methods (double underscore) 366 # 3a. Private methods for retrieving data from sources 367 def __pull_marker_data_from_source(self): 368 """Pulls marker samples from source, sanity checks and appends to buffer. 
369 370 Parameters 371 ---------- 372 `None` 373 374 Returns 375 ------- 376 `None` 377 378 """ 379 380 # if there isn't a marker source, abort 381 if self.__marker_source is None: 382 return 383 384 # read in the data 385 markers, timestamps = self.__marker_source.get_markers() 386 markers = np.array(markers) 387 timestamps = np.array(timestamps) 388 389 if markers.size == 0: 390 return 391 392 if markers.ndim != 2: 393 logger.warning("discarded invalid marker data") 394 return 395 396 # apply time correction 397 time_correction = self.__marker_source.time_correction() 398 timestamps = [timestamps[i] + time_correction for i in range(len(timestamps))] 399 400 for i, marker in enumerate(markers): 401 marker = marker[0] 402 if "Ping" in marker: 403 continue 404 405 # Add all markers to the controller 406 self.marker_data = np.append(self.marker_data, marker) 407 self.marker_timestamps = np.append(self.marker_timestamps, timestamps[i]) 408 409 # Add all markers to the data tank 410 self.__data_tank.add_raw_markers( 411 np.array([marker]), np.array([timestamps[i]]) 412 ) 413 414 def __pull_eeg_data_from_source(self): 415 """Pulls eeg samples from source, sanity checks and appends to buffer. 
416 417 Parameters 418 ---------- 419 `None` 420 421 Returns 422 ------- 423 `None` 424 425 """ 426 427 # read in the data 428 eeg, timestamps = self.__eeg_source.get_samples() 429 eeg = np.array(eeg) 430 timestamps = np.array(timestamps) 431 432 if eeg.size == 0: 433 return 434 435 if eeg.ndim != 2: 436 logger.warning("discarded invalid eeg data") 437 return 438 439 # if time is in milliseconds, divide by 1000, works for sampling rates above 10Hz 440 try: 441 if self.time_units == "milliseconds": 442 timestamps = [(timestamps[i] / 1000) for i in range(len(timestamps))] 443 444 # If time units are not defined then define them 445 except Exception: 446 dif_low = -2 447 dif_high = -1 448 while timestamps[dif_high] - timestamps[dif_low] == 0: 449 dif_low -= 1 450 dif_high -= 1 451 452 if timestamps[dif_high] - timestamps[dif_low] > 0.1: 453 timestamps = [(timestamps[i] / 1000) for i in range(len(timestamps))] 454 self.time_units = "milliseconds" 455 else: 456 self.time_units = "seconds" 457 458 # apply time correction, this is essential for headsets like neurosity which have their own clock 459 time_correction = self.__eeg_source.time_correction() 460 timestamps = [timestamps[i] + time_correction for i in range(len(timestamps))] 461 462 self.__data_tank.add_raw_eeg(eeg.T, timestamps) 463 464 # Update latest EEG timestamp 465 self.latest_eeg_timestamp = timestamps[-1] 466 467 # 3b. Private methods for data processing and classification 468 def __process_resting_state_data(self): 469 """Handles the resting state data by packaging it and adding it to the data tank. 470 471 Parameters 472 ---------- 473 `None` 474 475 Returns 476 ------ 477 continue_flag : bool 478 Flag indicating to continue the while loop in step(). 
479 480 """ 481 ( 482 self.bci_controller, 483 self.eeg_timestamps, 484 ) = self.__data_tank.get_raw_eeg() 485 486 resting_state_data = self.__paradigm.package_resting_state_data( 487 self.marker_data, 488 self.marker_timestamps, 489 self.bci_controller, 490 self.eeg_timestamps, 491 self.fsample, 492 ) 493 494 self.__data_tank.add_resting_state_data(resting_state_data) 495 496 return True # Continue processing 497 498 def __process_and_classify(self): 499 """Process the markers and classify the data. 500 501 Parameters 502 ---------- 503 `None` 504 505 Returns 506 ------- 507 success_string : str 508 String indicating if the processing and classification was successful. 509 Potential values are "Success", "Skip", "Wait". 510 - "Success": The processing and classification was successful. 511 - "Skip": EEG is either absent entirely or contains lost packets. 512 - "Wait": The processing is waiting for more data. 513 514 """ 515 516 eeg_start_time, eeg_end_time = self.__paradigm.get_eeg_start_and_end_times( 517 self.event_marker_buffer, self.event_timestamp_buffer 518 ) 519 520 # No we actually need to wait until we have all the data for these markers 521 eeg, timestamps = self.__data_tank.get_raw_eeg() 522 523 # Check if there is available EEG data 524 if len(eeg) == 0: 525 logger.warning("No EEG data available") 526 return "Skip" 527 528 # If the last timestamp is less than the end time, then we don't have the necessary EEG to process 529 if timestamps[-1] < eeg_end_time: 530 return "Wait" 531 532 # Check if EEG sampling is continuous over this time period 533 start_indices = np.where(timestamps > eeg_start_time)[0] 534 if len(start_indices) == 0: 535 logger.warning("No timestamps exceed eeg_start_time") 536 return "Skip" 537 start_index = start_indices[0] 538 end_index = np.where(timestamps < eeg_end_time)[0][-1] 539 540 time_diffs = np.diff(timestamps[start_index:end_index]) 541 if np.any(time_diffs > 2 / self.fsample): 542 logger.warning("Time gaps in EEG data") 
543 return "Skip" 544 545 X, y = self.__paradigm.process_markers( 546 self.event_marker_buffer, 547 self.event_timestamp_buffer, 548 eeg, 549 timestamps, 550 self.fsample, 551 ) 552 553 sum_new_labeled_trials = np.sum(y != -1) 554 555 # Add the epochs to the data tank 556 self.__data_tank.add_epochs(X, y) 557 558 # Save epochs to temp_epochs file 559 if self.auto_save_epochs and self.online and sum_new_labeled_trials > 0: 560 paradigm_str = self.__paradigm.paradigm_name 561 562 with open(self.temp_epochs, "wb") as f: 563 np.savez( 564 f, 565 X=self.__data_tank.epochs, 566 y=self.__data_tank.labels, 567 paradigm=paradigm_str, 568 ) 569 570 # If either there are no labels OR iterative training is on, then make a prediction 571 if self.train_complete: 572 if -1 in y or self.__paradigm.iterative_training: 573 prediction = self._classifier.predict(X) 574 self.__send_prediction(prediction) 575 576 self.event_marker_buffer = [] 577 self.event_timestamp_buffer = [] 578 579 return "Success" 580 581 def __update_and_train_classifier(self): 582 """Updates the classifier if required. 583 584 Parameters 585 ---------- 586 `None` 587 588 Returns 589 ------- 590 continue_flag : bool 591 Flag indicating to continue the while loop in step(). 592 """ 593 if self.train_lock is False: 594 # Pull the epochs from the data tank and pass them to the classifier 595 X, y = self.__data_tank.get_epochs(latest=True) 596 597 # Remove epochs with label -1 598 ind_to_remove = [] 599 for i, label in enumerate(y): 600 if label == -1: 601 ind_to_remove.append(i) 602 X = np.delete(X, ind_to_remove, axis=0) 603 y = np.delete(y, ind_to_remove, axis=0) 604 605 # Check that there are epochs 606 if len(y) > 0: 607 self._classifier.add_to_train(X, y) 608 609 if self._classifier._check_ready_for_fit(): 610 self._classifier.fit() 611 self.train_complete = True 612 613 return True 614 615 # 3c. 
Private methods for event handling (trial and markers) and messaging 616 def __log_trial_start(self): 617 """Logs the start of a trial. 618 619 Parameters 620 ---------- 621 `None` 622 623 Returns 624 ------- 625 continue_flag : bool 626 Flag indicating to continue the while loop in step(). 627 628 """ 629 logger.debug("Trial started, incrementing marker count and continuing") 630 # Note that a marker occured, but do nothing else 631 return True # Continue processing 632 633 def __handle_trial_end(self): 634 """Handles the end of a trial. Processes and classifies trial data if required. 635 636 Parameters 637 ---------- 638 `None` 639 640 Returns 641 ------ 642 success_flag : bool 643 Flag indicating if the processing and classification was successful. 644 - Returns `True` if not classifying. 645 """ 646 # If we are classifying based on trials, then process the trial, 647 if self.__paradigm.classify_each_trial: 648 success_string = self.__process_and_classify() 649 if success_string == "Wait": 650 logger.debug("Processing of trial not run: waiting for more data") 651 return False 652 if success_string == "Skip": 653 logger.warning("Processing of trial failed: skipping trial") 654 self.event_marker_buffer = [] 655 self.event_timestamp_buffer = [] 656 self.marker_count += 1 657 return False 658 659 return True # Return True by default if not classifying 660 661 def __handle_event_marker(self, marker, timestamp): 662 """Processes and classifies event markers. 663 664 Parameters 665 ---------- 666 marker : str 667 Event marker string containing comma-separated values. 668 - Format depends on paradigm implementation. 669 timestamp : float 670 Timestamp of the marker in seconds (after time correction). 671 672 673 Returns 674 ------ 675 continue_flag : bool 676 Flag indicating to continue the while loop in step(). 
677 678 """ 679 # Add the marker to the event marker buffer 680 self.event_marker_buffer.append(marker) 681 self.event_timestamp_buffer.append(timestamp) 682 683 # If classification is on epochs, then update epochs, maybe classify, and clear the buffer 684 if self.__paradigm.classify_each_epoch: 685 success_string = self.__process_and_classify() 686 if success_string == "Wait": 687 logger.debug("Processing of epoch not run: waiting for more data") 688 self.event_marker_buffer = [] 689 self.event_timestamp_buffer = [] 690 return False # Stop processing 691 elif success_string == "Skip": 692 logger.warning("Processing of epoch failed: skipping epoch") 693 self.event_marker_buffer = [] 694 self.event_timestamp_buffer = [] 695 696 self.marker_count += 1 697 return False # Stop processing 698 699 return True # Continue processing 700 701 def __handle_command_marker(self, marker: str) -> bool: 702 """Processes a command marker by invoking its associated method. 703 704 The command marker string is assumed to be in the self.marker_methods dictionary. 705 The associated method is retrieved and called. 706 The return value of the method is used to determine if processing should continue. 707 708 Parameters 709 ---------- 710 marker : str 711 A command marker string (assumed to be in self.marker_methods). 712 713 Returns 714 ------- 715 bool 716 A flag indicating if the processing should continue. 717 718 """ 719 command_marker_method = self.marker_methods[marker] # Retrieve method 720 continue_flag = command_marker_method() # Call method 721 722 # Debug level logging if continue_flag is FALSE 723 if continue_flag is False: 724 logger.debug("Command marker '%s' set continue_flag to FALSE", marker) 725 726 return continue_flag 727 728 def __send_prediction(self, prediction): 729 """Send a prediction to the messenger object. 
730 731 Parameters 732 ---------- 733 `None` 734 735 Returns 736 ------- 737 `None` 738 739 """ 740 if self._messenger is not None: 741 logger.debug("Sending prediction: %s", prediction) 742 self._messenger.prediction(prediction) 743 elif self._messenger is None and self.online is True: 744 # If running in online mode and messenger is not available, log a warning 745 logger.warning( 746 "Messenger not available (self._messenger is None). Prediction not sent: %s", 747 prediction, 748 ) 749 750 def __load_temp_epochs_if_available(self, reload_data_time: int = 300): 751 """Load temp_epochs if available and valid. 752 753 Parameters 754 ---------- 755 reload_data_time : int, *optional* 756 Time in seconds of the last temp_epochs file to reload the data from. 757 Default is `300` seconds (5 minutes). 758 759 Returns 760 ------- 761 `None` 762 763 """ 764 self.temp_epochs = os.path.join( 765 os.path.dirname(os.path.dirname(__file__)), "temp_epochs.npz" 766 ) 767 768 if not os.path.exists(self.temp_epochs): 769 return 770 771 # If temp_epochs is older than `reload_data_time`, delete it 772 if os.path.getmtime(self.temp_epochs) < (time.time() - reload_data_time): 773 os.remove(self.temp_epochs) 774 logger.info("Deleted old temp_epochs file.") 775 return 776 777 # Load the temp_epochs file 778 with open(self.temp_epochs, "rb") as f: 779 npz = np.load(f, allow_pickle=True) 780 X = npz["X"] 781 y = npz["y"] 782 paradigm_str = npz["paradigm"].item() 783 784 # If the paradigm is different, delete the file 785 if self.__paradigm.paradigm_name != paradigm_str: 786 logger.warning( 787 "Paradigm in temp_epochs file does not match current paradigm. Deleting file." 
788 ) 789 os.remove(self.temp_epochs) 790 return 791 792 # If the paradigm is the same, then add the epochs to the data tank 793 logger.info("Loading epochs from temp_epochs file.") 794 logger.info("X shape: %s", X.shape) 795 logger.info("y shape: %s", y.shape) 796 self.__data_tank.add_epochs(X, y) 797 798 # If there are epochs in the data tank, then train the classifier 799 if len(self.__data_tank.labels) > 0: 800 self.__update_and_train_classifier()
Class that holds EEG data, divides it into trials, processes it, and classifies it. This class is used for processing continuous EEG data in trials of a defined length.
def __init__(
    self,
    classifier: GenericClassifier,
    eeg_source: EegSource,
    marker_source: MarkerSource | None = None,
    paradigm: Paradigm | None = None,
    data_tank: DataTank | None = None,
    messenger: Messenger | None = None,
):
    """Store the injected dependencies and create empty data buffers.

    Parameters
    ----------
    classifier : GenericClassifier
        The classifier used by BciController. Its `channel_labels`
        attribute is overwritten with the labels reported by `eeg_source`.
    eeg_source : EegSource
        Source of EEG data and timestamps (e.g. a file or a headset via LSL).
        Its name, sampling rate, channel count, channel types, channel units
        and channel labels are copied onto the controller.
    marker_source : MarkerSource, *optional*
        Source of marker/control data and timestamps (e.g. a file or Unity
        via LSL).
        - Default is `None`.
    paradigm : Paradigm, *optional*
        The paradigm used by BciController; defines the processing and
        reshaping steps for the EEG data.
        - Default is `None`.
    data_tank : DataTank, *optional*
        DataTank object that stores EEG trials and labels.
        - Default is `None`.
    messenger : Messenger, *optional*
        Messenger object that handles events from BciController, e.g.
        acknowledging markers and predictions.
        - Default is `None`.

    """
    # Validate the injected dependencies; optional ones may be None.
    assert isinstance(classifier, GenericClassifier)
    assert isinstance(eeg_source, EegSource)
    assert isinstance(marker_source, MarkerSource | None)
    assert isinstance(paradigm, Paradigm | None)
    assert isinstance(data_tank, DataTank | None)
    assert isinstance(messenger, Messenger | None)

    # Collaborators
    self._classifier = classifier
    self._messenger = messenger
    self.__eeg_source = eeg_source
    self.__marker_source = marker_source
    self.__paradigm = paradigm
    self.__data_tank = data_tank

    # Copy the headset description from the EEG source
    self.headset_string = eeg_source.name
    self.fsample = eeg_source.fsample
    self.n_channels = eeg_source.n_channels
    self.ch_type = eeg_source.channel_types
    self.ch_units = eeg_source.channel_units
    self.channel_labels = eeg_source.channel_labels

    # Emily EGI fix: some sources report no channel types, assume EEG everywhere
    if self.ch_type is None:
        logger.warning("Channel types are none, setting all to 'eeg'")
        self.ch_type = ["eeg"] * self.n_channels

    # NOTE(review): `data_tank` defaults to None in the signature but is
    # dereferenced unconditionally here - confirm callers always pass one.
    data_tank.set_source_data(
        self.headset_string,
        self.fsample,
        self.n_channels,
        self.ch_type,
        self.ch_units,
        self.channel_labels,
    )

    # The data tank receives the original types; afterwards trigger channels
    # are renamed to "stim" (presumably for mne/bids export - TODO confirm).
    self.ch_type = [kind.replace("trg", "stim") for kind in self.ch_type]

    self._classifier.channel_labels = self.channel_labels

    logger.info(self.headset_string)
    logger.info(self.channel_labels)

    # Zero-length buffers with the right dimensionality
    self.marker_data = np.zeros((0, 1))
    self.marker_timestamps = np.zeros(0)
    self.bci_controller = np.zeros((0, self.n_channels))
    self.eeg_timestamps = np.zeros(0)

    # Dispatch table: command marker string -> handler method
    self.marker_methods = {
        MarkerTypes.DONE_RS_COLLECTION.value: self.__process_resting_state_data,
        MarkerTypes.TRIAL_STARTED.value: self.__log_trial_start,
        MarkerTypes.TRIAL_ENDS.value: self.__handle_trial_end,
    }
    # All three training-related markers share one handler
    for training_marker in (
        MarkerTypes.TRAINING_COMPLETE,
        MarkerTypes.TRAIN_CLASSIFIER,
        MarkerTypes.UPDATE_CLASSIFIER,
    ):
        self.marker_methods[training_marker.value] = (
            self.__update_and_train_classifier
        )

    # Loop bookkeeping
    self.step_count = 0
    self.ping_interval = 1000
    self.n_samples = 0
    self.time_units = ""
Initializes BciController class.
Parameters
- classifier (GenericClassifier): The classifier used by BciController.
- eeg_source (EegSource): Source of EEG data and timestamps, this could be from a file or headset via LSL, etc.
- marker_source (MarkerSource):
Source of Marker/Control data and timestamps, this could be from a file or Unity via LSL, etc.
- Default is `None`.
- paradigm (Paradigm): The paradigm used by BciController. This defines the processing and reshaping steps for the EEG data.
- data_tank (DataTank):
DataTank object to handle the storage of EEG trials and labels.
- Default is `None`.
- messenger (Messenger):
Messenger object to handle events from BciController, e.g. acknowledging markers and predictions.
- Default is `None`.
154 def setup( 155 self, 156 online=True, 157 train_complete=False, 158 train_lock=False, 159 auto_save_epochs=True, 160 ): 161 """Configure processing loop. 162 163 This should be called before starting the loop with run() or step(). 164 165 Calling after will reset the loop state. 166 167 The processing loop reads in EEG and marker data and processes it. 168 The loop can be run in "offline" or "online" modes: 169 - If in `online` mode, then the loop will continuously try to read 170 in data from the `BciController` object and process it. The loop will 171 terminate when `max_loops` is reached, or when manually terminated. 172 - If in `offline` mode, then the loop will read in all of the data 173 at once, process it, and then terminate. 174 175 Parameters 176 ---------- 177 online : bool, *optional* 178 Flag to indicate if the data will be processed in `online` mode. 179 - `True`: The data will be processed in `online` mode. 180 - `False`: The data will be processed in `offline` mode. 181 - Default is `True`. 182 train_complete : bool, *optional* 183 Flag to indicate if the classifier has been trained. 184 - `True`: The classifier has been trained. 185 - `False`: The classifier has not been trained. 186 - Default is `False`. 187 train_lock : bool, *optional* 188 Flag to indicate if the classifier is locked (ie. no more training). 189 - `True`: The classifier is locked. 190 - `False`: The classifier is not locked. 191 - Default is `False`. 192 auto_save_epochs : bool, *optional* 193 Flag to indicate if labeled epochs should be automatically saved to a temp file so they can be reloaded if Bessy crashes. 194 - `True`: Epochs will be saved to a temp file. 195 - `False`: Epochs will not be saved to a temp file. 
196 197 198 Returns 199 ------- 200 `None` 201 202 """ 203 self.online = online 204 self.train_complete = train_complete 205 self.train_lock = train_lock 206 self.auto_save_epochs = auto_save_epochs 207 208 # initialize the numbers of markers and trials to zero 209 self.marker_count = 0 210 self.current_num_trials = 0 211 self.n_trials = 0 212 213 self.num_online_selections = 0 214 self.online_selection_indices = [] 215 self.online_selections = [] 216 217 # Check for a temp_epochs file 218 if online: 219 self.__load_temp_epochs_if_available()
Configure processing loop.
This should be called before starting the loop with run() or step().
Calling after will reset the loop state.
The processing loop reads in EEG and marker data and processes it. The loop can be run in "offline" or "online" modes:
- If in `online` mode, then the loop will continuously try to read in data from the `BciController` object and process it. The loop will terminate when `max_loops` is reached, or when manually terminated.
- If in `offline` mode, then the loop will read in all of the data at once, process it, and then terminate.
Parameters
- online (bool, optional):
Flag to indicate if the data will be processed in `online` mode.
  - `True`: The data will be processed in `online` mode.
  - `False`: The data will be processed in `offline` mode.
  - Default is `True`.
- train_complete (bool, optional):
Flag to indicate if the classifier has been trained.
  - `True`: The classifier has been trained.
  - `False`: The classifier has not been trained.
  - Default is `False`.
- train_lock (bool, optional):
Flag to indicate if the classifier is locked (i.e. no more training).
  - `True`: The classifier is locked.
  - `False`: The classifier is not locked.
  - Default is `False`.
- auto_save_epochs (bool, optional):
Flag to indicate if labeled epochs should be automatically saved to a temp file so they can be reloaded if Bessy crashes.
  - `True`: Epochs will be saved to a temp file.
  - `False`: Epochs will not be saved to a temp file.
Returns
None
def step(self):
    """Runs a single BciController processing step.

    See setup() for configuration of processing.

    The method:
    1. Pulls new data from the sources via _pull_data_from_sources().
    2. Loops while there are markers beyond `self.marker_count`:
        - Empty markers are logged, counted and skipped.
        - If a messenger is set, it is told about the marker.
        - A marker found in `self.marker_methods` is a command marker and
          is passed to __handle_command_marker().
        - Otherwise a marker containing a comma is an event marker and is
          passed to __handle_event_marker() together with its timestamp.
        - Anything else is logged as an unknown marker type and treated
          as processed.
    3. If the handler returns `False`, the loop ends for this step without
       advancing the marker count here. Otherwise the marker count is
       incremented and the next marker is processed.
    4. Increments the step count and pings the messenger (if set) every
       `self.ping_interval` steps.

    Parameters
    ----------
    `None`

    Returns
    ------
    `None`

    """
    # Fetch whatever new EEG and marker samples the sources have
    self._pull_data_from_sources()

    # Work through every marker that has not been consumed yet
    while self.marker_count < len(self.marker_timestamps):
        marker = self.marker_data[self.marker_count]  # String
        marker_time = self.marker_timestamps[self.marker_count]  # Float

        # Nothing to do for an empty marker, just move past it
        if not marker:
            logger.warning("Empty marker received")
            self.marker_count += 1
            continue

        # Acknowledge the marker through the messenger, when there is one
        if self._messenger is not None:
            self._messenger.marker_received(marker)

        # Dispatch: command markers first, then comma-separated event markers
        if marker in self.marker_methods:
            keep_going = self.__handle_command_marker(marker)
        elif "," in marker:
            keep_going = self.__handle_event_marker(marker, marker_time)
        else:
            logger.warning("Unknown marker type received: %s", marker)
            keep_going = True

        # A handler returning False ends this step; the same marker index is
        # looked at again on the next step unless the handler advanced it.
        if keep_going is False:
            break

        logger.info("Processed Marker: %s", marker)
        self.marker_count += 1

    # Periodic keep-alive ping
    self.step_count += 1
    if self.step_count % self.ping_interval == 0 and self._messenger is not None:
        self._messenger.ping()
Runs a single BciController processing step.
See setup() for configuration of processing.
The method:
- Pulls data from sources (EEG and markers).
- Run a while loop to process markers as long as there are unprocessed markers.
- The while loop processes the markers in the following order:
- First checks if the marker is a known command marker from self.marker_methods.
- Then checks if it's an event marker (contains commas)
- If neither, logs a warning about unknown marker type
- If the marker is a command marker, handles it by calling __handle_command_marker().
- If the marker is an event marker, handles it by calling __handle_event_marker().
- If the command or event marker handling return continue_flag as True, increment the marker count and process the next marker.
- Note: If there is an unknown marker type, the marker count is still incremented and processing continues.
- If the command or event marker handling return continue_flag as False, break out of the while loop and end the step.
Parameters
None
Returns
None
def run(self, max_loops: int = 1000000, ping_interval: int = 100):
    """Runs BciController processing in a loop.

    See setup() for configuration of processing.

    Each iteration calls step(). In online mode the loop sleeps briefly
    between iterations; in offline mode the loop counter starts at
    `max_loops - 1` so that only a single iteration is executed.

    Parameters
    ----------
    max_loops : int, *optional*
        Maximum number of loops to run, default is `1000000`.
    ping_interval : int, *optional*
        Number of steps between each messenger ping, default is `100`.

    Returns
    ------
    `None`

    """
    final_loop = max_loops - 1

    # Offline data is already fully loaded, so start on the final iteration
    self.loops = final_loop if self.online is False else 0

    self.ping_interval = ping_interval

    # Start with empty event marker buffers
    self.event_marker_buffer = []
    self.event_timestamp_buffer = []

    # Main loop: one step() per iteration until max_loops is reached
    while self.loops < max_loops:
        # Report loop status every 100 iterations
        if self.loops % 100 == 0:
            logger.debug(self.loops)

        if self.loops == final_loop:
            logger.debug("last loop")

        # Read from sources and process
        self.step()

        # Give the sources a moment before asking for more data
        if self.online:
            time.sleep(0.00001)

        self.loops += 1
Runs BciController processing in a loop.
See setup() for configuration of processing.
Parameters
- max_loops (int, optional): Maximum number of loops to run, default is `1000000`.
- ping_interval (int, optional): Number of steps between each messenger ping.
Returns
None