bci_essentials.channel_selection
Channel selection methods for BCI performance improvement.
This module includes functions for selecting channels in order to improve BCI performance.
Notes
The EEG data input for each function is a set of trials. The data must
be of the shape n_trials x n_channels x n_samples, where:
- n_trials = number of trials
- n_channels = number of channels
- n_samples = number of samples
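For instance, a recording segmented into 40 trials of 8 channels with 128 samples per trial would be passed as in the minimal sketch below (the shapes are illustrative only):

    import numpy as np

    # Illustrative shapes: 40 trials, 8 channels, 128 samples per trial.
    X = np.zeros((40, 8, 128))   # EEG trials, shape (n_trials, n_channels, n_samples)
    y = np.zeros(40, dtype=int)  # one class label per trial, shape (n_trials,)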
1"""Channel selection methods for BCI performance improvement. 2 3This module includes functions for selecting channels in order to 4improve BCI performance. 5 6Notes 7----- 8The EEG data input for each function is a set of trials. The data must 9be of the shape `n_trials x n_channels x n_samples`, where: 10 - n_trials = number of trials 11 - n_channels = number of channels 12 - n_samples = number of samples 13 14""" 15 16from joblib import Parallel, delayed 17import time 18import numpy as np 19import pandas as pd 20from dataclasses import dataclass, field 21from .utils.logger import Logger # Logger wrapper 22from sklearn.pipeline import Pipeline 23 24# Instantiate a logger for the module at the default level of logging.INFO 25# Logs to bci_essentials.__module__) where __module__ is the name of the module 26logger = Logger(name=__name__) 27 28 29@dataclass 30class ChannelSelectionOutput: 31 """Dataclass to store output from channel selection. 32 33 Parameters 34 ---------- 35 best_channel_subset : list of `str` 36 The best channel subset from the list of 'channel_labels'. 37 best_model : classifier 38 The trained classification model. 39 best_preds : numpy.ndarray 40 The predictions from the model. 41 best_accuracy : float 42 The accuracy of the trained classification model. 43 best_precision : float 44 The precision of the trained classification model. 45 best_recall : float 46 The recall of the trained classification model. 47 results_df : pandas.DataFrame 48 A dataframe containing the performance metrics at each step. 49 50 """ 51 52 best_channel_subset: list = field(default_factory=list) 53 best_model: Pipeline = field(default=None) 54 best_preds: np.ndarray = field(default_factory=np.ndarray) 55 best_accuracy: float = field(default=0.0) 56 best_precision: float = field(default=0.0) 57 best_recall: float = field(default=0.0) 58 results_df: pd.DataFrame = field(default_factory=pd.DataFrame) 59 60 61def channel_selection_by_method( 62 kernel_func, 63 X, 64 y, 65 channel_labels, 66 method="SBS", 67 metric="accuracy", 68 initial_channels=[], 69 max_time=999, 70 min_channels=1, 71 max_channels=999, 72 performance_delta=0.001, 73 n_jobs=1, 74 record_performance=True, 75): 76 """Passes the BCI kernel function into a wrapper defined by `method`. 77 78 Parameters 79 ---------- 80 kernel_func : function 81 The classification kernel function which does feature extraction 82 and classification. 83 Different functions are used for MI, P300, SSVEP, etc. 84 X : numpy.ndarray 85 Training data for the classifier as trials of EEG data. 86 3D array containing data with `float` type. 87 88 shape = (`n_trials`,`n_channels`,`n_samples`) 89 y : numpy.ndarray 90 Training labels for the classifier. 91 1D array. 92 93 shape = (`n_trials`) 94 channel_labels : list of `str` 95 The set of channel labels corresponding to `n_channels`. 96 A list of strings with length = `n_channels`. 97 method = str, *optional* 98 The wrapper method. Options are `"SBS"` or `"SBFS"`. 99 - Default is `"SBS"`. 100 metric : str, *optional* 101 The metric used to measure the "goodness" of the trained classifier. 102 - Default is `"accuracy"`. 103 initial_channels : list of `str`, *optional* 104 Initial guess of channels. 105 - Defaults is `[]`. Assigns an empty set for forward selections, 106 and a full set for backward selections. 107 max_time : int, *optional* 108 The maxiumum amount of time, in seconds, that the function will 109 search for the optimal solution. 110 - Default is `999` seconds. 
111 min_channels : int, *optional* 112 The minimum number of channels. 113 - Default is `1`. 114 max_channels : int, *optional* 115 The maximum number of channels. 116 - Default is `999`. 117 performance_delta : float, *optional* 118 The performance delta under which the algorithm is considered to 119 be close enough to optimal. 120 - Default is `0.001`. 121 n_jobs : int, *optional* 122 The number of threads to dedicate to this calculation. 123 - Default is `1`. 124 record_performance : bool, *optional* 125 Whether or not to record the performance of the channel selection 126 - Default is `True`. 127 128 Returns 129 ------- 130 channelSelectionOutput : ChannelSelectionOutput 131 ChannelSelectionOutput object containing the following attributes: 132 best_channel_subset : list of `str` 133 The new best channel subset from the list of `channel_labels`. 134 self.clf : classifier 135 The trained classification model. 136 preds : numpy.ndarray 137 The predictions from the model. 138 1D array with the same shape as `y`. 139 shape = (`n_trials`) 140 accuracy : float 141 The accuracy of the trained classification model. 142 precision : float 143 The precision of the trained classification model. 144 recall : float 145 The recall of the trained classification model. 146 results_df : pandas.DataFrame 147 The dataframe containing the results of each step of channel selection. 148 149 """ 150 151 # max length can't be greater than the length of channel labels 152 if max_channels > len(channel_labels): 153 logger.debug( 154 "Maximum number of channels must be less than or equal to " 155 + "the number of channels. Setting to number of channels." 156 ) 157 max_channels = len(channel_labels) 158 159 # min length can't be less than 1 160 if min_channels < 1: 161 logger.debug("Minimum number of channels must be greater than 0. 
Setting to 1.") 162 min_channels = 1 163 164 logger.debug("Running channel selection method: %s", method) 165 if method == "SBS": 166 if initial_channels == []: 167 initial_channels = channel_labels 168 logger.debug("Initial subset: %s", initial_channels) 169 170 # pass arguments to SBS 171 return __sbs( 172 kernel_func, 173 X, 174 y, 175 channel_labels=channel_labels, 176 metric=metric, 177 initial_channels=initial_channels, 178 max_time=max_time, 179 min_channels=min_channels, 180 max_channels=max_channels, 181 performance_delta=performance_delta, 182 n_jobs=n_jobs, 183 record_performance=record_performance, 184 ) 185 186 elif method == "SFS": 187 logger.debug("Initial subset: %s", initial_channels) 188 189 # pass arguments to SBS 190 return __sfs( 191 kernel_func, 192 X, 193 y, 194 channel_labels=channel_labels, 195 metric=metric, 196 initial_channels=initial_channels, 197 max_time=max_time, 198 min_channels=min_channels, 199 max_channels=max_channels, 200 performance_delta=performance_delta, 201 n_jobs=n_jobs, 202 record_performance=record_performance, 203 ) 204 205 elif method == "SBFS": 206 if initial_channels == []: 207 initial_channels = channel_labels 208 logger.debug("Initial subset: %s", initial_channels) 209 210 # pass arguments to SBS 211 return __sbfs( 212 kernel_func, 213 X, 214 y, 215 channel_labels=channel_labels, 216 metric=metric, 217 initial_channels=initial_channels, 218 max_time=max_time, 219 min_channels=min_channels, 220 max_channels=max_channels, 221 performance_delta=performance_delta, 222 n_jobs=n_jobs, 223 record_performance=record_performance, 224 ) 225 226 elif method == "SFFS": 227 logger.debug("Initial subset: %s", initial_channels) 228 229 # pass arguments to SBS 230 return __sffs( 231 kernel_func, 232 X, 233 y, 234 channel_labels=channel_labels, 235 metric=metric, 236 initial_channels=initial_channels, 237 max_time=max_time, 238 min_channels=min_channels, 239 max_channels=max_channels, 240 performance_delta=performance_delta, 241 n_jobs=n_jobs, 242 record_performance=record_performance, 243 ) 244 245 246def __check_stopping_criterion( 247 algorithm, 248 current_time, 249 n_channels, 250 current_performance_delta, 251 max_time, 252 min_channels, 253 max_channels, 254 performance_delta, 255): 256 """Function to check if a stopping criterion has been met. 257 258 Parameters 259 ---------- 260 algorithm : str 261 The algorithm being used for channel selection. 262 current_time : float 263 The time elapsed since the start of the channel selection method. 264 n_channels : int 265 The number of channels in the current iteration of the new best channel 266 subset (`len(new_channel_subset)`). 267 current_performance_delta : float 268 The performance delta between the current iteration and the previous. 269 max_time : int 270 The maxiumum amount of time, in seconds, that the function will 271 search for the optimal solution. 272 min_channels : int 273 The minimum number of channels. 274 max_channels : int 275 The maximum number of channels. 276 performance_delta : float 277 The performance delta under which the algorithm is considered to 278 be close enough to optimal. 279 280 Returns 281 ------- 282 *bool* 283 Has stopping criterion been met (`True`) or not (`False`). 
284 285 """ 286 if current_time > max_time: 287 logger.debug("Stopping based on time") 288 return True 289 290 if algorithm == "SBS" or algorithm == "SBFS": 291 if n_channels <= min_channels: 292 logger.debug("Stopping because minimum number of channels reached") 293 return True 294 295 if algorithm == "SFS" or algorithm == "SFFS": 296 if n_channels >= max_channels: 297 logger.debug("Stopping because maximum number of channels reached") 298 return True 299 300 if current_performance_delta < performance_delta: 301 logger.debug("Stopping because performance improvements are declining") 302 return True 303 304 return False 305 306 307def __sfs( 308 kernel_func, 309 X, 310 y, 311 channel_labels, 312 metric, 313 initial_channels, 314 max_time, 315 min_channels, 316 max_channels, 317 performance_delta, 318 n_jobs, 319 record_performance, 320): 321 """ 322 The Sequential Forward Selection (SFS) method for channel selection. 323 324 Parameters 325 ---------- 326 kernel_func : function 327 The classification kernel function which does feature extraction 328 and classification. 329 Different functions are used for MI, P300, SSVEP, etc. 330 X : numpy.ndarray 331 Training data for the classifier as trials of EEG data. 332 3D array containing data with `float` type. 333 334 shape = (`n_trials`,`n_channels`,`n_samples`) 335 y : numpy.ndarray 336 Training labels for the classifier. 337 1D array. 338 339 shape = (`n_trials`) 340 channel_labels: list of `str` 341 The set of channel labels corresponding to `n_channels`. 342 A list of strings with length = `n_channels`. 343 metric : str 344 The metric used to measure the "goodness" of the trained classifier. 345 initial_channels : list of `str` 346 Initial guess of channels. 347 max_time : int 348 The maxiumum amount of time, in seconds, that the function will 349 search for the optimal solution. 350 min_channels : int 351 The minimum number of channels. 352 max_channels : int 353 The maximum number of channels. 354 performance_delta : float 355 The performance delta under which the algorithm is considered to 356 be close enough to optimal. 357 n_jobs : int 358 The number of threads to dedicate to this calculation. 359 record_performance : bool 360 Flag on whether or not to record performance at each step. 361 362 363 Returns 364 ------- 365 channelSelectionOutput : ChannelSelectionOutput 366 ChannelSelectionOutput object containing the following attributes: 367 best_channel_subset : list of `str` 368 The new best channel subset from the list of `channel_labels`. 369 self.clf : classifier 370 The trained classification model. 371 preds : numpy.ndarray 372 The predictions from the model. 373 1D array with the same shape as `y`. 374 375 shape = (`n_trials`) 376 accuracy : float 377 The accuracy of the trained classification model. 378 precision : float 379 The precision of the trained classification model. 380 recall : float 381 The recall of the trained classification model. 382 results_df : pandas.DataFrame 383 The dataframe containing the results of each step of channel selection. 
384 """ 385 results_df = pd.DataFrame( 386 columns=[ 387 "Step", 388 "Time", 389 "N Channels", 390 "Channel Subset", 391 "Unique Combinations Tested in Step", 392 "Accuracy", 393 "Precision", 394 "Recall", 395 ] 396 ) 397 step = 1 398 399 start_time = time.time() 400 401 n_trials, n_channels, n_samples = X.shape 402 sfs_subset = [] 403 404 for i, c in enumerate(channel_labels): 405 if c in initial_channels: 406 sfs_subset.append(i) 407 408 previous_performance = 0 409 410 stop_criterion = False 411 412 # Get the performance of the initial subset, if possible 413 try: 414 initial_results = kernel_func(X[:, sfs_subset, :], y) 415 initial_model = initial_results.model 416 initial_preds = initial_results.cv_preds 417 initial_accuracy = initial_results.accuracy 418 initial_precision = initial_results.precision 419 initial_recall = initial_results.recall 420 421 if metric == "accuracy": 422 initial_performance = initial_accuracy 423 elif metric == "precision": 424 initial_performance = initial_precision 425 elif metric == "recall": 426 initial_performance = initial_recall 427 428 # Best 429 best_channel_subset = initial_channels 430 best_model = initial_model 431 best_performance = initial_performance 432 best_preds = initial_preds 433 best_accuracy = initial_accuracy 434 best_precision = initial_precision 435 best_recall = initial_recall 436 437 # If not possible then set the initial performance to 0 438 except ValueError: 439 best_channel_subset = [] 440 best_model = None 441 best_performance = 0 442 best_preds = [] 443 best_accuracy = 0 444 best_precision = 0 445 best_recall = 0 446 447 preds = [] 448 accuracy = 0 449 precision = 0 450 recall = 0 451 452 while stop_criterion is False: 453 sets_to_try = [] 454 X_to_try = [] 455 for channel in range(n_channels): 456 if channel not in sfs_subset: 457 set_to_try = sfs_subset.copy() 458 set_to_try.append(channel) 459 sets_to_try.append(set_to_try) 460 461 # Get the new subset of data 462 new_subset_data = np.zeros((n_trials, len(set_to_try), n_samples)) 463 for subset_idx, channel_number in enumerate(set_to_try): 464 channel_data = X[:, channel_number, :] 465 new_subset_data[:, subset_idx, :] = channel_data 466 467 # Add to list of all subsets of X to try 468 X_to_try.append(new_subset_data) 469 470 # This handles the multiprocessing to check multiple channel combinations at once if n_jobs > 1 471 outputs = Parallel(n_jobs=n_jobs)( 472 delayed(kernel_func)(Xtest, y) for Xtest in X_to_try 473 ) 474 475 models = [] 476 predictions = [] 477 accuracies = [] 478 precisions = [] 479 recalls = [] 480 481 # Extract the outputs 482 for output in outputs: 483 models.append(output.model) 484 predictions.append(output.cv_preds) 485 accuracies.append(output.accuracy) 486 precisions.append(output.precision) 487 recalls.append(output.recall) 488 489 # Get the performance metric 490 if metric == "accuracy": 491 performances = accuracies 492 elif metric == "precision": 493 performances = precisions 494 elif metric == "recall": 495 performances = recalls 496 else: 497 logger.warning("Performance metric invalid, defaulting to accuracy") 498 performances = accuracies 499 500 # Get the index of the best X tried in this round 501 best_set_index = accuracies.index(np.max(performances)) 502 503 sfs_subset = sets_to_try[best_set_index] 504 new_channel_subset = [channel_labels[c] for c in sfs_subset] 505 model = models[best_set_index] 506 preds = predictions[best_set_index] 507 accuracy = accuracies[best_set_index] 508 # best_overall_accuracy = accuracy 509 precision = 
precisions[best_set_index] 510 recall = recalls[best_set_index] 511 logger.debug("New subset: %s", new_channel_subset) 512 logger.debug("Accuracy: %s", accuracy) 513 logger.debug("Accuracies: %s", [float(acc) for acc in accuracies]) 514 515 if metric == "accuracy": 516 current_performance = accuracy 517 518 p_delta = current_performance - previous_performance 519 previous_performance = current_performance 520 521 if current_performance > best_performance: 522 best_channel_subset = new_channel_subset 523 best_model = model 524 best_performance = current_performance 525 best_preds = preds 526 best_accuracy = accuracy 527 best_precision = precision 528 best_recall = recall 529 elif current_performance >= best_performance and len(new_channel_subset) < len( 530 best_channel_subset 531 ): 532 best_channel_subset = new_channel_subset 533 best_model = model 534 best_performance = current_performance 535 best_preds = preds 536 best_accuracy = accuracy 537 best_precision = precision 538 best_recall = recall 539 540 new_channel_subset.sort() 541 results_df.loc[step] = [ 542 step, 543 time.time() - start_time, 544 len(new_channel_subset), 545 "".join(new_channel_subset), 546 len(sets_to_try), 547 accuracy, 548 precision, 549 recall, 550 ] 551 552 step += 1 553 554 stop_criterion = __check_stopping_criterion( 555 "SFS", 556 time.time() - start_time, 557 len(new_channel_subset), 558 p_delta, 559 max_time, 560 min_channels, 561 max_channels, 562 performance_delta, 563 ) 564 565 new_channel_subset = [channel_labels[c] for c in sfs_subset] 566 567 logger.debug("Best channel subset: %s", best_channel_subset) 568 logger.debug("%s : %s", metric, best_performance) 569 logger.debug("Time to optimal subset: %s s", time.time() - start_time) 570 571 if record_performance is True: 572 logger.info(results_df) 573 574 # Get the best model 575 576 return ChannelSelectionOutput( 577 best_channel_subset, 578 best_model, 579 best_preds, 580 best_accuracy, 581 best_precision, 582 best_recall, 583 results_df, 584 ) 585 586 587def __sbs( 588 kernel_func, 589 X, 590 y, 591 channel_labels, 592 metric, 593 initial_channels, 594 max_time, 595 min_channels, 596 max_channels, 597 performance_delta, 598 n_jobs, 599 record_performance, 600): 601 """The Sequential Backward Selection (SBS) method for channel selection. 602 603 Parameters 604 ---------- 605 kernel_func : function 606 The classification kernel function which does feature extraction 607 and classification. 608 Different functions are used for MI, P300, SSVEP, etc. 609 X : numpy.ndarray 610 Training data for the classifier as trials of EEG data. 611 3D array containing data with `float` type. 612 613 shape = (`n_trials`,`n_channels`,`n_samples`) 614 y : numpy.ndarray 615 Training labels for the classifier. 616 1D array. 617 618 shape = (`n_trials`) 619 channel_labels: list of `str` 620 The set of channel labels corresponding to `n_channels`. 621 A list of strings with length = `n_channels`. 622 metric : str 623 The metric used to measure the "goodness" of the trained classifier. 624 initial_channels : list of `str` 625 Initial guess of channels. 626 max_time : int 627 The maxiumum amount of time, in seconds, that the function will 628 search for the optimal solution. 629 min_channels : int 630 The minimum number of channels. 631 max_channels : int 632 The maximum number of channels. 633 performance_delta : float 634 The performance delta under which the algorithm is considered to 635 be close enough to optimal. 
636 n_jobs : int 637 The number of threads to dedicate to this calculation. 638 record_performance : bool 639 Flag on whether or not to record performance metrics at each step. 640 641 Returns 642 ------- 643 channelSelectionOutput : ChannelSelectionOutput 644 ChannelSelectionOutput object containing the following attributes: 645 best_channel_subset : list of `str` 646 The new best channel subset from the list of `channel_labels`. 647 self.clf : classifier 648 The trained classification model. 649 preds : numpy.ndarray 650 The predictions from the model. 651 1D array with the same shape as `y`. 652 shape = (`n_trials`) 653 accuracy : float 654 The accuracy of the trained classification model. 655 precision : float 656 The precision of the trained classification model. 657 recall : float 658 The recall of the trained classification model. 659 results_df : pandas.DataFrame 660 The dataframe containing the results of each step of channel selection. 661 662 663 """ 664 665 results_df = pd.DataFrame( 666 columns=[ 667 "Step", 668 "Time", 669 "N Channels", 670 "Channel Subset", 671 "Unique Combinations Tested in Step", 672 "Accuracy", 673 "Precision", 674 "Recall", 675 ] 676 ) 677 step = 1 678 679 if len(initial_channels) <= min_channels: 680 initial_channels = channel_labels 681 682 start_time = time.time() 683 684 n_trials, n_channels, n_samples = X.shape 685 sbs_subset = [] 686 all_sets_tried = [] # set of all channels that have been tried 687 688 for i, c in enumerate(channel_labels): 689 if c in initial_channels: 690 sbs_subset.append(i) 691 692 # Get the performance of the initial subset 693 initial_results = kernel_func(X[:, sbs_subset, :], y) 694 initial_model = initial_results.model 695 initial_preds = initial_results.cv_preds 696 initial_accuracy = initial_results.accuracy 697 initial_precision = initial_results.precision 698 initial_recall = initial_results.recall 699 700 if metric == "accuracy": 701 initial_performance = initial_accuracy 702 elif metric == "precision": 703 initial_performance = initial_precision 704 elif metric == "recall": 705 initial_performance = initial_recall 706 707 # Best 708 best_channel_subset = initial_channels 709 best_model = initial_model 710 best_performance = initial_performance 711 best_preds = initial_preds 712 best_accuracy = initial_accuracy 713 best_precision = initial_precision 714 best_recall = initial_recall 715 716 previous_performance = 0 717 718 stop_criterion = False 719 720 preds = [] 721 accuracy = 0 722 precision = 0 723 recall = 0 724 725 while stop_criterion is False: 726 # Exclusion Step 727 sets_to_try = [] 728 X_to_try = [] 729 for channel in sbs_subset: 730 set_to_try = sbs_subset.copy() 731 set_to_try.remove(channel) 732 set_to_try.sort() 733 734 # Only try sets that have not been tried before 735 if set_to_try not in all_sets_tried: 736 sets_to_try.append(set_to_try) 737 all_sets_tried.append(set_to_try) 738 else: 739 continue 740 741 # Get the new subset of data 742 new_subset_data = np.zeros((n_trials, len(set_to_try), n_samples)) 743 for subset_idx, channel_number in enumerate(set_to_try): 744 channel_data = X[:, channel_number, :] 745 new_subset_data[:, subset_idx, :] = channel_data 746 747 # Add to list of all subsets of X to try 748 X_to_try.append(new_subset_data) 749 750 # run the kernel function on all cores 751 outputs = Parallel(n_jobs=n_jobs)( 752 delayed(kernel_func)(Xtest, y) for Xtest in X_to_try 753 ) 754 755 # [all_sets_tried.append(set.sort()) for set in sets_to_try] 756 757 models = [] 758 predictions = [] 759 
accuracies = [] 760 precisions = [] 761 recalls = [] 762 763 # Extract the outputs 764 for output in outputs: 765 models.append(output.model) 766 predictions.append(output.cv_preds) 767 accuracies.append(output.accuracy) 768 precisions.append(output.precision) 769 recalls.append(output.recall) 770 771 # Get the performance metric 772 if metric == "accuracy": 773 performances = accuracies 774 elif metric == "precision": 775 performances = precisions 776 elif metric == "recall": 777 performances = recalls 778 else: 779 logger.warning("Performance metric invalid, defaulting to accuracy") 780 performances = accuracies 781 782 best_set_index = accuracies.index(np.max(performances)) 783 784 sbs_subset = sets_to_try[best_set_index] 785 new_channel_subset = [channel_labels[c] for c in sbs_subset] 786 model = models[best_set_index] 787 preds = predictions[best_set_index] 788 accuracy = accuracies[best_set_index] 789 # best_overall_accuracy = accuracy 790 precision = precisions[best_set_index] 791 recall = recalls[best_set_index] 792 793 current_performance = performances[best_set_index] 794 logger.debug("Removed a channel") 795 logger.debug("New subset: %s", new_channel_subset) 796 logger.debug("Accuracy: %s", accuracy) 797 logger.debug("Accuracies: %s", [float(acc) for acc in accuracies]) 798 799 p_delta = current_performance - previous_performance 800 previous_performance = current_performance 801 802 if current_performance > best_performance: 803 best_channel_subset = new_channel_subset 804 best_model = model 805 best_performance = current_performance 806 best_preds = preds 807 best_accuracy = accuracy 808 best_precision = precision 809 best_recall = recall 810 elif current_performance >= best_performance and len(new_channel_subset) < len( 811 best_channel_subset 812 ): 813 best_channel_subset = new_channel_subset 814 best_model = model 815 best_performance = current_performance 816 best_preds = preds 817 best_accuracy = accuracy 818 best_precision = precision 819 best_recall = recall 820 821 new_channel_subset.sort() 822 results_df.loc[step] = [ 823 step, 824 time.time() - start_time, 825 len(new_channel_subset), 826 "".join(new_channel_subset), 827 len(sets_to_try), 828 accuracy, 829 precision, 830 recall, 831 ] 832 833 step += 1 834 835 # Break if SBS subset is 1 channel 836 if len(sbs_subset) == 1: 837 break 838 839 stop_criterion = __check_stopping_criterion( 840 "SBS", 841 time.time() - start_time, 842 len(new_channel_subset), 843 p_delta, 844 max_time, 845 min_channels, 846 max_channels, 847 performance_delta, 848 ) 849 850 new_channel_subset = [channel_labels[c] for c in sbs_subset] 851 852 logger.debug("Best channel subset: %s", best_channel_subset) 853 logger.debug("%s : %s", metric, best_performance) 854 logger.debug("Time to optimal subset: %s s", time.time() - start_time) 855 856 if record_performance is True: 857 logger.info(results_df) 858 859 return ChannelSelectionOutput( 860 best_channel_subset, 861 best_model, 862 best_preds, 863 best_accuracy, 864 best_precision, 865 best_recall, 866 results_df, 867 ) 868 869 870def __sbfs( 871 kernel_func, 872 X, 873 y, 874 channel_labels, 875 metric, 876 initial_channels, 877 max_time, 878 min_channels, 879 max_channels, 880 performance_delta, 881 n_jobs, 882 record_performance, 883): 884 """The Sequential Backward Floating Selection (SBFS) method for channel selection. 885 886 Parameters 887 ---------- 888 kernel_func : function 889 The classification kernel function which does feature extraction 890 and classification. 
891 Different functions are used for MI, P300, SSVEP, etc. 892 X : numpy.ndarray 893 Training data for the classifier as trials of EEG data. 894 3D array containing data with `float` type. 895 896 shape = (`n_trials`,`n_channels`,`n_samples`) 897 y : numpy.ndarray 898 Training labels for the classifier. 899 1D array. 900 901 shape = (`n_trials`) 902 channel_labels: list of `str` 903 The set of channel labels corresponding to `n_channels`. 904 A list of strings with length = `n_channels`. 905 metric : str 906 The metric used to measure the "goodness" of the trained classifier. 907 initial_channels : list of `str` 908 Initial guess of channels. 909 max_time : int 910 The maxiumum amount of time, in seconds, that the function will 911 search for the optimal solution. 912 min_channels : int 913 The minimum number of channels. 914 max_channels : int 915 The maximum number of channels. 916 performance_delta : float 917 The performance delta under which the algorithm is considered to 918 be close enough to optimal. 919 n_jobs : int 920 The number of threads to dedicate to this calculation. 921 record_performance : bool 922 Flag on whether or not to record performance metrics at each step. 923 924 Returns 925 ------- 926 channelSelectionOutput : ChannelSelectionOutput 927 ChannelSelectionOutput object containing the following attributes: 928 best_channel_subset : list of `str` 929 The new best channel subset from the list of `channel_labels`. 930 self.clf : classifier 931 The trained classification model. 932 preds : numpy.ndarray 933 The predictions from the model. 934 1D array with the same shape as `y`. 935 shape = (`n_trials`) 936 accuracy : float 937 The accuracy of the trained classification model. 938 precision : float 939 The precision of the trained classification model. 940 recall : float 941 The recall of the trained classification model. 942 results_df : pandas.DataFrame 943 The dataframe containing the results of each step of channel selection. 
944 945 946 """ 947 results_df = pd.DataFrame( 948 columns=[ 949 "Step", 950 "Time", 951 "N Channels", 952 "Channel Subset", 953 "Unique Combinations Tested in Step", 954 "Accuracy", 955 "Precision", 956 "Recall", 957 ] 958 ) 959 step = 1 960 961 if len(initial_channels) <= min_channels or len(initial_channels) == 0: 962 initial_channels = channel_labels 963 964 start_time = time.time() 965 966 n_trials, n_channels, n_samples = X.shape 967 sbfs_subset = [] 968 all_sets_tried = [] # set of all channels that have been tried 969 970 for i, c in enumerate(channel_labels): 971 if c in initial_channels: 972 sbfs_subset.append(i) 973 974 performance_at_n_channels = np.zeros(len(channel_labels)) 975 best_subset_at_n_channels = [0] * len(channel_labels) 976 977 previous_performance = 0 978 979 stop_criterion = False 980 981 # Get the performance of the initial subset, if possible 982 try: 983 initial_results = kernel_func(X[:, sbfs_subset, :], y) 984 initial_model = initial_results.model 985 initial_preds = initial_results.cv_preds 986 initial_accuracy = initial_results.accuracy 987 initial_precision = initial_results.precision 988 initial_recall = initial_results.recall 989 990 if metric == "accuracy": 991 initial_performance = initial_accuracy 992 elif metric == "precision": 993 initial_performance = initial_precision 994 elif metric == "recall": 995 initial_performance = initial_recall 996 997 # Best 998 best_channel_subset = initial_channels 999 best_model = initial_model 1000 best_performance = initial_performance 1001 best_preds = initial_preds 1002 best_accuracy = initial_accuracy 1003 best_precision = initial_precision 1004 best_recall = initial_recall 1005 1006 performance_at_n_channels[len(initial_channels) - 1] = initial_performance 1007 best_subset_at_n_channels[len(initial_channels) - 1] = initial_channels 1008 1009 # If not possible then set the initial performance to 0 1010 except ValueError: 1011 best_channel_subset = [] 1012 best_model = None 1013 best_performance = 0 1014 best_preds = [] 1015 best_accuracy = 0 1016 best_precision = 0 1017 best_recall = 0 1018 1019 preds = [] 1020 accuracy = 0 1021 precision = 0 1022 recall = 0 1023 1024 while stop_criterion is False: 1025 # Exclusion Step 1026 sets_to_try = [] 1027 X_to_try = [] 1028 for c in sbfs_subset: 1029 set_to_try = sbfs_subset.copy() 1030 set_to_try.remove(c) 1031 set_to_try.sort() 1032 1033 # Only try sets that have not been tried before 1034 if set_to_try not in all_sets_tried: 1035 sets_to_try.append(set_to_try) 1036 all_sets_tried.append(set_to_try) 1037 else: 1038 continue 1039 1040 # Get the new subset of data 1041 new_subset_data = np.zeros((n_trials, len(set_to_try), n_samples)) 1042 for subset_idx, channel_number in enumerate(set_to_try): 1043 channel_data = X[:, channel_number, :] 1044 new_subset_data[:, subset_idx, :] = channel_data 1045 1046 # Add to list of all subsets of X to try 1047 X_to_try.append(new_subset_data) 1048 1049 # run the kernel function on all cores 1050 outputs = Parallel(n_jobs=n_jobs)( 1051 delayed(kernel_func)(Xtest, y) for Xtest in X_to_try 1052 ) 1053 1054 # [all_sets_tried.append(set.sort()) for set in sets_to_try] 1055 1056 models = [] 1057 predictions = [] 1058 accuracies = [] 1059 precisions = [] 1060 recalls = [] 1061 1062 # Extract the outputs 1063 for output in outputs: 1064 models.append(output.model) 1065 predictions.append(output.cv_preds) 1066 accuracies.append(output.accuracy) 1067 precisions.append(output.precision) 1068 recalls.append(output.recall) 1069 1070 # Get the 
performance metric 1071 if metric == "accuracy": 1072 performances = accuracies 1073 elif metric == "precision": 1074 performances = precisions 1075 elif metric == "recall": 1076 performances = recalls 1077 else: 1078 logger.warning("Performance metric invalid, defaulting to accuracy") 1079 performances = accuracies 1080 1081 best_round_performance = np.max(performances) 1082 best_set_index = accuracies.index(best_round_performance) 1083 1084 sbfs_subset = sets_to_try[best_set_index] 1085 new_channel_subset = [channel_labels[c] for c in sbfs_subset] 1086 model = models[best_set_index] 1087 preds = predictions[best_set_index] 1088 accuracy = accuracies[best_set_index] 1089 precision = precisions[best_set_index] 1090 recall = recalls[best_set_index] 1091 1092 logger.debug("Removed a channel") 1093 logger.debug("New subset: %s", new_channel_subset) 1094 logger.debug("Accuracy: %s", accuracy) 1095 logger.debug("Accuracies: %s", [float(acc) for acc in accuracies]) 1096 1097 current_performance = best_round_performance 1098 1099 # If this is the best perfomance at n_channels 1100 if performance_at_n_channels[len(sbfs_subset) - 1] < current_performance: 1101 performance_at_n_channels[len(sbfs_subset) - 1] = current_performance 1102 best_subset_at_n_channels[len(sbfs_subset) - 1] = sbfs_subset 1103 1104 p_delta = current_performance - previous_performance 1105 previous_performance = current_performance 1106 1107 # If the performance is the best so far, then save it as the best 1108 if current_performance > best_performance: 1109 best_channel_subset = new_channel_subset 1110 best_model = model 1111 best_performance = current_performance 1112 best_preds = preds 1113 best_accuracy = accuracy 1114 best_precision = precision 1115 best_recall = recall 1116 elif current_performance >= best_performance and len(new_channel_subset) < len( 1117 best_channel_subset 1118 ): 1119 best_channel_subset = new_channel_subset 1120 best_model = model 1121 best_performance = current_performance 1122 best_preds = preds 1123 best_accuracy = accuracy 1124 best_precision = precision 1125 best_recall = recall 1126 1127 if record_performance: 1128 new_channel_subset.sort() 1129 results_df.loc[step] = [ 1130 step, 1131 time.time() - start_time, 1132 len(new_channel_subset), 1133 "".join(new_channel_subset), 1134 len(sets_to_try), 1135 accuracy, 1136 precision, 1137 recall, 1138 ] 1139 1140 step += 1 1141 1142 # Conditional Inclusion 1143 while stop_criterion is False: 1144 # Get the length of the set if we were to include an additional channel 1145 length_of_resultant_set = len(sbfs_subset) + 1 1146 if ( 1147 length_of_resultant_set > max_channels 1148 or length_of_resultant_set == len(channel_labels) 1149 ): 1150 break 1151 1152 # Check all of the possible inclusions that do not lead to a previously tested subset 1153 potential_channels_to_add = list(range(len(channel_labels))) 1154 [potential_channels_to_add.remove(c) for c in sbfs_subset] 1155 1156 sets_to_try = [] 1157 X_to_try = [] 1158 1159 for c in potential_channels_to_add: 1160 set_to_try = sbfs_subset.copy() 1161 set_to_try.append(c) 1162 set_to_try.sort() 1163 1164 if set_to_try not in all_sets_tried: 1165 sets_to_try.append(set_to_try) 1166 all_sets_tried.append(set_to_try) 1167 1168 else: 1169 continue 1170 1171 # Get the new subset of data 1172 new_subset_data = np.zeros((n_trials, len(set_to_try), n_samples)) 1173 for subset_idx, channel_number in enumerate(set_to_try): 1174 channel_data = X[:, channel_number, :] 1175 new_subset_data[:, subset_idx, :] = 
channel_data 1176 1177 # Add to list of all subsets of X to try 1178 X_to_try.append(new_subset_data) 1179 1180 if X_to_try == []: 1181 break 1182 1183 # run the kernel on the new sets 1184 outputs = Parallel(n_jobs=n_jobs)( 1185 delayed(kernel_func)(Xtest, y) for Xtest in X_to_try 1186 ) 1187 1188 models = [] 1189 predictions = [] 1190 accuracies = [] 1191 precisions = [] 1192 recalls = [] 1193 performances = [] 1194 1195 # Extract the outputs 1196 for output in outputs: 1197 models.append(output.model) 1198 predictions.append(output.cv_preds) 1199 accuracies.append(output.accuracy) 1200 precisions.append(output.precision) 1201 recalls.append(output.recall) 1202 1203 # Get the performance metric 1204 if metric == "accuracy": 1205 performances = accuracies 1206 elif metric == "precision": 1207 performances = precisions 1208 elif metric == "recall": 1209 performances = recalls 1210 else: 1211 logger.warning("Performance metric invalid, defaulting to accuracy") 1212 performances = accuracies 1213 1214 best_round_performance = np.max(performances) 1215 best_set_index = accuracies.index(best_round_performance) 1216 1217 # if performance is better the best performance at n_channels 1218 if ( 1219 performance_at_n_channels[length_of_resultant_set - 1] 1220 < best_round_performance 1221 ): 1222 sbfs_subset = sets_to_try[best_set_index] 1223 new_channel_subset = [channel_labels[c] for c in sbfs_subset] 1224 model = models[best_set_index] 1225 preds = predictions[best_set_index] 1226 accuracy = accuracies[best_set_index] 1227 precision = precisions[best_set_index] 1228 recall = recalls[best_set_index] 1229 1230 logger.debug("Added back a channel") 1231 logger.debug("New subset: %s", new_channel_subset) 1232 logger.debug("Accuracy: %s", accuracy) 1233 logger.debug("Accuracies: %s", [float(acc) for acc in accuracies]) 1234 1235 current_performance = best_round_performance 1236 1237 p_delta = current_performance - previous_performance 1238 previous_performance = current_performance 1239 1240 # ADD Memory here 1241 if current_performance > best_performance: 1242 best_channel_subset = new_channel_subset 1243 best_model = model 1244 best_performance = current_performance 1245 best_preds = preds 1246 best_accuracy = accuracy 1247 best_precision = precision 1248 best_recall = recall 1249 elif current_performance >= best_performance and len( 1250 new_channel_subset 1251 ) < len(best_channel_subset): 1252 best_channel_subset = new_channel_subset 1253 best_model = model 1254 best_performance = current_performance 1255 best_preds = preds 1256 best_accuracy = accuracy 1257 best_precision = precision 1258 best_recall = recall 1259 1260 new_channel_subset.sort() 1261 results_df.loc[step] = [ 1262 step, 1263 time.time() - start_time, 1264 len(new_channel_subset), 1265 "".join(new_channel_subset), 1266 len(sets_to_try), 1267 accuracy, 1268 precision, 1269 recall, 1270 ] 1271 step += 1 1272 1273 performance_at_n_channels[length_of_resultant_set - 1] = ( 1274 current_performance 1275 ) 1276 best_subset_at_n_channels[length_of_resultant_set - 1] = sbfs_subset 1277 1278 # if no performance gains, then stop conditional inclusion 1279 else: 1280 break 1281 1282 # Check stopping criterion 1283 stop_criterion = __check_stopping_criterion( 1284 "SBFS", 1285 time.time() - start_time, 1286 len(new_channel_subset), 1287 p_delta, 1288 max_time, 1289 min_channels, 1290 max_channels, 1291 performance_delta, 1292 ) 1293 1294 stop_criterion = __check_stopping_criterion( 1295 "SBFS", 1296 time.time() - start_time, 1297 
len(new_channel_subset), 1298 p_delta, 1299 max_time, 1300 min_channels, 1301 max_channels, 1302 performance_delta, 1303 ) 1304 1305 # Break if SBFS subset is 1 channel 1306 if len(sbfs_subset) == 1: 1307 break 1308 1309 new_channel_subset = [channel_labels[c] for c in sbfs_subset] 1310 1311 logger.debug("Best channel subset: %s", best_channel_subset) 1312 logger.debug("%s : %s", metric, best_performance) 1313 logger.debug("Time to optimal subset: %s s", time.time() - start_time) 1314 1315 if record_performance is True: 1316 logger.info(results_df) 1317 1318 return ChannelSelectionOutput( 1319 best_channel_subset, 1320 best_model, 1321 best_preds, 1322 best_accuracy, 1323 best_precision, 1324 best_recall, 1325 results_df, 1326 ) 1327 1328 1329def __sffs( 1330 kernel_func, 1331 X, 1332 y, 1333 channel_labels, 1334 metric, 1335 initial_channels, 1336 max_time, 1337 min_channels, 1338 max_channels, 1339 performance_delta, 1340 n_jobs, 1341 record_performance, 1342): 1343 """The Sequential Forward Floating Selection (SFFS) method for channel selection. 1344 1345 Parameters 1346 ---------- 1347 kernel_func : function 1348 The classification kernel function which does feature extraction 1349 and classification. 1350 Different functions are used for MI, P300, SSVEP, etc. 1351 X : numpy.ndarray 1352 Training data for the classifier as trials of EEG data. 1353 3D array containing data with `float` type. 1354 1355 shape = (`n_trials`,`n_channels`,`n_samples`) 1356 y : numpy.ndarray 1357 Training labels for the classifier. 1358 1D array. 1359 1360 shape = (`n_trials`) 1361 channel_labels: list of `str` 1362 The set of channel labels corresponding to `n_channels`. 1363 A list of strings with length = `n_channels`. 1364 metric : str 1365 The metric used to measure the "goodness" of the trained classifier. 1366 initial_channels : list of `str` 1367 Initial guess of channels. 1368 max_time : int 1369 The maxiumum amount of time, in seconds, that the function will 1370 search for the optimal solution. 1371 min_channels : int 1372 The minimum number of channels. 1373 max_channels : int 1374 The maximum number of channels. 1375 performance_delta : float 1376 The performance delta under which the algorithm is considered to 1377 be close enough to optimal. 1378 n_jobs : int 1379 The number of threads to dedicate to this calculation. 1380 record_performance : bool 1381 Flag on whether or not to record performance metrics at each step. 1382 1383 Returns 1384 ------- 1385 channelSelectionOutput : ChannelSelectionOutput 1386 ChannelSelectionOutput object containing the following attributes: 1387 best_channel_subset : list of `str` 1388 The new best channel subset from the list of `channel_labels`. 1389 self.clf : classifier 1390 The trained classification model. 1391 preds : numpy.ndarray 1392 The predictions from the model. 1393 1D array with the same shape as `y`. 1394 shape = (`n_trials`) 1395 accuracy : float 1396 The accuracy of the trained classification model. 1397 precision : float 1398 The precision of the trained classification model. 1399 recall : float 1400 The recall of the trained classification model. 1401 results_df : pandas.DataFrame 1402 The dataframe containing the results of each step of channel selection. 
1403 1404 1405 """ 1406 results_df = pd.DataFrame( 1407 columns=[ 1408 "Step", 1409 "Time", 1410 "N Channels", 1411 "Channel Subset", 1412 "Unique Combinations Tested in Step", 1413 "Accuracy", 1414 "Precision", 1415 "Recall", 1416 ] 1417 ) 1418 step = 1 1419 1420 start_time = time.time() 1421 1422 n_trials, n_channels, n_samples = X.shape 1423 sffs_subset = [] 1424 all_sets_tried = [] # set of all channels that have been tried 1425 1426 for i, c in enumerate(channel_labels): 1427 if c in initial_channels: 1428 sffs_subset.append(i) 1429 1430 performance_at_n_channels = np.zeros(len(channel_labels)) 1431 performance_at_n_channels[: min_channels - 1] = np.inf 1432 best_subset_at_n_channels = [0] * len(channel_labels) 1433 1434 previous_performance = 0 1435 1436 stop_criterion = False 1437 1438 # Get the performance of the initial subset, if possible 1439 try: 1440 initial_results = kernel_func(X[:, sffs_subset, :], y) 1441 initial_model = initial_results.model 1442 initial_preds = initial_results.cv_preds 1443 initial_accuracy = initial_results.accuracy 1444 initial_precision = initial_results.precision 1445 initial_recall = initial_results.recall 1446 1447 if metric == "accuracy": 1448 initial_performance = initial_accuracy 1449 elif metric == "precision": 1450 initial_performance = initial_precision 1451 elif metric == "recall": 1452 initial_performance = initial_recall 1453 1454 # Best 1455 best_channel_subset = initial_channels 1456 best_model = initial_model 1457 best_performance = initial_performance 1458 best_preds = initial_preds 1459 best_accuracy = initial_accuracy 1460 best_precision = initial_precision 1461 best_recall = initial_recall 1462 1463 performance_at_n_channels[len(initial_channels) - 1] = initial_performance 1464 best_subset_at_n_channels[len(initial_channels) - 1] = initial_channels 1465 1466 # If not possible then set the initial performance to 0 1467 except ValueError: 1468 best_channel_subset = [] 1469 best_model = None 1470 best_performance = 0 1471 best_preds = [] 1472 best_accuracy = 0 1473 best_precision = 0 1474 best_recall = 0 1475 1476 preds = [] 1477 accuracy = 0 1478 precision = 0 1479 recall = 0 1480 1481 pass_stopping_criterion = False 1482 1483 # TODO Test the initial subset 1484 1485 while stop_criterion is False: 1486 sets_to_try = [] 1487 X_to_try = [] 1488 for c in range(n_channels): 1489 if c not in sffs_subset: 1490 set_to_try = sffs_subset.copy() 1491 set_to_try.append(c) 1492 sets_to_try.append(set_to_try) 1493 1494 # Get the new subset of data 1495 new_subset_data = np.zeros((n_trials, len(set_to_try), n_samples)) 1496 for subset_idx, channel_number in enumerate(set_to_try): 1497 channel_data = X[:, channel_number, :] 1498 new_subset_data[:, subset_idx, :] = channel_data 1499 1500 # Add to list of all subsets of X to try 1501 X_to_try.append(new_subset_data) 1502 1503 # This handles the multiprocessing to check multiple channel combinations at once if n_jobs > 1 1504 outputs = Parallel(n_jobs=n_jobs)( 1505 delayed(kernel_func)(Xtest, y) for Xtest in X_to_try 1506 ) 1507 1508 models = [] 1509 predictions = [] 1510 accuracies = [] 1511 precisions = [] 1512 recalls = [] 1513 1514 # Extract the outputs 1515 for output in outputs: 1516 models.append(output.model) 1517 predictions.append(output.cv_preds) 1518 accuracies.append(output.accuracy) 1519 precisions.append(output.precision) 1520 recalls.append(output.recall) 1521 1522 # Get the performance metric 1523 if metric == "accuracy": 1524 performances = accuracies 1525 elif metric == "precision": 
1526 performances = precisions 1527 elif metric == "recall": 1528 performances = recalls 1529 else: 1530 logger.warning("Performance metric invalid, defaulting to accuracy") 1531 performances = accuracies 1532 1533 best_round_performance = np.max(performances) 1534 best_set_index = accuracies.index(best_round_performance) 1535 1536 sffs_subset = sets_to_try[best_set_index] 1537 new_channel_subset = [channel_labels[c] for c in sffs_subset] 1538 model = models[best_set_index] 1539 preds = predictions[best_set_index] 1540 accuracy = accuracies[best_set_index] 1541 precision = precisions[best_set_index] 1542 recall = recalls[best_set_index] 1543 1544 current_performance = best_round_performance 1545 1546 logger.debug("Removed a channel") 1547 logger.debug("New subset: %s", new_channel_subset) 1548 logger.debug("Accuracy: %s", accuracy) 1549 logger.debug("Accuracies: %s", [float(acc) for acc in accuracies]) 1550 1551 # If this is the best perfomance at n_channels 1552 if performance_at_n_channels[len(sffs_subset) - 1] < current_performance: 1553 performance_at_n_channels[len(sffs_subset) - 1] = current_performance 1554 best_subset_at_n_channels[len(sffs_subset) - 1] = sffs_subset 1555 1556 p_delta = current_performance - previous_performance 1557 previous_performance = current_performance 1558 1559 if current_performance > best_performance: 1560 best_channel_subset = new_channel_subset 1561 best_model = model 1562 best_performance = current_performance 1563 best_preds = preds 1564 best_accuracy = accuracy 1565 best_precision = precision 1566 best_recall = recall 1567 elif current_performance >= best_performance and len(new_channel_subset) < len( 1568 best_channel_subset 1569 ): 1570 best_channel_subset = new_channel_subset 1571 best_model = model 1572 best_performance = current_performance 1573 best_preds = preds 1574 best_accuracy = accuracy 1575 best_precision = precision 1576 best_recall = recall 1577 1578 new_channel_subset.sort() 1579 results_df.loc[step] = [ 1580 step, 1581 time.time() - start_time, 1582 len(new_channel_subset), 1583 "".join(new_channel_subset), 1584 len(sets_to_try), 1585 accuracy, 1586 precision, 1587 recall, 1588 ] 1589 1590 step += 1 1591 1592 # Conditional Exclusion 1593 while stop_criterion is False: 1594 # Get the length of the set if we were to include an additional channel 1595 length_of_resultant_set = len(sffs_subset) - 1 1596 if length_of_resultant_set < min_channels or length_of_resultant_set == 0: 1597 break 1598 1599 # If length of resultant set equal to min channels the pass stopping criterion 1600 if length_of_resultant_set == min_channels: 1601 pass_stopping_criterion = True 1602 1603 # Check all of the possible inclusions that do not lead to a previously tested subset 1604 potential_channels_to_add = list(range(len(channel_labels))) 1605 [potential_channels_to_add.remove(c) for c in sffs_subset] 1606 1607 sets_to_try = [] 1608 X_to_try = [] 1609 1610 for c in sffs_subset: 1611 set_to_try = sffs_subset.copy() 1612 set_to_try.remove(c) 1613 set_to_try.sort() 1614 1615 # Only try sets that have not been tried before 1616 if set_to_try not in all_sets_tried: 1617 sets_to_try.append(set_to_try) 1618 all_sets_tried.append(set_to_try) 1619 else: 1620 continue 1621 1622 # Get the new subset of data 1623 new_subset_data = np.zeros((n_trials, len(set_to_try), n_samples)) 1624 for subset_idx, channel_number in enumerate(set_to_try): 1625 channel_data = X[:, channel_number, :] 1626 new_subset_data[:, subset_idx, :] = channel_data 1627 1628 # Add to list of all 
subsets of X to try 1629 X_to_try.append(new_subset_data) 1630 1631 if X_to_try == []: 1632 break 1633 1634 # run the kernel on the new sets 1635 outputs = Parallel(n_jobs=n_jobs)( 1636 delayed(kernel_func)(Xtest, y) for Xtest in X_to_try 1637 ) 1638 1639 # [all_sets_tried.append(set.sort()) for set in sets_to_try] 1640 1641 models = [] 1642 predictions = [] 1643 accuracies = [] 1644 precisions = [] 1645 recalls = [] 1646 performances = [] 1647 1648 # Extract the outputs 1649 for output in outputs: 1650 models.append(output.model) 1651 predictions.append(output.cv_preds) 1652 accuracies.append(output.accuracy) 1653 precisions.append(output.precision) 1654 recalls.append(output.recall) 1655 1656 # Get the performance metric 1657 if metric == "accuracy": 1658 performances = accuracies 1659 elif metric == "precision": 1660 performances = precisions 1661 elif metric == "recall": 1662 performances = recalls 1663 else: 1664 logger.warning("Performance metric invalid, defaulting to accuracy") 1665 performances = accuracies 1666 1667 best_round_performance = np.max(performances) 1668 best_set_index = accuracies.index(best_round_performance) 1669 1670 # if performance is better at the resultant channel length 1671 if ( 1672 performance_at_n_channels[length_of_resultant_set - 1] 1673 < best_round_performance 1674 ): 1675 sffs_subset = sets_to_try[best_set_index] 1676 new_channel_subset = [channel_labels[c] for c in sffs_subset] 1677 model = models[best_set_index] 1678 preds = predictions[best_set_index] 1679 accuracy = accuracies[best_set_index] 1680 precision = precisions[best_set_index] 1681 recall = recalls[best_set_index] 1682 1683 logger.debug("Added back a channel") 1684 logger.debug("New subset: %s", new_channel_subset) 1685 logger.debug("Accuracy: %s", accuracy) 1686 logger.debug("Accuracies: %s", [float(acc) for acc in accuracies]) 1687 1688 current_performance = best_round_performance 1689 1690 p_delta = current_performance - previous_performance 1691 previous_performance = current_performance 1692 1693 if current_performance > best_performance: 1694 best_channel_subset = new_channel_subset 1695 best_model = model 1696 best_performance = current_performance 1697 best_preds = preds 1698 best_accuracy = accuracy 1699 best_precision = precision 1700 best_recall = recall 1701 elif current_performance >= best_performance and len( 1702 new_channel_subset 1703 ) < len(best_channel_subset): 1704 best_channel_subset = new_channel_subset 1705 best_model = model 1706 best_performance = current_performance 1707 best_preds = preds 1708 best_accuracy = accuracy 1709 best_precision = precision 1710 best_recall = recall 1711 1712 if record_performance: 1713 new_channel_subset.sort() 1714 results_df.loc[step] = [ 1715 step, 1716 time.time() - start_time, 1717 len(new_channel_subset), 1718 "".join(new_channel_subset), 1719 len(sets_to_try), 1720 accuracy, 1721 precision, 1722 recall, 1723 ] 1724 step += 1 1725 1726 performance_at_n_channels[length_of_resultant_set - 1] = ( 1727 current_performance 1728 ) 1729 best_subset_at_n_channels[length_of_resultant_set - 1] = sffs_subset 1730 1731 # if no performance gains, then stop conditional exclusion 1732 else: 1733 break 1734 1735 # Check stopping criterion 1736 if pass_stopping_criterion is False: 1737 stop_criterion = __check_stopping_criterion( 1738 "SFFS", 1739 time.time() - start_time, 1740 len(new_channel_subset), 1741 p_delta, 1742 max_time, 1743 min_channels, 1744 max_channels, 1745 performance_delta, 1746 ) 1747 1748 if pass_stopping_criterion: 1749 
pass_stopping_criterion = False 1750 continue 1751 else: 1752 stop_criterion = __check_stopping_criterion( 1753 "SFFS", 1754 time.time() - start_time, 1755 len(new_channel_subset), 1756 p_delta, 1757 max_time, 1758 min_channels, 1759 max_channels, 1760 performance_delta, 1761 ) 1762 1763 new_channel_subset = [channel_labels[c] for c in sffs_subset] 1764 1765 logger.debug("Best channel subset: %s", best_channel_subset) 1766 logger.debug("%s : %s", metric, best_performance) 1767 logger.debug("Time to optimal subset: %s s", time.time() - start_time) 1768 1769 if record_performance is True: 1770 logger.info(results_df) 1771 1772 return ChannelSelectionOutput( 1773 best_channel_subset, 1774 best_model, 1775 best_preds, 1776 best_accuracy, 1777 best_precision, 1778 best_recall, 1779 results_df, 1780 )
logger = Logger(name=__name__)
Module-level logger, instantiated at the default level of logging.INFO.
@dataclass
class ChannelSelectionOutput:
    best_channel_subset: list = field(default_factory=list)
    best_model: Pipeline = field(default=None)
    best_preds: np.ndarray = field(default_factory=np.ndarray)
    best_accuracy: float = field(default=0.0)
    best_precision: float = field(default=0.0)
    best_recall: float = field(default=0.0)
    results_df: pd.DataFrame = field(default_factory=pd.DataFrame)
Dataclass to store output from channel selection.
Parameters
- best_channel_subset (list of `str`): The best channel subset from the list of `channel_labels`.
- best_model (classifier): The trained classification model.
- best_preds (numpy.ndarray): The predictions from the model.
- best_accuracy (float): The accuracy of the trained classification model.
- best_precision (float): The precision of the trained classification model.
- best_recall (float): The recall of the trained classification model.
- results_df (pandas.DataFrame): A dataframe containing the performance metrics at each step.
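Because ChannelSelectionOutput is a plain dataclass, results are read directly from its fields. A minimal sketch follows; the values are hand-filled for illustration only, since in practice the object is returned by channel_selection_by_method rather than built by hand:

    import numpy as np
    import pandas as pd
    from bci_essentials.channel_selection import ChannelSelectionOutput

    # Hand-filled example values, for illustration only.
    selection = ChannelSelectionOutput(
        best_channel_subset=["C3", "Cz", "C4"],
        best_model=None,
        best_preds=np.array([0, 1, 0, 1]),
        best_accuracy=0.75,
        best_precision=0.80,
        best_recall=0.70,
        results_df=pd.DataFrame(),
    )

    print(selection.best_channel_subset)   # ['C3', 'Cz', 'C4']
    print(selection.best_accuracy)         # 0.75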
def channel_selection_by_method(kernel_func, X, y, channel_labels, method="SBS", metric="accuracy", initial_channels=[], max_time=999, min_channels=1, max_channels=999, performance_delta=0.001, n_jobs=1, record_performance=True):
Passes the BCI kernel function into a wrapper defined by `method`.

Parameters
- kernel_func (function): The classification kernel function which does feature extraction and classification. Different functions are used for MI, P300, SSVEP, etc.
- X (numpy.ndarray): Training data for the classifier as trials of EEG data. 3D array of `float` with shape (`n_trials`, `n_channels`, `n_samples`).
- y (numpy.ndarray): Training labels for the classifier. 1D array with shape (`n_trials`,).
- channel_labels (list of `str`): The channel labels corresponding to `n_channels`, as a list of strings of length `n_channels`.
- method (str, optional): The wrapper method. Options are `"SBS"`, `"SFS"`, `"SBFS"` and `"SFFS"`. Default is `"SBS"`.
- metric (str, optional): The metric used to measure the "goodness" of the trained classifier. Options are `"accuracy"`, `"precision"` and `"recall"`. Default is `"accuracy"`.
- initial_channels (list of `str`, optional): Initial guess of channels. Default is `[]`, which assigns an empty starting set for forward selections and the full channel set for backward selections.
- max_time (int, optional): The maximum amount of time, in seconds, that the function will search for the optimal solution. Default is `999` seconds.
- min_channels (int, optional): The minimum number of channels. Default is `1`.
- max_channels (int, optional): The maximum number of channels. Default is `999`.
- performance_delta (float, optional): The performance delta under which the algorithm is considered close enough to optimal. Default is `0.001`.
- n_jobs (int, optional): The number of parallel jobs to dedicate to this calculation. Default is `1`.
- record_performance (bool, optional): Whether or not to record the performance of each channel selection step. Default is `True`.

Returns
- channelSelectionOutput (ChannelSelectionOutput): Object containing the following attributes:
    - best_channel_subset (list of `str`): The best channel subset from the list of `channel_labels`.
    - best_model (classifier): The trained classification model.
    - best_preds (numpy.ndarray): The predictions from the model; 1D array with the same shape as `y`.
    - best_accuracy (float): The accuracy of the trained classification model.
    - best_precision (float): The precision of the trained classification model.
    - best_recall (float): The recall of the trained classification model.
    - results_df (pandas.DataFrame): A dataframe containing the results of each step of channel selection.
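A usage sketch tying the pieces together: the selection routines in this module read `model`, `cv_preds`, `accuracy`, `precision` and `recall` from whatever object the kernel returns, so a stand-in kernel built on scikit-learn can be wired up as below. The `toy_kernel`, the synthetic data and the channel labels here are illustrative assumptions, not part of the package:

    import numpy as np
    from types import SimpleNamespace
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
    from sklearn.metrics import accuracy_score, precision_score, recall_score
    from sklearn.model_selection import cross_val_predict
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler

    from bci_essentials.channel_selection import channel_selection_by_method


    def toy_kernel(X, y):
        # Flatten each trial and cross-validate a simple LDA classifier, then
        # return an object exposing the attributes the selection methods expect.
        n_trials, n_channels, n_samples = X.shape
        features = X.reshape(n_trials, n_channels * n_samples)

        model = make_pipeline(StandardScaler(), LinearDiscriminantAnalysis())
        cv_preds = cross_val_predict(model, features, y, cv=5)
        model.fit(features, y)

        return SimpleNamespace(
            model=model,
            cv_preds=cv_preds,
            accuracy=accuracy_score(y, cv_preds),
            precision=precision_score(y, cv_preds, average="macro"),
            recall=recall_score(y, cv_preds, average="macro"),
        )


    # Synthetic data in the expected (n_trials, n_channels, n_samples) layout.
    rng = np.random.default_rng(seed=0)
    X = rng.standard_normal((60, 8, 128))
    y = rng.integers(0, 2, size=60)
    channel_labels = ["Fz", "Cz", "Pz", "Oz", "C3", "C4", "P3", "P4"]

    selection = channel_selection_by_method(
        toy_kernel,
        X,
        y,
        channel_labels,
        method="SBS",       # sequential backward selection
        metric="accuracy",
        max_time=60,
        min_channels=2,
        n_jobs=1,
    )
    print(selection.best_channel_subset)
    print(selection.best_accuracy)

With the forward methods (`"SFS"`, `"SFFS"`) the same call starts from `initial_channels` (empty by default) and grows the subset instead of shrinking it.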