Skip to content

Stackix Module

Stackix

Bases: BasePipeline

Stackix pipeline for training multiple VAEs on different modalities and stacking their latent spaces.

This pipeline uses: (1) StackixPreprocessor to prepare data for multi-modality training, and (2) StackixTrainer to train individual VAEs, extract latent spaces, and train the final stacked model.

Like other pipelines, it follows the standard BasePipeline interface and workflow.

Additional Attributes

_default_config: Is set to StackixConfig here.

Source code in src/autoencodix/stackix.py
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
class Stackix(BasePipeline):
    """Stackix pipeline for training multiple VAEs on different modalities and stacking their latent spaces.

    This pipeline uses:
    1. StackixPreprocessor to prepare data for multi-modality training
    2. StackixTrainer to train individual VAEs, extract latent spaces, and train the final stacked model

    Like other pipelines, it follows the standard BasePipeline interface and workflow.

    Additional Attributes:
        _default_config: Is set to StackixConfig here.

    """

    def __init__(
        self,
        data: Optional[Union[DataPackage, DatasetContainer]] = None,
        trainer_type: Type[BaseTrainer] = StackixTrainer,
        dataset_type: Type[BaseDataset] = StackixDataset,
        model_type: Type[BaseAutoencoder] = VarixArchitecture,
        loss_type: Type[BaseLoss] = VarixLoss,
        preprocessor_type: Type[BasePreprocessor] = StackixPreprocessor,
        visualizer: Type[BaseVisualizer] = GeneralVisualizer,
        evaluator: Optional[Type[BaseEvaluator]] = GeneralEvaluator,
        result: Optional[Result] = None,
        datasplitter_type: Type[DataSplitter] = DataSplitter,
        custom_splits: Optional[Dict[str, np.ndarray]] = None,
        config: Optional[DefaultConfig] = None,
        ontologies: Optional[Union[List, Dict]] = None,
    ) -> None:
        """Initialize the Stackix pipeline.

        See parent class for full list of Args.

        Raises:
            TypeError: If the resolved ``self.config`` is not a StackixConfig.
        """
        # Set before super().__init__ — presumably the base class falls back
        # to _default_config when ``config`` is None (verify against BasePipeline).
        self._default_config = StackixConfig()
        super().__init__(
            data=data,
            # Guards an explicit ``dataset_type=None``; the fallback is not
            # directly used by the stacked workflow.
            dataset_type=dataset_type or NumericDataset,
            trainer_type=trainer_type,
            model_type=model_type,
            loss_type=loss_type,
            preprocessor_type=preprocessor_type,
            visualizer=visualizer,
            evaluator=evaluator,
            result=result,
            datasplitter_type=datasplitter_type,
            config=config,
            custom_split=custom_splits,
            ontologies=ontologies,
        )
        if not isinstance(self.config, StackixConfig):
            raise TypeError(
                f"For Stackix Pipeline, we only allow StackixConfig as type for config, got {type(self.config)}"
            )

    def _process_latent_results(
        self, predictor_results: Result, predict_data: DatasetContainer
    ) -> None:
        """Processes the latent spaces from the StackixTrainer prediction results.

        Creates a correctly annotated AnnData object in ``self.result.adata_latent``.
        This method overrides the BasePipeline implementation to specifically handle
        the aligned latent space from the unpaired/stacked workflow.

        Args:
            predictor_results: Result object after predict step
            predict_data: not used here, only to keep interface structure

        """
        import warnings  # local import kept; module-level imports are outside this block's view

        latent = predictor_results.latentspaces.get(epoch=-1, split="test")
        if latent is None:
            warnings.warn(
                "No latent space found in predictor results. Cannot create AnnData object."
            )
            return

        self.result.adata_latent = ad.AnnData(X=latent)

        # Previously a missing sample-id container crashed the AnnData annotation
        # step; now we keep the latent space and fall back to default obs names.
        sample_ids = predictor_results.sample_ids.get(epoch=-1, split="test")
        if sample_ids is not None:
            self.result.adata_latent.obs_names = sample_ids
        else:
            warnings.warn(
                "No sample ids found in predictor results; keeping default obs_names."
            )

        self.result.adata_latent.var_names = [
            f"Latent_{i}" for i in range(latent.shape[1])  # ty: ignore
        ]

        # Update the main result object with the rest of the prediction results.
        self.result.update(predictor_results)

        print("Successfully created annotated latent space object (adata_latent).")

__init__(data=None, trainer_type=StackixTrainer, dataset_type=StackixDataset, model_type=VarixArchitecture, loss_type=VarixLoss, preprocessor_type=StackixPreprocessor, visualizer=GeneralVisualizer, evaluator=GeneralEvaluator, result=None, datasplitter_type=DataSplitter, custom_splits=None, config=None, ontologies=None)

Initialize the Stackix pipeline.

See parent class for full list of Args.

Source code in src/autoencodix/stackix.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def __init__(
    self,
    data: Optional[Union[DataPackage, DatasetContainer]] = None,
    trainer_type: Type[BaseTrainer] = StackixTrainer,
    dataset_type: Type[BaseDataset] = StackixDataset,
    model_type: Type[BaseAutoencoder] = VarixArchitecture,
    loss_type: Type[BaseLoss] = VarixLoss,
    preprocessor_type: Type[BasePreprocessor] = StackixPreprocessor,
    visualizer: Type[BaseVisualizer] = GeneralVisualizer,
    evaluator: Optional[Type[BaseEvaluator]] = GeneralEvaluator,
    result: Optional[Result] = None,
    datasplitter_type: Type[DataSplitter] = DataSplitter,
    custom_splits: Optional[Dict[str, np.ndarray]] = None,
    config: Optional[DefaultConfig] = None,
    ontologies: Optional[Union[List, Dict]] = None,
) -> None:
    """Initialize the Stackix pipeline.

    See parent class for full list of Args.

    Raises:
        TypeError: If the resolved ``self.config`` is not a StackixConfig.
    """
    # Set before super().__init__ — presumably consumed by BasePipeline when
    # ``config`` is None (TODO: confirm against the base class).
    self._default_config = StackixConfig()
    super().__init__(
        data=data,
        # Guards an explicit ``dataset_type=None``; the class-level default
        # (StackixDataset) is never falsy, so the fallback only fires then.
        dataset_type=dataset_type
        or NumericDataset,  # Fallback, but not directly used
        trainer_type=trainer_type,
        model_type=model_type,
        loss_type=loss_type,
        preprocessor_type=preprocessor_type,
        visualizer=visualizer,
        evaluator=evaluator,
        result=result,
        datasplitter_type=datasplitter_type,
        config=config,
        custom_split=custom_splits,
        ontologies=ontologies,
    )
    if not isinstance(self.config, StackixConfig):
        raise TypeError(
            f"For Stackix Pipeline, we only allow StackixConfig as type for config, got {type(self.config)}"
        )