Skip to content

Evaluate Module

GeneralEvaluator

Bases: BaseEvaluator

Source code in src/autoencodix/evaluate/_general_evaluator.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
class GeneralEvaluator(BaseEvaluator):
    def __init__(self):
        """Create an evaluator; no instance attributes are set here."""
        # NOTE(review): the base-class initializer call below is commented out,
        # so BaseEvaluator.__init__ is not run for this class - confirm that
        # this is intended.
        # super().__init__()
        pass

    @no_type_check
    def evaluate(
        self,
        datasets: DatasetContainer,
        result: Result,
        ml_model_class: ClassifierMixin = linear_model.LogisticRegression(
            max_iter=1000
        ),  # Default is sklearn LogisticRegression
        ml_model_regression: RegressorMixin = linear_model.LinearRegression(),  # Default is sklearn LinearRegression
        params: Union[list, str] = "all",
        metric_class: str = "roc_auc_ovo",  # see https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-string-names
        metric_regression: str = "r2",
        reference_methods: Union[list, None] = None,  # Options are "PCA", "UMAP", "TSNE", "RandomFeature"
        split_type: str = "use-split",  # other options: "CV-N" (e.g. "CV-5"), "LOOCV"
        n_downsample: Union[int, None] = 10000,
    ) -> Result:
        """Evaluates machine learning models on several feature representations and clinical parameters.

        For every feature representation (the latent space plus any requested
        reference method) and every annotation parameter, a classifier or a
        regressor is trained and scored. The scores are collected in one long
        DataFrame that is stored in ``result.embedding_evaluation``.

        - Samples with a missing annotation value for a parameter are excluded
          from the evaluation of that parameter only; they are used again for
          the other parameters.
        - Downsampling is applied per parameter, after the samples with
          missing annotations were removed.
        - For "RandomFeature", five random feature sets are evaluated.
        - New results are appended to an existing, non-empty
          `embedding_evaluation` of the result object.
        - With "use-split" the passed estimator instances are fitted in place.

        Args:
            datasets: A DatasetContainer with train, valid and test datasets. If None, an empty container is created. If `result.new_datasets.test` is set, it replaces `datasets.test`.
            result: A Result object providing latent spaces and metadata. Its `embedding_evaluation` attribute is updated.
            ml_model_class: The scikit-learn classifier used for classification tasks (default: `sklearn.linear_model.LogisticRegression(max_iter=1000)`).
            ml_model_regression: The scikit-learn regressor used for regression tasks (default: `sklearn.linear_model.LinearRegression()`).
            params: List of clinical annotation columns to evaluate, a single column name, or "all" to use all columns (default: "all").
            metric_class: Scoring metric for classification tasks (default: "roc_auc_ovo").
            metric_regression: Scoring metric for regression tasks (default: "r2").
            reference_methods: List of additional feature representations to evaluate (e.g., "PCA", "UMAP", "TSNE", "RandomFeature"). "Latent" is always included. The passed list is not modified (default: None, i.e. only "Latent").
            split_type: Which split strategy to use:
                "use-split" for pre-defined splits, "CV-N" for N-fold cross-validation, or "LOOCV" for leave-one-out cross-validation (default: "use-split"). Without a train split, "use-split" falls back to "CV-5" with a warning.
            n_downsample: If provided, randomly downsample to this number of samples for faster evaluation. Default is 10000. Set to None to disable downsampling.
        Returns:
            The updated result object with evaluation results stored in `embedding_evaluation`.
        Raises:
            ValueError: If no dataset split is available or if an unsupported split type is specified.

        """

        already_warned = False
        result_frames = []

        # Copy before adding "Latent": appending to the argument itself would
        # modify the caller's list (or a shared default) on every call.
        reference_methods = list(reference_methods) if reference_methods else []
        if "Latent" not in reference_methods:
            reference_methods.append("Latent")

        reference_methods = self._expand_reference_methods(
            reference_methods=reference_methods, result=result
        )

        ## Overwrite original datasets with new_datasets if available after predict with other data
        if datasets is None:
            datasets = DatasetContainer()

        if bool(result.new_datasets.test):
            datasets.test = result.new_datasets.test

        if not bool(datasets.train or datasets.valid or datasets.test):
            raise ValueError(
                "No datasets found in result object. Please run predict with new data or save/load with all datasets by using save_all=True."
            )
        elif split_type == "use-split" and not bool(datasets.train):
            warnings.warn(
                "Warning: No train split found in result datasets for 'use-split' evaluation. ML model cannot be trained without a train split. Switch to cross-validation (CV-5) instead."
            )
            split_type = "CV-5"

        # Annotation data and split assignment do not depend on the feature
        # representation, so they are prepared once.
        clin_data = BaseVisualizer._collect_all_metadata(result=result)

        if isinstance(params, str):
            # A bare column name must not be iterated character by character.
            params = clin_data.columns.tolist() if params == "all" else [params]

        full_sample_split = None
        if split_type == "use-split":
            full_sample_split = self._build_sample_split(datasets)

        for task in reference_methods:
            print(f"Perform ML task with feature df: {task}")

            ## df -> task
            subtask = [task]
            if "RandomFeature" in task:
                subtask = [task + "_R" + str(x) for x in range(1, 6)]
            for sub in subtask:
                print(sub)
                features = self._load_input_for_ml(task, datasets, result)

                for task_param in params:
                    if "Latent" in task:
                        print(f"Perform ML task for target parameter: {task_param}")
                    ## Check if classification or regression task
                    ml_type = self._get_ml_type(clin_data, task_param)

                    # Start every parameter from the complete data. Filtering
                    # the shared frames would drop samples for all following
                    # parameters as well.
                    df = features
                    sample_split = full_sample_split

                    if pd.isna(clin_data[task_param]).sum() > 0:
                        if not already_warned:
                            print(
                                "There are NA values in the annotation file. Samples with missing data will be removed for ML task evaluation."
                            )
                        already_warned = True

                        samples_nonna = clin_data.loc[
                            pd.notna(clin_data[task_param]), task_param
                        ].index
                        df = df.loc[samples_nonna.intersection(df.index), :]
                        if split_type == "use-split":
                            sample_split = sample_split.loc[
                                samples_nonna.intersection(sample_split.index), :
                            ]

                    if n_downsample is not None and df.shape[0] > n_downsample:
                        sample_idx = np.random.choice(
                            df.shape[0], n_downsample, replace=False
                        )
                        df = df.iloc[sample_idx]
                        if split_type == "use-split":
                            sample_split = sample_split.loc[df.index, :]

                    if ml_type == "classification":
                        metric = metric_class
                        sklearn_ml = ml_model_class
                    else:
                        metric = metric_regression
                        sklearn_ml = ml_model_regression

                    if split_type == "use-split":
                        results = self._single_ml_presplit(
                            sample_split=sample_split,
                            df=df,
                            clin_data=clin_data,
                            task_param=task_param,
                            sklearn_ml=sklearn_ml,
                            metric=metric,
                            ml_type=ml_type,
                        )
                    elif split_type.startswith("CV-"):
                        cv_folds = int(split_type.split("-")[1])

                        results = self._single_ml(
                            df=df,
                            clin_data=clin_data,
                            task_param=task_param,
                            sklearn_ml=sklearn_ml,
                            metric=metric,
                            cv_folds=cv_folds,
                        )
                    elif split_type == "LOOCV":
                        # Leave One Out Cross Validation: one fold per sample
                        results = self._single_ml(
                            df=df,
                            clin_data=clin_data,
                            task_param=task_param,
                            sklearn_ml=sklearn_ml,
                            metric=metric,
                            cv_folds=len(df),
                        )
                    else:
                        raise ValueError(
                            f"Your split type {split_type} is not supported. Please use 'use-split', 'CV-5', 'LOOCV' or 'CV-N'."
                        )
                    results = self._enrich_results(
                        results=results,
                        sklearn_ml=sklearn_ml,
                        ml_type=ml_type,
                        task=task,
                        sub=sub,
                    )

                    result_frames.append(results)

        # One concat at the end instead of re-copying the growing frame in
        # every iteration.
        df_results = pd.concat(result_frames) if result_frames else pd.DataFrame()

        ## Check if embedding_evaluation is empty
        if (
            hasattr(result, "embedding_evaluation")
            and len(result.embedding_evaluation) == 0
        ):
            result.embedding_evaluation = df_results
        else:
            # merge with existing results
            result.embedding_evaluation = pd.concat(
                [result.embedding_evaluation, df_results], axis=0
            )

        return result

    @staticmethod
    def _build_sample_split(datasets: DatasetContainer) -> pd.DataFrame:
        """Builds the sample-to-split table for the available dataset splits.

        Args:
            datasets: Container whose `train`, `valid` and `test` attributes are either None or a dataset exposing `sample_ids` (and optionally `paired_sample_ids`).
        Returns:
            DataFrame with the columns "SAMPLE_ID" and "SPLIT", indexed by "SAMPLE_ID".
        """
        frames = []
        for split_name in ("train", "valid", "test"):
            split_data = getattr(datasets, split_name)
            if split_data is None:
                continue
            # Use the paired ids when they are set, otherwise the regular ids.
            # This also covers a `paired_sample_ids` attribute that exists
            # but is None.
            paired_ids = getattr(split_data, "paired_sample_ids", None)
            sample_ids = paired_ids if paired_ids is not None else split_data.sample_ids
            frames.append(
                pd.DataFrame(
                    {
                        "SAMPLE_ID": sample_ids,
                        "SPLIT": [split_name] * len(sample_ids),
                    }
                )
            )

        if frames:
            sample_split = pd.concat(frames, axis=0, ignore_index=True)
        else:
            sample_split = pd.DataFrame(columns=["SAMPLE_ID", "SPLIT"])

        return sample_split.set_index("SAMPLE_ID", drop=False)

    @staticmethod
    def _single_ml(
        df: pd.DataFrame,
        clin_data: pd.DataFrame,
        task_param: str,
        sklearn_ml: Union[ClassifierMixin, RegressorMixin],
        metric: str,
        cv_folds: int = 5,
    ):
        """Cross-validates the given sklearn model on the features in df.

        The target values are read from `clin_data[task_param]` for the samples
        in `df.index`. Train and test scores of every fold are returned in long
        format. If the target has at most one distinct value, no model is
        evaluated and an empty DataFrame without columns is returned.

        Args:
            df: Dataframe with input data, indexed by sample ids
            clin_data: Dataframe with label data, indexed by sample ids
            task_param: Column name with label data
            sklearn_ml: Sklearn ML module specifying the ML algorithm
            metric: string specifying the metric to be calculated by cross validation
            cv_folds: number of cross validation folds, passed as `cv` to `cross_validate`
        Returns:
            score_df: data frame containing metrics (scores) for all CV runs (long format)
                with the columns cv_run, score_split, CLINIC_PARAM, metric and value

        """
        targets: Union[pd.Series, pd.DataFrame] = clin_data.loc[df.index, task_param]

        # Nothing to learn from a constant target.
        if len(targets.unique()) <= 1:  # ty: ignore
            return pd.DataFrame(dict())

        cv_result = cross_validate(
            sklearn_ml,
            df,
            targets,
            cv=cv_folds,
            scoring=metric,
            return_train_score=True,
        )

        # Long format: CV_RUN | SCORE_SPLIT | TASK_PARAM | METRIC | VALUE
        long_scores = {
            "cv_run": [],
            "score_split": [],
            "CLINIC_PARAM": [],
            "metric": [],
            "value": [],
        }
        run_labels = [f"CV_{fold}" for fold in range(1, cv_folds + 1)]

        for key, fold_values in cv_result.items():
            # Keys look like "test_score" / "train_score"; timing entries
            # ("fit_time", "score_time") are skipped.
            prefix = key.split("_")[0]
            if prefix not in ("test", "train"):
                continue

            long_scores["cv_run"].extend(run_labels)
            long_scores["score_split"].extend([prefix] * cv_folds)
            long_scores["CLINIC_PARAM"].extend([task_param] * cv_folds)
            long_scores["metric"].extend([metric] * cv_folds)
            long_scores["value"].extend(fold_values)

        return pd.DataFrame(long_scores)

    def _enrich_results(
        self,
        results: pd.DataFrame,
        sklearn_ml: Union[ClassifierMixin, RegressorMixin],
        ml_type: str,
        task: str,
        sub: str,
    ) -> pd.DataFrame:
        res_ml_alg = [str(sklearn_ml) for x in range(0, results.shape[0])]
        res_ml_type = [ml_type for x in range(0, results.shape[0])]
        res_ml_task = [task for x in range(0, results.shape[0])]
        res_ml_subtask = [sub for x in range(0, results.shape[0])]

        results["ML_ALG"] = res_ml_alg
        results["ML_TYPE"] = res_ml_type
        # if is_modalix:
        #     results["MODALITY"] = [modality for x in range(0, results.shape[0])]
        #     results["ML_TASK"] = [task_xmodal for x in range(0, results.shape[0])]
        # else:
        results["ML_TASK"] = res_ml_task
        results["ML_SUBTASK"] = res_ml_subtask

        return results

    @staticmethod
    def _single_ml_presplit(
        sample_split: pd.DataFrame,
        df: pd.DataFrame,
        clin_data: pd.DataFrame,
        task_param: str,
        sklearn_ml: Union[ClassifierMixin, RegressorMixin],
        metric: str,
        ml_type: str,
    ):
        """Fits the model on the train split and scores it on the train, valid and test splits.

        The passed model instance is fitted in place. Splits without samples are
        skipped. For classification, samples whose class does not occur in the
        training data are removed from a split before it is scored. If the
        training target has at most one distinct value, a warning is issued and
        a table without rows is returned.

        Args:
            sample_split: DataFrame with the columns "SAMPLE_ID" and "SPLIT" ("train", "valid", "test").
            df: DataFrame with input features, indexed by sample IDs.
            clin_data: DataFrame with label/annotation data, indexed by sample IDs.
            task_param: Column name in clin_data specifying the target variable.
            sklearn_ml: Instantiated sklearn model to use for training and evaluation.
            metric: Scoring metric compatible with sklearn's get_scorer.
            ml_type: Type of machine learning task ("classification" or "regression").

        Returns:
            DataFrame with the columns score_split, CLINIC_PARAM, metric and value, one row per scored split.

        Raises:
            ValueError: If no scorer is found for the metric or if ml_type is not supported.
        """
        collected = {
            "score_split": [],
            "CLINIC_PARAM": [],
            "metric": [],
            "value": [],
        }

        def _features_of(split_name):
            # Feature rows of all samples assigned to the given split.
            member_ids = sample_split.loc[
                sample_split.SPLIT == split_name, "SAMPLE_ID"
            ]
            return df.loc[member_ids, :]

        X_train = _features_of("train")
        Y_train = clin_data.loc[list(X_train.index), task_param]
        train_classes = Y_train.unique()  # ty: ignore

        if len(train_classes) <= 1:
            ## Warning that there is only one class in the training data
            warnings.warn(
                f"Warning: There is only one class in the training data for task parameter {task_param}. Skipping evaluation for this task."
            )
            return pd.DataFrame(collected)

        # The model is trained once and reused for all splits.
        sklearn_ml.fit(X_train, Y_train)  # ty: ignore

        for split in ("train", "valid", "test"):
            X = _features_of(split)
            if X.shape[0] == 0:
                continue
            Y = clin_data.loc[list(X.index), task_param]

            collected["score_split"].append(split)
            collected["CLINIC_PARAM"].append(task_param)
            collected["metric"].append(metric)

            sklearn_scorer = get_scorer(metric)
            if sklearn_scorer is None:
                raise ValueError(
                    f"Your metric {metric} is not supported by sklearn. Please use a valid metric."
                )

            if ml_type == "classification":
                # The model only knows the classes seen during training.
                unseen = set(Y.unique()).difference(set(train_classes))  # ty: ignore
                if len(unseen) > 0:
                    print(
                        f"Classes in split {split} are not present in training data"
                    )
                    Y = Y[Y.isin(train_classes)]  # ty: ignore
                    X = X.loc[Y.index, :]

                split_score = sklearn_scorer(
                    sklearn_ml, X, Y, labels=np.sort(train_classes)
                )
            elif ml_type == "regression":
                split_score = sklearn_scorer(sklearn_ml, X, Y)
            else:
                raise ValueError(
                    f"Your ML type {ml_type} is not supported. Please use 'classification' or 'regression'."
                )

            collected["value"].append(split_score)

        return pd.DataFrame(collected)

    @staticmethod
    def _get_ml_type(clin_data: pd.DataFrame, task_param: str) -> str:
        """Determines the machine learning task type (classification or regression) based on the data type of a specified column in clinical data.

        Args:
            clin_data: The clinical data as a pandas DataFrame.
            task_param: The column name in clin_data to inspect for determining the task type.

        Returns:
            "classification" if the first value in the specified column is a string, otherwise "regression".
        """
        ## Auto-Detection
        if type(list(clin_data[task_param])[0]) is str:
            ml_type = "classification"
        elif clin_data[task_param].unique().shape[0] < 3:
            ml_type = "classification"
        else:
            ml_type = "regression"

        return ml_type

    @staticmethod
    def _load_input_for_ml(
        task: str, dataset: DatasetContainer, result: Result
    ) -> pd.DataFrame:
        """Loads and processes input data for various machine learning tasks based on the specified task type.


        Task Details:
            - "Latent": Concatenates the latent representations of the train and validation splits at the final epoch and of the test split at epoch -1. Missing or empty splits are skipped.
            - "UMAP": Applies UMAP dimensionality reduction to the concatenated dataset splits.
            - "PCA": Applies PCA dimensionality reduction to the concatenated dataset splits.
            - "TSNE": Applies t-SNE dimensionality reduction to the concatenated dataset splits.
            - "RandomFeature": Randomly samples columns (features) from the concatenated dataset splits.

            All reference methods produce `result.model.config.latent_dim` features.

        Args:
            task: The type of ML task. Supported values are "Latent", "UMAP", "PCA", "TSNE", and "RandomFeature".
            dataset: The dataset container object holding train, validation, and test splits.
            result: The result object containing model configuration and methods to retrieve latent representations.
        Returns:
            A DataFrame containing the processed input data suitable for the specified ML task.
        Raises:
            ValueError: If the provided task is not supported or if no dataset split is available for a reference method.
        """
        if task == "Latent":
            final_epoch = result.model.config.epochs - 1
            latent_frames = []
            for split in ["train", "valid", "test"]:
                df_split = result.get_latent_df(
                    epoch=final_epoch if split != "test" else -1, split=split
                )
                if df_split is not None and not df_split.empty:
                    latent_frames.append(df_split)

            return pd.concat(latent_frames) if latent_frames else pd.DataFrame()

        if task not in ("UMAP", "PCA", "TSNE", "RandomFeature"):
            raise ValueError(
                f"Your ML task {task} is not supported. Please use Latent, UMAP, PCA, TSNE or RandomFeature."
            )

        split_frames = []
        for split_name in ["train", "valid", "test"]:
            split_data = getattr(dataset, split_name, None)
            if split_data is not None:
                split_frames.append(split_data._to_df())

        if not split_frames:
            raise ValueError(
                "No available dataset splits (train, valid, test) to process."
            )

        df_processed = pd.concat(split_frames)
        n_components = result.model.config.latent_dim

        if task == "RandomFeature":
            return df_processed.sample(n=n_components, axis=1)

        if task == "UMAP":
            reducer = UMAP(n_components=n_components)
        elif task == "PCA":
            reducer = PCA(n_components=n_components)
        else:
            # The default Barnes-Hut implementation of sklearn's TSNE only
            # supports fewer than 4 components; larger latent dimensions need
            # the exact method. The default is kept wherever it works.
            tsne_options = {"method": "exact"} if n_components >= 4 else {}
            reducer = TSNE(n_components=n_components, **tsne_options)

        return pd.DataFrame(
            reducer.fit_transform(df_processed), index=df_processed.index
        )

evaluate(datasets, result, ml_model_class=linear_model.LogisticRegression(max_iter=1000), ml_model_regression=linear_model.LinearRegression(), params='all', metric_class='roc_auc_ovo', metric_regression='r2', reference_methods=[], split_type='use-split', n_downsample=10000)

Evaluates the performance of machine learning models on various feature representations and clinical parameters.

This method performs classification or regression tasks using specified machine learning models on different feature sets (e.g., latent space, PCA, UMAP, TSNE, RandomFeature) and clinical annotation parameters. It supports multiple evaluation strategies, including pre-defined train/valid/test splits, k-fold cross-validation, and leave-one-out cross-validation. The results are aggregated and stored in the provided result object.

- Samples with missing annotation values for a given parameter are excluded from the corresponding evaluation.
- For "RandomFeature", five random feature sets are evaluated.
- The method appends results to any existing embedding_evaluation in the result object.

Parameters:

Name Type Description Default
datasets DatasetContainer

A DatasetContainer containing train, valid, and test datasets, each with sample_ids and metadata (either a DataFrame or a dictionary with a 'paired' key for clinical annotations).

required
result Result

A Result object to store the evaluation results. It should have an embedding_evaluation attribute (typically a DataFrame), which is updated.

required
ml_model_class ClassifierMixin

The scikit-learn classifier to use for classification tasks (default: sklearn.linear_model.LogisticRegression()).

LogisticRegression(max_iter=1000)
ml_model_regression RegressorMixin

The scikit-learn regressor to use for regression tasks (default: sklearn.linear_model.LinearRegression()).

LinearRegression()
params Union[list, str]

List of clinical annotation columns to evaluate, or "all" to use all columns (default: "all").

'all'
metric_class str

Scoring metric for classification tasks (default: "roc_auc_ovo").

'roc_auc_ovo'
metric_regression str

Scoring metric for regression tasks (default: "r2").

'r2'
reference_methods list

List of feature representations to evaluate (e.g., "PCA", "UMAP", "TSNE", "RandomFeature"). "Latent" is always included (default: []).

[]
split_type str

Which split strategy to use: "use-split" for pre-defined splits, "CV-N" for N-fold cross-validation, or "LOOCV" for leave-one-out cross-validation (default: "use-split").

'use-split'
n_downsample Union[int, None]

If provided, downsample the data to this number of samples for faster evaluation. Default is 10000. Set to None to disable downsampling.

10000

Returns: The updated result object with evaluation results stored in embedding_evaluation.

Raises: ValueError: If required annotation data is missing or improperly formatted, or if an unsupported split type is specified.

Source code in src/autoencodix/evaluate/_general_evaluator.py
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
@no_type_check
def evaluate(
    self,
    datasets: DatasetContainer,
    result: Result,
    ml_model_class: ClassifierMixin = linear_model.LogisticRegression(
        max_iter=1000
    ),  # Default is sklearn LogisticRegression
    ml_model_regression: RegressorMixin = linear_model.LinearRegression(),  # Default is sklearn LinearRegression
    params: Union[
        list, str
    ] = "all",  # "all" evaluates every column of the collected annotation
    metric_class: str = "roc_auc_ovo",  # Default is 'roc_auc_ovo' via https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-string-names
    metric_regression: str = "r2",  # Default is 'r2'
    reference_methods: Union[
        list, None
    ] = None,  # None means no extra methods; options are "PCA", "UMAP", "TSNE", "RandomFeature"
    split_type: str = "use-split",  # Default is "use-split", other options: "CV-N", "LOOCV"
    n_downsample: Union[
        int, None
    ] = 10000,  # Default is 10000, if provided downsample to this number of samples for faster evaluation. Set to None to disable downsampling.
) -> Result:
    """Evaluate ML models on the latent space and on reference feature representations.

    For every feature representation ("Latent" plus the requested reference
    methods) and every annotation parameter, a classification or regression
    model is scored, either on the pre-defined splits or with cross-validation.
    All per-run results are concatenated and stored on ``result``.

    - Samples with a missing annotation value for a parameter are excluded only
      from the evaluation of that parameter.
    - Downsampling is drawn independently for each parameter evaluation.
    - For "RandomFeature", five sub-runs (``_R1`` ... ``_R5``) are evaluated.
    - Results are appended to any existing ``embedding_evaluation`` in ``result``.

    Args:
        datasets: A DatasetContainer with train, valid and test datasets. If None, an empty container is created. Its test split is replaced by ``result.new_datasets.test`` when that is set.
        result: A Result object that provides the latent spaces/metadata and receives the evaluation in its ``embedding_evaluation`` attribute.
        ml_model_class: The scikit-learn classifier to use for classification tasks (default: ``sklearn.linear_model.LogisticRegression(max_iter=1000)``).
        ml_model_regression: The scikit-learn regressor to use for regression tasks (default: ``sklearn.linear_model.LinearRegression()``).
        params: List of annotation columns to evaluate, or "all" to use all columns (default: "all").
        metric_class: Scoring metric for classification tasks (default: "roc_auc_ovo").
        metric_regression: Scoring metric for regression tasks (default: "r2").
        reference_methods: List of additional feature representations to evaluate (e.g., "PCA", "UMAP", "TSNE", "RandomFeature"). "Latent" is always included. The passed list is not modified (default: None, i.e. no additional methods).
        split_type: Which split strategy to use: "use-split" for pre-defined splits, "CV-N" for N-fold cross-validation, or "LOOCV" for leave-one-out cross-validation (default: "use-split").
        n_downsample: If provided, downsample the data to this number of samples for faster evaluation. Default is 10000. Set to None to disable downsampling.

    Returns:
        The updated result object with evaluation results stored in ``embedding_evaluation``.

    Raises:
        ValueError: If no dataset is available at all, or if an unsupported split type is specified.
    """

    def _resolve_sample_ids(split):
        # Prefer paired sample ids when the split provides them; fall back to the
        # plain sample ids when the attribute is missing or set to None.
        paired_ids = getattr(split, "paired_sample_ids", None)
        if paired_ids is not None:
            return paired_ids
        return split.sample_ids

    already_warned = False

    df_results = pd.DataFrame()

    # Copy before appending: a shared default list (or the caller's list) must
    # not accumulate "Latent" entries from one call to the next.
    reference_methods = (
        list(reference_methods) if reference_methods is not None else []
    )
    if "Latent" not in reference_methods:
        reference_methods.append("Latent")

    reference_methods = self._expand_reference_methods(
        reference_methods=reference_methods, result=result
    )

    ## Overwrite original datasets with new_datasets if available after predict with other data
    if datasets is None:
        datasets = DatasetContainer()

    if bool(result.new_datasets.test):
        datasets.test = result.new_datasets.test

    if not bool(datasets.train or datasets.valid or datasets.test):
        raise ValueError(
            "No datasets found in result object. Please run predict with new data or save/load with all datasets by using save_all=True."
        )
    elif split_type == "use-split" and not bool(datasets.train):
        warnings.warn(
            "Warning: No train split found in result datasets for 'use-split' evaluation. ML model cannot be trained without a train split. Switch to cross-validation (CV-5) instead."
        )
        split_type = "CV-5"

    # Table mapping every sample id to its split. It only depends on the
    # datasets, so it is built once and never modified afterwards; the
    # per-parameter filtering below works on derived views.
    sample_split = None
    if split_type == "use-split":
        split_frames = [pd.DataFrame(columns=["SAMPLE_ID", "SPLIT"])]
        for split_name in ("train", "valid", "test"):
            split_data = getattr(datasets, split_name)
            if split_data is None:
                continue
            split_ids = _resolve_sample_ids(split_data)
            split_frames.append(
                pd.DataFrame(
                    {
                        "SAMPLE_ID": split_ids,
                        "SPLIT": [split_name] * len(split_ids),
                    }
                )
            )
        sample_split = pd.concat(split_frames, axis=0, ignore_index=True)
        sample_split = sample_split.set_index("SAMPLE_ID", drop=False)

    for task in reference_methods:
        print(f"Perform ML task with feature df: {task}")

        clin_data = BaseVisualizer._collect_all_metadata(result=result)

        ## df -> task
        subtask = [task]
        if "RandomFeature" in task:
            subtask = [task + "_R" + str(x) for x in range(1, 6)]
        for sub in subtask:
            print(sub)
            df = self._load_input_for_ml(task, datasets, result)

            if params == "all":
                params = clin_data.columns.tolist()

            for task_param in params:
                if "Latent" in task:
                    print(f"Perform ML task for target parameter: {task_param}")
                ## Check if classification or regression task
                ml_type = self._get_ml_type(clin_data, task_param)

                # Per-parameter views: filtering for one parameter must not
                # shrink the data seen by the following parameters/sub-runs.
                df_param = df
                split_param = sample_split

                if pd.isna(clin_data[task_param]).sum() > 0:
                    if not already_warned:
                        print(
                            "There are NA values in the annotation file. Samples with missing data will be removed for ML task evaluation."
                        )
                    already_warned = True

                    samples_nonna = clin_data.loc[
                        pd.notna(clin_data[task_param]), task_param
                    ].index
                    df_param = df.loc[samples_nonna.intersection(df.index), :]
                    if split_type == "use-split":
                        split_param = sample_split.loc[
                            samples_nonna.intersection(sample_split.index), :
                        ]

                if n_downsample is not None and df_param.shape[0] > n_downsample:
                    sample_idx = np.random.choice(
                        df_param.shape[0], n_downsample, replace=False
                    )
                    df_param = df_param.iloc[sample_idx]
                    if split_type == "use-split":
                        split_param = split_param.loc[df_param.index, :]

                if ml_type == "classification":
                    metric = metric_class
                    sklearn_ml = ml_model_class
                elif ml_type == "regression":
                    metric = metric_regression
                    sklearn_ml = ml_model_regression

                if split_type == "use-split":
                    results = self._single_ml_presplit(
                        sample_split=split_param,
                        df=df_param,
                        clin_data=clin_data,
                        task_param=task_param,
                        sklearn_ml=sklearn_ml,
                        metric=metric,
                        ml_type=ml_type,
                    )
                elif split_type.startswith("CV-"):
                    cv_folds = int(split_type.split("-")[1])

                    results = self._single_ml(
                        df=df_param,
                        clin_data=clin_data,
                        task_param=task_param,
                        sklearn_ml=sklearn_ml,
                        metric=metric,
                        cv_folds=cv_folds,
                    )
                elif split_type == "LOOCV":
                    # Leave-one-out: one fold per remaining sample
                    results = self._single_ml(
                        df=df_param,
                        clin_data=clin_data,
                        task_param=task_param,
                        sklearn_ml=sklearn_ml,
                        metric=metric,
                        cv_folds=len(df_param),
                    )
                else:
                    raise ValueError(
                        f"Your split type {split_type} is not supported. Please use 'use-split', 'CV-5', 'LOOCV' or 'CV-N'."
                    )
                results = self._enrich_results(
                    results=results,
                    sklearn_ml=sklearn_ml,
                    ml_type=ml_type,
                    task=task,
                    sub=sub,
                )

                df_results = pd.concat([df_results, results])

    ## Check if embedding_evaluation is empty
    if (
        hasattr(result, "embedding_evaluation")
        and len(result.embedding_evaluation) == 0
    ):
        result.embedding_evaluation = df_results
    else:
        # merge with existing results
        result.embedding_evaluation = pd.concat(
            [result.embedding_evaluation, df_results], axis=0
        )

    return result

XModalixEvaluator

Bases: GeneralEvaluator

Source code in src/autoencodix/evaluate/_xmodalix_evaluator.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
class XModalixEvaluator(GeneralEvaluator):
    """Evaluator for cross-modal (xmodalix) results.

    Task names handled by this class carry the modality as a suffix in the
    form ``"<method>_$_<modality>"`` (see ``_expand_reference_methods``).
    """

    def __init__(self):
        # The parent initializer is not invoked here.
        # super().__init__()
        pass

    @staticmethod
    @no_type_check
    def pure_vae_comparison(
        xmodalix_result: Result,
        pure_vae_result: Result,
        to_key: str,
        param: Optional[str] = None,
    ) -> Tuple[Figure, pd.DataFrame]:
        """Compares the reconstruction performance of a pure VAE model and a cross-modal VAE (xmodalix) model using Mean Squared Error (MSE) on test samples.

        For each sample in the test set, computes the MSE between the original and reconstructed images for:
            - Pure VAE reconstructions ("imagix")
            - xmodalix reference reconstructions ("xmodalix_reference")
            - xmodalix translated reconstructions ("xmodalix_translated")
        The results are merged with sample metadata and returned in a long-format DataFrame suitable for plotting. Boxplots are grouped by ``param`` when given, otherwise by model.

        Samples are matched to their data by position: the i-th row of the test metadata is compared with the i-th test item and the i-th reconstruction.

        Args:
            xmodalix_result: The result object containing xmodalix model outputs and test datasets.
            pure_vae_result: The result object containing pure VAE model outputs and test datasets.
            to_key: The key specifying the target modality in the xmodalix dataset. Must contain "img".
            param: Metadata column name to group boxplots by. If None, plots are grouped by model only.

        Returns:
            A tuple of
                - the object returned by ``seaborn.boxplot`` for the MSE comparison, and
                - a long-format DataFrame containing MSE values and associated metadata for each sample and model.

        Raises:
            NotImplementedError: If ``to_key`` does not contain "img".
            ValueError: If the test metadata of either result is None.
        """

        if "img" not in to_key:
            raise NotImplementedError(
                "Comparison is currently only implemented for the image case."
            )

        ## Pure VAE MSE calculation
        meta_imagix = pure_vae_result.datasets.test.metadata
        if meta_imagix is None:
            raise ValueError("metadata cannot be None")
        sample_ids = list(meta_imagix.index)

        # Loop-invariant lookups are done once instead of once per sample.
        raw_imagix = pure_vae_result.datasets.test.raw_data
        recon_imagix = pure_vae_result.reconstructions.get(split="test", epoch=-1)

        mse_records = []
        for pos, sample_id in enumerate(sample_ids):
            # Original image
            orig = torch.Tensor(raw_imagix[pos].img.squeeze())

            # Reconstructed image
            recon = torch.Tensor(recon_imagix[pos].squeeze())

            # Calculate MSE via torch
            mse_sample = F.mse_loss(orig, recon, reduction="mean")

            # Collect results
            mse_records.append(
                {"sample_id": sample_id, "mse_imagix": mse_sample.item()}
            )

        df_imagix_mse = pd.DataFrame(mse_records)
        df_imagix_mse.set_index("sample_id", inplace=True)
        # Merge with meta_imagix
        df_imagix_mse = df_imagix_mse.join(meta_imagix, on="sample_id")

        ## xmodalix MSE calculation
        target_dataset = xmodalix_result.datasets.test.datasets[to_key]
        meta_xmodalix = target_dataset.metadata
        if meta_xmodalix is None:
            raise ValueError("metadata cannot be None")
        sample_ids = list(meta_xmodalix.index)

        recon_xmodalix = xmodalix_result.reconstructions.get(epoch=-1, split="test")
        reference_recon = recon_xmodalix[f"reference_{to_key}_to_{to_key}"]
        translation_recon = recon_xmodalix["translation"]

        mse_records = []
        for pos, sample_id in enumerate(sample_ids):
            # Original image (second element of the dataset item)
            orig = torch.Tensor(target_dataset[pos][1].squeeze())

            # Reference Reconstructed image
            reference = torch.Tensor(reference_recon[pos].squeeze())

            # Translated Reconstructed image
            translation = torch.Tensor(translation_recon[pos].squeeze())

            # Calculate MSE via torch
            mse_sample_translated = F.mse_loss(orig, translation, reduction="mean")
            mse_sample_reference = F.mse_loss(orig, reference, reduction="mean")

            # Collect results
            mse_records.append(
                {
                    "sample_id": sample_id,
                    "mse_xmodalix_translated": mse_sample_translated.item(),
                    "mse_xmodalix_reference": mse_sample_reference.item(),
                }
            )

        df_xmodalix_mse = pd.DataFrame(mse_records)
        df_xmodalix_mse.set_index("sample_id", inplace=True)

        # Merge with meta_xmodalix
        df_xmodalix_mse = df_xmodalix_mse.join(meta_xmodalix, on="sample_id")

        # Outer merge on the pure-VAE metadata columns keeps non overlapping entries
        df_both_mse = df_imagix_mse.merge(
            df_xmodalix_mse, on=list(meta_imagix.columns), how="outer"
        )

        # Make long format for plotting
        mse_columns = [
            "mse_imagix",
            "mse_xmodalix_translated",
            "mse_xmodalix_reference",
        ]
        df_long = df_both_mse.melt(
            id_vars=[col for col in df_both_mse.columns if col not in mse_columns],
            value_vars=mse_columns,
            var_name="model",
            value_name="mse_value",
        )

        df_long["model"] = df_long["model"].map(
            {
                "mse_imagix": "imagix",
                "mse_xmodalix_translated": "xmodalix_translated",
                "mse_xmodalix_reference": "xmodalix_reference",
            }
        )

        # NOTE(review): seaborn.boxplot is documented to return a matplotlib
        # Axes, so `fig` is presumably an Axes rather than a Figure — confirm
        # before relying on Figure-only APIs.
        if param:
            plt.figure(figsize=(2 * len(df_long[param].unique()), 8))

            fig = sns.boxplot(data=df_long, x=param, y="mse_value", hue="model")
            sns.move_legend(
                fig,
                "lower center",
                bbox_to_anchor=(0.5, 1),
                ncol=3,
                title=None,
                frameon=False,
            )
        else:
            plt.figure(figsize=(5, 8))

            fig = sns.boxplot(data=df_long, x="model", y="mse_value")
            # Rotate tick labels
            plt.xticks(rotation=-45)
            plt.xlabel("")

        return fig, df_long

    @staticmethod
    def _get_clin_data(datasets) -> Union[pd.Series, pd.DataFrame]:
        """Retrieves the clinical annotation DataFrame (clin_data) from the provided datasets.

        Only the XModalix structure is handled: the train split must expose a
        ``datasets`` mapping. The metadata of every modality in every available
        split is stacked row-wise, restricted to the columns shared with what
        has been collected so far, and rows with a duplicated index are dropped
        (first occurrence kept). Splits that are None are skipped.

        Raises:
            ValueError: If the train split has no ``datasets`` attribute.
        """
        # XModalix-Case
        if hasattr(datasets.train, "datasets"):
            clin_data = pd.DataFrame()
            splits = [datasets.train, datasets.valid, datasets.test]

            for s in splits:
                if s is None:
                    # valid/test may be absent; nothing to collect from them
                    continue
                for k in s.datasets.keys():
                    print(f"Processing dataset: {k}")
                    modality_meta = s.datasets[k].metadata
                    # Merge metadata by overlapping columns
                    overlap = clin_data.columns.intersection(modality_meta.columns)
                    if overlap.empty:
                        # Nothing shared (e.g. first modality): take all columns
                        overlap = modality_meta.columns
                    clin_data = pd.concat([clin_data, modality_meta[overlap]], axis=0)

            # Remove duplicate rows
            clin_data = clin_data[~clin_data.index.duplicated(keep="first")]
        else:
            # Raise error no annotation given
            raise ValueError(
                "No annotation data found. Please provide a valid annotation data type."
            )
        return clin_data

    def _enrich_results(
        self,
        results: pd.DataFrame,
        sklearn_ml: Union[ClassifierMixin, RegressorMixin],
        ml_type: str,
        task: str,
        sub: str,
    ) -> pd.DataFrame:
        """Adds descriptive columns to a results table, in place.

        Args:
            results: Results table to annotate.
            sklearn_ml: The model used; its string representation is stored in "ML_ALG".
            ml_type: Stored in "ML_TYPE".
            task: Task name of the form "<method>_$_<modality>"; the method part is stored in "ML_TASK" and the modality part in "MODALITY".
            sub: Stored in "ML_SUBTASK".

        Returns:
            The same ``results`` object with the added columns.
        """
        n_rows = results.shape[0]

        results["ML_ALG"] = [str(sklearn_ml)] * n_rows
        results["ML_TYPE"] = [ml_type] * n_rows

        task_parts = task.split("_$_")
        modality = task_parts[1]
        task_xmodal = task_parts[0]

        results["MODALITY"] = [modality] * n_rows
        results["ML_TASK"] = [task_xmodal] * n_rows

        results["ML_SUBTASK"] = [sub] * n_rows

        return results

    @staticmethod
    @no_type_check
    def _expand_reference_methods(reference_methods: list, result: Result) -> list:
        """Expands each reference method with a suffix for every data modality.

        For each method in ``reference_methods`` and each key of the final-epoch
        train latent spaces, a name of the form ``"<method>_$_<key>"`` is generated.

        Args:
            reference_methods: A list of reference method names to be expanded.
            result: An object containing latent space information.

        Returns:
            A list of expanded reference method names, each suffixed with a key from the latent space.

        Raises:
            NotImplementedError: If the final-epoch train latent spaces are not a dict.
        """
        train_latents = result.latentspaces.get(epoch=-1, split="train")
        if not isinstance(train_latents, dict):
            raise NotImplementedError(
                "This evaluate feature does not support .save(save_all=False) results."
            )

        return [
            f"{method}_$_{key}"
            for method in reference_methods
            for key in train_latents.keys()
        ]

    ## New for x-modalix
    @staticmethod
    def _load_input_for_ml(
        task: str, dataset: DatasetContainer, result: Result
    ) -> pd.DataFrame:
        """Loads and processes input data for various machine learning tasks based on the specified task type.

        ``task`` must have the form ``"<method>_$_<modality>"``.

        Task Details:
            - "Latent": Concatenates latent representations of the modality from train, validation, and test splits at the final epoch.
            - "UMAP": Applies UMAP dimensionality reduction to the concatenated dataset splits.
            - "PCA": Applies PCA dimensionality reduction to the concatenated dataset splits.
            - "TSNE": Applies t-SNE dimensionality reduction to the concatenated dataset splits.
            - "RandomFeature": Randomly samples columns (features) from the concatenated dataset splits.

        For all non-latent tasks the number of output features equals the number of columns of the train latent DataFrame.

        Args:
            task: The ML task with modality suffix. Supported methods are "Latent", "UMAP", "PCA", "TSNE", and "RandomFeature".
            dataset: The dataset container object holding train, validation, and test splits.
            result: The result object used to retrieve latent representations.

        Returns:
            A DataFrame containing the processed input data suitable for the specified ML task.

        Raises:
            ValueError: If the provided task is not supported, or if a split is None for a non-latent task.
        """

        task_parts = task.split("_$_")
        modality = task_parts[1]
        task = task_parts[0]

        if task == "Latent":
            df = pd.concat(
                [
                    result.get_latent_df(epoch=-1, split="train", modality=modality),
                    result.get_latent_df(epoch=-1, split="valid", modality=modality),
                    result.get_latent_df(epoch=-1, split="test", modality=modality),
                ]
            )
        elif task in ["UMAP", "PCA", "TSNE", "RandomFeature"]:
            latent_dim = result.get_latent_df(
                epoch=-1, split="train", modality=modality
            ).shape[1]
            if dataset.train is None:
                raise ValueError("train attribute of dataset cannot be None")
            if dataset.valid is None:
                raise ValueError("valid attribute of dataset cannot be None")
            if dataset.test is None:
                raise ValueError("test attribute of dataset cannot be None")

            df_processed = pd.concat(
                [
                    dataset.train._to_df(modality=modality),
                    dataset.test._to_df(modality=modality),
                    dataset.valid._to_df(modality=modality),
                ]
            )
            if task == "RandomFeature":
                df = df_processed.sample(n=latent_dim, axis=1)
            else:
                if task == "UMAP":
                    reducer = UMAP(n_components=latent_dim)
                elif task == "PCA":
                    reducer = PCA(n_components=latent_dim)
                else:
                    # NOTE(review): scikit-learn's default barnes_hut t-SNE
                    # presumably rejects n_components > 3 — confirm for
                    # larger latent dimensions.
                    reducer = TSNE(n_components=latent_dim)
                df = pd.DataFrame(
                    reducer.fit_transform(df_processed), index=df_processed.index
                )
        else:
            raise ValueError(
                f"Your ML task {task} is not supported. Please use Latent, UMAP, PCA, TSNE or RandomFeature."
            )

        return df

pure_vae_comparison(xmodalix_result, pure_vae_result, to_key, param=None) staticmethod

Compares the reconstruction performance of a pure VAE model and a cross-modal VAE (xmodalix) model using Mean Squared Error (MSE) on test samples.

For each sample in the test set, computes the MSE between the original and reconstructed images for: - Pure VAE reconstructions ("imagix") - xmodalix reference reconstructions ("xmodalix_reference") - xmodalix translated reconstructions ("xmodalix_translated") The results are merged with sample metadata and returned in a long-format DataFrame suitable for plotting. Optionally, boxplots are generated grouped by a specified metadata parameter.

Parameters:

Name Type Description Default
xmodalix_result Result

The result object containing xmodalix model outputs and test datasets.

required
pure_vae_result Result

The result object containing pure VAE model outputs and test datasets.

required
to_key str

The key specifying the target modality in the xmodalix dataset.

required
param Optional[str]

Metadata column name to group boxplots by. If None, plots are grouped by model only.

None

Returns:

Type Description
Figure
  • The matplotlib/seaborn boxplot figure comparing MSE distributions.
DataFrame
  • DataFrame: Long-format DataFrame containing MSE values and associated metadata for each sample and model.
Source code in src/autoencodix/evaluate/_xmodalix_evaluator.py
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
@staticmethod
@no_type_check
def pure_vae_comparison(
    xmodalix_result: Result,
    pure_vae_result: Result,
    to_key: str,
    param: Optional[str] = None,
) -> Tuple[Figure, pd.DataFrame]:
    """Compares the reconstruction performance of a pure VAE model and a cross-modal VAE (xmodalix) model using Mean Squared Error (MSE) on test samples.

    For each sample in the test set, computes the MSE between the original and reconstructed images for:
        - Pure VAE reconstructions ("imagix")
        - xmodalix reference reconstructions ("xmodalix_reference")
        - xmodalix translated reconstructions ("xmodalix_translated")
    The results are merged with sample metadata and returned in a long-format DataFrame suitable for plotting. Optionally, boxplots are generated grouped by a specified metadata parameter.

    Args:
        xmodalix_result: The result object containing xmodalix model outputs and test datasets.
        pure_vae_result: The result object containing pure VAE model outputs and test datasets.
        to_key: The key specifying the target modality in the xmodalix dataset.
        param: Metadata column name to group boxplots by. If None, plots are grouped by model only.

    Returns:
            - The matplotlib/seaborn boxplot figure comparing MSE distributions.
            - DataFrame: Long-format DataFrame containing MSE values and associated metadata for each sample and model.
    """

    if "img" not in to_key:
        raise NotImplementedError(
            "Comparison is currently only implemented for the image case."
        )

    ## Pure VAE MSE calculation
    meta_imagix = pure_vae_result.datasets.test.metadata
    if meta_imagix is None:
        raise ValueError("metadata cannot be None")
    sample_ids = list(meta_imagix.index)

    all_sample_order = sample_ids  ## TODO check code, seems unnecessary
    indices = [
        all_sample_order.index(sid) for sid in sample_ids if sid in all_sample_order
    ]

    mse_records = []

    for c in range(len(indices)):
        # print(f"Sample {c+1}/{len(indices)}: {sample_ids[c]}")

        # Original image
        orig = torch.Tensor(
            pure_vae_result.datasets.test.raw_data[indices[c]].img.squeeze()
        )

        # Reconstructed image
        recon = torch.Tensor(
            pure_vae_result.reconstructions.get(split="test", epoch=-1)[
                indices[c]
            ].squeeze()
        )

        # Calculate MSE via torch
        mse_sample = F.mse_loss(orig, recon, reduction="mean")
        # print(f"Mean Squared Error (MSE) for sample {c+1}: {mse_sample.item()}")

        # Collect results
        mse_records.append(
            {"sample_id": sample_ids[c], "mse_imagix": mse_sample.item()}
        )

    df_imagix_mse = pd.DataFrame(mse_records)
    df_imagix_mse.set_index("sample_id", inplace=True)
    # Merge with meta_imagix
    df_imagix_mse = df_imagix_mse.join(meta_imagix, on="sample_id")

    meta_xmodalix = xmodalix_result.datasets.test.datasets[to_key].metadata
    sample_ids = list(meta_xmodalix.index)

    all_sample_order = sample_ids
    indices = [
        all_sample_order.index(sid) for sid in sample_ids if sid in all_sample_order
    ]

    mse_records = []

    for c in range(len(indices)):
        # print(f"Sample {c+1}/{len(indices)}: {sample_ids[c]}")

        # Original image
        orig = torch.Tensor(
            xmodalix_result.datasets.test.datasets[to_key][indices[c]][1].squeeze()
        )
        # print(orig.shape)

        # Reference Reconstructed image
        reference = torch.Tensor(
            xmodalix_result.reconstructions.get(epoch=-1, split="test")[
                f"reference_{to_key}_to_{to_key}"
            ][indices[c]].squeeze()
        )
        # print(reference.shape)

        # Translated Reconstructed image
        translation = torch.Tensor(
            xmodalix_result.reconstructions.get(epoch=-1, split="test")[
                "translation"
            ][indices[c]].squeeze()
        )
        # print(translation.shape)

        # Calculate MSE via torch
        mse_sample_translated = F.mse_loss(orig, translation, reduction="mean")
        # print(f"Mean Squared Error (MSE) for sample {c+1}: {mse_sample_translated.item()}")
        mse_sample_reference = F.mse_loss(orig, reference, reduction="mean")
        # print(f"Mean Squared Error (MSE) for sample {c+1}: {mse_sample_reference.item()}")

        # Collect results
        mse_records.append(
            {
                "sample_id": sample_ids[c],
                "mse_xmodalix_translated": mse_sample_translated.item(),
                "mse_xmodalix_reference": mse_sample_reference.item(),
            }
        )

    df_xmodalix_mse = pd.DataFrame(mse_records)
    df_xmodalix_mse.set_index("sample_id", inplace=True)

    # Merge with meta_xmodalix
    df_xmodalix_mse = df_xmodalix_mse.join(meta_xmodalix, on="sample_id")

    # Merge via sample_id and keep non overlapping entries
    df_both_mse = df_imagix_mse.merge(
        df_xmodalix_mse, on=list(meta_imagix.columns), how="outer"
    )

    # Make long format for plotting
    df_long = df_both_mse.melt(
        id_vars=[
            col
            for col in df_both_mse.columns
            if col
            not in [
                "mse_imagix",
                "mse_xmodalix_translated",
                "mse_xmodalix_reference",
            ]
        ],
        value_vars=[
            "mse_imagix",
            "mse_xmodalix_translated",
            "mse_xmodalix_reference",
        ],
        var_name="model",
        value_name="mse_value",
    )

    df_long["model"] = df_long["model"].map(
        {
            "mse_imagix": "imagix",
            "mse_xmodalix_translated": "xmodalix_translated",
            "mse_xmodalix_reference": "xmodalix_reference",
        }
    )

    if param:
        plt.figure(figsize=(2 * len(df_long[param].unique()), 8))

        fig = sns.boxplot(data=df_long, x=param, y="mse_value", hue="model")
        sns.move_legend(
            fig,
            "lower center",
            bbox_to_anchor=(0.5, 1),
            ncol=3,
            title=None,
            frameon=False,
        )
    else:
        plt.figure(figsize=(5, 8))

        fig = sns.boxplot(data=df_long, x="model", y="mse_value")
        # Rotate tick labels
        plt.xticks(rotation=-45)
        plt.xlabel("")

    return fig, df_long