Skip to content

Configs Module

DataInfo

Bases: BaseModel, SchemaPrinterMixin

Source code in src/autoencodix/configs/default_config.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
class DataInfo(BaseModel, SchemaPrinterMixin):
    """Settings describing a single data modality.

    Groups the file location, data type and preprocessing choices (scaling,
    filtering) with options that only apply to single cell data, images, or
    translation ("xmodalix") setups.

    Validation rules enforced here:
        * ``selected_layers`` must contain ``"X"``.
        * ``k_filter`` may not be supplied at construction time.
        * Image resize dimensions must be positive and, when both are set,
          equal to each other.
    """

    # general -------------------------------------
    file_path: str = Field(default="", description="Path to raw data file")
    data_type: Literal["NUMERIC", "CATEGORICAL", "IMG", "ANNOTATION"] = Field(
        default="NUMERIC"
    )
    scaling: Literal[
        "STANDARD", "MINMAX", "ROBUST", "MAXABS", "NONE", "NOTSET", "LOG1P"
    ] = Field(
        default="NOTSET",
        description="Setting the scaling here in DataInfo overrides the globally set scaling method for the specific data modality",
    )  # can also be set globally, for all data modalities.

    filtering: Literal["VAR", "MAD", "CORR", "VARCORR", "NOFILT", "NONZEROVAR"] = Field(
        default="VAR"
    )
    sep: Union[str, None] = Field(default=None)  # for pandas read_csv
    extra_anno_file: Union[str, None] = Field(default=None)

    # single cell specific -------------------------
    is_single_cell: bool = Field(default=False)

    min_cells: float = Field(
        default=0.05,
        ge=0,
        le=1,
        description="Minimum fraction of cells a gene must be expressed in to be kept. Genes expressed in fewer cells will be filtered out.",
    )  # Controls gene filtering based on expression prevalence

    min_genes: float = Field(
        default=0.02,
        ge=0,
        le=1,
        description="Minimum fraction of genes a cell must express to be kept. Cells expressing fewer genes will be filtered out.",
    )  # Controls cell quality filtering
    selected_layers: List[str] = Field(default=["X"])

    is_X: bool = Field(default=False)  # only for single cell data
    normalize_counts: bool = Field(
        default=True, description="Whether to normalize by total counts"
    )
    log_transform: bool = Field(
        default=False, description="Whether to apply log1p transformation"
    )
    k_filter: Optional[int] = Field(
        default=None,
        description="Don't set this gets calculated dynamically, based on k_filter in general config ",
    )
    # image specific ------------------------------
    img_width_resize: Union[int, None] = Field(default=64)
    img_height_resize: Union[int, None] = Field(default=64)
    # annotation specific -------------------------
    # xmodalix specific -------------------------
    translate_direction: Union[Literal["from", "to"], None] = Field(default=None)
    pretrain_epochs: Optional[int] = Field(
        default=None,
        description="Number of pretraining epochs. This overwrites the global 'pretraining_epochs' in DefaultConfig class to have different number of pretraining epochs for each data modality",
    )

    @field_validator("selected_layers")
    @classmethod
    def validate_selected_layers(cls, v):
        """Reject a layer selection that does not include ``"X"``.

        Raises:
            ValueError: If ``"X"`` is missing from the list.
        """
        if "X" not in v:
            raise ValueError('"X" must always be a part of the selected_layers list')
        return v

    @field_validator("k_filter", mode="before")
    @classmethod
    def _forbid_user_k_filter(cls, v: Any, info: ValidationInfo) -> Any:
        """Refuse any non-``None`` ``k_filter`` passed at construction time.

        The validator only runs on supplied input, so assigning
        ``data_info.k_filter = xx`` after instantiation remains possible.

        Raises:
            ValueError: If a value other than ``None`` is supplied.
        """
        if v is not None:
            raise ValueError(
                "k_filter is computed automatically for each data modality, based on global k_filter – remove it from your DataInfo configuration."
            )
        return v

    @field_validator("img_width_resize", "img_height_resize")
    @classmethod
    def validate_image_resize(cls, v, info: ValidationInfo):
        """Ensure a supplied resize dimension is strictly positive.

        The width/height equality rule is checked separately once both
        fields are available, because ``info.data`` only holds fields that
        were validated *before* the current one and therefore never contains
        both dimensions at the same time.

        Raises:
            ValueError: If the dimension is zero or negative.
        """
        if v is not None and v <= 0:
            raise ValueError("Image resize dimensions must be positive integers")
        return v

    @model_validator(mode="after")
    def _validate_square_resize(self) -> "DataInfo":
        """Require equal width and height whenever both resize values are set.

        Runs after all fields (including defaults) are populated, so the
        check also covers the case where only one dimension was supplied.

        Raises:
            ValueError: If both dimensions are set and differ.
        """
        width, height = self.img_width_resize, self.img_height_resize
        if width is not None and height is not None and width != height:
            raise ValueError("Image width and height must be the same for resizing")
        return self

DefaultConfig

Bases: BaseModel, SchemaPrinterMixin

Complete configuration for model, training, hardware, and data handling.

Source code in src/autoencodix/configs/default_config.py
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
class DefaultConfig(BaseModel, SchemaPrinterMixin):
    """Complete configuration for model, training, hardware, and data handling.

    Unknown keyword arguments are rejected (``extra="forbid"``). Besides the
    per-field constraints, the validators below check dataset consistency,
    derive ``data_case``, bound the data split ratios and restrict precision
    and GPU strategy when the ``mps`` device is selected.
    """

    # Input validation
    model_config = ConfigDict(extra="forbid")
    # Datasets configuration --------------------------------------------------
    data_config: DataConfig = DataConfig(data_info={})
    img_path_col: str = Field(
        default="img_paths",
        description="When working with images, we except a column in your annotation file that specifies the path of the image for a particular sample. Here you can define the name of this column",
    )
    requires_paired: Union[bool, None] = Field(
        default_factory=lambda: True,
        description="Indicator if the samples for the xmodalix are paired, based on some sample id",
    )

    data_case: Union[DataCase, None] = Field(
        default_factory=lambda: None,
        description="Data case for the model, will be determined automatically",
    )
    k_filter: Union[int, None] = Field(
        default=20, description="Number of features to keep"
    )
    scaling: Literal["STANDARD", "MINMAX", "ROBUST", "MAXABS", "NONE", "LOG1P"] = Field(
        default="STANDARD",
        description="Setting the scaling here for all data modalities, can per overruled by setting scaling at data modality level per data modality",
    )

    skip_preprocessing: bool = Field(
        default=False, description="If set don't scale, filter or clean the input data."
    )

    class_param: Optional[str] = Field(default=None)

    # Model configuration -----------------------------------------------------
    latent_dim: int = Field(
        default=16, ge=1, description="Dimension of the latent space"
    )
    hidden_dim: int = Field(
        default=16,
        ge=1,
        description="Hidden dimension of image_vae, applies only to image_vae",
    )
    n_layers: int = Field(
        default=3,
        ge=0,
        description="Number of layers in encoder/decoder, without latent layer. If 0, is only the latent layer.",
    )
    enc_factor: float = Field(
        default=4, gt=0, description="Scaling factor for encoder dimensions"
    )
    maskix_hidden_dim: int = Field(
        default=256,
        ge=8,
        description="The Maskix implementation follows https://doi.org/10.1093/bioinformatics/btae020. The authors use a hidden dimension 0f 256 for their neural network, so we set this as default",
    )
    maskix_swap_prob: float = Field(
        default=0.4,
        ge=0,
        description="For the Maskix input_data masinkg, we sample a probablity if samples within one gene should be swapt. This is done with a Bernoulli distribution, maskix_swap_prob is the probablity passed to the bernoulli distribution ",
    )
    drop_p: float = Field(
        default=0.1, ge=0.0, le=1.0, description="Dropout probability"
    )

    # Training configuration --------------------------------------------------
    save_memory: bool = Field(
        default=False, description="If set to True we don't store TrainingDynamics"
    )
    learning_rate: float = Field(
        default=0.001, gt=0, description="Learning rate for optimization"
    )
    batch_size: int = Field(
        default=32,
        ge=2,
        description="Number of samples per batch, has to be > 1, because we use BatchNorm() Layer",
    )
    epochs: int = Field(default=3, ge=1, description="Number of training epochs")
    weight_decay: float = Field(
        default=0.01, ge=0, description="L2 regularization factor"
    )
    reconstruction_loss: Literal["mse", "bce"] = Field(
        default="mse", description="Type of reconstruction loss"
    )
    default_vae_loss: Literal["kl", "mmd"] = Field(
        default="kl", description="Type of VAE loss"
    )
    loss_reduction: Literal["sum", "mean"] = Field(
        default="sum",
        description="Loss reduction in PyTorch i.e in torch.nn.functional.binary_cross_entropy_with_logits(reduction=loss_reduction)",
    )
    beta: float = Field(
        default=1, ge=0, description="Beta weighting factor for VAE loss"
    )
    beta_mi: float = Field(
        default=1,
        ge=0,
        description="Beta weighting factor for mutual information term in disentangled VAE loss",
    )
    beta_tc: float = Field(
        default=1,
        ge=0,
        description="Beta weighting factor for total correlation term in disentangled VAE loss",
    )
    beta_dimKL: float = Field(
        default=1,
        ge=0,
        description="Beta weighting factor for dimension-wise KL in disentangled VAE loss",
    )
    use_mss: bool = Field(
        default=True,
        description="Using minibatch stratified sampling for disentangled VAE loss calculation (faster estimation)",
    )
    gamma: float = Field(
        default=10.0,
        ge=0,
        description="Gamma weighting factor for Adversial Loss Term i.e. for XModalix Classfier training",
    )
    delta_pair: float = Field(
        default=5.0,
        ge=0,
        description="Delta weighting factor for paired loss term in XModalix Training",
    )
    delta_class: float = Field(
        default=5.0,
        ge=0,
        description="Delta weighting factor for class loss term in XModalix Training",
    )
    delta_mask_predictor: float = Field(
        default=0.7,
        ge=0.0,
        description="Delt weighting factor of the mask predictin loss term for the Maskix",
    )
    delta_mask_corrupted: float = Field(
        default=0.75,
        ge=0.0,
        description="For the Maskix: if >0.5 this gives more weight for the correct reconstruction of corrupted input",
    )
    min_samples_per_split: int = Field(
        default=1, ge=1, description="Minimum number of samples per split"
    )
    anneal_function: Literal[
        "5phase-constant",
        "3phase-linear",
        "3phase-log",
        "logistic-mid",
        "logistic-early",
        "logistic-late",
        "no-annealing",
    ] = Field(
        default="logistic-mid",
        description="Annealing function strategy for VAE loss scheduling",
    )
    pretrain_epochs: int = Field(
        default=0,
        ge=0,
        description="Number of pretraining epochs, can be overwritten in DataInfo to have different number of pretraining epochs for each data modality",
    )

    # Hardware configuration --------------------------------------------------
    device: Literal["cpu", "cuda", "gpu", "tpu", "mps", "auto"] = Field(
        default="auto", description="Device to use"
    )
    # NOTE(review): an earlier comment here said "0 uses cpu and not gpu", but
    # ge=1 rejects 0 — confirm which behaviour is intended.
    n_gpus: int = Field(default=1, ge=1, description="Number of GPUs to use")
    checkpoint_interval: int = Field(
        default=10, ge=1, description="Interval for saving checkpoints"
    )
    float_precision: Literal[
        "transformer-engine",
        "transformer-engine-float16",
        "16-true",
        "16-mixed",
        "bf16-true",
        "bf16-mixed",
        "32-true",
        "64-true",
        "64",
        "32",
        "16",
        "bf16",
    ] = Field(default="32", description="Floating point precision")
    gpu_strategy: Literal[
        "auto",
        "dp",
        "ddp",
        "ddp_spawn",
        "ddp_find_unused_parameters_true",
        "xla",
        "deepspeed",
        "fsdp",
    ] = Field(default="auto", description="GPU parallelization strategy")

    # Data handling configuration ---------------------------------------------
    train_ratio: float = Field(
        default=0.7, ge=0, lt=1, description="Ratio of data for training"
    )
    test_ratio: float = Field(
        default=0.2, ge=0, lt=1, description="Ratio of data for testing"
    )
    valid_ratio: float = Field(
        default=0.1, ge=0, lt=1, description="Ratio of data for validation"
    )

    # General configuration ---------------------------------------------------
    reproducible: bool = Field(
        default=False, description="Whether to ensure reproducibility"
    )
    global_seed: int = Field(default=1, ge=0, description="Global random seed")

    ##### VALIDATION ##### -----------------------------------------------------
    ##### ----------------- -----------------------------------------------------
    @field_validator("data_config")
    @classmethod
    def validate_data_config(cls, data_config: DataConfig):
        """Main validation logic for dataset consistency and translation.

        Checks, in order:
            1. At least one NUMERIC or IMG entry exists.
            2. All NUMERIC entries agree on ``is_single_cell``.
            3. A ``"from"`` entry exists if and only if a ``"to"`` entry exists.
            4. A NUMERIC-to-NUMERIC translation does not mix single cell and
               bulk data.

        Returns:
            The unchanged ``data_config``.

        Raises:
            ConfigValidationError: If any of the checks fails.
        """
        data_info = data_config.data_info

        numeric_count = sum(
            1 for info in data_info.values() if info.data_type == "NUMERIC"
        )
        img_count = sum(1 for info in data_info.values() if info.data_type == "IMG")

        if numeric_count == 0 and img_count == 0:
            raise ConfigValidationError("At least one NUMERIC dataset is required.")

        numeric_datasets = [
            info for info in data_info.values() if info.data_type == "NUMERIC"
        ]
        if numeric_datasets:
            # The first numeric entry acts as the reference for all others.
            is_single_cell = numeric_datasets[0].is_single_cell
            if any(info.is_single_cell != is_single_cell for info in numeric_datasets):
                raise ConfigValidationError(
                    "All numeric datasets must be either single cell or bulk."
                )

        # Only the first entry per direction is considered.
        from_dataset = next(
            (
                (name, info)
                for name, info in data_info.items()
                if info.translate_direction == "from"
            ),
            None,
        )
        to_dataset = next(
            (
                (name, info)
                for name, info in data_info.items()
                if info.translate_direction == "to"
            ),
            None,
        )

        if bool(from_dataset) != bool(to_dataset):
            raise ConfigValidationError(
                "Translation requires exactly one 'from' and one 'to' dataset."
            )

        if from_dataset and to_dataset:
            from_info, to_info = from_dataset[1], to_dataset[1]
            if from_info.data_type == "NUMERIC" and to_info.data_type == "NUMERIC":
                if from_info.is_single_cell != to_info.is_single_cell:
                    raise ConfigValidationError(
                        "Cannot translate between single cell and bulk data."
                    )

        return data_config

    @model_validator(mode="after")
    def determine_case(self) -> "DefaultConfig":
        """Assign the correct DataCase after model validation.

        With a translation pair (one ``"from"`` and one ``"to"`` entry) the
        case is chosen from the data types of both sides. Without a pair it
        is chosen from the modalities present, a NUMERIC entry taking
        precedence over an IMG entry. An empty ``data_info`` leaves
        ``data_case`` untouched.

        Returns:
            The validated model itself.
        """
        data_info = self.data_config.data_info

        # Handle empty data_info case
        if not data_info:
            return self

        # Find 'from' and 'to' datasets
        from_dataset = next(
            (
                (name, info)
                for name, info in data_info.items()
                if info.translate_direction == "from"
            ),
            None,
        )
        to_dataset = next(
            (
                (name, info)
                for name, info in data_info.items()
                if info.translate_direction == "to"
            ),
            None,
        )

        if from_dataset and to_dataset:
            from_info, to_info = from_dataset[1], to_dataset[1]
            if from_info.data_type == "NUMERIC" and to_info.data_type == "NUMERIC":
                self.data_case = (
                    DataCase.SINGLE_CELL_TO_SINGLE_CELL
                    if from_info.is_single_cell
                    else DataCase.BULK_TO_BULK
                )
            elif "IMG" in {from_info.data_type, to_info.data_type}:
                numeric_dataset = (
                    from_info if from_info.data_type == "NUMERIC" else to_info
                )
                # check for IMG_IMG
                if from_info.data_type == "IMG" and to_info.data_type == "IMG":
                    self.data_case = DataCase.IMG_TO_IMG
                else:
                    self.data_case = (
                        DataCase.SINGLE_CELL_TO_IMG
                        if numeric_dataset.is_single_cell
                        else DataCase.IMG_TO_BULK
                    )
        else:
            img_ds = [info for info in data_info.values() if info.data_type == "IMG"]
            if img_ds:
                self.data_case = DataCase.IMG_TO_IMG

            numeric_datasets = [
                info for info in data_info.values() if info.data_type == "NUMERIC"
            ]

            # A numeric entry overrides the IMG_TO_IMG choice made above.
            if numeric_datasets:
                numeric_dataset = numeric_datasets[0]
                self.data_case = (
                    DataCase.MULTI_SINGLE_CELL
                    if numeric_dataset.is_single_cell
                    else DataCase.MULTI_BULK
                )
            if self.data_case is None:
                import warnings

                warnings.warn(message="Could not determine data_case")

        return self

    @field_validator("test_ratio", "valid_ratio")
    @classmethod
    def validate_ratios(cls, v, values):
        """Ensure the split ratios validated so far do not exceed 1.0.

        ``values.data`` only holds fields validated before the current one,
        so the running total is those ratios plus ``v``.

        Raises:
            ValueError: If the running total is greater than 1.0.
        """
        total = (
            sum(
                values.data.get(key, 0)
                for key in ["train_ratio", "test_ratio", "valid_ratio"]
            )
            + v
        )
        # Small tolerance so exact splits are not rejected by float rounding
        # (e.g. 0.1 + 0.2 + 0.7 evaluates to 1.0000000000000002).
        if total - 1.0 > 1e-9:
            raise ValueError(f"Data split ratios must sum to 1.0 or less (got {total})")
        return v

    # TODO test if other float precisions work with MPS
    @field_validator("float_precision")
    @classmethod
    def validate_float_precision(cls, v, values):
        """Validate float precision based on device type.

        Raises:
            ValueError: If the device is ``"mps"`` and precision is not ``"32"``.
        """
        # .get(): "device" is absent from values.data when it failed its own
        # validation; indexing would surface a raw KeyError instead.
        device = values.data.get("device")
        if device == "mps" and v != "32":
            raise ValueError("MPS backend only supports float precision '32'")
        return v

    # gpu strategy needs to be auto for mps # TODO test if other strategies work
    @field_validator("gpu_strategy")
    @classmethod
    def validate_gpu_strategy(cls, v, values):
        """Validate the GPU strategy based on device type.

        Raises:
            ValueError: If the device is ``"mps"`` and strategy is not ``"auto"``.
        """
        device = values.data.get("device")
        if device == "mps" and v != "auto":
            raise ValueError("MPS backend only supports GPU strategy 'auto'")
        # The validator's return value becomes the field value; without this
        # return an explicitly supplied strategy was replaced by None.
        return v

    @model_validator(mode="after")
    def validate_k_filter_with_nonzero_var(self):
        """Reject a set ``k_filter`` when any modality filters with NONZEROVAR.

        Raises:
            ValueError: If ``k_filter`` is not ``None`` and at least one
                DataInfo has ``filtering == "NONZEROVAR"``.
        """
        k_filter = self.k_filter

        data_info = self.data_config.data_info

        for info in data_info.values():
            if info.filtering == "NONZEROVAR" and k_filter is not None:
                raise ValueError(
                    "k_filter cannot be combined with DataInfo that has filtering set to 'NONZEROVAR'"
                )

        return self

    #### END VALIDATION #### --------------------------------------------------

    #### READIBILITY #### ------------------------------------------------------
    #### ------------ #### ------------------------------------------------------
    @classmethod
    def get_params(cls) -> Dict[str, Dict[str, Any]]:
        """
        Get detailed information about all config fields including types and default values.

        Returns:
            Dictionary containing field name, type, default value, and description if available
        """
        fields_info = {}
        for name, field in cls.model_fields.items():
            fields_info[name] = {
                "type": str(field.annotation),
                "default": field.default,
                "description": field.description or "No description available",
            }
        return fields_info

    @classmethod
    def print_schema(cls, filter_params: Optional[List[str]] = None) -> None:  # type: ignore
        """
        Print a human-readable schema of all config parameters.

        Args:
            filter_params: Optional names of the parameters to print. When
                empty or ``None`` every parameter is printed.
        """
        if filter_params:
            filter_params = list(filter_params)
            print("Valid Keyword Arguments:")
            print("-" * 50)
        else:
            print(f"\n{cls.__name__} Configuration Parameters:")
            print("-" * 50)

        for name, info in cls.get_params().items():
            if filter_params and name not in filter_params:
                continue
            print(f"\n{name}:")
            print(f"  Type: {info['type']}")
            print(f"  Default: {info['default']}")  # type: ignore
            print(f"  Description: {info['description']}")

determine_case()

Assign the correct DataCase after model validation.

Source code in src/autoencodix/configs/default_config.py
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
@model_validator(mode="after")
def determine_case(self) -> "DefaultConfig":
    """Derive ``data_case`` from the configured modalities.

    Runs after field validation. If a translation pair exists (an entry
    marked ``"from"`` and one marked ``"to"``), the case follows from the
    data types of both sides. Otherwise it follows from the modalities
    present, where a NUMERIC entry wins over an IMG entry. Nothing is
    assigned when ``data_info`` is empty.

    Returns:
        The model instance.
    """
    modalities = self.data_config.data_info

    # No modalities configured: keep whatever data_case currently holds.
    if not modalities:
        return self

    def _first_with_direction(direction):
        # First entry whose translate_direction matches, else None.
        for entry in modalities.values():
            if entry.translate_direction == direction:
                return entry
        return None

    source = _first_with_direction("from")
    target = _first_with_direction("to")

    if source is not None and target is not None:
        kinds = (source.data_type, target.data_type)
        if kinds == ("NUMERIC", "NUMERIC"):
            if source.is_single_cell:
                self.data_case = DataCase.SINGLE_CELL_TO_SINGLE_CELL
            else:
                self.data_case = DataCase.BULK_TO_BULK
        elif kinds == ("IMG", "IMG"):
            self.data_case = DataCase.IMG_TO_IMG
        elif "IMG" in kinds:
            # Mixed pair: the side that is not NUMERIC falls back to target.
            other_side = source if source.data_type == "NUMERIC" else target
            if other_side.is_single_cell:
                self.data_case = DataCase.SINGLE_CELL_TO_IMG
            else:
                self.data_case = DataCase.IMG_TO_BULK
        return self

    # No translation pair: decide from the modalities that are present.
    if any(entry.data_type == "IMG" for entry in modalities.values()):
        self.data_case = DataCase.IMG_TO_IMG

    numeric_entries = [
        entry for entry in modalities.values() if entry.data_type == "NUMERIC"
    ]
    if numeric_entries:
        # Overrides the image case assigned just above.
        if numeric_entries[0].is_single_cell:
            self.data_case = DataCase.MULTI_SINGLE_CELL
        else:
            self.data_case = DataCase.MULTI_BULK

    if self.data_case is None:
        import warnings

        warnings.warn(message="Could not determine data_case")

    return self

get_params() classmethod

Get detailed information about all config fields including types and default values.

Returns:

Type Description
Dict[str, Dict[str, Any]]

Dictionary containing field name, type, default value, and description if available

Source code in src/autoencodix/configs/default_config.py
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
@classmethod
def get_params(cls) -> Dict[str, Dict[str, Any]]:
    """
    Collect type, default and description for every field of the config class.

    Returns:
        Mapping from field name to a dict with the keys ``"type"`` (the
        annotation rendered as a string), ``"default"`` and
        ``"description"`` (a placeholder text when the field has none).
    """
    return {
        field_name: {
            "type": str(spec.annotation),
            "default": spec.default,
            "description": spec.description or "No description available",
        }
        for field_name, spec in cls.model_fields.items()
    }

print_schema(filter_params=None) classmethod

Print a human-readable schema of all config parameters.

Source code in src/autoencodix/configs/default_config.py
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
@classmethod
def print_schema(cls, filter_params: Optional[None] = None) -> None:  # type: ignore
    """
    Write a readable listing of the config parameters to stdout.

    When ``filter_params`` is non-empty, only the named parameters are
    listed under a "Valid Keyword Arguments" header; otherwise every
    parameter is listed under a header naming the class.
    """
    if filter_params:
        # Materialise once so membership tests work for any iterable.
        filter_params = list(filter_params)
        header = "Valid Keyword Arguments:"
    else:
        header = f"\n{cls.__name__} Configuration Parameters:"
    print(header)
    print("-" * 50)

    rows = (("Type", "type"), ("Default", "default"), ("Description", "description"))
    for name, info in cls.get_params().items():
        selected = not filter_params or name in filter_params
        if selected:
            print(f"\n{name}:")
            for label, key in rows:
                print(f"  {label}: {info[key]}")

validate_data_config(data_config) classmethod

Main validation logic for dataset consistency and translation.

Source code in src/autoencodix/configs/default_config.py
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
@field_validator("data_config")
@classmethod
def validate_data_config(cls, data_config: DataConfig):
    """Check that the configured datasets are consistent with each other.

    The checks run in this order and the first failure is raised:
        1. There is at least one NUMERIC or IMG entry.
        2. Every NUMERIC entry has the same ``is_single_cell`` value.
        3. A ``"from"`` entry is present exactly when a ``"to"`` entry is.
        4. A NUMERIC-to-NUMERIC translation pair agrees on ``is_single_cell``.

    Returns:
        The ``data_config`` that was passed in.

    Raises:
        ConfigValidationError: When one of the checks fails.
    """
    entries = list(data_config.data_info.values())
    numeric_entries = [entry for entry in entries if entry.data_type == "NUMERIC"]
    has_image = any(entry.data_type == "IMG" for entry in entries)

    if not numeric_entries and not has_image:
        raise ConfigValidationError("At least one NUMERIC dataset is required.")

    if numeric_entries:
        # Every numeric entry is compared against the first one.
        reference_flag = numeric_entries[0].is_single_cell
        for entry in numeric_entries:
            if entry.is_single_cell != reference_flag:
                raise ConfigValidationError(
                    "All numeric datasets must be either single cell or bulk."
                )

    def _first_with_direction(direction):
        # First entry marked with the given translate_direction, else None.
        for entry in entries:
            if entry.translate_direction == direction:
                return entry
        return None

    source = _first_with_direction("from")
    target = _first_with_direction("to")

    # Exactly one side of the pair present is an error.
    if (source is None) != (target is None):
        raise ConfigValidationError(
            "Translation requires exactly one 'from' and one 'to' dataset."
        )

    if source is not None and target is not None:
        both_numeric = source.data_type == "NUMERIC" and target.data_type == "NUMERIC"
        if both_numeric and source.is_single_cell != target.is_single_cell:
            raise ConfigValidationError(
                "Cannot translate between single cell and bulk data."
            )

    return data_config

validate_float_precision(v, values)

Validate float precision based on device type.

Source code in src/autoencodix/configs/default_config.py
538
539
540
541
542
543
544
@field_validator("float_precision")
def validate_float_precision(cls, v, values):
    """Validate float precision based on device type.

    Args:
        v: The float precision value being validated.
        values: Validation info; ``values.data`` holds the fields that were
            validated before this one.

    Returns:
        The unchanged precision value.

    Raises:
        ValueError: If the device is ``"mps"`` and precision is not ``"32"``.
    """
    # .get(): "device" is missing from values.data when it failed its own
    # validation; indexing would raise a raw KeyError instead of letting the
    # device error be reported.
    device = values.data.get("device")
    if device == "mps" and v != "32":
        raise ValueError("MPS backend only supports float precision '32'")
    return v

DisentanglixConfig

Bases: DefaultConfig

A specialized configuration inheriting from DefaultConfig.

Source code in src/autoencodix/configs/disentanglix_config.py
 5
 6
 7
 8
 9
10
11
12
13
14
class DisentanglixConfig(DefaultConfig):
    """
    DefaultConfig variant that only changes the default of ``beta``.
    """

    # Default lowered to 0.1 (the inherited default was 1.0); the lower
    # bound and description are restated unchanged.
    beta: float = Field(
        description="Beta weighting factor for VAE loss",
        ge=0,
        default=0.1,
    )

MaskixConfig

Bases: DefaultConfig

A specialized configuration inheriting from DefaultConfig.

Source code in src/autoencodix/configs/maskix_config.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class MaskixConfig(DefaultConfig):
    """
    A specialized configuration inheriting from DefaultConfig.

    Redeclares ``beta`` and ``epoch`` with Maskix-specific defaults and adds
    the Maskix-only hyperparameters (hidden dimension, swap probability and
    the two delta loss weights).
    """

    beta: float = Field(
        default=0.1,  # Overridden default (was 1.0)
        ge=0,
        description="Beta weighting factor for VAE loss",
    )
    epoch: int = Field(
        default=30, ge=0, description="How many epochs should the model train for."
    )
    maskix_hidden_dim: int = Field(
        default=128,
        ge=8,
        description="The Maskix implementation follows https://doi.org/10.1093/bioinformatics/btae020. The authors use a hidden dimension of 256 for their neural network; the default here is 128.",
    )
    # This value is handed to a Bernoulli distribution (see description), so
    # it must be a valid probability: bounded to [0, 1] instead of only >= 0.
    maskix_swap_prob: float = Field(
        default=0.2,
        ge=0,
        le=1,
        description="For the Maskix input_data masking, we sample whether samples within one gene should be swapped. This is done with a Bernoulli distribution; maskix_swap_prob is the probability passed to that distribution.",
    )
    delta_mask_predictor: float = Field(
        default=0.7,
        ge=0.0,
        description="Delta weighting factor of the mask prediction loss term for the Maskix",
    )
    delta_mask_corrupted: float = Field(
        default=0.75,
        ge=0.0,
        description="For the Maskix: if >0.5 this gives more weight for the correct reconstruction of corrupted input",
    )

OntixConfig

Bases: DefaultConfig

A specialized configuration for Ontix that only allows scaling methods guaranteeing non-negative outputs.

Source code in src/autoencodix/configs/ontix_config.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class OntixConfig(DefaultConfig):
    """
    A specialized configuration for Ontix that only allows scaling methods
    guaranteeing non-negative outputs.
    """

    # 1. Override the top-level 'scaling' attribute
    scaling: Literal["MINMAX", "NONE", "NOTSET", "LOG1P"] = Field(
        default="MINMAX",
        description="Global scaling method. For Ontix, only 'MINMAX', 'NONE', 'NOTSET' and 'LOG1P' are allowed, because we need positive values only",
    )

    # 2. Add a validator for the nested 'scaling' attributes
    @model_validator(mode="after")
    def validate_nested_scaling(self) -> "OntixConfig":
        """
        Ensures that any scaling method set within DataInfo is also a valid
        positive-value scaler.

        Returns:
            The validated config instance itself.

        Raises:
            ValueError: If any modality in ``data_config.data_info`` uses a
                scaling method outside the allowed set.
        """
        # Define the set of allowed scaling methods
        # NOTE(review): the global ``scaling`` Literal above also accepts
        # "LOG1P", but it is rejected per modality here -- confirm which of
        # the two is intended.
        allowed_scalers = {"MINMAX", "NONE", "NOTSET"}

        # Loop through each data modality defined in the data_config
        for modality_name, data_info in self.data_config.data_info.items():
            if data_info.scaling not in allowed_scalers:
                # sorted(): set iteration order is not stable between runs,
                # so sort to keep the error message deterministic.
                raise ValueError(
                    f"Invalid scaling '{data_info.scaling}' for modality '{modality_name}'. "
                    f"OntixConfig only permits {sorted(allowed_scalers)}."
                )
        return self

validate_nested_scaling()

Ensures that any scaling method set within DataInfo is also a valid positive-value scaler.

Source code in src/autoencodix/configs/ontix_config.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
@model_validator(mode="after")
def validate_nested_scaling(self) -> "OntixConfig":
    """
    Ensures that any scaling method set within DataInfo is also a valid
    positive-value scaler.

    Returns:
        The validated config instance itself.

    Raises:
        ValueError: If any modality in ``data_config.data_info`` uses a
            scaling method outside the allowed set.
    """
    # Define the set of allowed scaling methods
    allowed_scalers = {"MINMAX", "NONE", "NOTSET"}

    # Loop through each data modality defined in the data_config
    for modality_name, data_info in self.data_config.data_info.items():
        if data_info.scaling not in allowed_scalers:
            # sorted(): set iteration order is not stable between runs,
            # so sort to keep the error message deterministic.
            raise ValueError(
                f"Invalid scaling '{data_info.scaling}' for modality '{modality_name}'. "
                f"OntixConfig only permits {sorted(allowed_scalers)}."
            )
    return self

StackixConfig

Bases: DefaultConfig

A specialized configuration inheriting from DefaultConfig. For Stackix, save_memory is always False (feature not supported).

Source code in src/autoencodix/configs/stackix_config.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class StackixConfig(DefaultConfig):
    """
    A specialized configuration inheriting from DefaultConfig.
    For Stackix, `save_memory` is always False (feature not supported).
    """

    beta: float = Field(
        default=0.1,
        ge=0,
        description="Beta weighting factor for VAE loss",
    )

    save_memory: bool = Field(
        default=False,
        description="Always False — not supported for Stackix.",
    )

    @model_validator(mode="before")
    @classmethod
    def _force_save_memory_false(cls, values):
        """Reset an explicit ``save_memory=True`` to False and warn about it.

        Args:
            values: Raw, not yet validated input passed to the model.

        Returns:
            The input unchanged, or a copy with ``save_memory`` set to False
            when the caller passed ``save_memory=True``.
        """
        from collections.abc import Mapping

        # A "before" validator receives the raw input, which is not
        # guaranteed to be a mapping; leave anything else for the regular
        # validation to handle instead of failing on ``.get``.
        if not isinstance(values, Mapping):
            return values

        if values.get("save_memory") is True:
            warnings.warn(
                "`save_memory=True` is not supported for StackixConfig — forcing to False., Set the checkpoint_interval to number of epochs if you want to save memory",
                UserWarning,
                stacklevel=2,
            )
            # Build a new dict rather than mutating the caller's object.
            values = {**values, "save_memory": False}
        return values

VanillixConfig

Bases: DefaultConfig

A specialized configuration inheriting from DefaultConfig.

Source code in src/autoencodix/configs/vanillix_config.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
class VanillixConfig(DefaultConfig):
    """
    DefaultConfig variant used for Vanillix.

    Every setting is inherited from DefaultConfig; ``beta`` and ``epoch``
    are redeclared here with their own defaults.
    """

    # Default lowered to 0.1 (the inherited default was 1.0).
    beta: float = Field(
        description="Beta weighting factor for VAE loss",
        ge=0,
        default=0.1,
    )
    epoch: int = Field(
        description="How many epochs should the model train for.",
        ge=0,
        default=30,
    )

VarixConfig

Bases: DefaultConfig

A specialized configuration inheriting from DefaultConfig.

Source code in src/autoencodix/configs/varix_config.py
 5
 6
 7
 8
 9
10
11
12
13
14
class VarixConfig(DefaultConfig):
    """
    DefaultConfig variant used for Varix.

    Every setting is inherited from DefaultConfig; only ``beta`` is
    redeclared here with a different default.
    """

    # Default lowered to 0.1 (the inherited default was 1.0).
    beta: float = Field(
        description="Beta weighting factor for VAE loss",
        ge=0,
        default=0.1,
    )

XModalixConfig

Bases: DefaultConfig

A specialized configuration inheriting from DefaultConfig.

This class overrides specific training parameters like pretrain_epochs and beta for the XModalix model, while inheriting all other settings.

Source code in src/autoencodix/configs/xmodalix_config.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class XModalixConfig(DefaultConfig):
    """
    A specialized configuration inheriting from DefaultConfig.

    This class overrides specific training parameters like pretrain_epochs and beta
    for the XModalix model, while inheriting all other settings.
    `save_memory` is always forced to False for XModalix.
    """

    pretrain_epochs: Optional[int] = Field(
        default=None,  # Overridden default (was 0)
        description="Number of pretraining epochs, can be overwritten in DataInfo to have different number of pretraining epochs for each data modality",
    )

    beta: float = Field(
        default=0.1,  # Overridden default (was 1.0)
        ge=0,
        description="Beta weighting factor for VAE loss",
    )
    requires_paired: bool = Field(default=False)
    save_memory: bool = Field(
        default=False,
        description="Always False — not supported for XModalix.",
    )

    @model_validator(mode="before")
    @classmethod
    def _force_save_memory_false(cls, values):
        """Reset an explicit ``save_memory=True`` to False and warn about it.

        Args:
            values: Raw, not yet validated input passed to the model.

        Returns:
            The input unchanged, or a copy with ``save_memory`` set to False
            when the caller passed ``save_memory=True``.
        """
        from collections.abc import Mapping

        # A "before" validator receives the raw input, which is not
        # guaranteed to be a mapping; leave anything else for the regular
        # validation to handle instead of failing on ``.get``.
        if not isinstance(values, Mapping):
            return values

        if values.get("save_memory") is True:
            warnings.warn(
                "`save_memory=True` is not supported for XModalixConfig — forcing to False., Set the checkpoint_interval to number of epochs if you want to save memory",
                UserWarning,
                stacklevel=2,
            )
            # Build a new dict rather than mutating the caller's object.
            values = {**values, "save_memory": False}
        return values