bci_essentials.channel_selection

Channel selection methods for BCI performance improvement.

This module includes functions for selecting channels in order to improve BCI performance.
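The wrapper methods dispatched by `channel_selection_by_method` ("SBS", "SFS", "SBFS", "SFFS") retrain a classifier on candidate channel subsets and keep the best-scoring one. As a rough illustration of the forward variant only, here is a minimal sketch under stated assumptions. It greedily adds the channel that most improves a score and stops when the step-to-step change falls below a threshold. The `greedy_forward_selection` and `toy_score` names, the `score_fn` callable, and the channel weights are hypothetical. This is not the module's implementation, which also enforces a time limit, evaluates candidates in parallel with joblib, and records per-step results.

```python
from typing import Callable, List, Sequence, Tuple


def greedy_forward_selection(
    channel_labels: Sequence[str],
    score_fn: Callable[[List[str]], float],
    max_channels: int,
    performance_delta: float = 0.001,
) -> Tuple[List[str], float]:
    """Hypothetical sketch of Sequential Forward Selection (SFS)."""
    selected: List[str] = []
    best_subset: List[str] = []
    best_score = float("-inf")
    previous_score = 0.0

    while len(selected) < max_channels:
        candidates = [c for c in channel_labels if c not in selected]
        if not candidates:
            break
        # Score every one-channel extension of the current subset;
        # max() with a key returns the first candidate among ties
        scored = [(score_fn(selected + [c]), c) for c in candidates]
        step_score, step_channel = max(scored, key=lambda pair: pair[0])
        selected = selected + [step_channel]

        # Subsets only grow, so a strict ">" keeps the smallest subset among ties
        if step_score > best_score:
            best_subset, best_score = list(selected), step_score

        # Stop once the step-to-step change drops below the threshold
        gain = step_score - previous_score
        previous_score = step_score
        if gain < performance_delta:
            break

    return best_subset, best_score


# Toy example: made-up per-channel "usefulness" weights minus a size penalty
weights = {"C3": 0.30, "C4": 0.25, "Cz": 0.05, "Fz": 0.001}


def toy_score(subset: List[str]) -> float:
    return sum(weights[c] for c in subset) - 0.01 * len(subset)


subset, score = greedy_forward_selection(list(weights), toy_score, max_channels=4)
print(subset, round(score, 3))  # ['C3', 'C4', 'Cz'] 0.57
```

Adding "Fz" in the fourth step lowers the toy score, so the change is negative and the search stops with the three-channel subset. The backward methods run the same loop in reverse, removing one channel per step.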

Notes

The EEG data input for each function is a set of trials. The data must be of the shape n_trials x n_channels x n_samples, where:

- n_trials = number of trials
- n_channels = number of channels
- n_samples = number of samples
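The shape convention above, and the way a channel subset is carved out of the trial array, can be sketched with NumPy. The random data and the channel labels are made up for illustration. Indexing the middle axis with a list of channel indices keeps the `n_trials x n_channels x n_samples` layout. This mirrors the `X[:, subset, :]` slicing used by the selection routines. Indices are collected in `channel_labels` order, as the source loop does.

```python
import numpy as np

# Hypothetical recording: 40 trials, 8 channels, 256 samples per trial
n_trials, n_channels, n_samples = 40, 8, 256
rng = np.random.default_rng(seed=0)
X = rng.standard_normal((n_trials, n_channels, n_samples))
y = rng.integers(0, 2, size=n_trials)  # one label per trial

channel_labels = ["Fz", "C3", "Cz", "C4", "P3", "Pz", "P4", "Oz"]

# Map a subset of channel names to their indices along axis 1
subset = ["C3", "C4"]
subset_idx = [i for i, c in enumerate(channel_labels) if c in subset]

# Fancy indexing on the channel axis keeps the 3D trials x channels x samples layout
X_subset = X[:, subset_idx, :]
print(X_subset.shape)  # (40, 2, 256)
```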

"""Channel selection methods for BCI performance improvement.

This module includes functions for selecting channels in order to
improve BCI performance.

Notes
-----
The EEG data input for each function is a set of trials. The data must
be of the shape `n_trials x n_channels x n_samples`, where:
    - n_trials = number of trials
    - n_channels = number of channels
    - n_samples = number of samples

"""

from joblib import Parallel, delayed
import time
import numpy as np
import pandas as pd
from dataclasses import dataclass, field
from .utils.logger import Logger  # Logger wrapper
from sklearn.pipeline import Pipeline

# Instantiate a logger for the module at the default level of logging.INFO
# Logs to bci_essentials.<module>, where <module> is the name of this module
logger = Logger(name=__name__)


@dataclass
class ChannelSelectionOutput:
    """Dataclass to store output from channel selection.

    Parameters
    ----------
    best_channel_subset : list of `str`
        The best channel subset from the list of `channel_labels`.
    best_model : classifier
        The trained classification model.
    best_preds : numpy.ndarray
        The predictions from the model.
    best_accuracy : float
        The accuracy of the trained classification model.
    best_precision : float
        The precision of the trained classification model.
    best_recall : float
        The recall of the trained classification model.
    results_df : pandas.DataFrame
        A dataframe containing the performance metrics at each step.

    """

    best_channel_subset: list = field(default_factory=list)
    best_model: Pipeline = field(default=None)
    # np.ndarray() cannot be called without a shape, so default to an empty array
    best_preds: np.ndarray = field(default_factory=lambda: np.array([]))
    best_accuracy: float = field(default=0.0)
    best_precision: float = field(default=0.0)
    best_recall: float = field(default=0.0)
    results_df: pd.DataFrame = field(default_factory=pd.DataFrame)


def channel_selection_by_method(
    kernel_func,
    X,
    y,
    channel_labels,
    method="SBS",
    metric="accuracy",
    initial_channels=[],
    max_time=999,
    min_channels=1,
    max_channels=999,
    performance_delta=0.001,
    n_jobs=1,
    record_performance=True,
):
    """Passes the BCI kernel function into a wrapper defined by `method`.

    Parameters
    ----------
    kernel_func : function
        The classification kernel function which does feature extraction
        and classification.
        Different functions are used for MI, P300, SSVEP, etc.
    X : numpy.ndarray
        Training data for the classifier as trials of EEG data.
        3D array containing data with `float` type.

        shape = (`n_trials`,`n_channels`,`n_samples`)
    y : numpy.ndarray
        Training labels for the classifier.
        1D array.

        shape = (`n_trials`)
    channel_labels : list of `str`
        The set of channel labels corresponding to `n_channels`.
        A list of strings with length = `n_channels`.
    method : str, *optional*
        The wrapper method. Options are `"SBS"`, `"SFS"`, `"SBFS"`, or `"SFFS"`.
        - Default is `"SBS"`.
    metric : str, *optional*
        The metric used to measure the "goodness" of the trained classifier.
        Options are `"accuracy"`, `"precision"`, or `"recall"`.
        - Default is `"accuracy"`.
    initial_channels : list of `str`, *optional*
        Initial guess of channels.
        - Default is `[]`. An empty list starts forward selections ("SFS",
        "SFFS") from the empty set and backward selections ("SBS", "SBFS")
        from the full set of `channel_labels`.
    max_time : int, *optional*
        The maximum amount of time, in seconds, that the function will
        search for the optimal solution.
        - Default is `999` seconds.
    min_channels : int, *optional*
        The minimum number of channels; used as a stopping criterion by the
        backward methods. Values below 1 are set to 1.
        - Default is `1`.
    max_channels : int, *optional*
        The maximum number of channels; used as a stopping criterion by the
        forward methods. Values above `len(channel_labels)` are set to
        `len(channel_labels)`.
        - Default is `999`.
    performance_delta : float, *optional*
        The minimum step-to-step change in the performance metric; the
        search stops once the change falls below this value.
        - Default is `0.001`.
    n_jobs : int, *optional*
        The number of parallel jobs (passed to `joblib.Parallel`) used to
        evaluate candidate channel subsets.
        - Default is `1`.
    record_performance : bool, *optional*
        Whether or not to log the performance recorded at each step of the
        channel selection.
        - Default is `True`.

    Returns
    -------
    channelSelectionOutput : ChannelSelectionOutput
        ChannelSelectionOutput object containing the following attributes:
            best_channel_subset : list of `str`
                The new best channel subset from the list of `channel_labels`.
            best_model : classifier
                The trained classification model.
            best_preds : numpy.ndarray
                The predictions from the model.
                1D array with the same shape as `y`.
                shape = (`n_trials`)
            best_accuracy : float
                The accuracy of the trained classification model.
            best_precision : float
                The precision of the trained classification model.
            best_recall : float
                The recall of the trained classification model.
            results_df : pandas.DataFrame
                The dataframe containing the results of each step of channel selection.

    """

    # max length can't be greater than the length of channel labels
    if max_channels > len(channel_labels):
        logger.debug(
            "Maximum number of channels must be less than or equal to "
            + "the number of channels. Setting to number of channels."
        )
        max_channels = len(channel_labels)

    # min length can't be less than 1
    if min_channels < 1:
        logger.debug("Minimum number of channels must be greater than 0. Setting to 1.")
        min_channels = 1

    logger.debug("Running channel selection method: %s", method)
    if method == "SBS":
        if initial_channels == []:
            initial_channels = channel_labels
        logger.debug("Initial subset: %s", initial_channels)

        # pass arguments to SBS
        return __sbs(
            kernel_func,
            X,
            y,
            channel_labels=channel_labels,
            metric=metric,
            initial_channels=initial_channels,
            max_time=max_time,
            min_channels=min_channels,
            max_channels=max_channels,
            performance_delta=performance_delta,
            n_jobs=n_jobs,
            record_performance=record_performance,
        )

    elif method == "SFS":
        logger.debug("Initial subset: %s", initial_channels)

        # pass arguments to SFS
        return __sfs(
            kernel_func,
            X,
            y,
            channel_labels=channel_labels,
            metric=metric,
            initial_channels=initial_channels,
            max_time=max_time,
            min_channels=min_channels,
            max_channels=max_channels,
            performance_delta=performance_delta,
            n_jobs=n_jobs,
            record_performance=record_performance,
        )

    elif method == "SBFS":
        if initial_channels == []:
            initial_channels = channel_labels
        logger.debug("Initial subset: %s", initial_channels)

        # pass arguments to SBFS
        return __sbfs(
            kernel_func,
            X,
            y,
            channel_labels=channel_labels,
            metric=metric,
            initial_channels=initial_channels,
            max_time=max_time,
            min_channels=min_channels,
            max_channels=max_channels,
            performance_delta=performance_delta,
            n_jobs=n_jobs,
            record_performance=record_performance,
        )

    elif method == "SFFS":
        logger.debug("Initial subset: %s", initial_channels)

        # pass arguments to SFFS
        return __sffs(
            kernel_func,
            X,
            y,
            channel_labels=channel_labels,
            metric=metric,
            initial_channels=initial_channels,
            max_time=max_time,
            min_channels=min_channels,
            max_channels=max_channels,
            performance_delta=performance_delta,
            n_jobs=n_jobs,
            record_performance=record_performance,
        )
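    else:
        raise ValueError(
            f"Invalid channel selection method: {method}. "
            + 'Options are "SBS", "SFS", "SBFS", or "SFFS".'
        )

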
def __check_stopping_criterion(
    algorithm,
    current_time,
    n_channels,
    current_performance_delta,
    max_time,
    min_channels,
    max_channels,
    performance_delta,
):
    """Function to check if a stopping criterion has been met.

    Parameters
    ----------
    algorithm : str
        The algorithm being used for channel selection.
    current_time : float
        The time elapsed since the start of the channel selection method.
    n_channels : int
        The number of channels in the subset selected at the current
        iteration (`len(new_channel_subset)`).
    current_performance_delta : float
        The change in performance between the current iteration and the
        previous one.
    max_time : int
        The maximum amount of time, in seconds, that the function will
        search for the optimal solution.
    min_channels : int
        The minimum number of channels.
    max_channels : int
        The maximum number of channels.
    performance_delta : float
        The minimum step-to-step change in the performance metric; the
        search stops once the change falls below this value.

    Returns
    -------
    *bool*
        `True` if a stopping criterion has been met, otherwise `False`.

    """
    if current_time > max_time:
        logger.debug("Stopping based on time")
        return True

    if algorithm == "SBS" or algorithm == "SBFS":
        if n_channels <= min_channels:
            logger.debug("Stopping because minimum number of channels reached")
            return True

    if algorithm == "SFS" or algorithm == "SFFS":
        if n_channels >= max_channels:
            logger.debug("Stopping because maximum number of channels reached")
            return True

    if current_performance_delta < performance_delta:
        logger.debug("Stopping because performance improvements are declining")
        return True

    return False


def __sfs(
    kernel_func,
    X,
    y,
    channel_labels,
    metric,
    initial_channels,
    max_time,
    min_channels,
    max_channels,
    performance_delta,
    n_jobs,
    record_performance,
):
    """The Sequential Forward Selection (SFS) method for channel selection.

    Parameters
    ----------
    kernel_func : function
        The classification kernel function which does feature extraction
        and classification.
        Different functions are used for MI, P300, SSVEP, etc.
    X : numpy.ndarray
        Training data for the classifier as trials of EEG data.
        3D array containing data with `float` type.

        shape = (`n_trials`,`n_channels`,`n_samples`)
    y : numpy.ndarray
        Training labels for the classifier.
        1D array.

        shape = (`n_trials`)
    channel_labels : list of `str`
        The set of channel labels corresponding to `n_channels`.
        A list of strings with length = `n_channels`.
    metric : str
        The metric used to measure the "goodness" of the trained classifier.
    initial_channels : list of `str`
        Initial guess of channels.
    max_time : int
        The maximum amount of time, in seconds, that the function will
        search for the optimal solution.
    min_channels : int
        The minimum number of channels.
    max_channels : int
        The maximum number of channels.
    performance_delta : float
        The minimum step-to-step change in the performance metric; the
        search stops once the change falls below this value.
    n_jobs : int
        The number of parallel jobs (passed to `joblib.Parallel`).
    record_performance : bool
        Flag on whether or not to log the performance recorded at each step.

    Returns
    -------
    channelSelectionOutput : ChannelSelectionOutput
        ChannelSelectionOutput object containing the following attributes:
            best_channel_subset : list of `str`
                The new best channel subset from the list of `channel_labels`.
            best_model : classifier
                The trained classification model.
            best_preds : numpy.ndarray
                The predictions from the model.
                1D array with the same shape as `y`.

                shape = (`n_trials`)
            best_accuracy : float
                The accuracy of the trained classification model.
            best_precision : float
                The precision of the trained classification model.
            best_recall : float
                The recall of the trained classification model.
            results_df : pandas.DataFrame
                The dataframe containing the results of each step of channel selection.
    """
    results_df = pd.DataFrame(
        columns=[
            "Step",
            "Time",
            "N Channels",
            "Channel Subset",
            "Unique Combinations Tested in Step",
            "Accuracy",
            "Precision",
            "Recall",
        ]
    )
    step = 1

    start_time = time.time()

    n_trials, n_channels, n_samples = X.shape
    sfs_subset = []

    for i, c in enumerate(channel_labels):
        if c in initial_channels:
            sfs_subset.append(i)

    previous_performance = 0

    stop_criterion = False

    # Get the performance of the initial subset, if possible
    try:
        initial_results = kernel_func(X[:, sfs_subset, :], y)
        initial_model = initial_results.model
        initial_preds = initial_results.cv_preds
        initial_accuracy = initial_results.accuracy
        initial_precision = initial_results.precision
        initial_recall = initial_results.recall

        if metric == "accuracy":
            initial_performance = initial_accuracy
        elif metric == "precision":
            initial_performance = initial_precision
        elif metric == "recall":
            initial_performance = initial_recall
        else:
            # Invalid metric: fall back to accuracy, as the search loop does
            initial_performance = initial_accuracy

        # Best
        best_channel_subset = initial_channels
        best_model = initial_model
        best_performance = initial_performance
        best_preds = initial_preds
        best_accuracy = initial_accuracy
        best_precision = initial_precision
        best_recall = initial_recall

    # If not possible then set the initial performance to 0
    except ValueError:
        best_channel_subset = []
        best_model = None
        best_performance = 0
        best_preds = []
        best_accuracy = 0
        best_precision = 0
        best_recall = 0

    preds = []
    accuracy = 0
    precision = 0
    recall = 0

    while stop_criterion is False:
        sets_to_try = []
        X_to_try = []
        for channel in range(n_channels):
            if channel not in sfs_subset:
                set_to_try = sfs_subset.copy()
                set_to_try.append(channel)
                sets_to_try.append(set_to_try)

                # Get the new subset of data
                new_subset_data = np.zeros((n_trials, len(set_to_try), n_samples))
                for subset_idx, channel_number in enumerate(set_to_try):
                    channel_data = X[:, channel_number, :]
                    new_subset_data[:, subset_idx, :] = channel_data

                # Add to list of all subsets of X to try
                X_to_try.append(new_subset_data)

        # This handles the multiprocessing to check multiple channel combinations at once if n_jobs > 1
        outputs = Parallel(n_jobs=n_jobs)(
            delayed(kernel_func)(Xtest, y) for Xtest in X_to_try
        )

        models = []
        predictions = []
        accuracies = []
        precisions = []
        recalls = []

        # Extract the outputs
        for output in outputs:
            models.append(output.model)
            predictions.append(output.cv_preds)
            accuracies.append(output.accuracy)
            precisions.append(output.precision)
            recalls.append(output.recall)

        # Get the performance metric
        if metric == "accuracy":
            performances = accuracies
        elif metric == "precision":
            performances = precisions
        elif metric == "recall":
            performances = recalls
        else:
            logger.warning("Performance metric invalid, defaulting to accuracy")
            performances = accuracies

        # Get the index of the best X tried in this round, ranked by the chosen metric
        best_set_index = int(np.argmax(performances))

        sfs_subset = sets_to_try[best_set_index]
        new_channel_subset = [channel_labels[c] for c in sfs_subset]
        model = models[best_set_index]
        preds = predictions[best_set_index]
        accuracy = accuracies[best_set_index]
        # best_overall_accuracy = accuracy
        precision = precisions[best_set_index]
        recall = recalls[best_set_index]
        logger.debug("New subset: %s", new_channel_subset)
        logger.debug("Accuracy: %s", accuracy)
        logger.debug("Accuracies: %s", [float(acc) for acc in accuracies])

        # Performance of the chosen subset under the selected metric
        current_performance = performances[best_set_index]

        p_delta = current_performance - previous_performance
        previous_performance = current_performance

        if current_performance > best_performance:
            best_channel_subset = new_channel_subset
            best_model = model
            best_performance = current_performance
            best_preds = preds
            best_accuracy = accuracy
            best_precision = precision
            best_recall = recall
        elif current_performance >= best_performance and len(new_channel_subset) < len(
            best_channel_subset
        ):
            best_channel_subset = new_channel_subset
            best_model = model
            best_performance = current_performance
            best_preds = preds
            best_accuracy = accuracy
            best_precision = precision
            best_recall = recall

        new_channel_subset.sort()
        results_df.loc[step] = [
            step,
            time.time() - start_time,
            len(new_channel_subset),
            "".join(new_channel_subset),
            len(sets_to_try),
            accuracy,
            precision,
            recall,
        ]

        step += 1

        stop_criterion = __check_stopping_criterion(
            "SFS",
            time.time() - start_time,
            len(new_channel_subset),
            p_delta,
            max_time,
            min_channels,
            max_channels,
            performance_delta,
        )

    new_channel_subset = [channel_labels[c] for c in sfs_subset]

    logger.debug("Best channel subset: %s", best_channel_subset)
    logger.debug("%s : %s", metric, best_performance)
    logger.debug("Time to optimal subset: %s s", time.time() - start_time)

    if record_performance is True:
        logger.info(results_df)

    # Get the best model

    return ChannelSelectionOutput(
        best_channel_subset,
        best_model,
        best_preds,
        best_accuracy,
        best_precision,
        best_recall,
        results_df,
    )


def __sbs(
    kernel_func,
    X,
    y,
    channel_labels,
    metric,
    initial_channels,
    max_time,
    min_channels,
    max_channels,
    performance_delta,
    n_jobs,
    record_performance,
):
    """The Sequential Backward Selection (SBS) method for channel selection.

    Parameters
    ----------
    kernel_func : function
        The classification kernel function which does feature extraction
        and classification.
        Different functions are used for MI, P300, SSVEP, etc.
    X : numpy.ndarray
        Training data for the classifier as trials of EEG data.
        3D array containing data with `float` type.

        shape = (`n_trials`,`n_channels`,`n_samples`)
    y : numpy.ndarray
        Training labels for the classifier.
        1D array.

        shape = (`n_trials`)
    channel_labels : list of `str`
        The set of channel labels corresponding to `n_channels`.
        A list of strings with length = `n_channels`.
    metric : str
        The metric used to measure the "goodness" of the trained classifier.
    initial_channels : list of `str`
        Initial guess of channels.
    max_time : int
        The maximum amount of time, in seconds, that the function will
        search for the optimal solution.
    min_channels : int
        The minimum number of channels.
    max_channels : int
        The maximum number of channels.
    performance_delta : float
        The minimum step-to-step change in the performance metric; the
        search stops once the change falls below this value.
    n_jobs : int
        The number of parallel jobs (passed to `joblib.Parallel`).
    record_performance : bool
        Flag on whether or not to log the performance recorded at each step.

    Returns
    -------
    channelSelectionOutput : ChannelSelectionOutput
        ChannelSelectionOutput object containing the following attributes:
            best_channel_subset : list of `str`
                The new best channel subset from the list of `channel_labels`.
            best_model : classifier
                The trained classification model.
            best_preds : numpy.ndarray
                The predictions from the model.
                1D array with the same shape as `y`.
                shape = (`n_trials`)
            best_accuracy : float
                The accuracy of the trained classification model.
            best_precision : float
                The precision of the trained classification model.
            best_recall : float
                The recall of the trained classification model.
            results_df : pandas.DataFrame
                The dataframe containing the results of each step of channel selection.

    """

    results_df = pd.DataFrame(
        columns=[
            "Step",
            "Time",
            "N Channels",
            "Channel Subset",
            "Unique Combinations Tested in Step",
            "Accuracy",
            "Precision",
            "Recall",
        ]
    )
    step = 1

    if len(initial_channels) <= min_channels:
        initial_channels = channel_labels

    start_time = time.time()

    n_trials, n_channels, n_samples = X.shape
    sbs_subset = []
    all_sets_tried = []  # list of all channel subsets that have been tried

    for i, c in enumerate(channel_labels):
        if c in initial_channels:
            sbs_subset.append(i)

    # Get the performance of the initial subset
    initial_results = kernel_func(X[:, sbs_subset, :], y)
    initial_model = initial_results.model
    initial_preds = initial_results.cv_preds
    initial_accuracy = initial_results.accuracy
    initial_precision = initial_results.precision
    initial_recall = initial_results.recall

    if metric == "accuracy":
        initial_performance = initial_accuracy
    elif metric == "precision":
        initial_performance = initial_precision
    elif metric == "recall":
        initial_performance = initial_recall
    else:
        # Invalid metric: fall back to accuracy, as the search loop does
        initial_performance = initial_accuracy

    # Best
    best_channel_subset = initial_channels
    best_model = initial_model
    best_performance = initial_performance
    best_preds = initial_preds
    best_accuracy = initial_accuracy
    best_precision = initial_precision
    best_recall = initial_recall

    previous_performance = 0

    stop_criterion = False

    preds = []
    accuracy = 0
    precision = 0
    recall = 0

    while stop_criterion is False:
        # Exclusion Step
        sets_to_try = []
        X_to_try = []
        for channel in sbs_subset:
            set_to_try = sbs_subset.copy()
            set_to_try.remove(channel)
            set_to_try.sort()

            # Only try sets that have not been tried before
            if set_to_try not in all_sets_tried:
                sets_to_try.append(set_to_try)
                all_sets_tried.append(set_to_try)
            else:
                continue

            # Get the new subset of data
            new_subset_data = np.zeros((n_trials, len(set_to_try), n_samples))
            for subset_idx, channel_number in enumerate(set_to_try):
                channel_data = X[:, channel_number, :]
                new_subset_data[:, subset_idx, :] = channel_data

            # Add to list of all subsets of X to try
            X_to_try.append(new_subset_data)

        # Evaluate every candidate subset, in parallel when n_jobs > 1
        outputs = Parallel(n_jobs=n_jobs)(
            delayed(kernel_func)(Xtest, y) for Xtest in X_to_try
        )

        # [all_sets_tried.append(set.sort()) for set in sets_to_try]

        models = []
        predictions = []
        accuracies = []
        precisions = []
        recalls = []

        # Extract the outputs
        for output in outputs:
            models.append(output.model)
            predictions.append(output.cv_preds)
            accuracies.append(output.accuracy)
            precisions.append(output.precision)
            recalls.append(output.recall)

        # Get the performance metric
        if metric == "accuracy":
            performances = accuracies
        elif metric == "precision":
            performances = precisions
        elif metric == "recall":
            performances = recalls
        else:
            logger.warning("Performance metric invalid, defaulting to accuracy")
            performances = accuracies

        # Get the index of the best set tried in this round, ranked by the chosen metric
        best_set_index = int(np.argmax(performances))
 783
 784        sbs_subset = sets_to_try[best_set_index]
 785        new_channel_subset = [channel_labels[c] for c in sbs_subset]
 786        model = models[best_set_index]
 787        preds = predictions[best_set_index]
 788        accuracy = accuracies[best_set_index]
 789        # best_overall_accuracy = accuracy
 790        precision = precisions[best_set_index]
 791        recall = recalls[best_set_index]
 792
 793        current_performance = performances[best_set_index]
 794        logger.debug("Removed a channel")
 795        logger.debug("New subset: %s", new_channel_subset)
 796        logger.debug("Accuracy: %s", accuracy)
 797        logger.debug("Accuracies: %s", [float(acc) for acc in accuracies])
 798
 799        p_delta = current_performance - previous_performance
 800        previous_performance = current_performance
 801
 802        if current_performance > best_performance:
 803            best_channel_subset = new_channel_subset
 804            best_model = model
            best_performance = current_performance
            best_preds = preds
            best_accuracy = accuracy
            best_precision = precision
            best_recall = recall
        elif current_performance >= best_performance and len(new_channel_subset) < len(
            best_channel_subset
        ):
            best_channel_subset = new_channel_subset
            best_model = model
            best_performance = current_performance
            best_preds = preds
            best_accuracy = accuracy
            best_precision = precision
            best_recall = recall

        new_channel_subset.sort()
        results_df.loc[step] = [
            step,
            time.time() - start_time,
            len(new_channel_subset),
            "".join(new_channel_subset),
            len(sets_to_try),
            accuracy,
            precision,
            recall,
        ]

        step += 1

        # Break if SBS subset is 1 channel
        if len(sbs_subset) == 1:
            break

        stop_criterion = __check_stopping_criterion(
            "SBS",
            time.time() - start_time,
            len(new_channel_subset),
            p_delta,
            max_time,
            min_channels,
            max_channels,
            performance_delta,
        )

    new_channel_subset = [channel_labels[c] for c in sbs_subset]

    logger.debug("Best channel subset: %s", best_channel_subset)
    logger.debug("%s : %s", metric, best_performance)
    logger.debug("Time to optimal subset: %s s", time.time() - start_time)

    if record_performance is True:
        logger.info(results_df)

    return ChannelSelectionOutput(
        best_channel_subset,
        best_model,
        best_preds,
        best_accuracy,
        best_precision,
        best_recall,
        results_df,
    )


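# Illustrative sketch (hypothetical, not part of the bci_essentials API and not
# called by any function in this module): the SBFS search performed by `__sbfs`,
# reduced to channel *indices* and a plain `score` callable that stands in for
# `kernel_func` plus the chosen metric. It omits the time limit, the
# performance-delta stopping rule and the "already tried" bookkeeping; it only
# shows the exclusion step followed by the floating (conditional) inclusion step.
# Example: with score(s) = (0 in s) + (2 in s) - 0.1 * len(s) over 5 channels,
# only channels 0 and 2 carry information and _sbfs_sketch(score, 5) keeps [0, 2].
def _sbfs_sketch(score, n_channels, min_channels=1):
    """Return the best channel-index subset found by a toy SBFS search.

    `score` maps a sorted tuple of channel indices to a float (higher is better).
    """
    neg_inf = float("-inf")
    current = tuple(range(n_channels))
    best_at_size = {n_channels: score(current)}  # best score seen per subset size
    best_subset = current
    while len(current) > min_channels:
        # Exclusion: drop the channel whose removal gives the highest score
        current = max(
            (tuple(c for c in current if c != drop) for drop in current), key=score
        )
        size = len(current)
        best_at_size[size] = max(best_at_size.get(size, neg_inf), score(current))
        # Conditional inclusion: add a channel back only while the result strictly
        # beats the best score already recorded for that larger subset size
        while len(current) + 1 < n_channels:
            outside = [c for c in range(n_channels) if c not in current]
            candidate = max(
                (tuple(sorted(current + (c,))) for c in outside), key=score
            )
            if score(candidate) <= best_at_size.get(len(candidate), neg_inf):
                break
            current = candidate
            best_at_size[len(current)] = score(current)
        # Keep the highest-scoring subset, preferring fewer channels on ties
        if (score(current), -len(current)) > (score(best_subset), -len(best_subset)):
            best_subset = current
    return list(best_subset)

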
def __sbfs(
    kernel_func,
    X,
    y,
    channel_labels,
    metric,
    initial_channels,
    max_time,
    min_channels,
    max_channels,
    performance_delta,
    n_jobs,
    record_performance,
):
    """The Sequential Backward Floating Selection (SBFS) method for channel selection.

    Parameters
    ----------
    kernel_func : function
        The classification kernel function which does feature extraction
        and classification.
        Different functions are used for MI, P300, SSVEP, etc.
    X : numpy.ndarray
        Training data for the classifier as trials of EEG data.
        3D array containing data with `float` type.

        shape = (`n_trials`,`n_channels`,`n_samples`)
    y : numpy.ndarray
        Training labels for the classifier.
        1D array.

        shape = (`n_trials`)
    channel_labels : list of `str`
        The set of channel labels corresponding to `n_channels`.
        A list of strings with length = `n_channels`.
    metric : str
        The metric used to measure the "goodness" of the trained classifier.
        One of "accuracy", "precision" or "recall".
    initial_channels : list of `str`
        Initial guess of channels. If it is empty or contains no more than
        `min_channels` channels, all of `channel_labels` are used instead.
    max_time : int
        The maximum amount of time, in seconds, that the function will
        search for the optimal solution.
    min_channels : int
        The minimum number of channels.
    max_channels : int
        The maximum number of channels.
    performance_delta : float
        The performance delta under which the algorithm is considered to
        be close enough to optimal.
    n_jobs : int
        The number of parallel jobs passed to `joblib.Parallel` when
        evaluating candidate channel subsets.
    record_performance : bool
        Flag on whether or not to record performance metrics at each step.

    Returns
    -------
    channelSelectionOutput : ChannelSelectionOutput
        ChannelSelectionOutput object containing the following attributes:
            best_channel_subset : list of `str`
                The new best channel subset from the list of `channel_labels`.
            best_model : classifier
                The trained classification model.
            best_preds : numpy.ndarray
                The predictions from the model.
                1D array with the same shape as `y`.
                shape = (`n_trials`)
            best_accuracy : float
                The accuracy of the trained classification model.
            best_precision : float
                The precision of the trained classification model.
            best_recall : float
                The recall of the trained classification model.
            results_df : pandas.DataFrame
                The dataframe containing the results of each step of channel selection.

    """
    results_df = pd.DataFrame(
        columns=[
            "Step",
            "Time",
            "N Channels",
            "Channel Subset",
            "Unique Combinations Tested in Step",
            "Accuracy",
            "Precision",
            "Recall",
        ]
    )
    step = 1

    if len(initial_channels) <= min_channels or len(initial_channels) == 0:
        initial_channels = channel_labels

    start_time = time.time()

    n_trials, n_channels, n_samples = X.shape
    sbfs_subset = []
    all_sets_tried = []  # every channel subset (sorted index list) already tried

    for i, c in enumerate(channel_labels):
        if c in initial_channels:
            sbfs_subset.append(i)

    performance_at_n_channels = np.zeros(len(channel_labels))
    best_subset_at_n_channels = [0] * len(channel_labels)

    previous_performance = 0

    stop_criterion = False

    # Get the performance of the initial subset, if possible
    try:
        initial_results = kernel_func(X[:, sbfs_subset, :], y)
        initial_model = initial_results.model
        initial_preds = initial_results.cv_preds
        initial_accuracy = initial_results.accuracy
        initial_precision = initial_results.precision
        initial_recall = initial_results.recall

        if metric == "accuracy":
            initial_performance = initial_accuracy
        elif metric == "precision":
            initial_performance = initial_precision
        elif metric == "recall":
            initial_performance = initial_recall
        else:
            # Same fallback as the search loop uses for an invalid metric
            initial_performance = initial_accuracy

        # Best
        best_channel_subset = initial_channels
        best_model = initial_model
        best_performance = initial_performance
        best_preds = initial_preds
        best_accuracy = initial_accuracy
        best_precision = initial_precision
        best_recall = initial_recall

        performance_at_n_channels[len(initial_channels) - 1] = initial_performance
        best_subset_at_n_channels[len(initial_channels) - 1] = initial_channels

    # If not possible then set the initial performance to 0
    except ValueError:
        best_channel_subset = []
        best_model = None
        best_performance = 0
        best_preds = []
        best_accuracy = 0
        best_precision = 0
        best_recall = 0

    preds = []
    accuracy = 0
    precision = 0
    recall = 0

    while stop_criterion is False:
        # Exclusion Step
        sets_to_try = []
        X_to_try = []
        for c in sbfs_subset:
            set_to_try = sbfs_subset.copy()
            set_to_try.remove(c)
            set_to_try.sort()

            # Only try sets that have not been tried before
            if set_to_try not in all_sets_tried:
                sets_to_try.append(set_to_try)
                all_sets_tried.append(set_to_try)
            else:
                continue

            # Get the new subset of data
            new_subset_data = np.zeros((n_trials, len(set_to_try), n_samples))
            for subset_idx, channel_number in enumerate(set_to_try):
                channel_data = X[:, channel_number, :]
                new_subset_data[:, subset_idx, :] = channel_data

            # Add to list of all subsets of X to try
            X_to_try.append(new_subset_data)

        # Stop if every possible exclusion has already been tried
        # (otherwise np.argmax below would fail on an empty list)
        if not X_to_try:
            break

        # Evaluate all candidate subsets, in parallel when n_jobs > 1
        outputs = Parallel(n_jobs=n_jobs)(
            delayed(kernel_func)(Xtest, y) for Xtest in X_to_try
        )

        models = []
        predictions = []
        accuracies = []
        precisions = []
        recalls = []

        # Extract the outputs
        for output in outputs:
            models.append(output.model)
            predictions.append(output.cv_preds)
            accuracies.append(output.accuracy)
            precisions.append(output.precision)
            recalls.append(output.recall)

        # Get the performance metric
        if metric == "accuracy":
            performances = accuracies
        elif metric == "precision":
            performances = precisions
        elif metric == "recall":
            performances = recalls
        else:
            logger.warning("Performance metric invalid, defaulting to accuracy")
            performances = accuracies

        # Select the best candidate by the chosen metric (not always by accuracy)
        best_set_index = int(np.argmax(performances))
        best_round_performance = performances[best_set_index]

        sbfs_subset = sets_to_try[best_set_index]
        new_channel_subset = [channel_labels[c] for c in sbfs_subset]
        model = models[best_set_index]
        preds = predictions[best_set_index]
        accuracy = accuracies[best_set_index]
        precision = precisions[best_set_index]
        recall = recalls[best_set_index]

        logger.debug("Removed a channel")
        logger.debug("New subset: %s", new_channel_subset)
        logger.debug("Accuracy: %s", accuracy)
        logger.debug("Accuracies: %s", [float(acc) for acc in accuracies])

        current_performance = best_round_performance

        # If this is the best performance at n_channels
        if performance_at_n_channels[len(sbfs_subset) - 1] < current_performance:
            performance_at_n_channels[len(sbfs_subset) - 1] = current_performance
            best_subset_at_n_channels[len(sbfs_subset) - 1] = sbfs_subset

        p_delta = current_performance - previous_performance
        previous_performance = current_performance

        # If the performance is the best so far, then save it as the best
        if current_performance > best_performance:
            best_channel_subset = new_channel_subset
            best_model = model
            best_performance = current_performance
            best_preds = preds
            best_accuracy = accuracy
            best_precision = precision
            best_recall = recall
        elif current_performance >= best_performance and len(new_channel_subset) < len(
            best_channel_subset
        ):
            best_channel_subset = new_channel_subset
            best_model = model
            best_performance = current_performance
            best_preds = preds
            best_accuracy = accuracy
            best_precision = precision
            best_recall = recall

        if record_performance:
            new_channel_subset.sort()
            results_df.loc[step] = [
                step,
                time.time() - start_time,
                len(new_channel_subset),
                "".join(new_channel_subset),
                len(sets_to_try),
                accuracy,
                precision,
                recall,
            ]

        step += 1

        # Conditional Inclusion
        while stop_criterion is False:
            # Get the length of the set if we were to include an additional channel
            length_of_resultant_set = len(sbfs_subset) + 1
            if (
                length_of_resultant_set > max_channels
                or length_of_resultant_set == len(channel_labels)
            ):
                break

            # Check all of the possible inclusions that do not lead to a previously tested subset
            potential_channels_to_add = [
                c for c in range(len(channel_labels)) if c not in sbfs_subset
            ]

            sets_to_try = []
            X_to_try = []

            for c in potential_channels_to_add:
                set_to_try = sbfs_subset.copy()
                set_to_try.append(c)
                set_to_try.sort()

                if set_to_try not in all_sets_tried:
                    sets_to_try.append(set_to_try)
                    all_sets_tried.append(set_to_try)
                else:
                    continue

                # Get the new subset of data
                new_subset_data = np.zeros((n_trials, len(set_to_try), n_samples))
                for subset_idx, channel_number in enumerate(set_to_try):
                    channel_data = X[:, channel_number, :]
                    new_subset_data[:, subset_idx, :] = channel_data

                # Add to list of all subsets of X to try
                X_to_try.append(new_subset_data)

            if not X_to_try:
                break

            # run the kernel on the new sets
            outputs = Parallel(n_jobs=n_jobs)(
                delayed(kernel_func)(Xtest, y) for Xtest in X_to_try
            )

            models = []
            predictions = []
            accuracies = []
            precisions = []
            recalls = []
            performances = []

            # Extract the outputs
            for output in outputs:
                models.append(output.model)
                predictions.append(output.cv_preds)
                accuracies.append(output.accuracy)
                precisions.append(output.precision)
                recalls.append(output.recall)

            # Get the performance metric
            if metric == "accuracy":
                performances = accuracies
            elif metric == "precision":
                performances = precisions
            elif metric == "recall":
                performances = recalls
            else:
                logger.warning("Performance metric invalid, defaulting to accuracy")
                performances = accuracies

            # Select the best candidate by the chosen metric (not always by accuracy)
            best_set_index = int(np.argmax(performances))
            best_round_performance = performances[best_set_index]

            # If performance is better than the best performance at n_channels
            if (
                performance_at_n_channels[length_of_resultant_set - 1]
                < best_round_performance
            ):
                sbfs_subset = sets_to_try[best_set_index]
                new_channel_subset = [channel_labels[c] for c in sbfs_subset]
                model = models[best_set_index]
                preds = predictions[best_set_index]
                accuracy = accuracies[best_set_index]
                precision = precisions[best_set_index]
                recall = recalls[best_set_index]

                logger.debug("Added back a channel")
                logger.debug("New subset: %s", new_channel_subset)
                logger.debug("Accuracy: %s", accuracy)
                logger.debug("Accuracies: %s", [float(acc) for acc in accuracies])

                current_performance = best_round_performance

                p_delta = current_performance - previous_performance
                previous_performance = current_performance

                # ADD Memory here
                if current_performance > best_performance:
                    best_channel_subset = new_channel_subset
                    best_model = model
                    best_performance = current_performance
                    best_preds = preds
                    best_accuracy = accuracy
                    best_precision = precision
                    best_recall = recall
                elif current_performance >= best_performance and len(
                    new_channel_subset
                ) < len(best_channel_subset):
                    best_channel_subset = new_channel_subset
                    best_model = model
                    best_performance = current_performance
                    best_preds = preds
                    best_accuracy = accuracy
                    best_precision = precision
                    best_recall = recall

                new_channel_subset.sort()
                results_df.loc[step] = [
                    step,
                    time.time() - start_time,
                    len(new_channel_subset),
                    "".join(new_channel_subset),
                    len(sets_to_try),
                    accuracy,
                    precision,
                    recall,
                ]
                step += 1

                performance_at_n_channels[length_of_resultant_set - 1] = (
                    current_performance
                )
                best_subset_at_n_channels[length_of_resultant_set - 1] = sbfs_subset

            # if no performance gains, then stop conditional inclusion
            else:
                break

            # Check stopping criterion
            stop_criterion = __check_stopping_criterion(
                "SBFS",
                time.time() - start_time,
                len(new_channel_subset),
                p_delta,
                max_time,
                min_channels,
                max_channels,
                performance_delta,
            )

        stop_criterion = __check_stopping_criterion(
            "SBFS",
            time.time() - start_time,
            len(new_channel_subset),
            p_delta,
            max_time,
            min_channels,
            max_channels,
            performance_delta,
        )

        # Break if SBFS subset is 1 channel
        if len(sbfs_subset) == 1:
            break

    new_channel_subset = [channel_labels[c] for c in sbfs_subset]

    logger.debug("Best channel subset: %s", best_channel_subset)
    logger.debug("%s : %s", metric, best_performance)
    logger.debug("Time to optimal subset: %s s", time.time() - start_time)

    if record_performance is True:
        logger.info(results_df)

    return ChannelSelectionOutput(
        best_channel_subset,
        best_model,
        best_preds,
        best_accuracy,
        best_precision,
        best_recall,
        results_df,
    )


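# Illustrative sketch (hypothetical, not part of the bci_essentials API and not
# called by any function in this module): the SFFS search performed by `__sffs`,
# reduced to channel *indices* and a plain `score` callable that stands in for
# `kernel_func` plus the chosen metric. It starts from the empty set and omits
# the time limit, the performance-delta stopping rule and the "already tried"
# bookkeeping; it only shows the inclusion step followed by the floating
# (conditional) exclusion step.
# Example: with score(s) = (0 in s) + (2 in s) - 0.1 * len(s) over 5 channels,
# _sffs_sketch(score, 5) returns [0, 2].
def _sffs_sketch(score, n_channels, max_channels=None):
    """Return the best channel-index subset found by a toy SFFS search.

    `score` maps a sorted tuple of channel indices to a float (higher is better).
    """
    neg_inf = float("-inf")
    if max_channels is None:
        max_channels = n_channels
    current = ()
    best_at_size = {}  # best score seen per subset size
    best_subset, best_score = current, neg_inf
    while len(current) < max_channels:
        # Inclusion: add the channel whose addition gives the highest score
        outside = [c for c in range(n_channels) if c not in current]
        current = max((tuple(sorted(current + (c,))) for c in outside), key=score)
        size = len(current)
        best_at_size[size] = max(best_at_size.get(size, neg_inf), score(current))
        # Conditional exclusion: drop a channel only while the result strictly
        # beats the best score already recorded for that smaller subset size
        while len(current) > 1:
            candidate = max(
                (tuple(c for c in current if c != drop) for drop in current), key=score
            )
            if score(candidate) <= best_at_size.get(len(candidate), neg_inf):
                break
            current = candidate
            best_at_size[len(current)] = score(current)
        # Keep the highest-scoring subset, preferring fewer channels on ties
        if (score(current), -len(current)) > (best_score, -len(best_subset)):
            best_subset, best_score = current, score(current)
    return list(best_subset)

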
def __sffs(
    kernel_func,
    X,
    y,
    channel_labels,
    metric,
    initial_channels,
    max_time,
    min_channels,
    max_channels,
    performance_delta,
    n_jobs,
    record_performance,
):
    """The Sequential Forward Floating Selection (SFFS) method for channel selection.

    Parameters
    ----------
    kernel_func : function
        The classification kernel function which does feature extraction
        and classification.
        Different functions are used for MI, P300, SSVEP, etc.
    X : numpy.ndarray
        Training data for the classifier as trials of EEG data.
        3D array containing data with `float` type.

        shape = (`n_trials`,`n_channels`,`n_samples`)
    y : numpy.ndarray
        Training labels for the classifier.
        1D array.

        shape = (`n_trials`)
    channel_labels : list of `str`
        The set of channel labels corresponding to `n_channels`.
        A list of strings with length = `n_channels`.
    metric : str
        The metric used to measure the "goodness" of the trained classifier.
        One of "accuracy", "precision" or "recall".
    initial_channels : list of `str`
        Initial guess of channels.
    max_time : int
        The maximum amount of time, in seconds, that the function will
        search for the optimal solution.
    min_channels : int
        The minimum number of channels.
    max_channels : int
        The maximum number of channels.
    performance_delta : float
        The performance delta under which the algorithm is considered to
        be close enough to optimal.
    n_jobs : int
        The number of parallel jobs passed to `joblib.Parallel` when
        evaluating candidate channel subsets.
    record_performance : bool
        Flag on whether or not to record performance metrics at each step.

    Returns
    -------
    channelSelectionOutput : ChannelSelectionOutput
        ChannelSelectionOutput object containing the following attributes:
            best_channel_subset : list of `str`
                The new best channel subset from the list of `channel_labels`.
            best_model : classifier
                The trained classification model.
            best_preds : numpy.ndarray
                The predictions from the model.
                1D array with the same shape as `y`.
                shape = (`n_trials`)
            best_accuracy : float
                The accuracy of the trained classification model.
            best_precision : float
                The precision of the trained classification model.
            best_recall : float
                The recall of the trained classification model.
            results_df : pandas.DataFrame
                The dataframe containing the results of each step of channel selection.

    """
    results_df = pd.DataFrame(
        columns=[
            "Step",
            "Time",
            "N Channels",
            "Channel Subset",
            "Unique Combinations Tested in Step",
            "Accuracy",
            "Precision",
            "Recall",
        ]
    )
    step = 1

    start_time = time.time()

    n_trials, n_channels, n_samples = X.shape
    sffs_subset = []
    all_sets_tried = []  # every channel subset (sorted index list) already tried

    for i, c in enumerate(channel_labels):
        if c in initial_channels:
            sffs_subset.append(i)

    performance_at_n_channels = np.zeros(len(channel_labels))
    # Subset sizes below min_channels get +inf so they can never be beaten
    performance_at_n_channels[: min_channels - 1] = np.inf
    best_subset_at_n_channels = [0] * len(channel_labels)

    previous_performance = 0

    stop_criterion = False

    # Get the performance of the initial subset, if possible
    try:
        initial_results = kernel_func(X[:, sffs_subset, :], y)
        initial_model = initial_results.model
        initial_preds = initial_results.cv_preds
        initial_accuracy = initial_results.accuracy
        initial_precision = initial_results.precision
        initial_recall = initial_results.recall

        if metric == "accuracy":
            initial_performance = initial_accuracy
        elif metric == "precision":
            initial_performance = initial_precision
        elif metric == "recall":
            initial_performance = initial_recall
        else:
            # Same fallback as the search loop uses for an invalid metric
            initial_performance = initial_accuracy

        # Best
        best_channel_subset = initial_channels
        best_model = initial_model
        best_performance = initial_performance
        best_preds = initial_preds
        best_accuracy = initial_accuracy
        best_precision = initial_precision
        best_recall = initial_recall

        performance_at_n_channels[len(initial_channels) - 1] = initial_performance
        best_subset_at_n_channels[len(initial_channels) - 1] = initial_channels

    # If not possible then set the initial performance to 0
    except ValueError:
        best_channel_subset = []
        best_model = None
        best_performance = 0
        best_preds = []
        best_accuracy = 0
        best_precision = 0
        best_recall = 0

    preds = []
    accuracy = 0
    precision = 0
    recall = 0

    pass_stopping_criterion = False

    # TODO Test the initial subset

    while stop_criterion is False:
        # Inclusion Step
        sets_to_try = []
        X_to_try = []
        for c in range(n_channels):
            if c not in sffs_subset:
                set_to_try = sffs_subset.copy()
                set_to_try.append(c)
                sets_to_try.append(set_to_try)

                # Get the new subset of data
                new_subset_data = np.zeros((n_trials, len(set_to_try), n_samples))
                for subset_idx, channel_number in enumerate(set_to_try):
                    channel_data = X[:, channel_number, :]
                    new_subset_data[:, subset_idx, :] = channel_data

                # Add to list of all subsets of X to try
                X_to_try.append(new_subset_data)

        # Stop if every channel is already in the subset
        # (otherwise np.argmax below would fail on an empty list)
        if not X_to_try:
            break

        # Evaluate all candidate subsets, in parallel when n_jobs > 1
        outputs = Parallel(n_jobs=n_jobs)(
            delayed(kernel_func)(Xtest, y) for Xtest in X_to_try
        )

        models = []
        predictions = []
        accuracies = []
        precisions = []
        recalls = []

        # Extract the outputs
        for output in outputs:
            models.append(output.model)
            predictions.append(output.cv_preds)
            accuracies.append(output.accuracy)
            precisions.append(output.precision)
            recalls.append(output.recall)

        # Get the performance metric
        if metric == "accuracy":
            performances = accuracies
        elif metric == "precision":
            performances = precisions
        elif metric == "recall":
            performances = recalls
        else:
            logger.warning("Performance metric invalid, defaulting to accuracy")
            performances = accuracies

        # Select the best candidate by the chosen metric (not always by accuracy)
        best_set_index = int(np.argmax(performances))
        best_round_performance = performances[best_set_index]

        sffs_subset = sets_to_try[best_set_index]
        new_channel_subset = [channel_labels[c] for c in sffs_subset]
        model = models[best_set_index]
        preds = predictions[best_set_index]
        accuracy = accuracies[best_set_index]
        precision = precisions[best_set_index]
        recall = recalls[best_set_index]

        current_performance = best_round_performance

        logger.debug("Added a channel")
        logger.debug("New subset: %s", new_channel_subset)
        logger.debug("Accuracy: %s", accuracy)
        logger.debug("Accuracies: %s", [float(acc) for acc in accuracies])

        # If this is the best performance at n_channels
        if performance_at_n_channels[len(sffs_subset) - 1] < current_performance:
            performance_at_n_channels[len(sffs_subset) - 1] = current_performance
            best_subset_at_n_channels[len(sffs_subset) - 1] = sffs_subset

        p_delta = current_performance - previous_performance
        previous_performance = current_performance

        if current_performance > best_performance:
            best_channel_subset = new_channel_subset
            best_model = model
            best_performance = current_performance
            best_preds = preds
            best_accuracy = accuracy
            best_precision = precision
            best_recall = recall
        elif current_performance >= best_performance and len(new_channel_subset) < len(
            best_channel_subset
        ):
            best_channel_subset = new_channel_subset
            best_model = model
            best_performance = current_performance
            best_preds = preds
            best_accuracy = accuracy
            best_precision = precision
            best_recall = recall

        new_channel_subset.sort()
        results_df.loc[step] = [
            step,
            time.time() - start_time,
            len(new_channel_subset),
            "".join(new_channel_subset),
            len(sets_to_try),
            accuracy,
            precision,
            recall,
        ]

        step += 1

        # Conditional Exclusion
        while stop_criterion is False:
            # Get the length of the set if we were to remove a channel
            length_of_resultant_set = len(sffs_subset) - 1
            if length_of_resultant_set < min_channels or length_of_resultant_set == 0:
                break

            # If the length of the resultant set equals min_channels, then set
            # pass_stopping_criterion
            if length_of_resultant_set == min_channels:
                pass_stopping_criterion = True

            # Check all of the possible exclusions that do not lead to a previously tested subset
            sets_to_try = []
            X_to_try = []

            for c in sffs_subset:
                set_to_try = sffs_subset.copy()
                set_to_try.remove(c)
                set_to_try.sort()

                # Only try sets that have not been tried before
                if set_to_try not in all_sets_tried:
                    sets_to_try.append(set_to_try)
                    all_sets_tried.append(set_to_try)
                else:
                    continue

                # Get the new subset of data
                new_subset_data = np.zeros((n_trials, len(set_to_try), n_samples))
                for subset_idx, channel_number in enumerate(set_to_try):
                    channel_data = X[:, channel_number, :]
                    new_subset_data[:, subset_idx, :] = channel_data

                # Add to list of all subsets of X to try
                X_to_try.append(new_subset_data)

            if not X_to_try:
                break

            # run the kernel on the new sets
            outputs = Parallel(n_jobs=n_jobs)(
                delayed(kernel_func)(Xtest, y) for Xtest in X_to_try
            )

            models = []
            predictions = []
            accuracies = []
            precisions = []
            recalls = []
            performances = []

            # Extract the outputs
            for output in outputs:
                models.append(output.model)
                predictions.append(output.cv_preds)
                accuracies.append(output.accuracy)
                precisions.append(output.precision)
                recalls.append(output.recall)

            # Get the performance metric
            if metric == "accuracy":
                performances = accuracies
            elif metric == "precision":
                performances = precisions
            elif metric == "recall":
                performances = recalls
            else:
                logger.warning("Performance metric invalid, defaulting to accuracy")
                performances = accuracies

            # Select the best candidate by the chosen metric (not always by accuracy)
            best_set_index = int(np.argmax(performances))
            best_round_performance = performances[best_set_index]

            # if performance is better at the resultant channel length
            if (
                performance_at_n_channels[length_of_resultant_set - 1]
                < best_round_performance
            ):
                sffs_subset = sets_to_try[best_set_index]
                new_channel_subset = [channel_labels[c] for c in sffs_subset]
                model = models[best_set_index]
                preds = predictions[best_set_index]
                accuracy = accuracies[best_set_index]
                precision = precisions[best_set_index]
                recall = recalls[best_set_index]

                logger.debug("Removed a channel")
                logger.debug("New subset: %s", new_channel_subset)
                logger.debug("Accuracy: %s", accuracy)
                logger.debug("Accuracies: %s", [float(acc) for acc in accuracies])

                current_performance = best_round_performance

                p_delta = current_performance - previous_performance
                previous_performance = current_performance

                if current_performance > best_performance:
                    best_channel_subset = new_channel_subset
                    best_model = model
                    best_performance = current_performance
                    best_preds = preds
                    best_accuracy = accuracy
                    best_precision = precision
                    best_recall = recall
                elif current_performance >= best_performance and len(
                    new_channel_subset
                ) < len(best_channel_subset):
                    best_channel_subset = new_channel_subset
                    best_model = model
                    best_performance = current_performance
                    best_preds = preds
                    best_accuracy = accuracy
                    best_precision = precision
                    best_recall = recall

                if record_performance:
                    new_channel_subset.sort()
                    results_df.loc[step] = [
                        step,
                        time.time() - start_time,
                        len(new_channel_subset),
                        "".join(new_channel_subset),
                        len(sets_to_try),
                        accuracy,
                        precision,
                        recall,
                    ]
                    step += 1

                performance_at_n_channels[length_of_resultant_set - 1] = (
                    current_performance
1728                )
1729                best_subset_at_n_channels[length_of_resultant_set - 1] = sffs_subset
1730
1731            # if no performance gains, then stop conditional exclusion
1732            else:
1733                break
1734
1735            # Check stopping criterion
1736            if pass_stopping_criterion is False:
1737                stop_criterion = __check_stopping_criterion(
1738                    "SFFS",
1739                    time.time() - start_time,
1740                    len(new_channel_subset),
1741                    p_delta,
1742                    max_time,
1743                    min_channels,
1744                    max_channels,
1745                    performance_delta,
1746                )
1747
1748        if pass_stopping_criterion:
1749            pass_stopping_criterion = False
1750            continue
1751        else:
1752            stop_criterion = __check_stopping_criterion(
1753                "SFFS",
1754                time.time() - start_time,
1755                len(new_channel_subset),
1756                p_delta,
1757                max_time,
1758                min_channels,
1759                max_channels,
1760                performance_delta,
1761            )
1762
1763    new_channel_subset = [channel_labels[c] for c in sffs_subset]
1764
1765    logger.debug("Best channel subset: %s", best_channel_subset)
1766    logger.debug("%s : %s", metric, best_performance)
1767    logger.debug("Time to optimal subset: %s s", time.time() - start_time)
1768
1769    if record_performance is True:
1770        logger.info(results_df)
1771
1772    return ChannelSelectionOutput(
1773        best_channel_subset,
1774        best_model,
1775        best_preds,
1776        best_accuracy,
1777        best_precision,
1778        best_recall,
1779        results_df,
1780    )
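The conditional-exclusion step above can be illustrated outside the library. The sketch below is a stand-alone, stdlib-only version on a toy objective; `WEIGHTS`, `toy_score`, and `conditional_exclusion` are hypothetical names, not bci_essentials API, and the per-size record is a dict rather than the `performance_at_n_channels` list. As in the listing, a removal is kept only if it beats the best score recorded for the resulting subset size, the winner's index is taken from the same list the maximum came from, and the loop stops at the first round without such a gain.

```python
# Hypothetical per-channel weights for a toy scoring function (integers keep
# the comparisons exact).
WEIGHTS = {0: 8, 1: 5, 2: 1, 3: 3}


def toy_score(subset):
    """Toy objective: channel weights minus a cost of 2 per channel."""
    return sum(WEIGHTS[c] for c in subset) - 2 * len(subset)


def conditional_exclusion(subset, score, best_at_size):
    """Floating step of SFFS: drop channels while doing so beats the record.

    `score` maps a tuple of channel indices to a number (higher is better).
    `best_at_size` maps subset length -> best score recorded at that length
    and is updated in place whenever a removal sets a new record.
    Returns the (possibly smaller) subset as a tuple.
    """
    subset = tuple(subset)
    while len(subset) > 1:
        # All subsets obtained by removing exactly one channel
        candidates = [subset[:i] + subset[i + 1:] for i in range(len(subset))]
        scores = [score(c) for c in candidates]
        round_best = max(scores)
        # Take the index from the same list the maximum came from
        best_index = scores.index(round_best)
        size = len(subset) - 1
        # Accept the removal only if it improves on the best score at that size
        if round_best > best_at_size.get(size, float("-inf")):
            best_at_size[size] = round_best
            subset = candidates[best_index]
        else:
            break
    return subset


# Hypothetical scores recorded earlier in the search for each subset size,
# with (0, 1, 2, 3) as the subset just produced by a forward step.
best_at_size = {1: 6, 2: 9, 3: 8}
result = conditional_exclusion((0, 1, 2, 3), toy_score, best_at_size)
print(result, best_at_size)  # (0, 1, 3) {1: 6, 2: 9, 3: 10}
```

Here dropping channel 2 raises the 3-channel record from 8 to 10, while the best 2-channel candidate only ties the existing record of 9, so the floating step stops.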
@dataclass
class ChannelSelectionOutput:
@dataclass
class ChannelSelectionOutput:
    """Dataclass to store output from channel selection.

    Parameters
    ----------
    best_channel_subset : list of `str`
        The best channel subset from the list of `channel_labels`.
    best_model : classifier
        The trained classification model.
    best_preds : numpy.ndarray
        The predictions from the model.
    best_accuracy : float
        The accuracy of the trained classification model.
    best_precision : float
        The precision of the trained classification model.
    best_recall : float
        The recall of the trained classification model.
    results_df : pandas.DataFrame
        A dataframe containing the performance metrics at each step.

    """

    best_channel_subset: list = field(default_factory=list)
    best_model: Pipeline = field(default=None)
    # np.ndarray cannot be called without a shape, so default to an empty array
    best_preds: np.ndarray = field(default_factory=lambda: np.empty(0))
    best_accuracy: float = field(default=0.0)
    best_precision: float = field(default=0.0)
    best_recall: float = field(default=0.0)
    results_df: pd.DataFrame = field(default_factory=pd.DataFrame)

Dataclass to store output from channel selection.

Parameters
  • best_channel_subset (list of str): The best channel subset from the list of channel_labels.
  • best_model (classifier): The trained classification model.
  • best_preds (numpy.ndarray): The predictions from the model.
  • best_accuracy (float): The accuracy of the trained classification model.
  • best_precision (float): The precision of the trained classification model.
  • best_recall (float): The recall of the trained classification model.
  • results_df (pandas.DataFrame): A dataframe containing the performance metrics at each step.
ChannelSelectionOutput( best_channel_subset: list = <factory>, best_model: sklearn.pipeline.Pipeline = None, best_preds: numpy.ndarray = <factory>, best_accuracy: float = 0.0, best_precision: float = 0.0, best_recall: float = 0.0, results_df: pandas.core.frame.DataFrame = <factory>)
best_channel_subset: list
best_model: sklearn.pipeline.Pipeline = None
best_preds: numpy.ndarray
best_accuracy: float = 0.0
best_precision: float = 0.0
best_recall: float = 0.0
results_df: pandas.core.frame.DataFrame
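The `best_preds` default deserves care: a dataclass `default_factory` is called with no arguments, and `np.ndarray` cannot be constructed without a `shape`. A minimal sketch with a hypothetical `SelectionResultSketch` class (not part of bci_essentials) that mirrors three of the fields and uses an empty-array factory instead:

```python
from dataclasses import dataclass, field

import numpy as np

# np.ndarray needs a `shape` argument, so calling it with no arguments
# (which is what a default factory does) raises TypeError.
try:
    np.ndarray()
    ndarray_is_zero_arg = True
except TypeError:
    ndarray_is_zero_arg = False


@dataclass
class SelectionResultSketch:
    """Hypothetical stand-in mirroring a few ChannelSelectionOutput fields."""

    best_channel_subset: list = field(default_factory=list)
    best_preds: np.ndarray = field(default_factory=lambda: np.empty(0))
    best_accuracy: float = 0.0


empty = SelectionResultSketch()
filled = SelectionResultSketch(["C3", "C4"], np.array([0, 1, 1]), 0.67)

print(ndarray_is_zero_arg)         # False
print(empty.best_preds.shape)      # (0,)
print(filled.best_channel_subset)  # ['C3', 'C4']
```

Using factories (rather than a shared literal) also gives every instance its own list and array.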
def channel_selection_by_method( kernel_func, X, y, channel_labels, method='SBS', metric='accuracy', initial_channels=[], max_time=999, min_channels=1, max_channels=999, performance_delta=0.001, n_jobs=1, record_performance=True):
def channel_selection_by_method(
    kernel_func,
    X,
    y,
    channel_labels,
    method="SBS",
    metric="accuracy",
    initial_channels=[],
    max_time=999,
    min_channels=1,
    max_channels=999,
    performance_delta=0.001,
    n_jobs=1,
    record_performance=True,
):
    """Passes the BCI kernel function into a wrapper defined by `method`.

    Parameters
    ----------
    kernel_func : function
        The classification kernel function which does feature extraction
        and classification.
        Different functions are used for MI, P300, SSVEP, etc.
    X : numpy.ndarray
        Training data for the classifier as trials of EEG data.
        3D array containing data with `float` type.

        shape = (`n_trials`,`n_channels`,`n_samples`)
    y : numpy.ndarray
        Training labels for the classifier.
        1D array.

        shape = (`n_trials`)
    channel_labels : list of `str`
        The set of channel labels corresponding to `n_channels`.
        A list of strings with length = `n_channels`.
    method : str, *optional*
        The wrapper method. Options are `"SBS"`, `"SFS"`, `"SBFS"`, or
        `"SFFS"`.
        - Default is `"SBS"`.
    metric : str, *optional*
        The metric used to measure the "goodness" of the trained classifier.
        Supported values are `"accuracy"`, `"precision"`, and `"recall"`.
        - Default is `"accuracy"`.
    initial_channels : list of `str`, *optional*
        Initial guess of channels.
        - Default is `[]`. Assigns an empty set for forward selections
        and the full set for backward selections.
    max_time : int, *optional*
        The maximum amount of time, in seconds, that the function will
        search for the optimal solution.
        - Default is `999` seconds.
    min_channels : int, *optional*
        The minimum number of channels.
        - Default is `1`.
    max_channels : int, *optional*
        The maximum number of channels.
        - Default is `999`.
    performance_delta : float, *optional*
        The performance delta under which the algorithm is considered to
        be close enough to optimal.
        - Default is `0.001`.
    n_jobs : int, *optional*
        The number of parallel jobs to dedicate to this calculation.
        - Default is `1`.
    record_performance : bool, *optional*
        Whether or not to record the performance of each step of the
        channel selection.
        - Default is `True`.

    Returns
    -------
    channelSelectionOutput : ChannelSelectionOutput
        ChannelSelectionOutput object containing the following attributes:
            best_channel_subset : list of `str`
                The new best channel subset from the list of `channel_labels`.
            best_model : classifier
                The trained classification model.
            best_preds : numpy.ndarray
                The predictions from the model.
                1D array with the same shape as `y`.
                shape = (`n_trials`)
            best_accuracy : float
                The accuracy of the trained classification model.
            best_precision : float
                The precision of the trained classification model.
            best_recall : float
                The recall of the trained classification model.
            results_df : pandas.DataFrame
                The dataframe containing the results of each step of channel selection.

    """

    # max length can't be greater than the length of channel labels
    if max_channels > len(channel_labels):
        logger.debug(
            "Maximum number of channels must be less than or equal to "
            + "the number of channels. Setting to number of channels."
        )
        max_channels = len(channel_labels)

    # min length can't be less than 1
    if min_channels < 1:
        logger.debug("Minimum number of channels must be greater than 0. Setting to 1.")
        min_channels = 1

    logger.debug("Running channel selection method: %s", method)
    if method == "SBS":
        if initial_channels == []:
            initial_channels = channel_labels
        logger.debug("Initial subset: %s", initial_channels)

        # pass arguments to SBS
        return __sbs(
            kernel_func,
            X,
            y,
            channel_labels=channel_labels,
            metric=metric,
            initial_channels=initial_channels,
            max_time=max_time,
            min_channels=min_channels,
            max_channels=max_channels,
            performance_delta=performance_delta,
            n_jobs=n_jobs,
            record_performance=record_performance,
        )

    elif method == "SFS":
        logger.debug("Initial subset: %s", initial_channels)

        # pass arguments to SFS
        return __sfs(
            kernel_func,
            X,
            y,
            channel_labels=channel_labels,
            metric=metric,
            initial_channels=initial_channels,
            max_time=max_time,
            min_channels=min_channels,
            max_channels=max_channels,
            performance_delta=performance_delta,
            n_jobs=n_jobs,
            record_performance=record_performance,
        )

    elif method == "SBFS":
        if initial_channels == []:
            initial_channels = channel_labels
        logger.debug("Initial subset: %s", initial_channels)

        # pass arguments to SBFS
        return __sbfs(
            kernel_func,
            X,
            y,
            channel_labels=channel_labels,
            metric=metric,
            initial_channels=initial_channels,
            max_time=max_time,
            min_channels=min_channels,
            max_channels=max_channels,
            performance_delta=performance_delta,
            n_jobs=n_jobs,
            record_performance=record_performance,
        )

    elif method == "SFFS":
        logger.debug("Initial subset: %s", initial_channels)

        # pass arguments to SFFS
        return __sffs(
            kernel_func,
            X,
            y,
            channel_labels=channel_labels,
            metric=metric,
            initial_channels=initial_channels,
            max_time=max_time,
            min_channels=min_channels,
            max_channels=max_channels,
            performance_delta=performance_delta,
            n_jobs=n_jobs,
            record_performance=record_performance,
        )
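The default `"SBS"` wrapper performs sequential backward selection. The body of `__sbs` is not shown in this section, so the following is only a generic sketch of the technique under assumed stopping rules loosely modeled on `min_channels`, `performance_delta`, and `max_time`; `WEIGHTS`, `toy_score`, and `sequential_backward_selection` are hypothetical names, and the toy objective stands in for a trained kernel's cross-validated metric.

```python
import time

# Hypothetical channel weights; the toy score is the sum of weights minus
# a cost of 2 per channel (integers keep the comparisons exact).
WEIGHTS = {"C3": 8, "C4": 5, "Cz": 1, "Pz": 3}


def toy_score(subset):
    return sum(WEIGHTS[ch] for ch in subset) - 2 * len(subset)


def sequential_backward_selection(
    channels, score, min_channels=1, performance_delta=0.001, max_time=999.0
):
    """Generic greedy backward elimination.

    Each round removes the channel whose removal gives the highest score.
    The search stops when the subset reaches `min_channels`, when a round
    improves the score by less than `performance_delta`, or after
    `max_time` seconds. Returns (best_subset, best_score); ties in score
    are resolved in favour of the smaller subset.
    """
    start = time.time()
    current = tuple(channels)
    current_score = score(current)
    best, best_score = current, current_score
    while len(current) > min_channels and time.time() - start < max_time:
        candidates = [current[:i] + current[i + 1:] for i in range(len(current))]
        scores = [score(c) for c in candidates]
        round_best = max(scores)
        delta = round_best - current_score
        current, current_score = candidates[scores.index(round_best)], round_best
        # `current` only shrinks, so >= keeps the smaller subset on ties
        if round_best >= best_score:
            best, best_score = current, round_best
        if delta < performance_delta:
            break
    return best, best_score


print(sequential_backward_selection(["C3", "C4", "Cz", "Pz"], toy_score))
# (('C3', 'C4', 'Pz'), 10)
```

The tie-breaking rule mirrors the `>=` plus shorter-subset comparison in the SFFS listing; the exact rules inside `__sbs` may differ.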

Passes the BCI kernel function into a wrapper defined by method.

Parameters
  • kernel_func (function): The classification kernel function which does feature extraction and classification. Different functions are used for MI, P300, SSVEP, etc.
  • X (numpy.ndarray): Training data for the classifier as trials of EEG data. 3D array containing data with float type.

    shape = (n_trials,n_channels,n_samples)

  • y (numpy.ndarray): Training labels for the classifier. 1D array.

    shape = (n_trials)

  • channel_labels (list of str): The set of channel labels corresponding to n_channels. A list of strings with length = n_channels.
  • method (str, optional): The wrapper method. Options are "SBS", "SFS", "SBFS", or "SFFS".
    • Default is "SBS".
  • metric (str, optional): The metric used to measure the "goodness" of the trained classifier. Supported values are "accuracy", "precision", and "recall".
    • Default is "accuracy".
  • initial_channels (list of str, optional): Initial guess of channels.
    • Default is []. Assigns an empty set for forward selections and the full set for backward selections.
  • max_time (int, optional): The maximum amount of time, in seconds, that the function will search for the optimal solution.
    • Default is 999 seconds.
  • min_channels (int, optional): The minimum number of channels.
    • Default is 1.
  • max_channels (int, optional): The maximum number of channels.
    • Default is 999.
  • performance_delta (float, optional): The performance delta under which the algorithm is considered to be close enough to optimal.
    • Default is 0.001.
  • n_jobs (int, optional): The number of parallel jobs to dedicate to this calculation.
    • Default is 1.
  • record_performance (bool, optional): Whether or not to record the performance of each step of the channel selection.
    • Default is True.
Returns
  • channelSelectionOutput (ChannelSelectionOutput): ChannelSelectionOutput object containing the following attributes:
    • best_channel_subset (list of str): The new best channel subset from the list of channel_labels.
    • best_model (classifier): The trained classification model.
    • best_preds (numpy.ndarray): The predictions from the model. 1D array with the same shape as y. shape = (n_trials)
    • best_accuracy (float): The accuracy of the trained classification model.
    • best_precision (float): The precision of the trained classification model.
    • best_recall (float): The recall of the trained classification model.
    • results_df (pandas.DataFrame): The dataframe containing the results of each step of channel selection.
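In the source listing of `channel_selection_by_method`, a `method` outside the four handled names falls through every branch, so the function returns `None`. A minimal sketch of an alternative, assuming a hypothetical `normalize_selection_args` helper (not part of bci_essentials) that validates `method` up front and applies the same channel-bound clamping and initial-subset defaults as the listing:

```python
BACKWARD_METHODS = {"SBS", "SBFS"}
FORWARD_METHODS = {"SFS", "SFFS"}


def normalize_selection_args(
    method, channel_labels, initial_channels=None, min_channels=1, max_channels=999
):
    """Hypothetical helper: validate `method` and normalize channel bounds."""
    valid = BACKWARD_METHODS | FORWARD_METHODS
    if method not in valid:
        raise ValueError(f"Unknown method {method!r}; expected one of {sorted(valid)}")
    # Same clamping as the listing: min_channels >= 1, max_channels <= n_channels
    min_channels = max(min_channels, 1)
    max_channels = min(max_channels, len(channel_labels))
    # Empty initial guess: full set for backward searches, empty set for forward
    if not initial_channels:
        initial_channels = list(channel_labels) if method in BACKWARD_METHODS else []
    return list(initial_channels), min_channels, max_channels


print(normalize_selection_args("SBS", ["C3", "C4", "Cz"]))
# (['C3', 'C4', 'Cz'], 1, 3)
print(normalize_selection_args("SFFS", ["C3", "C4"], min_channels=0))
# ([], 1, 2)
```

Using `None` rather than `[]` as the default for `initial_channels` also avoids sharing one mutable list across calls.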