bci_essentials.classification.switch_mdm_classifier

*Switch MDM Classifier *

This is a switch_classifier.

  • This means that classification occurs between neutral and one other label (i.e. Binary classification).
  • The produced probabilities between labels are then compared for one final classification.

ToDo: Missing correct implementation of this classifier'

  1"""**Switch MDM Classifier **
  2
  3This is a switch_classifier.
  4- This means that classification occurs between neutral and one other
  5label (i.e. Binary classification).
  6- The produced probabilities between labels are then compared for one
  7final classification.
  8
  9**`ToDo`: Missing correct implementation of this classifier**'
 10
 11"""
 12
 13# Stock libraries
 14import numpy as np
 15from sklearn.model_selection import StratifiedKFold
 16from sklearn.pipeline import Pipeline
 17
 18# from sklearn.metrics import confusion_matrix, precision_score, recall_score
 19from sklearn import preprocessing
 20from pyriemann.classification import MDM
 21
 22# Import bci_essentials modules and methods
 23from ..classification.generic_classifier import GenericClassifier, Prediction
 24from ..utils.logger import Logger  # Logger wrapper
 25
 26# Instantiate a logger for the module at the default level of logging.INFO
 27# Logs to bci_essentials.__module__) where __module__ is the name of the module
 28logger = Logger(name=__name__)
 29
 30
 31# TODO: Missing correct implementation of this classifier
 32class SwitchMdmClassifier(GenericClassifier):
 33    """Switch MDM Classifier class (*inherits from GenericClassifier*)."""
 34
 35    def set_switch_classifier_mdm_settings(
 36        self,
 37        n_splits=2,
 38        rebuild=True,
 39        random_seed=42,
 40        n_jobs=1,
 41        activation_main="relu",
 42        activation_class="sigmoid",
 43    ):
 44        """Set the Switch Classifier MDM settings.
 45
 46        Parameters
 47        ----------
 48        n_splits : int, *optional*
 49            Number of folds for cross-validation.
 50            - Default is `2`.
 51        rebuild : bool, *optional*
 52            Rebuild the classifier each time. *More description needed*.
 53            - Default is `True`.
 54        random_seed : int, *optional*
 55            Random seed.
 56            - Default is `42`.
 57        n_jobs : int, *optional*
 58            The number of threads to dedicate to this calculation.
 59        activation_main : str, *optional*
 60            Activation function for hidden layers.
 61            - Default is `relu`.
 62        activation_class : str, *optional*
 63            Activation function for the output layer.
 64            - Default is `sigmoid`.
 65
 66        Returns
 67        -------
 68        `None`
 69            Models created are used in `fit()`.
 70
 71        """
 72        self.n_splits = n_splits
 73        self.cv = StratifiedKFold(
 74            n_splits=self.n_splits, shuffle=True, random_state=random_seed
 75        )
 76        self.rebuild = rebuild
 77
 78        mdm = MDM(metric=dict(mean="riemann", distance="riemann"), n_jobs=n_jobs)
 79        self.clf_model = Pipeline([("MDM", mdm)])
 80        self.clf = Pipeline([("MDM", mdm)])
 81        # self.clf0and1 = MDM()
 82
 83    def fit(self):
 84        """Fit the model.
 85
 86        Returns
 87        -------
 88        `None`
 89            Models created used in `predict()`.
 90
 91        """
 92        # get dimensions
 93        n_trials, n_channels, n_samples = self.X.shape
 94
 95        # do the rest of the training if train_free is false
 96        X = np.array(self.X)
 97        y = np.array(self.y)
 98
 99        # find the number of classes in y there shoud be N + 1, where N is the number of objects in the scene and also the number of classifiers
100        self.num_classifiers = len(list(np.unique(self.y))) - 1
101        logger.info("Number of classes: %s", self.num_classifiers)
102
103        # make a list to hold all of the classifiers
104        self.clfs = []
105
106        # loop through and build the classifiers
107        for i in range(self.num_classifiers):
108            # take a subset / do spatial filtering
109            X = X[:, :, :]  # Does nothing for now
110
111            class_indices = np.logical_or(y == 0, y == (i + 1))
112            X_class = X[class_indices, :, :]
113            y_class = y[class_indices]
114
115            # Try rebuilding the classifier each time
116            if self.rebuild:
117                self.next_fit_trial = 0
118                # tf.keras.backend.clear_session()
119
120            subX = X_class[self.next_fit_trial :, :, :]
121            suby = y_class[self.next_fit_trial :]
122            self.next_fit_trial = n_trials
123
124            for train_idx, test_idx in self.cv.split(subX, suby):
125                X_train, X_test = subX[train_idx], subX[test_idx]
126                y_train, y_test = suby[train_idx], suby[test_idx]
127
128                z_dim, y_dim, x_dim = X_train.shape
129                X_train = X_train.reshape(z_dim, x_dim * y_dim)
130                scaler_train = preprocessing.StandardScaler().fit(X_train)
131                X_train_scaled = scaler_train.transform(X_train)
132
133                logger.info("The shape of X_train_scaled is %s", X_train_scaled.shape)
134
135                z_dim, y_dim, x_dim = X_test.shape
136                X_test = X_test.reshape(z_dim, x_dim * y_dim)
137                scaler_test = preprocessing.StandardScaler().fit(X_test)
138                X_test_scaled = scaler_test.transform(X_test)
139
140                if i == 0:
141                    # Compile the model
142                    logger.info("\nWorking on first model...")
143                    self.clf0and1.compile(
144                        # optimizer=Adam(learning_rate=0.001),
145                        loss="sparse_categorical_crossentropy",
146                        metrics=["accuracy"],
147                    )
148                    # Fit the model
149                    self.clf0and1.fit(
150                        x=X_train_scaled,
151                        y=y_train,
152                        batch_size=5,
153                        epochs=4,
154                        shuffle=True,
155                        verbose=2,
156                        validation_data=(X_test_scaled, y_test),
157                    )  # Need to reshape X_train
158
159                else:
160                    logger.info("\nWorking on second model...")
161                    # Compile the model
162                    self.clf0and2.compile(
163                        # optimizer=Adam(learning_rate=0.001),
164                        loss="sparse_categorical_crossentropy",
165                        metrics=["accuracy"],
166                    )
167                    # Fit the model
168                    self.clf0and2.fit(
169                        x=X_train_scaled,
170                        y=y_train,
171                        batch_size=5,
172                        epochs=4,
173                        shuffle=True,
174                        verbose=2,
175                        validation_data=(X_test_scaled, y_test),
176                    )  # Need to reshape X_train
177
178            # Log performance stats
179            # accuracy
180            # correct = preds == self.y
181            # logger.info("Correct: %s", correct)
182
183            # COMMENTED OUT DUE TO INCOMPLETE IMPLEMENTATION
184            """
185            self.offline_trial_count = n_trials
186            self.offline_trial_counts.append(self.offline_trial_count)
187            # accuracy
188            accuracy = sum(preds == self.y) / len(preds)
189            self.offline_accuracy.append(accuracy)
190            logger.info("Accuracy = %s", accuracy)
191            # precision
192            precision = precision_score(self.y, preds, average="micro")
193            self.offline_precision.append(precision)
194            logger.info("Precision = %s", precision))
195            # recall
196            recall = recall_score(self.y, preds, average="micro")
197            self.offline_recall.append(recall)
198            logger.info("Recall = %s", recall)
199            # confusion matrix in command line
200            cm = confusion_matrix(self.y, preds)
201            self.offline_cm = cm
202            logger.info("Confusion matrix:\n%s", cm)
203            """
204
205    def predict(self, X):
206        """Predict the class labels for the provided data.
207
208        Parameters
209        ----------
210        X : numpy.ndarray
211            3D array where shape = (trials, channels, samples)
212
213        Returns
214        -------
215        prediction : Prediction
216            Results of predict call containing the predicted class labels.  Probabilities
217            are not available (empty list).
218
219        """
220        # if X is 2D, make it 3D with one as first dimension
221        if len(X.shape) < 3:
222            X = X[np.newaxis, ...]
223
224        logger.info("The shape of X is %s", X.shape)
225
226        # self.predict0and1 = Sequential(
227        #     [
228        #         Flatten(),
229        #         Dense(units=8, input_shape=(4,), activation="relu"),
230        #         Dense(units=16, activation="relu"),
231        #         Dense(units=3, activation="sigmoid"),
232        #     ]
233        # )
234
235        # self.predict0and2 = Sequential(
236        #     [
237        #         Flatten(),
238        #         Dense(units=8, input_shape=(4,), activation="relu"),
239        #         Dense(units=16, activation="relu"),
240        #         Dense(units=3, activation="sigmoid"),
241        #     ]
242        # )
243
244        z_dim, y_dim, x_dim = X.shape
245        X_predict = X.reshape(z_dim, x_dim * y_dim)
246        scaler_train = preprocessing.StandardScaler().fit(X_predict)
247        X_predict_scaled = scaler_train.transform(X_predict)
248
249        pred0and1 = self.predict0and1.predict(X_predict_scaled)
250        pred0and2 = self.predict0and2.predict(X_predict_scaled)
251
252        final_predictions = np.array([])
253
254        for row1, row2 in zip(pred0and1, pred0and2):
255            if row1[0] > row1[1] and row2[0] > row2[2]:
256                np.append(final_predictions, 0)
257            elif row1[0] > row1[1] and row2[0] < row2[2]:
258                np.append(final_predictions, 2)
259            elif row1[0] < row1[1] and row2[0] > row2[2]:
260                np.append(final_predictions, 1)
261            elif row1[0] < row1[1] and row2[0] < row2[2]:
262                if row1[1] > row2[2]:
263                    np.append(final_predictions, 1)
264                else:
265                    np.append(final_predictions, 2)
266
267        return Prediction(labels=final_predictions)
 33class SwitchMdmClassifier(GenericClassifier):
 34    """Switch MDM Classifier class (*inherits from GenericClassifier*)."""
 35
 36    def set_switch_classifier_mdm_settings(
 37        self,
 38        n_splits=2,
 39        rebuild=True,
 40        random_seed=42,
 41        n_jobs=1,
 42        activation_main="relu",
 43        activation_class="sigmoid",
 44    ):
 45        """Set the Switch Classifier MDM settings.
 46
 47        Parameters
 48        ----------
 49        n_splits : int, *optional*
 50            Number of folds for cross-validation.
 51            - Default is `2`.
 52        rebuild : bool, *optional*
 53            Rebuild the classifier each time. *More description needed*.
 54            - Default is `True`.
 55        random_seed : int, *optional*
 56            Random seed.
 57            - Default is `42`.
 58        n_jobs : int, *optional*
 59            The number of threads to dedicate to this calculation.
 60        activation_main : str, *optional*
 61            Activation function for hidden layers.
 62            - Default is `relu`.
 63        activation_class : str, *optional*
 64            Activation function for the output layer.
 65            - Default is `sigmoid`.
 66
 67        Returns
 68        -------
 69        `None`
 70            Models created are used in `fit()`.
 71
 72        """
 73        self.n_splits = n_splits
 74        self.cv = StratifiedKFold(
 75            n_splits=self.n_splits, shuffle=True, random_state=random_seed
 76        )
 77        self.rebuild = rebuild
 78
 79        mdm = MDM(metric=dict(mean="riemann", distance="riemann"), n_jobs=n_jobs)
 80        self.clf_model = Pipeline([("MDM", mdm)])
 81        self.clf = Pipeline([("MDM", mdm)])
 82        # self.clf0and1 = MDM()
 83
 84    def fit(self):
 85        """Fit the model.
 86
 87        Returns
 88        -------
 89        `None`
 90            Models created used in `predict()`.
 91
 92        """
 93        # get dimensions
 94        n_trials, n_channels, n_samples = self.X.shape
 95
 96        # do the rest of the training if train_free is false
 97        X = np.array(self.X)
 98        y = np.array(self.y)
 99
100        # find the number of classes in y there shoud be N + 1, where N is the number of objects in the scene and also the number of classifiers
101        self.num_classifiers = len(list(np.unique(self.y))) - 1
102        logger.info("Number of classes: %s", self.num_classifiers)
103
104        # make a list to hold all of the classifiers
105        self.clfs = []
106
107        # loop through and build the classifiers
108        for i in range(self.num_classifiers):
109            # take a subset / do spatial filtering
110            X = X[:, :, :]  # Does nothing for now
111
112            class_indices = np.logical_or(y == 0, y == (i + 1))
113            X_class = X[class_indices, :, :]
114            y_class = y[class_indices]
115
116            # Try rebuilding the classifier each time
117            if self.rebuild:
118                self.next_fit_trial = 0
119                # tf.keras.backend.clear_session()
120
121            subX = X_class[self.next_fit_trial :, :, :]
122            suby = y_class[self.next_fit_trial :]
123            self.next_fit_trial = n_trials
124
125            for train_idx, test_idx in self.cv.split(subX, suby):
126                X_train, X_test = subX[train_idx], subX[test_idx]
127                y_train, y_test = suby[train_idx], suby[test_idx]
128
129                z_dim, y_dim, x_dim = X_train.shape
130                X_train = X_train.reshape(z_dim, x_dim * y_dim)
131                scaler_train = preprocessing.StandardScaler().fit(X_train)
132                X_train_scaled = scaler_train.transform(X_train)
133
134                logger.info("The shape of X_train_scaled is %s", X_train_scaled.shape)
135
136                z_dim, y_dim, x_dim = X_test.shape
137                X_test = X_test.reshape(z_dim, x_dim * y_dim)
138                scaler_test = preprocessing.StandardScaler().fit(X_test)
139                X_test_scaled = scaler_test.transform(X_test)
140
141                if i == 0:
142                    # Compile the model
143                    logger.info("\nWorking on first model...")
144                    self.clf0and1.compile(
145                        # optimizer=Adam(learning_rate=0.001),
146                        loss="sparse_categorical_crossentropy",
147                        metrics=["accuracy"],
148                    )
149                    # Fit the model
150                    self.clf0and1.fit(
151                        x=X_train_scaled,
152                        y=y_train,
153                        batch_size=5,
154                        epochs=4,
155                        shuffle=True,
156                        verbose=2,
157                        validation_data=(X_test_scaled, y_test),
158                    )  # Need to reshape X_train
159
160                else:
161                    logger.info("\nWorking on second model...")
162                    # Compile the model
163                    self.clf0and2.compile(
164                        # optimizer=Adam(learning_rate=0.001),
165                        loss="sparse_categorical_crossentropy",
166                        metrics=["accuracy"],
167                    )
168                    # Fit the model
169                    self.clf0and2.fit(
170                        x=X_train_scaled,
171                        y=y_train,
172                        batch_size=5,
173                        epochs=4,
174                        shuffle=True,
175                        verbose=2,
176                        validation_data=(X_test_scaled, y_test),
177                    )  # Need to reshape X_train
178
179            # Log performance stats
180            # accuracy
181            # correct = preds == self.y
182            # logger.info("Correct: %s", correct)
183
184            # COMMENTED OUT DUE TO INCOMPLETE IMPLEMENTATION
185            """
186            self.offline_trial_count = n_trials
187            self.offline_trial_counts.append(self.offline_trial_count)
188            # accuracy
189            accuracy = sum(preds == self.y) / len(preds)
190            self.offline_accuracy.append(accuracy)
191            logger.info("Accuracy = %s", accuracy)
192            # precision
193            precision = precision_score(self.y, preds, average="micro")
194            self.offline_precision.append(precision)
195            logger.info("Precision = %s", precision))
196            # recall
197            recall = recall_score(self.y, preds, average="micro")
198            self.offline_recall.append(recall)
199            logger.info("Recall = %s", recall)
200            # confusion matrix in command line
201            cm = confusion_matrix(self.y, preds)
202            self.offline_cm = cm
203            logger.info("Confusion matrix:\n%s", cm)
204            """
205
206    def predict(self, X):
207        """Predict the class labels for the provided data.
208
209        Parameters
210        ----------
211        X : numpy.ndarray
212            3D array where shape = (trials, channels, samples)
213
214        Returns
215        -------
216        prediction : Prediction
217            Results of predict call containing the predicted class labels.  Probabilities
218            are not available (empty list).
219
220        """
221        # if X is 2D, make it 3D with one as first dimension
222        if len(X.shape) < 3:
223            X = X[np.newaxis, ...]
224
225        logger.info("The shape of X is %s", X.shape)
226
227        # self.predict0and1 = Sequential(
228        #     [
229        #         Flatten(),
230        #         Dense(units=8, input_shape=(4,), activation="relu"),
231        #         Dense(units=16, activation="relu"),
232        #         Dense(units=3, activation="sigmoid"),
233        #     ]
234        # )
235
236        # self.predict0and2 = Sequential(
237        #     [
238        #         Flatten(),
239        #         Dense(units=8, input_shape=(4,), activation="relu"),
240        #         Dense(units=16, activation="relu"),
241        #         Dense(units=3, activation="sigmoid"),
242        #     ]
243        # )
244
245        z_dim, y_dim, x_dim = X.shape
246        X_predict = X.reshape(z_dim, x_dim * y_dim)
247        scaler_train = preprocessing.StandardScaler().fit(X_predict)
248        X_predict_scaled = scaler_train.transform(X_predict)
249
250        pred0and1 = self.predict0and1.predict(X_predict_scaled)
251        pred0and2 = self.predict0and2.predict(X_predict_scaled)
252
253        final_predictions = np.array([])
254
255        for row1, row2 in zip(pred0and1, pred0and2):
256            if row1[0] > row1[1] and row2[0] > row2[2]:
257                np.append(final_predictions, 0)
258            elif row1[0] > row1[1] and row2[0] < row2[2]:
259                np.append(final_predictions, 2)
260            elif row1[0] < row1[1] and row2[0] > row2[2]:
261                np.append(final_predictions, 1)
262            elif row1[0] < row1[1] and row2[0] < row2[2]:
263                if row1[1] > row2[2]:
264                    np.append(final_predictions, 1)
265                else:
266                    np.append(final_predictions, 2)
267
268        return Prediction(labels=final_predictions)

Switch MDM Classifier class (inherits from GenericClassifier).

def set_switch_classifier_mdm_settings( self, n_splits=2, rebuild=True, random_seed=42, n_jobs=1, activation_main='relu', activation_class='sigmoid'):
36    def set_switch_classifier_mdm_settings(
37        self,
38        n_splits=2,
39        rebuild=True,
40        random_seed=42,
41        n_jobs=1,
42        activation_main="relu",
43        activation_class="sigmoid",
44    ):
45        """Set the Switch Classifier MDM settings.
46
47        Parameters
48        ----------
49        n_splits : int, *optional*
50            Number of folds for cross-validation.
51            - Default is `2`.
52        rebuild : bool, *optional*
53            Rebuild the classifier each time. *More description needed*.
54            - Default is `True`.
55        random_seed : int, *optional*
56            Random seed.
57            - Default is `42`.
58        n_jobs : int, *optional*
59            The number of threads to dedicate to this calculation.
60        activation_main : str, *optional*
61            Activation function for hidden layers.
62            - Default is `relu`.
63        activation_class : str, *optional*
64            Activation function for the output layer.
65            - Default is `sigmoid`.
66
67        Returns
68        -------
69        `None`
70            Models created are used in `fit()`.
71
72        """
73        self.n_splits = n_splits
74        self.cv = StratifiedKFold(
75            n_splits=self.n_splits, shuffle=True, random_state=random_seed
76        )
77        self.rebuild = rebuild
78
79        mdm = MDM(metric=dict(mean="riemann", distance="riemann"), n_jobs=n_jobs)
80        self.clf_model = Pipeline([("MDM", mdm)])
81        self.clf = Pipeline([("MDM", mdm)])
82        # self.clf0and1 = MDM()

Set the Switch Classifier MDM settings.

Parameters
  • n_splits (int, optional): Number of folds for cross-validation.
    • Default is 2.
  • rebuild (bool, optional): Rebuild the classifier each time. More description needed.
    • Default is True.
  • random_seed (int, optional): Random seed.
    • Default is 42.
  • n_jobs (int, optional): The number of threads to dedicate to this calculation.
  • activation_main (str, optional): Activation function for hidden layers.
    • Default is relu.
  • activation_class (str, optional): Activation function for the output layer.
    • Default is sigmoid.
Returns
  • None: Models created are used in fit().
def fit(self):
 84    def fit(self):
 85        """Fit the model.
 86
 87        Returns
 88        -------
 89        `None`
 90            Models created used in `predict()`.
 91
 92        """
 93        # get dimensions
 94        n_trials, n_channels, n_samples = self.X.shape
 95
 96        # do the rest of the training if train_free is false
 97        X = np.array(self.X)
 98        y = np.array(self.y)
 99
100        # find the number of classes in y there shoud be N + 1, where N is the number of objects in the scene and also the number of classifiers
101        self.num_classifiers = len(list(np.unique(self.y))) - 1
102        logger.info("Number of classes: %s", self.num_classifiers)
103
104        # make a list to hold all of the classifiers
105        self.clfs = []
106
107        # loop through and build the classifiers
108        for i in range(self.num_classifiers):
109            # take a subset / do spatial filtering
110            X = X[:, :, :]  # Does nothing for now
111
112            class_indices = np.logical_or(y == 0, y == (i + 1))
113            X_class = X[class_indices, :, :]
114            y_class = y[class_indices]
115
116            # Try rebuilding the classifier each time
117            if self.rebuild:
118                self.next_fit_trial = 0
119                # tf.keras.backend.clear_session()
120
121            subX = X_class[self.next_fit_trial :, :, :]
122            suby = y_class[self.next_fit_trial :]
123            self.next_fit_trial = n_trials
124
125            for train_idx, test_idx in self.cv.split(subX, suby):
126                X_train, X_test = subX[train_idx], subX[test_idx]
127                y_train, y_test = suby[train_idx], suby[test_idx]
128
129                z_dim, y_dim, x_dim = X_train.shape
130                X_train = X_train.reshape(z_dim, x_dim * y_dim)
131                scaler_train = preprocessing.StandardScaler().fit(X_train)
132                X_train_scaled = scaler_train.transform(X_train)
133
134                logger.info("The shape of X_train_scaled is %s", X_train_scaled.shape)
135
136                z_dim, y_dim, x_dim = X_test.shape
137                X_test = X_test.reshape(z_dim, x_dim * y_dim)
138                scaler_test = preprocessing.StandardScaler().fit(X_test)
139                X_test_scaled = scaler_test.transform(X_test)
140
141                if i == 0:
142                    # Compile the model
143                    logger.info("\nWorking on first model...")
144                    self.clf0and1.compile(
145                        # optimizer=Adam(learning_rate=0.001),
146                        loss="sparse_categorical_crossentropy",
147                        metrics=["accuracy"],
148                    )
149                    # Fit the model
150                    self.clf0and1.fit(
151                        x=X_train_scaled,
152                        y=y_train,
153                        batch_size=5,
154                        epochs=4,
155                        shuffle=True,
156                        verbose=2,
157                        validation_data=(X_test_scaled, y_test),
158                    )  # Need to reshape X_train
159
160                else:
161                    logger.info("\nWorking on second model...")
162                    # Compile the model
163                    self.clf0and2.compile(
164                        # optimizer=Adam(learning_rate=0.001),
165                        loss="sparse_categorical_crossentropy",
166                        metrics=["accuracy"],
167                    )
168                    # Fit the model
169                    self.clf0and2.fit(
170                        x=X_train_scaled,
171                        y=y_train,
172                        batch_size=5,
173                        epochs=4,
174                        shuffle=True,
175                        verbose=2,
176                        validation_data=(X_test_scaled, y_test),
177                    )  # Need to reshape X_train
178
179            # Log performance stats
180            # accuracy
181            # correct = preds == self.y
182            # logger.info("Correct: %s", correct)
183
184            # COMMENTED OUT DUE TO INCOMPLETE IMPLEMENTATION
185            """
186            self.offline_trial_count = n_trials
187            self.offline_trial_counts.append(self.offline_trial_count)
188            # accuracy
189            accuracy = sum(preds == self.y) / len(preds)
190            self.offline_accuracy.append(accuracy)
191            logger.info("Accuracy = %s", accuracy)
192            # precision
193            precision = precision_score(self.y, preds, average="micro")
194            self.offline_precision.append(precision)
195            logger.info("Precision = %s", precision))
196            # recall
197            recall = recall_score(self.y, preds, average="micro")
198            self.offline_recall.append(recall)
199            logger.info("Recall = %s", recall)
200            # confusion matrix in command line
201            cm = confusion_matrix(self.y, preds)
202            self.offline_cm = cm
203            logger.info("Confusion matrix:\n%s", cm)
204            """

Fit the model.

Returns
def predict(self, X):
206    def predict(self, X):
207        """Predict the class labels for the provided data.
208
209        Parameters
210        ----------
211        X : numpy.ndarray
212            3D array where shape = (trials, channels, samples)
213
214        Returns
215        -------
216        prediction : Prediction
217            Results of predict call containing the predicted class labels.  Probabilities
218            are not available (empty list).
219
220        """
221        # if X is 2D, make it 3D with one as first dimension
222        if len(X.shape) < 3:
223            X = X[np.newaxis, ...]
224
225        logger.info("The shape of X is %s", X.shape)
226
227        # self.predict0and1 = Sequential(
228        #     [
229        #         Flatten(),
230        #         Dense(units=8, input_shape=(4,), activation="relu"),
231        #         Dense(units=16, activation="relu"),
232        #         Dense(units=3, activation="sigmoid"),
233        #     ]
234        # )
235
236        # self.predict0and2 = Sequential(
237        #     [
238        #         Flatten(),
239        #         Dense(units=8, input_shape=(4,), activation="relu"),
240        #         Dense(units=16, activation="relu"),
241        #         Dense(units=3, activation="sigmoid"),
242        #     ]
243        # )
244
245        z_dim, y_dim, x_dim = X.shape
246        X_predict = X.reshape(z_dim, x_dim * y_dim)
247        scaler_train = preprocessing.StandardScaler().fit(X_predict)
248        X_predict_scaled = scaler_train.transform(X_predict)
249
250        pred0and1 = self.predict0and1.predict(X_predict_scaled)
251        pred0and2 = self.predict0and2.predict(X_predict_scaled)
252
253        final_predictions = np.array([])
254
255        for row1, row2 in zip(pred0and1, pred0and2):
256            if row1[0] > row1[1] and row2[0] > row2[2]:
257                np.append(final_predictions, 0)
258            elif row1[0] > row1[1] and row2[0] < row2[2]:
259                np.append(final_predictions, 2)
260            elif row1[0] < row1[1] and row2[0] > row2[2]:
261                np.append(final_predictions, 1)
262            elif row1[0] < row1[1] and row2[0] < row2[2]:
263                if row1[1] > row2[2]:
264                    np.append(final_predictions, 1)
265                else:
266                    np.append(final_predictions, 2)
267
268        return Prediction(labels=final_predictions)

Predict the class labels for the provided data.

Parameters
  • X (numpy.ndarray): 3D array where shape = (trials, channels, samples)
Returns
  • prediction (Prediction): Results of predict call containing the predicted class labels. Probabilities are not available (empty list).