Base Module

BaseAutoencoder

Bases: ABC, Module

Interface for building autoencoder models.

Defines standard methods for encoding data to a latent space and decoding back to the original space. Includes a weight initialization method for stable training. Intended to be extended by specific autoencoder variants like VAE.

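Attributes:

| Name | Type | Description |
|------|------|-------------|
| input_dim | | Number of input features, or the input shape as a tuple. |
| config | | Configuration object containing model architecture parameters. |
| _encoder | Optional[Module] | Encoder network. |
| _decoder | Optional[Module] | Decoder network. |
| ontologies | | Ontology information, if provided (used by Ontix). |
| feature_order | | Feature order, if provided (used by Ontix). |
| init_args | | Dictionary of the arguments passed to `__init__`. |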

Source code in src/autoencodix/base/_base_autoencoder.py
class BaseAutoencoder(ABC, nn.Module):
    """Interface for building autoencoder models.

    Defines standard methods for encoding data to a latent space and decoding
    back to the original space. Includes a weight initialization method for
    stable training. Intended to be extended by specific autoencoder variants
    like VAE.

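    Attributes:
        input_dim: Number of input features, or the input shape as a tuple.
        config: Configuration object containing model architecture parameters.
        _encoder: Encoder network.
        _decoder: Decoder network.
        ontologies: Ontology information, if provided (used by Ontix).
        feature_order: Feature order, if provided (used by Ontix).
        init_args: Dictionary of the arguments passed to __init__.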
    """

    def __init__(
        self,
        config: Optional[DefaultConfig],
        input_dim: Union[int, Tuple[int, ...]],
        ontologies: Optional[Union[Tuple, Dict]] = None,
        feature_order: Optional[Union[Tuple, Dict]] = None,
    ):
        """Initializes the BaseAutoencoder.

        Args:
            config: Configuration object containing model parameters.
                If None, a default configuration is used.
            input_dim: Number of input features, or the input shape as a tuple.
            ontologies: Ontology information, if provided (used by Ontix).
            feature_order: Feature order, if provided (used by Ontix).
        """
        super().__init__()
        if config is None:
            config = DefaultConfig()
        self.input_dim = input_dim
        self._encoder: Optional[nn.Module] = None
        self._decoder: Optional[nn.Module] = None
        self.config = config
        self.ontologies = ontologies
        self.feature_order = feature_order
        self.init_args = dict(
            config=config,
            input_dim=input_dim,
            ontologies=ontologies,
            feature_order=feature_order,
        )

    @abstractmethod
    def _build_network(self) -> None:
        """Builds the encoder and decoder networks for the autoencoder model.

        Populates the self._encoder and self._decoder attributes.
        This method should be implemented by subclasses to define
        the architecture of the encoder and decoder networks.
        """
        pass

    @abstractmethod
    def encode(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Encodes the input into the latent space.

        Args:
            x: The input tensor to be encoded.

        Returns:
            The encoded latent space representation, or mu and logvar for VAEs.
        """
        pass

    @abstractmethod
    def get_latent_space(self, x: torch.Tensor) -> torch.Tensor:
        """Returns the latent space representation of the input.

        Provides a unified way to obtain the latent representation from both
        vanilla and variational autoencoders. For vanilla autoencoders this
        wraps the encode method; for VAEs it wraps the reparameterization step.

        Args:
            x: The input tensor to be encoded.

        Returns:
            The latent space representation of the input tensor.
        """
        pass

    @abstractmethod
    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decodes the latent representation back to the input space.

        Args:
            x: The latent tensor to be decoded.

        Returns:
            The decoded tensor, reconstructed from the latent space.
        """
        pass

    @abstractmethod
    def forward(self, x: torch.Tensor) -> ModelOutput:
        """Combines encoding and decoding steps for the autoencoder.

        Args:
            x: The input tensor to be processed.

        Returns:
            The reconstructed input tensor and any additional information,
            depending on the model type.
        """
        pass

    def _init_weights(self, m):
        """Initializes linear layers with Xavier uniform weights and small biases.

        Xavier (Glorot) initialization scales the initial weights so that the
        variance of activations stays roughly constant across layers, which
        helps prevent vanishing or exploding gradients early in training.
        Biases are set to 0.01. The signature matches what nn.Module.apply
        expects, so the method can be applied to every submodule.

        Args:
            m: The module to initialize. Modules other than nn.Linear are
                left unchanged.
        """
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:  # nn.Linear(..., bias=False) has no bias
                m.bias.data.fill_(0.01)
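The sketch below shows what a concrete subclass might look like, assuming a small fully connected architecture. To keep it self-contained, `BaseAutoencoderSketch` is a hypothetical stand-in that reproduces only the abstract interface and `_init_weights` logic shown above (the real class additionally stores `config`, `ontologies`, `feature_order`, and `init_args`), and `SimpleOutput` is a hypothetical placeholder for the library's `ModelOutput`, whose fields are not documented here.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class SimpleOutput:
    # Hypothetical placeholder for the library's ModelOutput.
    reconstruction: torch.Tensor
    latentspace: torch.Tensor


class BaseAutoencoderSketch(ABC, nn.Module):
    # Hypothetical stand-in mirroring the abstract interface of BaseAutoencoder.
    def __init__(self, input_dim: int):
        super().__init__()
        self.input_dim = input_dim
        self._encoder = None
        self._decoder = None

    @abstractmethod
    def _build_network(self) -> None: ...

    @abstractmethod
    def encode(self, x: torch.Tensor) -> torch.Tensor: ...

    @abstractmethod
    def get_latent_space(self, x: torch.Tensor) -> torch.Tensor: ...

    @abstractmethod
    def decode(self, x: torch.Tensor) -> torch.Tensor: ...

    @abstractmethod
    def forward(self, x: torch.Tensor) -> SimpleOutput: ...

    def _init_weights(self, m: nn.Module) -> None:
        # Same logic as BaseAutoencoder._init_weights (with the bias guard).
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                m.bias.data.fill_(0.01)


class VanillaAutoencoder(BaseAutoencoderSketch):
    def __init__(self, input_dim: int, latent_dim: int = 2):
        super().__init__(input_dim)
        self.latent_dim = latent_dim
        self._build_network()
        # nn.Module.apply calls _init_weights on every submodule recursively.
        self.apply(self._init_weights)

    def _build_network(self) -> None:
        self._encoder = nn.Sequential(
            nn.Linear(self.input_dim, 16), nn.ReLU(), nn.Linear(16, self.latent_dim)
        )
        self._decoder = nn.Sequential(
            nn.Linear(self.latent_dim, 16), nn.ReLU(), nn.Linear(16, self.input_dim)
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self._encoder(x)

    def get_latent_space(self, x: torch.Tensor) -> torch.Tensor:
        # For a vanilla autoencoder the latent space is simply the encoder output.
        return self.encode(x)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        return self._decoder(x)

    def forward(self, x: torch.Tensor) -> SimpleOutput:
        z = self.encode(x)
        return SimpleOutput(reconstruction=self.decode(z), latentspace=z)


model = VanillaAutoencoder(input_dim=10, latent_dim=2)
out = model(torch.randn(4, 10))
print(out.reconstruction.shape, out.latentspace.shape)
```

Calling `self.apply(self._init_weights)` after building the network is one plausible way to use the initialization hook; where the library itself invokes it is not shown in this section.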

__init__(config, input_dim, ontologies=None, feature_order=None)

Initializes the BaseAutoencoder.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| config | Optional[DefaultConfig] | Configuration object containing model parameters. If None, a default configuration is used. | required |
| input_dim | Union[int, Tuple[int, ...]] | Number of input features, or the input shape as a tuple. | required |
| ontologies | Optional[Union[Tuple, Dict]] | Ontology information, if provided (used by Ontix). | None |
| feature_order | Optional[Union[Tuple, Dict]] | Feature order, if provided (used by Ontix). | None |

decode(x) abstractmethod

Decodes the latent representation back to the input space.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | Tensor | The latent tensor to be decoded. | required |

Returns:

| Type | Description |
|------|-------------|
| Tensor | The decoded tensor, reconstructed from the latent space. |


encode(x) abstractmethod

Encodes the input into the latent space.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | Tensor | The input tensor to be encoded. | required |

Returns:

| Type | Description |
|------|-------------|
| Union[Tensor, Tuple[Tensor, Tensor]] | The encoded latent representation, or a (mu, logvar) tuple for VAEs. |


forward(x) abstractmethod

Combines encoding and decoding steps for the autoencoder.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | Tensor | The input tensor to be processed. | required |

Returns:

| Type | Description |
|------|-------------|
| ModelOutput | The reconstructed input tensor and any additional information, depending on the model type. |


get_latent_space(x) abstractmethod

Returns the latent space representation of the input.

Provides a unified way to obtain the latent representation from both vanilla and variational autoencoders. For vanilla autoencoders this wraps the encode method; for VAEs it wraps the reparameterization step.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | Tensor | The input tensor to be encoded. | required |

Returns:

| Type | Description |
|------|-------------|
| Tensor | The latent space representation of the input tensor. |

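For VAEs, `get_latent_space` is described as wrapping the reparameterization step applied to the `(mu, logvar)` pair returned by `encode`. A minimal sketch of the standard reparameterization trick, assuming `logvar` holds log variances; the function name `reparameterize` is hypothetical and the library's own implementation may differ in details.

```python
import torch


def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Samples z ~ N(mu, exp(logvar)) so that gradients reach mu and logvar."""
    std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2), hence sigma = exp(logvar / 2)
    eps = torch.randn_like(std)    # noise from N(0, 1), independent of the parameters
    return mu + eps * std


torch.manual_seed(0)
mu = torch.zeros(3, 2, requires_grad=True)
logvar = torch.zeros(3, 2, requires_grad=True)
z = reparameterize(mu, logvar)
z.sum().backward()
print(z.shape, mu.grad)  # mu.grad is all ones, since dz/dmu = 1
```

Because the randomness is isolated in `eps`, the sample stays differentiable with respect to `mu` and `logvar`, which is what lets a VAE be trained with ordinary backpropagation.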

BaseDataset

Bases: ABC, Dataset

Interface to guide implementation for custom PyTorch datasets.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| data | | The dataset content (a torch.Tensor, a list of ImgData, or a SciPy sparse matrix). |
| config | | Optional configuration object. |
| sample_ids | | Optional list of identifiers for each sample. |
| feature_ids | | Optional list of identifiers for each feature. |
| mytype | Enum | Enum indicating the dataset type; must be set by subclasses. |

Source code in src/autoencodix/base/_base_dataset.py
class BaseDataset(abc.ABC, Dataset):
    """Interface to guide implementation for custom PyTorch datasets.

    Attributes:
        data: The dataset content (can be a torch.Tensor or other data structure).
        config: Optional configuration object.
        sample_ids: Optional list of identifiers for each sample.
        feature_ids: Optional list of identifiers for each feature.
        mytype: Enum indicating the dataset type (should be set in subclasses).
    """

    def __init__(
        self,
        data: Union[torch.Tensor, List[ImgData], sp.sparse.spmatrix],
        config: Optional[Any] = None,
        sample_ids: Optional[List[Any]] = None,
        feature_ids: Optional[List[Any]] = None,
    ):
        """Initializes the dataset.

        Args:
            data: The data to be used by the dataset.
            config: Optional configuration parameters.
            sample_ids: Optional identifiers for each sample.
            feature_ids: Optional identifiers for each feature.

        Note:
            Subclasses must set self.mytype to indicate the dataset type.
        """
        self.data = data
        self.raw_data = data  # for child class ImageDataset
        self.config = config
        self.sample_ids = sample_ids
        self.feature_ids = feature_ids
        self.mytype: Enum  # Should be set in subclasses to indicate the dataset type (e.g., DataSetTypes.NUM or DataSetTypes.IMG)

        self.metadata: Optional[Union[pd.Series, pd.DataFrame]] = None
        self.datasets: Dict[str, BaseDataset] = {}  # for xmodalix child

    def __len__(self) -> int:
        """Returns the number of samples in the dataset.

        Returns:
            The number of samples in the dataset.
        """
        if isinstance(self.data, list):
            return len(self.data)
        else:
            return self.data.shape[0]

    def get_input_dim(self) -> Union[int, Tuple[int, ...]]:
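        """Gets the input dimension of the dataset (number of features).

        Returns:
            The input dimension of the dataset's feature space.

        Raises:
            ValueError: If the data is an empty list, a list of objects other
                than ImgData, or an unsupported type.
        """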
        if isinstance(self.data, (torch.Tensor, sp.sparse.spmatrix)):
            return self.data.shape[1]

        elif isinstance(self.data, list):
            if len(self.data) == 0:
                raise ValueError(
                    "Dataset contains an empty list of ImgData; cannot determine input dimension."
                )
            if isinstance(self.data[0], ImgData):
                return self.data[0].img.shape[0]
            else:
                raise ValueError(
                    "List data is not of type ImgData, cannot determine input dimension."
                )
        else:
            raise ValueError("Unsupported data type for input dimension retrieval.")

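    def _to_df(self, modality: Optional[str] = None) -> pd.DataFrame:
        """Converts the dataset to a pandas DataFrame.

        Args:
            modality: Not used by the base implementation.

        Returns:
            DataFrame with samples as rows (indexed by sample_ids) and
            features as columns (named by feature_ids).

        Raises:
            TypeError: If the data is not a torch.Tensor.
        """
        if isinstance(self.data, torch.Tensor):
            return pd.DataFrame(
                # detach() and cpu() make the conversion safe for tensors that
                # require gradients or live on a GPU.
                self.data.detach().cpu().numpy(),
                columns=self.feature_ids,
                index=self.sample_ids,
            )
        else:
            raise TypeError(
                "Data is not a torch.Tensor and cannot be converted to DataFrame."
            )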
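The base class shown above does not define `__getitem__` (which `torch.utils.data.Dataset` subclasses are expected to provide) and leaves `mytype` to subclasses. A minimal standalone sketch of a numeric dataset, assuming 2-D tensor data with samples as rows; `DataSetTypesSketch` and `NumericDatasetSketch` are hypothetical names, and the sketch subclasses `Dataset` directly rather than `BaseDataset`.

```python
from enum import Enum
from typing import List, Optional

import torch
from torch.utils.data import DataLoader, Dataset


class DataSetTypesSketch(Enum):
    # Hypothetical stand-in for the library's dataset-type enum.
    NUM = "NUM"
    IMG = "IMG"


class NumericDatasetSketch(Dataset):
    def __init__(self, data: torch.Tensor, sample_ids: Optional[List[str]] = None):
        self.data = data
        self.sample_ids = sample_ids
        self.mytype = DataSetTypesSketch.NUM  # set by the subclass, as BaseDataset requires

    def __len__(self) -> int:
        return self.data.shape[0]  # rows are samples (tensor branch of __len__)

    def get_input_dim(self) -> int:
        return self.data.shape[1]  # columns are features (tensor branch of get_input_dim)

    def __getitem__(self, index: int):
        # Return the feature vector together with its sample ID (or index).
        sample_id = self.sample_ids[index] if self.sample_ids is not None else index
        return self.data[index], sample_id


ds = NumericDatasetSketch(torch.randn(6, 4), sample_ids=[f"s{i}" for i in range(6)])
loader = DataLoader(ds, batch_size=4, shuffle=False)
x, ids = next(iter(loader))
print(len(ds), ds.get_input_dim(), x.shape, ids)
```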

__init__(data, config=None, sample_ids=None, feature_ids=None)

Initializes the dataset.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| data | Union[Tensor, List[ImgData], spmatrix] | The data to be used by the dataset. | required |
| config | Optional[Any] | Optional configuration parameters. | None |
| sample_ids | Optional[List[Any]] | Optional identifiers for each sample. | None |
| feature_ids | Optional[List[Any]] | Optional identifiers for each feature. | None |

Subclasses must set mytype, an Enum indicating the dataset type; it is not a constructor parameter.

__len__()

Returns the number of samples in the dataset.

Returns:

| Type | Description |
|------|-------------|
| int | The number of samples in the dataset. |


get_input_dim()

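Gets the input dimension of the dataset (number of features).

Returns:

| Type | Description |
|------|-------------|
| Union[int, Tuple[int, ...]] | The input dimension of the dataset's feature space. |

Raises:

| Type | Description |
|------|-------------|
| ValueError | If the data is an empty list, a list of objects other than ImgData, or an unsupported type. |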


BaseEvaluator

Bases: ABC

Source code in src/autoencodix/base/_base_evaluator.py
class BaseEvaluator(abc.ABC):
    @abc.abstractmethod
    def evaluate(self, *args):
        """
        Evaluate the Autoencodix pipeline on defined machine learning tasks.

        Subclasses must implement this method to perform evaluation using the provided arguments.

        Args:
            *args: Variable length argument list for evaluation parameters.

        Returns:
            Result: The evaluation result.
        """
        pass

    @staticmethod
    def _expand_reference_methods(reference_methods: list, result: Result) -> list:
        """
        Expands the list of reference methods if needed for evaluation.

        Args:
            reference_methods (list): The list of reference methods to potentially expand.
            result (Result): The evaluation result object.

        Returns:
            list: The (possibly expanded) list of reference methods.
        """
        return reference_methods

evaluate(*args) abstractmethod

Evaluate the Autoencodix pipeline on defined machine learning tasks.

Subclasses must implement this method to perform evaluation using the provided arguments.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| *args | | Variable-length argument list of evaluation parameters. | () |

Returns:

| Type | Description |
|------|-------------|
| Result | The evaluation result. |

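Because `BaseEvaluator` derives from `abc.ABC` and marks `evaluate` as abstract, it cannot be instantiated until a subclass implements `evaluate`. A toy pure-Python sketch of that contract; `BaseEvaluatorSketch` is a hypothetical stand-in for the real base class, and `AccuracyEvaluatorSketch` with its dictionary return value is hypothetical too (the real method is documented to return a `Result` object).

```python
import abc


class BaseEvaluatorSketch(abc.ABC):
    # Hypothetical stand-in reproducing BaseEvaluator's abstract method.
    @abc.abstractmethod
    def evaluate(self, *args):
        pass


class AccuracyEvaluatorSketch(BaseEvaluatorSketch):
    def evaluate(self, *args):
        # Toy evaluation: fraction of matching labels between two sequences.
        y_true, y_pred = args
        matches = sum(t == p for t, p in zip(y_true, y_pred))
        return {"accuracy": matches / len(y_true)}


try:
    BaseEvaluatorSketch()
except TypeError as err:
    print("cannot instantiate:", err)

print(AccuracyEvaluatorSketch().evaluate([1, 0, 1, 1], [1, 0, 0, 1]))  # {'accuracy': 0.75}
```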

BaseLoss

Bases: Module, ABC

Provides common loss computation functionality for autoencoders.

Implements standard loss calculations including reconstruction loss, KL divergence, and Maximum Mean Discrepancy (MMD), while requiring subclasses to implement the specific forward method.
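For reference, the quantities computed by `compute_kl_loss`, `_mmd_kernel`, and `compute_mmd_loss` can be read off the source as follows, with $\sigma^2 = \exp(\mathrm{logvar})$, $d$ the number of latent dimensions, $p$ the prior samples (`true_samples`), and $z$ the encoded samples. With `loss_reduction = "mean"`, each sum becomes a mean over the same entries.

$$
D_{\mathrm{KL}} = -\frac{1}{2} \sum_{i,j} \left( 1 + \log \sigma_{ij}^2 - \mu_{ij}^2 - \sigma_{ij}^2 \right)
$$

$$
k(x, y) = \exp\left( -\frac{1}{d^2} \sum_{j=1}^{d} (x_j - y_j)^2 \right)
$$

$$
\mathrm{MMD} = \sum_{a,b} k(p_a, p_b) + \sum_{a,b} k(z_a, z_b) - 2 \sum_{a,b} k(p_a, z_b)
$$

The exponent carries a factor $1/d^2$ because `_mmd_kernel` first averages the squared differences over the $d$ features and then divides by $d$ again.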

Attributes:

| Name | Type | Description |
|------|------|-------------|
| config | | Configuration parameters for the loss function. |
| recon_loss | Module | Module for computing the reconstruction loss (nn.MSELoss or nn.BCEWithLogitsLoss). |
| reduction_fn | | Reduction function (torch.mean or torch.sum). |
| compute_kernel | | Function that computes the kernel matrix for the MMD loss. |
| annealing_scheduler | | Helper for loss calculation with annealing. |
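A standalone sketch of the KL and MMD computations, written as free functions that mirror `compute_kl_loss`, `_mmd_kernel`, and `compute_mmd_loss` in the source below with `loss_reduction="mean"`. The function names are hypothetical, the kernel uses broadcasting instead of `expand` (same result), and drawing the prior samples from a standard normal is an assumption about how `true_samples` is produced.

```python
import torch


def kl_loss(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # KL(N(mu, exp(logvar)) || N(0, 1)), averaged over all elements.
    return -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())


def mmd_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Gaussian kernel matrix of shape (x.shape[0], y.shape[0]).
    dim = x.size(1)
    diff = x.unsqueeze(1) - y.unsqueeze(0)  # broadcasts to (n_x, n_y, dim)
    return torch.exp(-diff.pow(2).mean(2) / float(dim))


def mmd_loss(z: torch.Tensor, true_samples: torch.Tensor) -> torch.Tensor:
    true_samples = true_samples.to(z.device)
    return (
        mmd_kernel(true_samples, true_samples).mean()
        + mmd_kernel(z, z).mean()
        - 2 * mmd_kernel(true_samples, z).mean()
    )


torch.manual_seed(0)
z = torch.randn(64, 8) * 3 + 2  # "encoded" samples far from the prior
prior = torch.randn(64, 8)      # samples from N(0, I)
print(kl_loss(torch.zeros(4, 8), torch.zeros(4, 8)))  # zero when q equals the prior
print(mmd_loss(z, prior), mmd_loss(prior, prior))     # the second value is zero
```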

Source code in src/autoencodix/base/_base_loss.py
class BaseLoss(nn.Module, ABC):
    """Provides common loss computation functionality for autoencoders.

    Implements standard loss calculations including reconstruction loss,
    KL divergence, and Maximum Mean Discrepancy (MMD), while requiring
    subclasses to implement the specific forward method.

    Attributes:
        config: Configuration parameters for the loss function.
        recon_loss: Module for computing reconstruction loss (MSE or BCE).
        reduction_fn: Function to apply reduction (mean or sum).
        compute_kernel: Function to compute kernel for MMD loss.
        annealing_scheduler: Helper for loss calculation with annealing.
    """

    def __init__(self, config: DefaultConfig, annealing_scheduler=None):
        """Initializes the loss module with the specified configuration.

        Args:
            config: Configuration parameters for the loss function.
            annealing_scheduler: Helper class for loss calculation with annealing.

        Raises:
            NotImplementedError: If unsupported loss reduction or reconstruction
                loss type is specified.
        """
        super().__init__()
        self.annealing_scheduler = annealing_scheduler or AnnealingScheduler()
        self.config = config
        self.recon_loss: nn.Module

        if self.config.loss_reduction == "mean":
            self.reduction_fn = torch.mean
        elif self.config.loss_reduction == "sum":
            self.reduction_fn = torch.sum
        else:
            raise NotImplementedError(
                f"Invalid loss reduction type: {self.config.loss_reduction}. "
                f"Only 'mean' and 'sum' are supported."
            )

        if self.config.reconstruction_loss == "mse":
            self.recon_loss = nn.MSELoss(reduction=config.loss_reduction)
        elif self.config.reconstruction_loss == "bce":
            self.recon_loss = nn.BCEWithLogitsLoss(reduction=config.loss_reduction)
        else:
            raise NotImplementedError(
                f"Invalid reconstruction loss type: {self.config.reconstruction_loss}. "
                f"Only 'mse' and 'bce' are supported. Please check the value of "
                f"'config.reconstruction_loss' for typos or unsupported types."
            )

        self.compute_kernel = self._mmd_kernel

    def _mmd_kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Computes Gaussian kernel for Maximum Mean Discrepancy calculation.

        Calculates the kernel matrix between two sets of samples, using a
        Gaussian kernel with normalization by feature dimension.

        Args:
            x: First set of input samples.
            y: Second set of input samples.

        Returns:
            Kernel matrix of shape (x.shape[0], y.shape[0]).
        """
        x_size = x.size(0)
        y_size = y.size(0)
        dim = x.size(1)

        x = x.unsqueeze(1)
        y = y.unsqueeze(0)
        tiled_x = x.expand(x_size, y_size, dim)
        tiled_y = y.expand(x_size, y_size, dim)

        kernel_input = (tiled_x - tiled_y).pow(2).mean(2) / float(dim)
        return torch.exp(-kernel_input)

    def compute_mmd_loss(
        self, z: torch.Tensor, true_samples: torch.Tensor
    ) -> torch.Tensor:
        """Computes Maximum Mean Discrepancy loss.

        Args:
            z: Samples from the encoded distribution.
            true_samples: Samples from the prior distribution.

        Returns:
            The MMD loss value.

        Raises:
            NotImplementedError: If unsupported loss reduction type is specified.
        """
        # Move the prior samples to z's device before computing any kernels.
        true_samples = true_samples.to(z.device)
        true_samples_kernel = self.compute_kernel(x=true_samples, y=true_samples)
        z_kernel = self.compute_kernel(z, z)
        ztr_kernel = self.compute_kernel(x=true_samples, y=z)

        if self.config.loss_reduction == "mean":
            return true_samples_kernel.mean() + z_kernel.mean() - 2 * ztr_kernel.mean()
        elif self.config.loss_reduction == "sum":
            return true_samples_kernel.sum() + z_kernel.sum() - 2 * ztr_kernel.sum()
        else:
            raise NotImplementedError(
                f"Invalid loss reduction type: {self.config.loss_reduction}. "
                f"Only 'mean' and 'sum' are supported."
            )

    def compute_kl_loss(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Computes the KL divergence between N(mu, exp(logvar)) and N(0, 1).

        Args:
            mu: Mean tensor.
            logvar: Log variance tensor.

        Returns:
            The KL divergence loss value.

        Raises:
            ValueError: If mu and logvar do not have the same shape.
        """
        if mu.shape != logvar.shape:
            raise ValueError(
                f"Shape mismatch: mu has shape {mu.shape}, but logvar has shape {logvar.shape}."
            )
        return -0.5 * self.reduction_fn(1 + logvar - mu.pow(2) - logvar.exp())

    def compute_variational_loss(
        self,
        mu: Optional[torch.Tensor],
        logvar: Optional[torch.Tensor],
        z: Optional[torch.Tensor] = None,
        true_samples: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Computes either KL or MMD loss based on configuration.

        Args:
            mu: Mean tensor for variational loss.
            logvar: Log variance tensor for variational loss.
            z: Encoded samples for MMD loss.
            true_samples: Prior samples for MMD loss.

        Returns:
            The computed variational loss.

        Raises:
            ValueError: If required parameters are missing or if mu and logvar have shape mismatch.
            NotImplementedError: If unsupported VAE loss type is specified.
        """

        if self.config.default_vae_loss == "kl":
            if mu is None:
                raise ValueError("mu must be provided for VAE loss")
            if logvar is None:
                raise ValueError("logvar must be provided for VAE loss")
            if mu.shape != logvar.shape:
                raise ValueError(
                    f"Shape mismatch: mu has shape {mu.shape}, but logvar has shape {logvar.shape}"
                )

            return self.compute_kl_loss(mu=mu, logvar=logvar)

        elif self.config.default_vae_loss == "mmd":
            if z is None:
                raise ValueError("z must be provided for MMD loss")
            if true_samples is None:
                raise ValueError("true_samples must be provided for MMD loss")
            return self.compute_mmd_loss(z=z, true_samples=true_samples)
        else:
            raise NotImplementedError(
                f"VAE loss type {self.config.default_vae_loss} is not implemented. "
                f"Only 'kl' and 'mmd' are supported."
            )

    def compute_paired_loss(
        self,
        latentspaces: dict[str, torch.Tensor],
        sample_ids: dict[str, list],
    ) -> torch.Tensor:
        """
        Calculates the paired distance loss across all pairs of modalities in a batch.

        Args:
            latentspaces: A dictionary mapping modality names to their latent space tensors.
                        e.g., {'RNA': tensor_rna, 'ATAC': tensor_atac}
            sample_ids: A dictionary mapping modality names to their list of sample IDs.

        Returns:
            A single scalar tensor representing the total paired loss.
        """

        loss_helper = []
        modality_names = list(latentspaces.keys())

        # 1. Iterate through all unique pairs of modalities
        for mod_a, mod_b in itertools.combinations(modality_names, 2):
            ids_a = sample_ids[mod_a]
            ids_b = sample_ids[mod_b]

            # 2. Find the intersection of sample IDs
            common_ids = set(ids_a) & set(ids_b)

            if not common_ids:
                print("no common ids")
                continue

            # 3. Create a mapping from sample ID to index for efficient lookup
            id_to_idx_a = {sample_id: i for i, sample_id in enumerate(ids_a)}
            id_to_idx_b = {sample_id: i for i, sample_id in enumerate(ids_b)}

            # Get the corresponding indices for the common samples
            indices_a = [id_to_idx_a[common_id] for common_id in common_ids]
            indices_b = [id_to_idx_b[common_id] for common_id in common_ids]

            # 4. Select the latent vectors for the paired samples
            paired_latents_a = latentspaces[mod_a][indices_a]
            paired_latents_b = latentspaces[mod_b][indices_b]

            # 5. Calculate the distance between the aligned latent vectors
            # L1 distance: mean absolute difference over latent dimensions,
            # then reduced over samples with reduction_fn (mean or sum)
            distance = torch.abs(paired_latents_a - paired_latents_b).mean(dim=1)
            pair_loss = self.reduction_fn(distance)
            loss_helper.append(pair_loss)
        if not loss_helper:
            return torch.tensor(0.0)
        return torch.stack(loss_helper).mean()

    @staticmethod
    def _compute_log_gauss_dense(
        z: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor
    ) -> torch.Tensor:
        """Computes the log probability of a Gaussian distribution.

        Args:
            z: Latent variable tensor.
            mu: Mean tensor.
            logvar: Log variance tensor.

        Returns:
            Log probability of the Gaussian distribution.
        """
        return -0.5 * (
            torch.log(torch.tensor([2 * torch.pi]).to(z.device))
            + logvar
            + (z - mu) ** 2 * torch.exp(-logvar)
        )

    @staticmethod
    def _compute_log_import_weight_mat(batch_size: int, n_samples: int) -> torch.Tensor:
        """Computes the log importance weight matrix for the disentangled loss.

        Similar to: https://github.com/rtqichen/beta-tcvae

        Args:
            batch_size: Number of samples in the batch.
            n_samples: Total number of samples in the dataset.

        Returns:
            Log importance weight matrix of shape (batch_size, batch_size).
        """

        N = n_samples
        M = batch_size - 1
        strat_weight = (N - M) / (N * M)
        W = torch.Tensor(batch_size, batch_size).fill_(1 / M)
        W.view(-1)[:: M + 1] = 1 / N
        W.view(-1)[1 :: M + 1] = strat_weight
        W[M - 1, 0] = strat_weight
        return W.log()

    @abstractmethod
    def forward(
        self,
        *args,
        **kwargs,
    ) -> Any:
        """Calculates the loss for the autoencoder.

        This method must be implemented by subclasses to define the specific
        loss computation logic for the autoencoder. The implementation should
        compute the total loss as well as any individual loss components
        (e.g., reconstruction loss, KL divergence, etc.) based on the model's
        output and the provided targets.

        Args:
            *args, **kwargs: Depend on the loss type and pipeline.

        Returns:
            - The total loss value as a scalar tensor.
            - A dictionary of individual loss components, where the keys are
                descriptive strings (e.g., "reconstruction_loss", "kl_loss") and
                the values are the corresponding loss tensors.
            - Implementation in subclasses is flexible, so for new loss classes this can differ.

        Note:
            Subclasses must implement this method to define the specific loss
            computation logic for their use case.
        """
        # TODO maybe standardize the return types more i.e. request a scalar and a dict
        pass
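`_compute_log_gauss_dense` evaluates, element by element, the log density of a diagonal Gaussian with mean `mu` and variance `exp(logvar)`. The sketch below restates that formula as a free function (`log_gauss_dense` is a hypothetical name, not library API) and checks it against PyTorch's `torch.distributions.Normal.log_prob`, assuming only that PyTorch is installed:

```python
import math

import torch
from torch.distributions import Normal


def log_gauss_dense(z: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # Element-wise log N(z; mu, exp(logvar)):
    # -0.5 * (log(2*pi) + logvar + (z - mu)^2 / exp(logvar))
    return -0.5 * (math.log(2 * math.pi) + logvar + (z - mu) ** 2 * torch.exp(-logvar))


torch.manual_seed(0)
z = torch.randn(4, 3)
mu = torch.randn(4, 3)
logvar = torch.randn(4, 3)

ours = log_gauss_dense(z, mu, logvar)
# Normal is parameterised by the standard deviation, i.e. exp(0.5 * logvar)
reference = Normal(mu, torch.exp(0.5 * logvar)).log_prob(z)
print(torch.allclose(ours, reference, atol=1e-5))  # True
```

Because `Normal` takes a standard deviation rather than a log variance, `logvar` has to be converted with `exp(0.5 * logvar)` before the comparison.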

__init__(config, annealing_scheduler=None)

Initializes the loss module with the specified configuration.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config | DefaultConfig | Configuration parameters for the loss function. | required |
| annealing_scheduler | | Helper class for loss calculation with annealing. A new AnnealingScheduler() is created if None. | None |

Raises:

| Type | Description |
| --- | --- |
| NotImplementedError | If an unsupported loss reduction or reconstruction loss type is specified. |

Source code in src/autoencodix/base/_base_loss.py
def __init__(self, config: DefaultConfig, annealing_scheduler=None):
    """Initializes the loss module with the specified configuration.

    Args:
        config: Configuration parameters for the loss function.
        annealing_scheduler: Helper class for loss calculation with annealing.

    Raises:
        NotImplementedError: If unsupported loss reduction or reconstruction
            loss type is specified.
    """
    super().__init__()
    self.annealing_scheduler = annealing_scheduler or AnnealingScheduler()
    self.config = config
    self.recon_loss: nn.Module

    if self.config.loss_reduction == "mean":
        self.reduction_fn = torch.mean
    elif self.config.loss_reduction == "sum":
        self.reduction_fn = torch.sum
    else:
        raise NotImplementedError(
            f"Invalid loss reduction type: {self.config.loss_reduction}. "
            f"Only 'mean' and 'sum' are supported."
        )

    if self.config.reconstruction_loss == "mse":
        self.recon_loss = nn.MSELoss(reduction=config.loss_reduction)
    elif self.config.reconstruction_loss == "bce":
        self.recon_loss = nn.BCEWithLogitsLoss(reduction=config.loss_reduction)
    else:
        raise NotImplementedError(
            f"Invalid reconstruction loss type: {self.config.reconstruction_loss}. "
            f"Only 'mse' and 'bce' are supported. Please check the value of "
            f"'config.reconstruction_loss' for typos or unsupported types."
        )

    self.compute_kernel = self._mmd_kernel
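The constructor maps the `reconstruction_loss` string to a PyTorch module, and `"bce"` selects `nn.BCEWithLogitsLoss`, which expects raw logits rather than probabilities. A minimal sketch of that mapping, assuming only PyTorch; `build_recon_loss` is a hypothetical helper, not part of the library:

```python
import torch
import torch.nn as nn


def build_recon_loss(kind: str, reduction: str) -> nn.Module:
    # Hypothetical helper mirroring the string-to-module mapping in BaseLoss.__init__
    if reduction not in ("mean", "sum"):
        raise NotImplementedError(f"Invalid loss reduction type: {reduction}.")
    if kind == "mse":
        return nn.MSELoss(reduction=reduction)
    if kind == "bce":
        # Expects raw logits; the sigmoid is applied inside the loss
        return nn.BCEWithLogitsLoss(reduction=reduction)
    raise NotImplementedError(f"Invalid reconstruction loss type: {kind}.")


torch.manual_seed(0)
logits = torch.randn(8, 5)  # raw decoder outputs, no sigmoid applied
targets = torch.randint(0, 2, (8, 5)).float()

with_logits = build_recon_loss("bce", "mean")(logits, targets)
# Same quantity computed on probabilities; BCEWithLogitsLoss is the more stable form
on_probs = nn.BCELoss(reduction="mean")(torch.sigmoid(logits), targets)
print(torch.allclose(with_logits, on_probs, atol=1e-5))  # True
```

Under this mapping, a decoder used with `reconstruction_loss="bce"` should return raw logits; applying a sigmoid in the decoder as well would squash the outputs twice.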

compute_kl_loss(mu, logvar)

Computes the KL divergence loss between N(mu, exp(logvar)) and the standard normal N(0, I).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| mu | Tensor | Mean tensor. | required |
| logvar | Tensor | Log variance tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | The KL divergence loss value. |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If mu and logvar do not have the same shape. |

Source code in src/autoencodix/base/_base_loss.py
def compute_kl_loss(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Computes KL divergence loss between N(mu, exp(logvar)) and N(0, I).

    Args:
        mu: Mean tensor.
        logvar: Log variance tensor.

    Returns:
        The KL divergence loss value.

    Raises:
        ValueError: If mu and logvar do not have the same shape.
    """
    if mu.shape != logvar.shape:
        raise ValueError(
            f"Shape mismatch: mu has shape {mu.shape}, but logvar has shape {logvar.shape}."
        )
    return -0.5 * self.reduction_fn(1 + logvar - mu.pow(2) - logvar.exp())
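The expression above is the closed-form KL divergence between a diagonal Gaussian and the standard normal. A minimal check, assuming only PyTorch, that compares it with `torch.distributions.kl_divergence` when the reduction is a sum:

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(16, 4)
logvar = torch.randn(16, 4)

# Closed form used by compute_kl_loss with loss_reduction="sum"
kl_closed = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# Reference: element-wise KL(N(mu, sigma^2) || N(0, 1)), summed over all elements
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_ref = kl_divergence(q, p).sum()

print(torch.allclose(kl_closed, kl_ref, atol=1e-4))  # True
```

With `loss_reduction="mean"`, `reduction_fn` is `torch.mean` over the whole tensor, so the same expression is averaged over batch size times latent dimensions instead of summed; this changes its scale relative to the reconstruction term.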

compute_mmd_loss(z, true_samples)

Computes the Maximum Mean Discrepancy (MMD) loss between encoded samples and samples drawn from the prior.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| z | Tensor | Samples from the encoded distribution. | required |
| true_samples | Tensor | Samples from the prior distribution. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | The MMD loss value. |

Raises:

| Type | Description |
| --- | --- |
| NotImplementedError | If an unsupported loss reduction type is specified. |

Source code in src/autoencodix/base/_base_loss.py
def compute_mmd_loss(
    self, z: torch.Tensor, true_samples: torch.Tensor
) -> torch.Tensor:
    """Computes the Maximum Mean Discrepancy (MMD) loss between z and prior samples.

    Args:
        z: Samples from the encoded distribution.
        true_samples: Samples from the prior distribution.

    Returns:
        The MMD loss value.

    Raises:
        NotImplementedError: If unsupported loss reduction type is specified.
    """
    # Move the prior samples to the encoder's device before computing any kernel
    true_samples = true_samples.to(z.device)
    true_samples_kernel = self.compute_kernel(x=true_samples, y=true_samples)
    z_kernel = self.compute_kernel(z, z)
    ztr_kernel = self.compute_kernel(x=true_samples, y=z)

    if self.config.loss_reduction == "mean":
        return true_samples_kernel.mean() + z_kernel.mean() - 2 * ztr_kernel.mean()
    elif self.config.loss_reduction == "sum":
        return true_samples_kernel.sum() + z_kernel.sum() - 2 * ztr_kernel.sum()
    else:
        raise NotImplementedError(
            f"Invalid loss reduction type: {self.config.loss_reduction}. "
            f"Only 'mean' and 'sum' are supported."
        )
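The kernel used by `compute_mmd_loss` (`self._mmd_kernel`) is not shown on this page. The sketch below assumes a Gaussian RBF kernel with bandwidth tied to the latent dimension; `rbf_kernel` and `mmd_loss` are hypothetical names, and the library's kernel may differ. It keeps the same three-term structure and `mean` reduction, and shows that the loss is small when the encoded samples follow the prior and larger when they are shifted away from it:

```python
import torch


def rbf_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Hypothetical Gaussian kernel k(a, b) = exp(-||a - b||^2 / dim);
    # the bandwidth used by the library's _mmd_kernel is not shown here.
    dim = x.size(1)
    sq_dist = (x.unsqueeze(1) - y.unsqueeze(0)).pow(2).sum(dim=-1)  # (n, m)
    return torch.exp(-sq_dist / dim)


def mmd_loss(z: torch.Tensor, true_samples: torch.Tensor) -> torch.Tensor:
    # Biased MMD^2 estimate with mean reduction, same three terms as compute_mmd_loss
    true_samples = true_samples.to(z.device)
    return (
        rbf_kernel(true_samples, true_samples).mean()
        + rbf_kernel(z, z).mean()
        - 2 * rbf_kernel(true_samples, z).mean()
    )


torch.manual_seed(0)
prior = torch.randn(200, 2)          # samples from N(0, I)
close = torch.randn(200, 2)          # same distribution as the prior
shifted = torch.randn(200, 2) + 3.0  # mean moved away from the prior

print(mmd_loss(close, prior).item() < mmd_loss(shifted, prior).item())  # True
```

With `loss_reduction="sum"`, the three kernel matrices are summed instead of averaged, so the value grows with the number of samples.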

compute_paired_loss(latentspaces, sample_ids)

Calculates the paired distance loss across all pairs of modalities in a batch.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| latentspaces | dict[str, Tensor] | A dictionary mapping modality names to their latent space tensors, e.g., {'RNA': tensor_rna, 'ATAC': tensor_atac}. | required |
| sample_ids | dict[str, list] | A dictionary mapping modality names to their list of sample IDs. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | A scalar tensor: the paired loss averaged over all modality pairs that share at least one sample ID, or 0.0 if no pair shares any. |

Source code in src/autoencodix/base/_base_loss.py
def compute_paired_loss(
    self,
    latentspaces: dict[str, torch.Tensor],
    sample_ids: dict[str, list],
) -> torch.Tensor:
    """Calculates the paired distance loss across all pairs of modalities in a batch.

    Args:
        latentspaces: A dictionary mapping modality names to their latent space
            tensors, e.g., {'RNA': tensor_rna, 'ATAC': tensor_atac}.
        sample_ids: A dictionary mapping modality names to their list of sample IDs.

    Returns:
        A scalar tensor: the paired loss averaged over all modality pairs that
        share at least one sample ID, or 0.0 if no pair shares any.
    """

    loss_helper = []
    modality_names = list(latentspaces.keys())

    # 1. Iterate through all unique pairs of modalities
    for mod_a, mod_b in itertools.combinations(modality_names, 2):
        ids_a = sample_ids[mod_a]
        ids_b = sample_ids[mod_b]

        # 2. Find the intersection of sample IDs
        common_ids = set(ids_a) & set(ids_b)

        if not common_ids:
            print("no common ids")
            continue

        # 3. Create a mapping from sample ID to index for efficient lookup
        id_to_idx_a = {sample_id: i for i, sample_id in enumerate(ids_a)}
        id_to_idx_b = {sample_id: i for i, sample_id in enumerate(ids_b)}

        # Get the corresponding indices for the common samples
        indices_a = [id_to_idx_a[common_id] for common_id in common_ids]
        indices_b = [id_to_idx_b[common_id] for common_id in common_ids]

        # 4. Select the latent vectors for the paired samples
        paired_latents_a = latentspaces[mod_a][indices_a]
        paired_latents_b = latentspaces[mod_b][indices_b]

        # 5. Calculate the distance between the aligned latent vectors
        # L1 distance, averaged over latent dimensions and then reduced
        # over samples with reduction_fn (mean or sum)
        distance = torch.abs(paired_latents_a - paired_latents_b).mean(dim=1)
        pair_loss = self.reduction_fn(distance)
        loss_helper.append(pair_loss)
    if not loss_helper:
        return torch.tensor(0.0)
    return torch.stack(loss_helper).mean()
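The ID matching in `compute_paired_loss` is easiest to see on toy data. A minimal sketch, assuming only PyTorch, that condenses the method into a standalone function (`paired_l1_loss` is a hypothetical name) and applies it to two modalities that share samples `s1` and `s2`, listed in a different order:

```python
import itertools

import torch


def paired_l1_loss(latentspaces, sample_ids, reduction_fn=torch.mean):
    # Condensed, standalone restatement of compute_paired_loss (illustration only)
    losses = []
    for mod_a, mod_b in itertools.combinations(latentspaces, 2):
        common = sorted(set(sample_ids[mod_a]) & set(sample_ids[mod_b]))
        if not common:
            continue  # pairs without shared sample IDs are skipped
        pos_a = {sid: i for i, sid in enumerate(sample_ids[mod_a])}
        pos_b = {sid: i for i, sid in enumerate(sample_ids[mod_b])}
        lat_a = latentspaces[mod_a][[pos_a[sid] for sid in common]]
        lat_b = latentspaces[mod_b][[pos_b[sid] for sid in common]]
        # Per-sample mean absolute difference, then reduced over samples
        distance = (lat_a - lat_b).abs().mean(dim=1)
        losses.append(reduction_fn(distance))
    return torch.stack(losses).mean() if losses else torch.tensor(0.0)


latents = {
    "RNA": torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]),
    "ATAC": torch.tensor([[1.0, 3.0], [0.0, 0.0]]),
}
ids = {"RNA": ["s1", "s2", "s3"], "ATAC": ["s2", "s1"]}
loss = paired_l1_loss(latents, ids)
print(loss)  # tensor(0.5000)
```

Here `s1` maps to identical vectors (distance 0) and `s2` differs by 2 in one of two dimensions (distance 1), giving a mean of 0.5; `s3` exists only in RNA and is ignored. Sorting the common IDs only fixes their order for readability: both index lists are built from the same sequence, so the loss does not depend on it.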

compute_variational_loss(mu, logvar, z=None, true_samples=None)

Computes either KL or MMD loss based on configuration.

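Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| mu | Optional[Tensor] | Mean tensor for variational loss. Required when config.default_vae_loss is "kl". | required |
| logvar | Optional[Tensor] | Log variance tensor for variational loss. Required when config.default_vae_loss is "kl". | required |
| z | Optional[Tensor] | Encoded samples for MMD loss. Required when config.default_vae_loss is "mmd". | None |
| true_samples | Optional[Tensor] | Prior samples for MMD loss. Required when config.default_vae_loss is "mmd". | None |

Returns:

| Type | Description |
| --- | --- |
| Tensor | The computed variational loss. |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If required parameters are missing or if mu and logvar have mismatched shapes. |
| NotImplementedError | If an unsupported VAE loss type is specified. |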

Source code in src/autoencodix/base/_base_loss.py
def compute_variational_loss(
    self,
    mu: Optional[torch.Tensor],
    logvar: Optional[torch.Tensor],
    z: Optional[torch.Tensor] = None,
    true_samples: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Computes either KL or MMD loss based on configuration.

    Args:
        mu: Mean tensor for variational loss.
        logvar: Log variance tensor for variational loss.
        z: Encoded samples for MMD loss.
        true_samples: Prior samples for MMD loss.

    Returns:
        The computed variational loss.

    Raises:
        ValueError: If required parameters are missing or if mu and logvar have mismatched shapes.
        NotImplementedError: If unsupported VAE loss type is specified.
    """

    if self.config.default_vae_loss == "kl":
        if mu is None:
            raise ValueError("mu must be provided for VAE loss")
        if logvar is None:
            raise ValueError("logvar must be provided for VAE loss")
        if mu.shape != logvar.shape:
            raise ValueError(
                f"Shape mismatch: mu has shape {mu.shape}, but logvar has shape {logvar.shape}"
            )

        return self.compute_kl_loss(mu=mu, logvar=logvar)

    elif self.config.default_vae_loss == "mmd":
        if z is None:
            raise ValueError("z must be provided for MMD loss")
        if true_samples is None:
            raise ValueError("true_samples must be provided for MMD loss")
        return self.compute_mmd_loss(z=z, true_samples=true_samples)
    else:
        raise NotImplementedError(
            f"VAE loss type {self.config.default_vae_loss} is not implemented. "
            f"Only 'kl' and 'mmd' are supported."
        )

forward(*args, **kwargs) abstractmethod

Calculates the loss for the autoencoder.

This method must be implemented by subclasses to define the specific loss computation logic for the autoencoder. The implementation should compute the total loss as well as any individual loss components (e.g., reconstruction loss, KL divergence, etc.) based on the model's output and the provided targets.

Returns:

| Type | Description |
| --- | --- |
| Any | The total loss value as a scalar tensor. |
| Any | A dictionary of individual loss components, where the keys are descriptive strings (e.g., "reconstruction_loss", "kl_loss") and the values are the corresponding loss tensors. |
| Any | Implementation in subclasses is flexible, so for new loss classes this can differ. |

Note:

Subclasses must implement this method to define the specific loss computation logic for their use case.

Source code in src/autoencodix/base/_base_loss.py
@abstractmethod
def forward(
    self,
    *args,
    **kwargs,
) -> Any:
    """Calculates the loss for the autoencoder.

    This method must be implemented by subclasses to define the specific
    loss computation logic for the autoencoder. The implementation should
    compute the total loss as well as any individual loss components
    (e.g., reconstruction loss, KL divergence, etc.) based on the model's
    output and the provided targets.

    Args:
        *args, **kwargs: Depend on the loss type and pipeline.

    Returns:
        - The total loss value as a scalar tensor.
        - A dictionary of individual loss components, where the keys are
            descriptive strings (e.g., "reconstruction_loss", "kl_loss") and
            the values are the corresponding loss tensors.
        - Implementation in subclasses is flexible, so for new loss classes this can differ.

    Note:
        Subclasses must implement this method to define the specific loss
        computation logic for their use case.
    """
    # TODO maybe standardize the return types more i.e. request a scalar and a dict
    pass
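Because `forward` is abstract, each concrete loss decides its own signature. A minimal sketch of the return convention described in the docstring (total loss plus a dictionary of named components), written as a standalone `nn.Module` so it runs without autoencodix installed; `ToyVAELoss`, its `beta` weight and its argument names are hypothetical and not part of the library:

```python
from typing import Dict, Tuple

import torch
import torch.nn as nn


class ToyVAELoss(nn.Module):
    """Hypothetical loss illustrating the (total, components) return convention."""

    def __init__(self, beta: float = 1.0):
        super().__init__()
        self.beta = beta
        self.recon_loss = nn.MSELoss(reduction="mean")

    def forward(
        self, x_hat: torch.Tensor, x: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        recon = self.recon_loss(x_hat, x)
        # Same closed-form KL term as compute_kl_loss with mean reduction
        kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        total = recon + self.beta * kl
        return total, {"reconstruction_loss": recon, "kl_loss": kl}


torch.manual_seed(0)
x = torch.randn(8, 10)
loss_fn = ToyVAELoss(beta=0.5)
total, parts = loss_fn(x_hat=x + 0.1, x=x, mu=torch.zeros(8, 2), logvar=torch.zeros(8, 2))
print(f"total={total.item():.4f}", sorted(parts))  # total=0.0100 ['kl_loss', 'reconstruction_loss']
```

Returning the components alongside the scalar keeps `total.backward()` simple while still letting a trainer log each term separately.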

BasePipeline

Bases: ABC

Provides a standardized interface for building model pipelines.

Implements methods for preprocessing data, training models, making predictions, evaluating performance, and visualizing results. Subclasses customize behavior by providing specific implementations for processing, training, evaluation, and visualization. For example, the Stackix model uses the StackixPreprocessor type for preprocessing.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| config | | Configuration for the pipeline's components and behavior. |
| preprocessed_data | Optional[DatasetContainer] | Pre-split and processed data that can be provided by the user. |
| raw_user_data | Union[DataPackage, AnnData, MuData, DataFrame, dict] | Raw input data for processing (DataFrames, MuData, etc.). |
| result | | Storage container for all pipeline outputs. |
| _preprocessor | | Component that filters, scales, and cleans data. |
| _visualizer | | Component that generates visual representations of results. |
| _dataset_type | | Base class for dataset implementations. |
| _trainer_type | | Base class for trainer implementations. |
| _model_type | | Base class for model architecture implementations. |
| _loss_type | | Base class for loss function implementations. |
| _datasets | Optional[DatasetContainer] | Split datasets after preprocessing. |
| _evaluator | Optional[DatasetContainer] | Component that assesses model performance. Not implemented yet. |
| _data_splitter | | Component that divides data into train/validation/test sets. |
| _ontologies | | Tuple of dictionaries containing the ontologies used to construct sparse decoder layers. If a list is provided, it is assumed to be a list of file paths to ontology files. The first item in the list or tuple is treated as the first layer (after the latent space), and so on. |

Source code in src/autoencodix/base/_base_pipeline.py
class BasePipeline(abc.ABC):
    """Provides a standardized interface for building model pipelines.

    Implements methods for preprocessing data, training models, making predictions,
    evaluating performance, and visualizing results. Subclasses customize behavior
    by providing specific implementations for processing, training, evaluation,
    and visualization. For example, the Stackix model uses the
    StackixPreprocessor type for preprocessing.

    Attributes:
        config: Configuration for the pipeline's components and behavior.
        preprocessed_data: Pre-split and processed data that can be provided by user.
        raw_user_data: Raw input data for processing (DataFrames, MuData, etc.).
        result: Storage container for all pipeline outputs.
        _preprocessor: Component that filters, scales, and cleans data.
        _visualizer: Component that generates visual representations of results.
        _dataset_type: Base class for dataset implementations.
        _trainer_type: Base class for trainer implementations.
        _model_type: Base class for model architecture implementations.
        _loss_type: Base class for loss function implementations.
        _datasets: Split datasets after preprocessing.
        _evaluator: Component that assesses model performance. Not implemented yet.
        _data_splitter: Component that divides data into train/validation/test sets.
        _ontologies: Tuple of dictionaries containing the ontologies used to construct
            sparse decoder layers. If a list is provided, it is assumed to be a list of
            file paths to ontology files. The first item in the list or tuple is treated
            as the first layer (after the latent space), and so on.

    """

    def __init__(
        self,
        dataset_type: Type[BaseDataset],
        trainer_type: Type[BaseTrainer],
        model_type: Type[BaseAutoencoder],
        loss_type: Type[BaseLoss],
        datasplitter_type: Type[DataSplitter],
        preprocessor_type: Type[BasePreprocessor],
        data: Optional[
            Union[DataPackage, DatasetContainer, ad.AnnData, MuData, pd.DataFrame, dict]  # type: ignore[invalid-type-form]
        ],
        visualizer: Optional[BaseVisualizer] = None,
        evaluator: Optional[BaseEvaluator] = None,
        result: Optional[Result] = None,
        config: Optional[DefaultConfig] = None,
        custom_split: Optional[Dict[str, np.ndarray]] = None,
        ontologies: Optional[Union[Tuple, Dict[Any, Any]]] = None,
        masking_fn: Optional[Callable] = None,
        masking_fn_kwargs: Dict[str, Any] = {},
        **kwargs: dict,
    ) -> None:  # ty: ignore[call-non-callable]
        """Initializes the pipeline with components and configuration.

        Args:
            dataset_type: Class for dataset implementations.
            trainer_type: Class for model training implementations.
            model_type: Class for model architecture implementations.
            loss_type: Class for loss function implementations.
            datasplitter_type: Class for data splitting implementation.
            preprocessor_type: Class for data preprocessing implementation.
            visualizer: Component for generating visualizations.
            data: Input data to be processed or already processed data.
            evaluator: Component for assessing model performance.
            result: Storage container for pipeline outputs.
            config: Configuration parameters for all pipeline components.
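            custom_split: User-provided data splits (train/validation/test).
            ontologies: Ontologies used to construct sparse decoder layers (see
                the _ontologies attribute above).
            masking_fn: Optional masking function, stored as self.masking_fn.
            masking_fn_kwargs: Keyword arguments for masking_fn, stored as
                self.masking_fn_kwargs.
            **kwargs: Additional keyword arguments.

        Raises:
            ValueError: If the pipeline class does not define _default_config.
            TypeError: If inputs have incorrect types.
        """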
        if not hasattr(self, "_default_config"):
            raise ValueError(
                "The _default_config attribute has not been specified in your "
                "pipeline class.\n\n"
                "Example:\n"
                "    self._default_config = XModalixConfig()\n\n"
                "This error typically occurs when a new architecture is added "
                "without setting the _default_config in its corresponding "
                "pipeline class.\n\n"
                "For more details, please refer to the 'how to add a new "
                "architecture' section in our documentation."
            )

        self._validate_config(config=config)
        self._validate_user_input(data=data)
        self.masking_fn = masking_fn
        self.masking_fn_kwargs = masking_fn_kwargs
        processed_data = data if isinstance(data, DatasetContainer) else None
        raw_user_data = (
            data
            if isinstance(data, (DataPackage, ad.AnnData, MuData, pd.DataFrame, dict))
            else None
        )
        if processed_data is not None and not isinstance(
            processed_data, DatasetContainer
        ):
            raise TypeError(
                f"Expected data type to be DatasetContainer, got {type(processed_data)}."
            )

        self.preprocessed_data: Optional[DatasetContainer] = processed_data
        self.raw_user_data: Union[
            DataPackage, ad.AnnData, MuData, pd.DataFrame, dict  # type: ignore[invalid-type-form]
        ] = raw_user_data
        self._trainer_type = trainer_type
        self._trainer: Optional[BaseTrainer] = None
        self._model_type = model_type
        self._loss_type = loss_type
        self._preprocessor_type = preprocessor_type
        if self.raw_user_data is not None:
            self.raw_user_data, datacase = self._handle_direct_user_data(
                data=self.raw_user_data,
            )
            self.config.data_case = datacase
            self._fill_data_info()

        self.ontologies = ontologies
        self._preprocessor = self._preprocessor_type(
            config=self.config, ontologies=self.ontologies
        )

        self.visualizer = (
            visualizer()  # ty: ignore[call-non-callable]
            if visualizer is not None
            else BaseVisualizer()  # ty: ignore[call-non-callable]
        )  # ty: ignore[call-non-callable]
        self.evaluator = (
            evaluator()  # ty: ignore[call-non-callable]
            if evaluator is not None
            else BaseEvaluator()  # ty: ignore[call-non-callable]
        )  # ty: ignore[call-non-callable]
        self.result = result if result is not None else Result()
        self._dataset_type = dataset_type
        self._data_splitter = datasplitter_type(
            config=self.config, custom_splits=custom_split
        )

        self._datasets: Optional[DatasetContainer] = (
            processed_data  # None, or user input
        )

    def _validate_config(self, config: Any) -> None:
        """Sets config to default if None, or validates its type.
        Args:
            config: Configuration object to validate or set to default.
        Raises:
            TypeError: If config is not of type DefaultConfig
        """
        if config is None:
            self.config = self._default_config  # type: ignore
        else:
            if not isinstance(config, DefaultConfig):
                raise TypeError(
                    f"Expected config type to be DefaultConfig, got {type(config)}."
                )
            if not isinstance(config, type(self._default_config)):  # type: ignore
                warnings.warn(
                    f"Your config is of type {type(config)}; this pipeline works "
                    f"best with the default parameters of {type(self._default_config)}."
                )
            self.config = config

    def _validate_user_input(self, data: Any) -> None:
        """Ensures that user-provided data is of a valid type.
        Args:
            data: User-provided data to validate.
        Raises:
            TypeError: If data is not of a supported type.
        """
        if not isinstance(
            data,
            (
                DataPackage,
                ad.AnnData,
                MuData,
                pd.DataFrame,
                dict,
                type(None),
                DatasetContainer,
            ),
        ):
            raise TypeError(
                f"Expected data type to be one of [DataPackage, AnnData, MuData, "
                f"pd.DataFrame, dict, DatasetContainer], got {type(data)}."
            )

    def _handle_direct_user_data(
        self,
        data,
    ) -> Tuple[DataPackage, DataCase]:
        """Converts raw user data into a standardized DataPackage format.

        Args:
            data: Raw input data in various formats.

        Returns:
            A tuple of the DataPackage containing the standardized data and the
            DataCase (e.g. MULTI_SINGLE_CELL or MULTI_BULK), taken from the
            config or, if unset there, inferred from the input type.

        Raises:
            TypeError: If data format is not supported.
            ValueError: If data doesn't meet format requirements or data_case
                cannot be inferred.
        """
        print(f"in handle_direct_user_data with data: {type(data)}")
        data_case = self.config.data_case
        if isinstance(data, DataPackage):
            data_package = data
            data_case = self.config.data_case
        elif isinstance(data, ad.AnnData):
            mudata = MuData({"user-data": data})
            data_package = DataPackage(multi_sc={"multi_sc": mudata})
            if self.config.data_case is None:
                data_case = DataCase.MULTI_SINGLE_CELL
        elif isinstance(data, MuData):
            data_package = DataPackage(multi_sc={"multi_sc": data})
            if self.config.data_case is None:
                data_case = DataCase.MULTI_SINGLE_CELL
        elif isinstance(data, pd.DataFrame):
            data_package = DataPackage(multi_bulk={"user-data": data})
            if self.config.data_case is None:
                data_case = DataCase.MULTI_BULK
        elif isinstance(data, dict):
            # Check if all values in the dictionary are pandas DataFrames
            if all(isinstance(value, pd.DataFrame) for value in data.values()):
                data_package = DataPackage(multi_bulk=data)
                if self.config.data_case is None:
                    data_case = DataCase.MULTI_BULK
            else:
                raise ValueError(
                    "All values in the dictionary must be pandas DataFrames."
                )
        if data_case is None:
            raise ValueError("data_case must be provided if it cannot be inferred.")

        return data_package, data_case

    def _validate_raw_user_data(self) -> None:
        """Validates the format and content of user-provided raw data.

        Ensures that raw_user_data is a valid DataPackage with properly formatted
        attributes.

        Raises:
            TypeError: If raw_user_data is not a DataPackage.
            ValueError: If DataPackage attributes aren't dictionaries or all are None.
        """
        if not isinstance(self.raw_user_data, DataPackage):
            raise TypeError(
                f"Expected raw_user_data to be of type DataPackage, got "
                f"{type(self.raw_user_data)}."
            )

        all_none = True
        for attr_name in self.raw_user_data.__annotations__:
            attr_value = getattr(self.raw_user_data, attr_name)
            if attr_value is not None:
                all_none = False
                if not isinstance(attr_value, dict):
                    raise ValueError(
                        f"Attribute '{attr_name}' of raw_user_data must be a dictionary, "
                        f"got {type(attr_value)}."
                    )

        if all_none:
            raise ValueError(
                "All attributes of raw_user_data are None. At least one must be non-None."
            )

    def _fill_data_info(self) -> None:
        """Populates the config's data_info with entries for all data keys.

        Creates DataInfo objects for each data key found in raw_user_data
        if they don't already exist in the configuration.
        This method is needed when the user provides data via the Pipeline
        rather than via the config.
        """
        all_keys = []
        for k in self.raw_user_data.__annotations__:
            attr_value = getattr(self.raw_user_data, k)
            all_keys.append(k)
            if isinstance(attr_value, dict):
                all_keys.extend(attr_value.keys())
                # MuData objects also contribute their modality keys
                for value in attr_value.values():
                    if isinstance(value, MuData):
                        all_keys.extend(value.mod.keys())
        for k in all_keys:
            if self.config.data_config.data_info.get(k) is None:
                self.config.data_config.data_info[k] = DataInfo()

    def _validate_user_data(self):
        """Validates user-provided data based on its source and format.

        Performs different validation based on whether the user provided
        preprocessed data, raw data, or a data configuration.

        Raises:
            TypeError: If raw user data is not a DataPackage.
            ValueError: If the preprocessed container, the config data, or the
                raw user data is invalid.
        """
        if self.raw_user_data is None:
            if self._datasets is not None:  # case when user passes preprocessed data
                self._validate_container()
            else:  # user passes data via config
                self._validate_config_data()
        else:
            self._validate_raw_user_data()

    def _validate_container(self):
        """Validates that a DatasetContainer has at least one valid dataset.

        Ensures the container has properly formatted datasets and at least
        one split is present.

        Raises:
            ValueError: If container validation fails.
        """
        if self.preprocessed_data is None:
            raise ValueError("DatasetContainer is None. Please provide valid datasets.")
        none_count = 0
        # Each split must be either a Dataset or None; at least one must be set
        for split_name in ("train", "test", "valid"):
            split = getattr(self.preprocessed_data, split_name)
            if isinstance(split, Dataset):
                continue
            if split is not None:
                raise ValueError(
                    f"{split_name.capitalize()} dataset has to be either None or "
                    f"Dataset, got {type(split)}"
                )
            none_count += 1
        if none_count == 3:
            raise ValueError("At least one split needs to be provided")

    def _validate_config_data(self):
        """Validates the data configuration provided via config.

        Ensures the data configuration has the necessary components based
        on the data types being processed.

        Raises:
            ValueError: If data configuration validation fails.
        """
        data_info_dict = self.config.data_config.data_info
        if not data_info_dict:
            raise ValueError("data_info dictionary is empty.")

        # Check if there's at least one non-annotation file
        non_annotation_files = {
            key: info
            for key, info in data_info_dict.items()
            if info.data_type != "ANNOTATION"
        }

        if not non_annotation_files:
            raise ValueError("At least one non-annotation file must be provided.")

        # Check if there's any non-single-cell data
        non_single_cell_data = {
            key: info
            for key, info in data_info_dict.items()
            if not info.is_single_cell and info.data_type != "ANNOTATION"
        }

        # If there's non-single-cell data, check for annotation file
        if non_single_cell_data:
            annotation_files = {
                key: info
                for key, info in data_info_dict.items()
                if info.data_type == "ANNOTATION"
            }

            if not annotation_files:
                raise ValueError(
                    "When working with non-single-cell data, an annotation file must be "
                    "provided."
                )

    def preprocess(self, config: Optional[Union[None, DefaultConfig]] = None, **kwargs):
        """Filters, normalizes and prepares data for model training.

        Processes raw input data into the format required by the model and creates
        train/validation/test splits as needed.

        Args:
            config: Currently unused by this method.
            **kwargs: Currently unused by this method.

        Raises:
            NotImplementedError: If preprocessor is not initialized.
        """
        if self._preprocessor_type is None:
            raise NotImplementedError("Preprocessor not initialized")
        self._validate_user_data()
        if self.preprocessed_data is None:
            self.preprocessed_data = self._preprocessor.preprocess(
                raw_user_data=self.raw_user_data,  # type: ignore
            )
        # Reuse the preprocessed data, whether just computed or user-provided
        self._datasets = self.preprocessed_data
        self.result.datasets = self.preprocessed_data

    def fit(self, config: Optional[Union[None, DefaultConfig]] = None, **kwargs):
        """Trains the model on preprocessed data.

        Creates and configures a trainer instance, then executes the training
        process using the preprocessed datasets.

        Args:
            config: Currently unused; the pipeline's own config is passed to
                the trainer.
            **kwargs: Currently unused.

        Raises:
            ValueError: If datasets aren't available for training.
        """
        if self._datasets is None:
            raise ValueError(
                "Datasets not built. Please run the preprocess method first."
            )

        self._trainer = self._trainer_type(
            trainset=self._datasets.train,
            validset=self._datasets.valid,
            result=self.result,
            config=self.config,
            model_type=self._model_type,
            loss_type=self._loss_type,
            ontologies=self.ontologies,  # Ontix
            masking_fn=getattr(self, "masking_fn", None),
            masking_fn_kwargs=getattr(self, "masking_fn_kwargs", None),
        )

        trainer_result: Result = self._trainer.train()
        self.result.update(other=trainer_result)

    def predict(
        self,
        data: Optional[
            Union[
                DataPackage,
                DatasetContainer,
                ad.AnnData,
                MuData,  # ty: ignore[invalid-type-form]
            ]  # ty: ignore[invalid-type-form]
        ] = None,  # ty: ignore[invalid-type-form]
        config: Optional[Union[None, DefaultConfig]] = None,
        from_key: Optional[str] = None,
        to_key: Optional[str] = None,
        **kwargs,
    ):
        """Generates predictions using the trained model.

        Uses the trained model to make predictions on test data or new data
        provided by the user. Processes the results and stores them in the
        result container.

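        Args:
            data: Optional new data for predictions. If None, the preprocessed
                test split is used.
            config: Currently unused by this method.
            from_key: Not used by the base implementation.
            to_key: Not used by the base implementation.
            **kwargs: Currently unused.

        Returns:
            The updated Result object.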

        Raises:
            NotImplementedError: If required components aren't initialized.
            ValueError: If no test data is available or data format is invalid.
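
        Example:
            Hypothetical usage sketch: `pipeline` stands for an already
            preprocessed and fitted instance of a concrete subclass, and
            `new_adata` for an AnnData object assumed to share the training
            features.

                result = pipeline.predict()  # uses the stored test split
                latent = result.adata_latent  # AnnData of test embeddings
                result = pipeline.predict(data=new_adata)  # new user data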
        """
        self._validate_prediction_requirements()
        if self._trainer is None:
            raise ValueError(
                "Trainer not initialized; call fit() first. If you used .save() "
                "and .load() and therefore did not call .fit(), this is a bug. "
                "In this case, please submit an issue."
            )

        self._trainer.setup_trainer(old_model=self.result.model)
        original_input = data
        predict_data = self._prepare_prediction_data(data=data)

        predictor_results = self._generate_predictions(
            predict_data=predict_data,
        )

        self._process_latent_results(
            predictor_results=predictor_results, predict_data=predict_data
        )
        self._postprocess_reconstruction(
            predictor_results=predictor_results,
            original_input=original_input,
            predict_data=predict_data,
        )
        self.result.update(predictor_results)
        return self.result

    def _validate_prediction_requirements(self):
        """Validate that required components are initialized."""
        if self._preprocessor is None:
            raise NotImplementedError("Preprocessor not initialized")
        if self.result.model is None:
            raise NotImplementedError(
                "Model not trained. Please run the fit method first"
            )

    def _prepare_prediction_data(
        self,
        data: Optional[
            Union[
                DataPackage,
                DatasetContainer,
                ad.AnnData,
                MuData,  # ty: ignore[invalid-type-form]
            ]  # ty: ignore[invalid-type-form]
        ] = None,  # ty: ignore[invalid-type-form]
    ) -> DatasetContainer:
        """Prepare and validate input data for prediction.
        Args:
            data: Optional new data for predictions. If None, uses existing datasets.
        Returns:
            DatasetContainer: The prepared dataset container for predictions.
        Raises:
            ValueError: If data type is unsupported or no test data is available.
        """
        if data is None:
            return self._get_existing_datasets()
        elif isinstance(data, DatasetContainer):
            return self._handle_dataset_container(data=data)
        elif isinstance(data, (DataPackage, ad.AnnData, MuData, dict, pd.DataFrame)):
            return self._handle_user_data(data=data)
        else:
            raise ValueError(f"Unsupported data type: {type(data)}")

    def _get_existing_datasets(self) -> DatasetContainer:
        """Get existing preprocessed datasets and validate them for prediction.
        Returns:
            DatasetContainer: The preprocessed datasets available for prediction.
        Raises:
            ValueError: If no datasets are available or no test data is present.
        """
        if self._datasets is None:
            raise ValueError(
                "No data provided for prediction and no preprocessed datasets "
                "available. Please run the preprocess method first or provide "
                "data for prediction."
            )
        if self._datasets.test is None:
            raise ValueError("No test data available for prediction")
        return self._datasets

    def _handle_dataset_container(self, data: DatasetContainer) -> DatasetContainer:
        """Handle DatasetContainer input for prediction.
        Args:
            data: DatasetContainer containing preprocessed datasets.
        Returns:
            DatasetContainer: The processed dataset container for predictions.
        """
        self.result.new_datasets = data

        if hasattr(self._preprocessor, "_dataset_container"):
            self._preprocessor._dataset_container = data

        return data

    def _handle_user_data(self, data: Any) -> DatasetContainer:
        """Handle user-provided data (DataPackage, AnnData, etc.).
        Args:
            data: Raw user data in various formats (DataPackage, AnnData, etc.).
        Returns:
            DatasetContainer: The processed dataset container for predictions.
        Raises:
            ValueError: If data type is unsupported or no test data is available.
        """
        processed_data, _ = self._handle_direct_user_data(data=data)
        predict_data = self._preprocessor.preprocess(
            raw_user_data=processed_data, predict_new_data=True
        )
        self.result.new_datasets = predict_data
        return predict_data

    def _validate_prediction_data(self, predict_data: DatasetContainer):
        """Validate that prediction data has required test split."""
        if predict_data.test is None:
            raise ValueError(
                "The data for prediction must be a DatasetContainer with a "
                f"non-None test split, got: {predict_data}"
            )

    def _generate_predictions(
        self,
        predict_data: DatasetContainer,
    ):
        """Generate predictions using the trained model.
        Args:
            predict_data: DatasetContainer with preprocessed datasets for prediction.
        Returns:
            Predictor results containing latent spaces and reconstructions.

        """
        self._validate_prediction_data(predict_data=predict_data)
        return self._trainer.predict(
            data=predict_data.test,
            model=self.result.model,
        )  # type: ignore

    def _process_latent_results(
        self, predictor_results, predict_data: DatasetContainer
    ):
        """Process and store latent space results.
        Args:
            predictor_results: Results from the prediction step containing latents.
            predict_data: DatasetContainer with preprocessed datasets for prediction.
        """
        latent = predictor_results.latentspaces.get(epoch=-1, split="test")
        if isinstance(latent, dict):
            print("Detected dictionary in latent results, extracting array...")
            latent = next(iter(latent.values()))  # TODO better adjust for xmodal
        self.result.adata_latent = ad.AnnData(latent)
        self.result.adata_latent.obs_names = predict_data.test.sample_ids  # type: ignore
        self.result.adata_latent.uns["var_names"] = predict_data.test.feature_ids  # type: ignore
        self.result.update(predictor_results)

    def _postprocess_reconstruction(
        self, predictor_results, original_input, predict_data: DatasetContainer
    ):
        """Postprocess reconstruction results based on input type.

        This outputs the reconstruction in the same format as the original input
        data, whether it is a DatasetContainer, DataPackage, AnnData, MuData, or
        another supported format.

        Args:
            predictor_results: Results from the prediction step containing reconstructions.
            original_input: Original input data format (if provided).
            predict_data: DatasetContainer with preprocessed datasets for prediction.
        Raises:
            ValueError: If reconstruction fails or data types are incompatible.
        """
        raw_recon: Union[Dict, np.ndarray, torch.Tensor] = (
            self.result.reconstructions.get(epoch=-1, split="test")
        )
        if isinstance(raw_recon, np.ndarray):
            raw_recon = torch.from_numpy(raw_recon)  # type: ignore
        elif isinstance(raw_recon, dict):
            raw_recon = raw_recon.get("translation")  # type: ignore
            if raw_recon is None:
                raise ValueError(
                    "Raw reconstruction is a dict but has no 'translation' key; "
                    "this should not happen."
                )
            raw_recon = torch.from_numpy(raw_recon)  # type: ignore
        elif not isinstance(raw_recon, torch.Tensor):
            # Tensors are already in the required format and pass through
            raise ValueError(
                "raw_recon must be a dict, np.ndarray, or torch.Tensor, got: "
                f"{type(raw_recon)}"
            )

        if original_input is None:
            # Using existing datasets
            self._handle_dataset_container_reconstruction(
                raw_recon=raw_recon,  # type: ignore
                dataset_container=predict_data,
                context="existing datasets",
            )
        elif isinstance(original_input, DatasetContainer):
            self._handle_dataset_container_reconstruction(
                raw_recon=raw_recon,  # type: ignore
                dataset_container=original_input,
                context="provided DatasetContainer",
            )
        elif self.config.data_case == DataCase.MULTI_SINGLE_CELL:
            self._handle_multi_single_cell_reconstruction(
                raw_recon=raw_recon,
                predictor_results=predictor_results,  # type: ignore
            )
        elif isinstance(
            original_input, (DataPackage, ad.AnnData, MuData, dict, pd.DataFrame)
        ):
            self._handle_user_data_reconstruction(
                raw_recon=raw_recon, predictor_results=predictor_results
            )
        else:
            self._handle_unsupported_reconstruction()

    def _handle_dataset_container_reconstruction(
        self,
        raw_recon: torch.Tensor,
        dataset_container: DatasetContainer,
        context: str = "DatasetContainer",
    ):
        """Handle reconstruction for DatasetContainer input.

        Currently a no-op: the formatting logic below is commented out, so
        `final_reconstruction` is not set for DatasetContainer inputs.

        Args:
            raw_recon: Raw reconstruction tensor from the model.
            dataset_container: Original DatasetContainer provided by the user.
            context: Description of the data context for error messages.
        """

        # if dataset_container.test is None:
        #     raise ValueError(f"No test data available in {context} for reconstruction.")
        # temp = copy.deepcopy(dataset_container.test)
        # temp.data = raw_recon
        # self.result.final_reconstruction = temp
        pass

    def _handle_multi_single_cell_reconstruction(
        self, raw_recon: torch.Tensor, predictor_results: Result
    ):
        """Handle reconstruction for multi-single-cell data
        Args:
            raw_recon: Raw reconstruction tensor from the model.
            predictor_results: Results from the prediction step containing reconstructions.
        Raises:
            ValueError: If reconstruction formatting fails or data types are incompatible.
        """
        pkg = self._preprocessor.format_reconstruction(
            reconstruction=raw_recon, result=predictor_results
        )
        if not isinstance(pkg.multi_sc, dict):
            raise ValueError(
                "Expected pkg.multi_sc to be a dictionary, got "
                f"{type(pkg.multi_sc)} instead."
            )
        self.result.final_reconstruction = pkg.multi_sc["multi_sc"]

    def _handle_user_data_reconstruction(
        self, raw_recon: torch.Tensor, predictor_results
    ):
        """Handle reconstruction for user-provided data formats.
        Args:
            raw_recon: Raw reconstruction tensor from the model.
            predictor_results: Results from the prediction step containing reconstructions.
        """
        pkg = self._preprocessor.format_reconstruction(
            reconstruction=raw_recon, result=predictor_results
        )
        self.result.final_reconstruction = pkg

    def _handle_unsupported_reconstruction(self):
        """Handle cases where reconstruction formatting is not available."""
        print(
            "Reconstruction formatting (using the reconstruction output of the "
            "autoencoder model and combining it with metadata to obtain exactly "
            "the same data structure as the raw input data, i.e., a DataPackage, "
            "DatasetContainer, or AnnData) is not available for this data type or case."
        )

    def decode(
        self, latent: Union[torch.Tensor, ad.AnnData, pd.DataFrame]
    ) -> Union[torch.Tensor, ad.AnnData, pd.DataFrame]:
        """Transforms latent space representations back to input space.

        Handles various input formats for the latent representation and
        returns the decoded data in a matching format.

        Args:
            latent: Latent space representation to decode.

        Returns:
            Decoded data in a format matching the input.

        Raises:
            TypeError: If no model has been trained or input type is invalid.
            ValueError: If latent dimensions are incompatible with the model.
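
        Example:
            Hypothetical sketch: `pipeline` is assumed to be a fitted pipeline
            whose latent size equals `pipeline.config.latent_dim` and whose
            `result.adata_latent` was populated by predict().

                import torch

                z = torch.randn(8, pipeline.config.latent_dim)
                recon = pipeline.decode(z)  # tensor in input-feature space
                recon_adata = pipeline.decode(pipeline.result.adata_latent)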
        """
        if self.result.model is None:
            raise TypeError("No model trained yet, use fit() or run() method first")
        recons: torch.Tensor
        if isinstance(latent, ad.AnnData):
            latent_data = torch.tensor(
                latent.X, dtype=torch.float32
            )  # Ensure float for compatibility

            expected_latent_dim = self.config.latent_dim
            if not latent_data.shape[1] == expected_latent_dim:
                raise ValueError(
                    f"Input AnnData's .X has shape {latent_data.shape}, but the model "
                    f"expects a latent vector of size {expected_latent_dim}. Consider "
                    f"projecting the AnnData to the correct latent space first."
                )
            latent_tensor = latent_data

            recons = self._trainer.decode(x=latent_tensor)
            if self._datasets is None:
                raise ValueError(
                    "No datasets available in the DatasetContainer to reconstruct "
                    "AnnData objects. Please provide a valid DatasetContainer."
                )
            if self._datasets.train is None:
                raise ValueError(
                    "The train dataset in the DatasetContainer is None. "
                    "Please provide a valid train dataset to reconstruct AnnData objects."
                )
            if not isinstance(self._datasets.train, BaseDataset):
                raise TypeError(
                    "The train dataset in the DatasetContainer must be a BaseDataset "
                    "to reconstruct AnnData objects."
                )
            recons_adata = ad.AnnData(
                X=recons.to("cpu").detach().numpy(),
                obs=pd.DataFrame(index=latent.obs_names),
                var=pd.DataFrame(index=self._datasets.train.feature_ids),
            )

            return recons_adata
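        elif isinstance(latent, pd.DataFrame):
            latent_tensor = torch.tensor(latent.values, dtype=torch.float32)
            recons = self._trainer.decode(x=latent_tensor)
            # Reconstructions live in input-feature space, so the latent column
            # names do not apply; use the training feature ids when available.
            feature_ids = None
            if self._datasets is not None and self._datasets.train is not None:
                feature_ids = getattr(self._datasets.train, "feature_ids", None)
            return pd.DataFrame(
                recons.to("cpu").detach().numpy(),
                index=latent.index,
                columns=feature_ids,
            )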
        elif isinstance(latent, torch.Tensor):
            # Check size compatibility
            expected_latent_dim = self.config.latent_dim
            if not latent.shape[1] == expected_latent_dim:
                if self._trainer._model._mu.out_features == latent.shape[1]:
                    warnings.warn(
                        f"Input latent has dimension {latent.shape[1]}, which differs "
                        f"from config.latent_dim ({self.config.latent_dim}) but matches "
                        "the output size of the model's mu layer. Did you mean to "
                        f"provide latent vectors of dimension {self.config.latent_dim}? "
                        "For Ontix this is the default behaviour and the warning can "
                        "be ignored."
                    )
                else:
                    raise ValueError(
                        f"Input latent has incompatible dimension {latent.shape[1]}, "
                        f"expected {self.config.latent_dim} or "
                        f"{self._trainer._model._mu.out_features}."
                    )

            latent_tensor = latent
        else:
            raise TypeError(
                "Input 'latent' must be a torch.Tensor, an AnnData object, or a "
                f"pandas DataFrame, not {type(latent)}."
            )

        return self._trainer.decode(x=latent_tensor)

    def evaluate(
        self,
        ml_model_class: ClassifierMixin = linear_model.LogisticRegression(),
        ml_model_regression: RegressorMixin = linear_model.LinearRegression(),
        params: Union[
            list, str
        ] = [],  # Default empty list, to use all parameters use string "all"
        metric_class: str = "roc_auc_ovo",  # Default is 'roc_auc_ovo' via https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-string-names
        metric_regression: str = "r2",  # Default is 'r2'
        reference_methods: list = [],  # Default [], Options are "PCA", "UMAP", "TSNE", "RandomFeature"
        split_type: Literal[
            "use-split", "CV-5", "LOOC"
        ] = "use-split",  # Default is "use-split"; other options: "CV-5", "LOOC"
        n_downsample: Optional[int] = 10000,
    ) -> Result:
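        """Evaluates the learned latent space on downstream ML tasks.

        Uses sklearn-style classifiers and regressors to assess how well the
        latent embeddings predict the given annotation parameters, optionally
        compares against reference methods, stores the scores in the result,
        and plots them.

        Args:
            ml_model_class: sklearn-style classifier for classification tasks.
            ml_model_regression: sklearn-style regressor for regression tasks.
            params: Annotation parameters to evaluate. If empty, falls back to
                `config.data_config.annotation_columns`; use "all" to evaluate
                all parameters.
            metric_class: sklearn scoring string for classification.
            metric_regression: sklearn scoring string for regression.
            reference_methods: Baselines to compare against, e.g. "PCA", "UMAP",
                "TSNE", or "RandomFeature".
            split_type: Evaluation strategy: "use-split", "CV-5", or "LOOC".
            n_downsample: Downsampling size passed to the evaluator.

        Returns:
            The updated Result object.

        Raises:
            NotImplementedError: If the evaluator is not initialized or the
                model has not been trained.
            ValueError: If no parameters are available, datasets are missing for
                "RandomFeature", or no latent spaces have been computed.
        """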
        if self.evaluator is None:
            raise NotImplementedError("Evaluator not initialized")
        if self.result.model is None:
            raise NotImplementedError(
                "Model not trained. Please run the fit method first"
            )
        if not is_classifier(ml_model_class):
            warnings.warn(
                "The provided model is not a sklearn-type classifier. "
                "Evaluation continues but may produce incorrect results or errors."
            )
        if not is_regressor(ml_model_regression):
            warnings.warn(
                "The provided model is not a sklearn-type regressor. "
                "Evaluation continues but may produce incorrect results or errors."
            )

        if len(params) == 0:
            if self.config.data_config.annotation_columns is None:
                params = []  # type: ignore
            else:
                params = self.config.data_config.annotation_columns  # type: ignore

        if len(params) == 0:
            raise ValueError(
                "No parameters specified for evaluation. Please provide a list of "
                "parameters or ensure that annotation_columns are set in the config."
            )

        if "RandomFeature" in reference_methods:
            if self._datasets is None:
                raise ValueError(
                    "Datasets not available for adding RandomFeature. Please keep "
                    "preprocessed data available before evaluation."
                )

        if len(self.result.latentspaces._data) == 0:
            raise ValueError(
                "No latent spaces found in results. Please run predict() to "
                "calculate embeddings before evaluation."
            )

        self.result = self.evaluator.evaluate(
            datasets=self._datasets,
            result=self.result,
            ml_model_class=ml_model_class,
            ml_model_regression=ml_model_regression,
            params=params,
            metric_class=metric_class,
            metric_regression=metric_regression,
            reference_methods=reference_methods,
            split_type=split_type,
            n_downsample=n_downsample,
        )

        _: Any = self.visualizer._plot_evaluation(result=self.result)

        return self.result

    def visualize(self, config: Optional[Union[None, DefaultConfig]] = None, **kwargs):
        """Creates visualizations of model results and performance.

        Args:
            config: Currently unused; the pipeline's own config is passed to the
                visualizer.
            **kwargs: Currently unused.

        Raises:
            NotImplementedError: If visualizer is not initialized.
        """
        if self.visualizer is None:
            raise NotImplementedError("Visualizer not initialized")

        self.visualizer.visualize(result=self.result, config=self.config)

    def show_result(self, split: str = "all", **kwargs):
        """Displays key visualizations of model results.

        This method generates the following visualizations:
        1. Loss Curves: Displays the absolute loss curves to provide insights into
           the model's training and validation performance over epochs.
        2. Latent Space Ridgeline Plot: Visualizes the distribution of the latent
           space representations across different dimensions, offering a high-level
           overview of the learned embeddings.
        3. Latent Space 2D Scatter Plot: Projects the latent space into two dimensions
           for a detailed view of the clustering or separation of data points.

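        These visualizations help in understanding the model's performance and
        the structure of the latent space representations.

        Args:
            split: Data split whose latent space is plotted; defaults to "all".
            **kwargs: May contain `params`, the annotation column(s) passed to
                the latent space plots. Defaults to the config's
                annotation_columns.
        """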
        print("Creating plots ...")

        params: Optional[Union[List[str], str]] = kwargs.pop("params", None)
        # Check if params are empty and annotation columns are available in config
        if params is None and self.config.data_config.annotation_columns:
            params = self.config.data_config.annotation_columns

        if len(self.result.losses._data) != 0:
            self.visualizer.show_loss(plot_type="absolute")
        else:
            warnings.warn(
                "No loss data found in results. Skipping loss curve visualization."
            )

        if len(self.result.latentspaces._data) != 0:
            self.visualizer.show_latent_space(
                result=self.result, plot_type="Ridgeline", split=split, param=params
            )
            self.visualizer.show_latent_space(
                result=self.result, plot_type="2D-scatter", split=split, param=params
            )
        else:
            warnings.warn(
                "No latent spaces found in results. Please run predict() to "
                "calculate embeddings."
            )

    def run(
        self, data: Optional[Union[DatasetContainer, DataPackage]] = None
    ) -> Result:
        """Executes the complete pipeline from preprocessing to visualization.

        Runs all pipeline steps in sequence and returns the result.

        Args:
            data: Optional data for prediction (overrides test data).

        Returns:
            Complete pipeline results.
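
        Example:
            Hypothetical sketch: `pipeline` stands for an instance of a concrete
            subclass that was constructed with data and a config.

                result = pipeline.run()
                # Equivalent to calling, in order:
                # pipeline.preprocess(); pipeline.fit()
                # pipeline.predict(); pipeline.visualize()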
        """
        self.preprocess()
        self.fit()
        self.predict(data=data)
        self.visualize()
        return self.result

    def save(self, file_path: str, save_all: bool = False):
        """Saves the pipeline to a file.

        Args:
            file_path: Path where the pipeline should be saved.
            save_all: Forwarded to `Saver` as its `save_all` flag.
        """
        saver = Saver(file_path, save_all=save_all)
        saver.save(self)

    @classmethod
    def load(cls, file_path) -> Any:
        """Loads a pipeline from a file.

        Args:
            file_path: Path to the saved pipeline.

        Returns:
            The loaded pipeline instance.
        """
        loader = Loader(file_path)
        return loader.load()

    def sample_latent_space(
        self,
        n_samples: int,
        split: str = "test",
        epoch: int = -1,
    ) -> torch.Tensor:
        """Samples latent points from a Gaussian built from the learned posteriors.

        Averages the per-sample posterior means (`result.mus`) and
        log-variances (`result.sigmas`) of the given split and epoch, then draws
        `n_samples` points from the resulting Gaussian with the model's
        `reparameterize` method.

        Args:
            n_samples: Number of latent points to sample; must be a positive
                integer.
            split: The split to sample from (train, valid, test), default is test.
            epoch: The epoch to sample from, default is the last epoch (-1).

        Returns:
            z: torch.Tensor - The sampled latent space points, one row per sample.

        Raises:
            ValueError: If the model has not been trained, latent statistics
                have not been computed, or n_samples is not a positive integer.
            TypeError: If mu or logvar are not numpy arrays.
        """

        if not hasattr(self, "_trainer") or self._trainer is None:
            raise ValueError("Model is not trained yet. Please train the model first.")
        if self.result.mus is None or self.result.sigmas is None:
            raise ValueError("Model has not learned the latent space distribution yet.")
        if not isinstance(n_samples, int) or n_samples <= 0:
            raise ValueError("n_samples must be a positive integer.")

        mu = self.result.mus.get(split=split, epoch=epoch)
        logvar = self.result.sigmas.get(split=split, epoch=epoch)

        if not isinstance(mu, np.ndarray):
            raise TypeError(
                f"Expected value to be of type numpy.ndarray, got {type(mu)}. "
                "This can happen if the model was not trained with VAE loss or "
                "if you forgot to run predict()."
            )
        if not isinstance(logvar, np.ndarray):
            raise TypeError(
                f"Expected value to be of type numpy.ndarray, got {type(logvar)}."
            )

        mu_t = torch.from_numpy(mu).to(
            device=self._trainer._model.device, dtype=self._trainer._model.dtype
        )
        logvar_t = torch.from_numpy(logvar).to(
            device=self._trainer._model.device, dtype=self._trainer._model.dtype
        )

        with torch.no_grad():
            global_mu = mu_t.mean(dim=0)
            global_logvar = logvar_t.mean(dim=0)

            mu_exp = global_mu.expand(n_samples, -1)
            logvar_exp = global_logvar.expand(n_samples, -1)

            z = self._trainer._model.reparameterize(mu_exp, logvar_exp)
            return z

    def generate(
        self,
        n_samples: Optional[int] = None,
        latent_prior: Optional[Union[np.ndarray, torch.Tensor]] = None,
        split: str = "test",
        epoch: int = -1,
    ) -> torch.Tensor:
        """Generates new samples from the model's latent space.

        This method allows for the generation of new data samples by sampling
        from the model's latent space. Users can either provide a custom latent
        prior or specify the number of samples to generate. If a custom latent
        prior is provided, its batch dimension must be compatible with n_samples.

        Args:
            n_samples: The number of samples to generate.
            latent_prior: Optional custom latent prior distribution. If provided,
                this will be used for sampling instead of the learned distribution.
                The prior must either be a single latent vector or a batch of
                latent vectors matching n_samples.
            split: The split to sample from (train, valid, test), default is test.
            epoch: The epoch to sample from, default is the last epoch (-1).

        Returns:
            torch.Tensor: The generated samples in the input space.

        Raises:
            ValueError: If n_samples is not a positive integer or if the latent
                prior has incompatible dimensions.
            TypeError: If latent_prior is not a numpy array or tensor.
        """
        if not isinstance(n_samples, int) or n_samples <= 0:
            if latent_prior is None:
                raise ValueError(
                    "n_samples must be a positive integer or latent_prior provided."
                )

        if latent_prior is None:
            latent_prior = self.sample_latent_space(
                n_samples=n_samples, split=split, epoch=epoch
            )

        if isinstance(latent_prior, np.ndarray):
            latent_prior = torch.from_numpy(latent_prior).to(
                device=self._trainer._model.device,
                dtype=self._trainer._model.dtype,
            )
        if not isinstance(latent_prior, torch.Tensor):
            raise TypeError(
                f"latent_prior must be numpy.ndarray or torch.Tensor, got {type(latent_prior)}."
            )
        if not latent_prior.shape[1] == self.config.latent_dim:
            if self._trainer._model._mu.out_features == latent_prior.shape[1]:
                warnings.warn(
                    f"latent_prior has latent dimension {latent_prior.shape[1]}, "
                    "which matches the output dimension of the model's mu layer. "
                    "Did you mean to provide latent vectors of dimension "
                    f"{self.config.latent_dim}? For Ontix this is the default "
                    "behaviour and the warning can be ignored."
                )
            else:
                raise ValueError(
                    f"latent_prior has incompatible latent dimension {latent_prior.shape[1]}, "
                    f"expected {self.config.latent_dim}."
                )

        with torch.no_grad():
            generated = self.decode(latent=latent_prior)
            return generated

    def explain(
        self,
        explainer: Any,
        baseline_type: Literal["mean", "random_sample"] = "mean",
        n_subset: int = 100,
        llm_explain: bool = False,
        llm_client: Literal["ollama", "mistral"] = "mistral",
        llm_model: str = "mistral-medium-latest",
    ):  # TODO Vincent: add return type
        my_converter = AnnDataConverter()
        dataset: Optional[DatasetContainer] = get_dataset(self.result)
        if dataset is None:
            raise ValueError(
                "No dataset available for explanation. "
                "This happens if you used .save and .load and did not run .predict before. "
                "This can also happen if you run .explain before .preprocess or .fit."
            )
        adata_train: Optional[Dict[str, ad.AnnData]] = my_converter.dataset_to_adata(
            dataset, split="train"
        )
        adata_test: Optional[Dict[str, ad.AnnData]] = my_converter.dataset_to_adata(
            dataset, split="test"
        )
        adata_valid: Optional[Dict[str, ad.AnnData]] = my_converter.dataset_to_adata(
            dataset, split="valid"
        )
        model = self.result.model
        if model is None:
            raise ValueError(
                "No model available for explanation. "
                "This happens if you used .save and .load and did not run .fit before. "
                "This can also happen if you run .explain before .fit."
            )
        # TODO Vincent: Implement feature importance explanation
        # Best with Explainer class that gets initialized here and has a method
        # Maybe like:
        # explainer = FeatureImportanceExplainer(adata_train, adata_test, model, explainer, ...)
        # output = explainer.explain()
        # Also note that adata_<split> can be None if the split is not available,
        # so best to check this before concatenating or using them

        if llm_explain:
            # TODO Vincent:
            # Depending on what the gene list looks like, you may still need to adjust
            # how the prompt is built in _init_prompt in src/autoencodix/utils/_llm_explainer.py.
            # I am currently assuming a list of strings, but I was not sure exactly
            # what your return type is.

            llm_explainer = LLMExplainer(
                client_name=llm_client,
                model_name=llm_model,
                gene_list=["GeneA", "GeneB", "GeneC"],  # Example gene list
            )
            explanation = llm_explainer.explain()
            print("LLM Explanation:")
            print(explanation)
            return explanation

__init__(dataset_type, trainer_type, model_type, loss_type, datasplitter_type, preprocessor_type, data, visualizer=None, evaluator=None, result=None, config=None, custom_split=None, ontologies=None, masking_fn=None, masking_fn_kwargs={}, **kwargs)

Initializes the pipeline with components and configuration.

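Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dataset_type | Type[BaseDataset] | Class for dataset implementations. | required |
| trainer_type | Type[BaseTrainer] | Class for model training implementations. | required |
| model_type | Type[BaseAutoencoder] | Class for model architecture implementations. | required |
| loss_type | Type[BaseLoss] | Class for loss function implementations. | required |
| datasplitter_type | Type[DataSplitter] | Class for data splitting implementation. | required |
| preprocessor_type | Type[BasePreprocessor] | Class for data preprocessing implementation. | required |
| data | Optional[Union[DataPackage, DatasetContainer, AnnData, MuData, DataFrame, dict]] | Input data to be processed or already processed data. | required |
| visualizer | Optional[BaseVisualizer] | Visualizer class; the pipeline instantiates it (BaseVisualizer if None). | None |
| evaluator | Optional[BaseEvaluator] | Evaluator class; the pipeline instantiates it (BaseEvaluator if None). | None |
| result | Optional[Result] | Storage container for pipeline outputs. | None |
| config | Optional[DefaultConfig] | Configuration parameters for all pipeline components. | None |
| custom_split | Optional[Dict[str, ndarray]] | User-provided data splits (train/validation/test). | None |
| ontologies | Optional[Union[Tuple, Dict[Any, Any]]] | Ontology information used by Ontix; passed to the preprocessor and trainer. | None |
| masking_fn | Optional[Callable] | Optional masking function forwarded to the trainer. | None |
| masking_fn_kwargs | Dict[str, Any] | Keyword arguments for masking_fn, forwarded to the trainer. | {} |
| **kwargs | dict | Additional keyword arguments. | {} |

Raises:

| Type | Description |
|---|---|
| ValueError | If the pipeline class does not define _default_config. |
| TypeError | If inputs have incorrect types. |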

Source code in src/autoencodix/base/_base_pipeline.py
def __init__(
    self,
    dataset_type: Type[BaseDataset],
    trainer_type: Type[BaseTrainer],
    model_type: Type[BaseAutoencoder],
    loss_type: Type[BaseLoss],
    datasplitter_type: Type[DataSplitter],
    preprocessor_type: Type[BasePreprocessor],
    data: Optional[
        Union[DataPackage, DatasetContainer, ad.AnnData, MuData, pd.DataFrame, dict]  # type: ignore[invalid-type-form]
    ],
    visualizer: Optional[BaseVisualizer] = None,
    evaluator: Optional[BaseEvaluator] = None,
    result: Optional[Result] = None,
    config: Optional[DefaultConfig] = None,
    custom_split: Optional[Dict[str, np.ndarray]] = None,
    ontologies: Optional[Union[Tuple, Dict[Any, Any]]] = None,
    masking_fn: Optional[Callable] = None,
    masking_fn_kwargs: Dict[str, Any] = {},
    **kwargs: dict,
) -> None:  # ty: ignore[call-non-callable]
    """Initializes the pipeline with components and configuration.

    Args:
        dataset_type: Class for dataset implementations.
        trainer_type: Class for model training implementations.
        model_type: Class for model architecture implementations.
        loss_type: Class for loss function implementations.
        datasplitter_type: Class for data splitting implementation.
        preprocessor_type: Class for data preprocessing implementation.
        data: Input data to be processed or already processed data.
        visualizer: Visualizer class; the pipeline instantiates it
            (BaseVisualizer if None).
        evaluator: Evaluator class; the pipeline instantiates it
            (BaseEvaluator if None).
        result: Storage container for pipeline outputs.
        config: Configuration parameters for all pipeline components.
        custom_split: User-provided data splits (train/validation/test).
        ontologies: Ontology information used by Ontix; passed to the
            preprocessor and trainer.
        masking_fn: Optional masking function forwarded to the trainer.
        masking_fn_kwargs: Keyword arguments for masking_fn, forwarded to
            the trainer.
        **kwargs: Additional keyword arguments.

    Raises:
        ValueError: If the pipeline class does not define _default_config.
        TypeError: If inputs have incorrect types.
    """
    if not hasattr(self, "_default_config"):
        raise ValueError(
            "The _default_config attribute has not been specified in your "
            "pipeline class.\n\n"
            "Example:\n"
            "    self._default_config = XModalixConfig()\n\n"
            "This error typically occurs when a new architecture is added "
            "without setting the _default_config in its corresponding "
            "pipeline class.\n\n"
            "For more details, please refer to the 'how to add a new "
            "architecture' section in our documentation."
        )

    self._validate_config(config=config)
    self._validate_user_input(data=data)
    self.masking_fn = masking_fn
    self.masking_fn_kwargs = masking_fn_kwargs
    processed_data = data if isinstance(data, DatasetContainer) else None
    raw_user_data = (
        data
        if isinstance(data, (DataPackage, ad.AnnData, MuData, pd.DataFrame, dict))
        else None
    )
    if processed_data is not None and not isinstance(
        processed_data, DatasetContainer
    ):
        raise TypeError(
            f"Expected data type to be DatasetContainer, got {type(processed_data)}."
        )

    self.preprocessed_data: Optional[DatasetContainer] = processed_data
    self.raw_user_data: Union[
        DataPackage, ad.AnnData, MuData, pd.DataFrame, dict  # type: ignore[invalid-type-form]
    ] = raw_user_data
    self._trainer_type = trainer_type
    self._trainer: Optional[BaseTrainer] = None
    self._model_type = model_type
    self._loss_type = loss_type
    self._preprocessor_type = preprocessor_type
    if self.raw_user_data is not None:
        self.raw_user_data, datacase = self._handle_direct_user_data(
            data=self.raw_user_data,
        )
        self.config.data_case = datacase
        self._fill_data_info()

    self.ontologies = ontologies
    self._preprocessor = self._preprocessor_type(
        config=self.config, ontologies=self.ontologies
    )

    self.visualizer = (
        visualizer()  # ty: ignore[call-non-callable]
        if visualizer is not None
        else BaseVisualizer()  # ty: ignore[call-non-callable]
    )  # ty: ignore[call-non-callable]
    self.evaluator = (
        evaluator()  # ty: ignore[call-non-callable]
        if evaluator is not None
        else BaseEvaluator()  # ty: ignore[call-non-callable]
    )  # ty: ignore[call-non-callable]
    self.result = result if result is not None else Result()
    self._dataset_type = dataset_type
    self._data_splitter = datasplitter_type(
        config=self.config, custom_splits=custom_split
    )

    self._datasets: Optional[DatasetContainer] = (
        processed_data  # None, or user input
    )
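Although `visualizer` and `evaluator` are annotated as `Optional[BaseVisualizer]` and `Optional[BaseEvaluator]`, `__init__` calls them with no arguments, so the pipeline expects a class rather than an already-built instance. The sketch below reproduces that instantiation pattern with hypothetical stand-in classes (`ToyVisualizer`, `CustomVisualizer` and `ToyPipeline` are illustrative names, not part of autoencodix) to show why passing an instance fails.

```python
from typing import Optional, Type


class ToyVisualizer:
    """Hypothetical stand-in for BaseVisualizer."""


class CustomVisualizer(ToyVisualizer):
    """Hypothetical user-defined visualizer subclass."""


class ToyPipeline:
    """Hypothetical pipeline mirroring the visualizer logic of __init__."""

    def __init__(self, visualizer: Optional[Type[ToyVisualizer]] = None) -> None:
        # The argument is *called*, so it must be a class (or another factory).
        self.visualizer = visualizer() if visualizer is not None else ToyVisualizer()


default = ToyPipeline()
custom = ToyPipeline(visualizer=CustomVisualizer)
print(type(default.visualizer).__name__, type(custom.visualizer).__name__)
# prints: ToyVisualizer CustomVisualizer

try:
    ToyPipeline(visualizer=CustomVisualizer())  # an instance, not a class
except TypeError as err:
    print(err)  # 'CustomVisualizer' object is not callable
```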

decode(latent)

Transforms latent space representations back to input space.

Handles various input formats for the latent representation and returns the decoded data in a matching format.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| latent | Union[Tensor, AnnData, DataFrame] | Latent space representation to decode. | required |

Returns:

| Type | Description |
|---|---|
| Union[Tensor, AnnData, DataFrame] | Decoded data in a format matching the input. |

Raises:

| Type | Description |
|---|---|
| TypeError | If no model has been trained or input type is invalid. |
| ValueError | If latent dimensions are incompatible with the model. |

Source code in src/autoencodix/base/_base_pipeline.py
def decode(
    self, latent: Union[torch.Tensor, ad.AnnData, pd.DataFrame]
) -> Union[torch.Tensor, ad.AnnData, pd.DataFrame]:
    """Transforms latent space representations back to input space.

    Handles various input formats for the latent representation and
    returns the decoded data in a matching format.

    Args:
        latent: Latent space representation to decode.

    Returns:
        Decoded data in a format matching the input.

    Raises:
        TypeError: If no model has been trained or input type is invalid.
        ValueError: If latent dimensions are incompatible with the model.
    """
    if self.result.model is None:
        raise TypeError("No model trained yet, use fit() or run() method first")
    recons: torch.Tensor
    if isinstance(latent, ad.AnnData):
        latent_data = torch.tensor(
            latent.X, dtype=torch.float32
        )  # Ensure float for compatibility

        expected_latent_dim = self.config.latent_dim
        if not latent_data.shape[1] == expected_latent_dim:
            raise ValueError(
                f"Input AnnData's .X has shape {latent_data.shape}, but the model "
                f"expects a latent vector of size {expected_latent_dim}. Consider "
                f"projecting the AnnData to the correct latent space first."
            )
        latent_tensor = latent_data

        recons = self._trainer.decode(x=latent_tensor)
        if self._datasets is None:
            raise ValueError(
                "No datasets available in the DatasetContainer to reconstruct "
                "AnnData objects. Please provide a valid DatasetContainer."
            )
        if self._datasets.train is None:
            raise ValueError(
                "The train dataset in the DatasetContainer is None. "
                "Please provide a valid train dataset to reconstruct AnnData objects."
            )
        if not isinstance(self._datasets.train, BaseDataset):
            raise TypeError(
                "The train dataset in the DatasetContainer must be a BaseDataset "
                "to reconstruct AnnData objects."
            )
        recons_adata = ad.AnnData(
            X=recons.to("cpu").detach().numpy(),
            obs=pd.DataFrame(index=latent.obs_names),
            var=pd.DataFrame(index=self._datasets.train.feature_ids),
        )

        return recons_adata
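    elif isinstance(latent, pd.DataFrame):
        latent_tensor = torch.tensor(latent.values, dtype=torch.float32)
        recons = self._trainer.decode(x=latent_tensor)
        # Reconstructions live in the input feature space, so label the columns
        # with the training feature IDs (when available), not the latent columns.
        feature_ids = (
            getattr(self._datasets.train, "feature_ids", None)
            if self._datasets is not None
            else None
        )
        return pd.DataFrame(
            recons.to("cpu").detach().numpy(),
            index=latent.index,
            columns=feature_ids,
        )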
    elif isinstance(latent, torch.Tensor):
        # Check size compatibility
        expected_latent_dim = self.config.latent_dim
        if not latent.shape[1] == expected_latent_dim:
            if self._trainer._model._mu.out_features == latent.shape[1]:
                warnings.warn(
                    f"latent has latent dimension {latent.shape[1]}, "
                    "which matches the output dimension of the model's mu layer. "
                    "Did you mean to provide latent vectors of dimension "
                    f"{self.config.latent_dim}? For Ontix this is the default "
                    "behaviour and the warning can be ignored."
                )
            else:
                raise ValueError(
                    f"latent has incompatible latent dimension {latent.shape[1]}, "
                    f"expected {self.config.latent_dim} or "
                    f"{self._trainer._model._mu.out_features}."
                )

        latent_tensor = latent
    else:
        raise TypeError(
            "Input 'latent' must be a torch.Tensor, an AnnData object, or a "
            f"pandas DataFrame, not {type(latent)}."
        )

    return self._trainer.decode(x=latent_tensor)
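For tensor input, `decode()` (like `generate()`) accepts a latent matrix whose width equals `config.latent_dim`, only warns when the width instead equals the output width of the model's `_mu` layer (the warning notes this is the default for Ontix), and raises otherwise. A minimal NumPy sketch of that three-way check follows; `check_latent_dim` is a hypothetical helper, and plain integers stand in for `config.latent_dim` and `_mu.out_features`.

```python
import warnings

import numpy as np


def check_latent_dim(latent: np.ndarray, latent_dim: int, mu_out_features: int) -> None:
    """Hypothetical helper mirroring the width checks in decode() and generate()."""
    width = latent.shape[1]
    if width == latent_dim:
        return  # expected case: one column per configured latent dimension
    if width == mu_out_features:
        # Tolerated case (default for Ontix): width matches the mu layer output.
        warnings.warn(
            f"latent has dimension {width}, which matches the output dimension "
            f"of the mu layer. Did you mean to provide vectors of dimension {latent_dim}?"
        )
        return
    raise ValueError(
        f"latent has incompatible dimension {width}, "
        f"expected {latent_dim} or {mu_out_features}."
    )


check_latent_dim(np.zeros((4, 8)), latent_dim=8, mu_out_features=8)  # passes silently

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    check_latent_dim(np.zeros((4, 12)), latent_dim=8, mu_out_features=12)
print(len(caught))  # 1

try:
    check_latent_dim(np.zeros((4, 5)), latent_dim=8, mu_out_features=12)
except ValueError as err:
    print(err)  # latent has incompatible dimension 5, expected 8 or 12.
```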

evaluate(ml_model_class=linear_model.LogisticRegression(), ml_model_regression=linear_model.LinearRegression(), params=[], metric_class='roc_auc_ovo', metric_regression='r2', reference_methods=[], split_type='use-split', n_downsample=10000)

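Evaluates the learned latent spaces with downstream scikit-learn models.

Delegates to the pipeline's evaluator, passing the stored datasets and results (including the latent spaces computed by predict()), the chosen classifier and regressor, and the scoring names, then plots the evaluation. If params is empty, it falls back to config.data_config.annotation_columns; pass "all" to use all parameters. Requires a trained model and latent spaces from predict().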

Source code in src/autoencodix/base/_base_pipeline.py
def evaluate(
    self,
    ml_model_class: ClassifierMixin = linear_model.LogisticRegression(),
    ml_model_regression: RegressorMixin = linear_model.LinearRegression(),
    params: Union[
        list, str
    ] = [],  # Default empty list, to use all parameters use string "all"
    metric_class: str = "roc_auc_ovo",  # Default is 'roc_auc_ovo' via https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-string-names
    metric_regression: str = "r2",  # Default is 'r2'
    reference_methods: list = [],  # Default [], Options are "PCA", "UMAP", "TSNE", "RandomFeature"
    split_type: Literal[
        "use-split", "CV-5", "LOOC"
    ] = "use-split",  # Default is "use-split", other options: "CV-5", ... "LOOCV"?
    n_downsample: Optional[int] = 10000,
) -> Result:
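    """Evaluates the learned latent spaces with downstream scikit-learn models.

    Delegates to the pipeline's evaluator, passing the stored datasets and
    results (including the latent spaces computed by predict()), the chosen
    classifier and regressor, and the scoring names, then plots the evaluation.

    Args:
        ml_model_class: scikit-learn classifier passed to the evaluator.
        ml_model_regression: scikit-learn regressor passed to the evaluator.
        params: Annotation parameters to evaluate. If empty, falls back to
            config.data_config.annotation_columns; use "all" for all parameters.
        metric_class: scikit-learn scoring name for the classifier.
        metric_regression: scikit-learn scoring name for the regressor.
        reference_methods: Reference methods to compare against; options are
            "PCA", "UMAP", "TSNE" and "RandomFeature".
        split_type: How data is split for evaluation: "use-split", "CV-5" or
            "LOOC".
        n_downsample: Passed to the evaluator.

    Returns:
        The updated Result.

    Raises:
        NotImplementedError: If the evaluator is not initialized or no model
            has been trained.
        ValueError: If no parameters are available, if datasets are missing
            for "RandomFeature", or if no latent spaces have been computed.
    """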
    if self.evaluator is None:
        raise NotImplementedError("Evaluator not initialized")
    if self.result.model is None:
        raise NotImplementedError(
            "Model not trained. Please run the fit method first"
        )
    if not is_classifier(ml_model_class):
        warnings.warn(
            "The provided model is not a sklearn-type classifier. "
            "Evaluation continues but may produce incorrect results or errors."
        )
    if not is_regressor(ml_model_regression):
        warnings.warn(
            "The provided model is not a sklearn-type regressor. "
            "Evaluation continues but may produce incorrect results or errors."
        )

    if len(params) == 0:
        if self.config.data_config.annotation_columns is None:
            params = []  # type: ignore
        else:
            params = self.config.data_config.annotation_columns  # type: ignore

    if len(params) == 0:
        raise ValueError(
            "No parameters specified for evaluation. Please provide a list of "
            "parameters or ensure that annotation_columns are set in the config."
        )

    if "RandomFeature" in reference_methods:
        if self._datasets is None:
            raise ValueError(
                "Datasets not available for adding RandomFeature. Please keep "
                "preprocessed data available before evaluation."
            )

    if len(self.result.latentspaces._data) == 0:
        raise ValueError(
            "No latent spaces found in results. Please run predict() to "
            "calculate embeddings before evaluation."
        )

    self.result = self.evaluator.evaluate(
        datasets=self._datasets,
        result=self.result,
        ml_model_class=ml_model_class,
        ml_model_regression=ml_model_regression,
        params=params,
        metric_class=metric_class,
        metric_regression=metric_regression,
        reference_methods=reference_methods,
        split_type=split_type,
        n_downsample=n_downsample,
    )

    _: Any = self.visualizer._plot_evaluation(result=self.result)

    return self.result
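The default `metric_class="roc_auc_ovo"` and `metric_regression="r2"` are scikit-learn scoring strings. As a minimal sketch of the general idea, assuming the evaluator fits the classifier and regressor on latent embeddings and scores them with these strings (the toy three-cluster "latent space", the continuous annotation, and the simple hold-out split below are illustrative, not how autoencodix's evaluator prepares its data), `sklearn.metrics.get_scorer` resolves the same names:

```python
import numpy as np
from sklearn import linear_model
from sklearn.metrics import get_scorer
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)

# Toy "latent space": three well-separated Gaussian clusters in 2 dimensions.
centers = np.array([[0.0, 0.0], [5.0, 5.0], [0.0, 5.0]])
labels = np.repeat(np.arange(3), 50)
latent = centers[labels] + rng.normal(scale=0.5, size=(labels.size, 2))

# Toy continuous annotation that depends linearly on the latent coordinates.
target = latent @ np.array([1.5, -2.0]) + rng.normal(scale=0.1, size=labels.size)

X_train, X_test, y_train, y_test, t_train, t_test = train_test_split(
    latent, labels, target, test_size=0.3, random_state=0, stratify=labels
)

clf = linear_model.LogisticRegression().fit(X_train, y_train)
reg = linear_model.LinearRegression().fit(X_train, t_train)

# Same scoring strings as the defaults of metric_class / metric_regression.
auc = get_scorer("roc_auc_ovo")(clf, X_test, y_test)
r2 = get_scorer("r2")(reg, X_test, t_test)
print(f"roc_auc_ovo={auc:.3f}, r2={r2:.3f}")
```

A high score means the annotation is linearly recoverable from the embedding, which is why a simple linear classifier and regressor are sensible defaults.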

fit(config=None, **kwargs)

Trains the model on preprocessed data.

Creates and configures a trainer instance, then executes the training process using the preprocessed datasets.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | Optional[Union[None, DefaultConfig]] | Optional custom configuration for training. | None |
| **kwargs | | Additional configuration parameters as keyword arguments. | {} |

Raises:

| Type | Description |
|---|---|
| ValueError | If datasets aren't available for training. |

Source code in src/autoencodix/base/_base_pipeline.py
def fit(self, config: Optional[Union[None, DefaultConfig]] = None, **kwargs):
    """Trains the model on preprocessed data.

    Creates and configures a trainer instance, then executes the training
    process using the preprocessed datasets.

    Args:
        config: Optional custom configuration for training.
        **kwargs: Additional configuration parameters as keyword arguments.

    Raises:
        ValueError: If datasets aren't available for training.
    """
    if self._datasets is None:
        raise ValueError(
            "Datasets not built. Please run the preprocess method first."
        )

    self._trainer = self._trainer_type(
        trainset=self._datasets.train,
        validset=self._datasets.valid,
        result=self.result,
        config=self.config,
        model_type=self._model_type,
        loss_type=self._loss_type,
        ontologies=self.ontologies,  # Ontix
        masking_fn=self.masking_fn if hasattr(self, "masking_fn") else None,
        masking_fn_kwargs=(
            self.masking_fn_kwargs if hasattr(self, "masking_fn_kwargs") else None
        ),
    )

    trainer_result: Result = self._trainer.train()
    self.result.update(other=trainer_result)

generate(n_samples=None, latent_prior=None, split='test', epoch=-1)

Generates new samples from the model's latent space.

This method allows for the generation of new data samples by sampling from the model's latent space. Users can either provide a custom latent prior or specify the number of samples to generate. If a custom latent prior is provided, its batch dimension must be compatible with n_samples.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| n_samples | Optional[int] | The number of samples to generate. | None |
| latent_prior | Optional[Union[ndarray, Tensor]] | Optional custom latent prior distribution. If provided, this will be used for sampling instead of the learned distribution. The prior must either be a single latent vector or a batch of latent vectors matching n_samples. | None |
| split | str | The split to sample from (train, valid, test), default is test. | 'test' |
| epoch | int | The epoch to sample from, default is the last epoch (-1). | -1 |

Returns:

| Type | Description |
|---|---|
| Tensor | The generated samples in the input space. |

Raises:

| Type | Description |
|---|---|
| ValueError | If n_samples is not a positive integer or if the latent prior has incompatible dimensions. |
| TypeError | If latent_prior is not a numpy array or tensor. |

Source code in src/autoencodix/base/_base_pipeline.py
def generate(
    self,
    n_samples: Optional[int] = None,
    latent_prior: Optional[Union[np.ndarray, torch.Tensor]] = None,
    split: str = "test",
    epoch: int = -1,
) -> torch.Tensor:
    """Generates new samples from the model's latent space.

    This method allows for the generation of new data samples by sampling
    from the model's latent space. Users can either provide a custom latent
    prior or specify the number of samples to generate. If a custom latent
    prior is provided, its batch dimension must be compatible with n_samples.

    Args:
        n_samples: The number of samples to generate.
        latent_prior: Optional custom latent prior distribution. If provided,
            this will be used for sampling instead of the learned distribution.
            The prior must either be a single latent vector or a batch of
            latent vectors matching n_samples.
        split: The split to sample from (train, valid, test), default is test.
        epoch: The epoch to sample from, default is the last epoch (-1).

    Returns:
        torch.Tensor: The generated samples in the input space.

    Raises:
        ValueError: If n_samples is not a positive integer or if the latent
            prior has incompatible dimensions.
        TypeError: If latent_prior is not a numpy array or tensor.
    """
    if not isinstance(n_samples, int) or n_samples <= 0:
        if latent_prior is None:
            raise ValueError(
                "n_samples must be a positive integer or latent_prior provided."
            )

    if latent_prior is None:
        latent_prior = self.sample_latent_space(
            n_samples=n_samples, split=split, epoch=epoch
        )

    if isinstance(latent_prior, np.ndarray):
        latent_prior = torch.from_numpy(latent_prior).to(
            device=self._trainer._model.device,
            dtype=self._trainer._model.dtype,
        )
    if not isinstance(latent_prior, torch.Tensor):
        raise TypeError(
            f"latent_prior must be numpy.ndarray or torch.Tensor, got {type(latent_prior)}."
        )
    if not latent_prior.shape[1] == self.config.latent_dim:
        if self._trainer._model._mu.out_features == latent_prior.shape[1]:
            warnings.warn(
                f"latent_prior has latent dimension {latent_prior.shape[1]}, "
                "which matches the output dimension of the model's mu layer. "
                "Did you mean to provide latent vectors of dimension "
                f"{self.config.latent_dim}? For Ontix this is the default "
                "behaviour and the warning can be ignored."
            )
        else:
            raise ValueError(
                f"latent_prior has incompatible latent dimension {latent_prior.shape[1]}, "
                f"expected {self.config.latent_dim}."
            )

    with torch.no_grad():
        generated = self.decode(latent=latent_prior)
        return generated
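When `latent_prior` is omitted, `generate()` calls `sample_latent_space()`, which averages the stored per-sample means (`result.mus`) and log-variances (`result.sigmas`) over the split and draws `n_samples` points through the model's `reparameterize` method. The NumPy sketch below mirrors that logic, assuming `reparameterize` implements the standard VAE form z = mu + exp(0.5 * logvar) * eps with eps ~ N(0, I); the function name `sample_from_averaged_posterior` is hypothetical. Note that this samples from a single Gaussian with averaged parameters, not from the mixture of per-sample posteriors.

```python
import numpy as np


def sample_from_averaged_posterior(
    mu: np.ndarray, logvar: np.ndarray, n_samples: int, seed: int = 0
) -> np.ndarray:
    """Hypothetical NumPy mirror of sample_latent_space().

    mu, logvar: (n_obs, latent_dim) per-sample posterior parameters.
    Returns an (n_samples, latent_dim) array of latent points.
    """
    if not isinstance(n_samples, int) or n_samples <= 0:
        raise ValueError("n_samples must be a positive integer.")
    global_mu = mu.mean(axis=0)          # average posterior mean per dimension
    global_logvar = logvar.mean(axis=0)  # average log-variance per dimension
    std = np.exp(0.5 * global_logvar)
    eps = np.random.default_rng(seed).standard_normal((n_samples, mu.shape[1]))
    return global_mu + eps * std  # reparameterization: z = mu + sigma * eps


rng = np.random.default_rng(1)
mu = rng.normal(loc=2.0, size=(100, 4))
logvar = np.full((100, 4), -20.0)  # tiny variance: samples collapse onto the mean
z = sample_from_averaged_posterior(mu, logvar, n_samples=5)
print(z.shape)  # (5, 4)
```

The resulting points would then be passed to `decode()` to map them back to the input space.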

load(file_path) classmethod

Loads a pipeline from a file.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| file_path | | Path to the saved pipeline. | required |

Returns:

| Type | Description |
|---|---|
| Any | The loaded pipeline instance. |

Source code in src/autoencodix/base/_base_pipeline.py
@classmethod
def load(cls, file_path) -> Any:
    """Loads a pipeline from a file.

    Args:
        file_path: Path to the saved pipeline.

    Returns:
        The loaded pipeline instance.
    """
    loader = Loader(file_path)
    return loader.load()

predict(data=None, config=None, from_key=None, to_key=None, **kwargs)

Generates predictions using the trained model.

Uses the trained model to make predictions on test data or new data provided by the user. Processes the results and stores them in the result container.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data | Optional[Union[DataPackage, DatasetContainer, AnnData, MuData]] | Optional new data for predictions. | None |
| config | Optional[Union[None, DefaultConfig]] | Optional custom configuration for prediction. | None |
| **kwargs | | Additional configuration parameters as keyword arguments. | {} |

Raises:

| Type | Description |
|---|---|
| NotImplementedError | If required components aren't initialized. |
| ValueError | If no test data is available or data format is invalid. |

Source code in src/autoencodix/base/_base_pipeline.py
def predict(
    self,
    data: Optional[
        Union[
            DataPackage,
            DatasetContainer,
            ad.AnnData,
            MuData,  # ty: ignore[invalid-type-form]
        ]  # ty: ignore[invalid-type-form]
    ] = None,  # ty: ignore[invalid-type-form]
    config: Optional[Union[None, DefaultConfig]] = None,
    from_key: Optional[str] = None,
    to_key: Optional[str] = None,
    **kwargs,
):
    """Generates predictions using the trained model.

    Uses the trained model to make predictions on test data or new data
    provided by the user. Processes the results and stores them in the
    result container.

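    Args:
        data: Optional new data for predictions. If None, the test data is used.
        config: Optional custom configuration for prediction.
        from_key: Optional key of the source modality (translation cases).
        to_key: Optional key of the target modality (translation cases).
        **kwargs: Additional configuration parameters as keyword arguments.

    Returns:
        The result container, updated with the prediction results.

    Raises:
        NotImplementedError: If required components aren't initialized.
        ValueError: If the trainer is not initialized (fit() was not called),
            if no test data is available, or if the data format is invalid.
    """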
    self._validate_prediction_requirements()
    if self._trainer is None:
        raise ValueError(
            "Trainer not initialized; call fit() first. If this pipeline was restored "
            "with .save()/.load() (in which case fit() is not needed), this is a bug; "
            "please submit an issue."
        )

    self._trainer.setup_trainer(old_model=self.result.model)
    original_input = data
    predict_data = self._prepare_prediction_data(data=data)

    predictor_results = self._generate_predictions(
        predict_data=predict_data,
    )

    self._process_latent_results(
        predictor_results=predictor_results, predict_data=predict_data
    )
    self._postprocess_reconstruction(
        predictor_results=predictor_results,
        original_input=original_input,
        predict_data=predict_data,
    )
    self.result.update(predictor_results)
    return self.result

preprocess(config=None, **kwargs)

Filters, normalizes and prepares data for model training.

Processes raw input data into the format required by the model and creates train/validation/test splits as needed.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
config | Optional[Union[None, DefaultConfig]] | Optional custom configuration for preprocessing. | None
**kwargs | | Additional configuration parameters as keyword arguments. | {}

Raises:

Type | Description
--- | ---
NotImplementedError | If preprocessor is not initialized.

Source code in src/autoencodix/base/_base_pipeline.py
def preprocess(self, config: Optional[Union[None, DefaultConfig]] = None, **kwargs):
    """Filters, normalizes and prepares data for model training.

    Processes raw input data into the format required by the model and creates
    train/validation/test splits as needed.

    Args:
        config: Optional custom configuration for preprocessing.
        **kwargs: Additional configuration parameters as keyword arguments.

    Raises:
        NotImplementedError: If preprocessor is not initialized.
    """
    if self._preprocessor_type is None:
        raise NotImplementedError("Preprocessor not initialized")
    self._validate_user_data()
    if self.preprocessed_data is None:
        self.preprocessed_data = self._preprocessor.preprocess(
            raw_user_data=self.raw_user_data,  # type: ignore
        )
        self.result.datasets = self.preprocessed_data
        self._datasets = self.preprocessed_data
    else:
        self._datasets = self.preprocessed_data
        self.result.datasets = self.preprocessed_data
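The caching behaviour of `preprocess()` can be illustrated with a stand-in class. This is a minimal sketch, not the library's code: `CachingPipeline` and `FakePreprocessor` are hypothetical names, and the sketch only mirrors the `preprocessed_data is None` check shown above.

```python
from typing import Any, Dict, Optional


class FakePreprocessor:
    """Hypothetical preprocessor that counts how often it runs."""

    def __init__(self) -> None:
        self.calls = 0

    def preprocess(self, raw_user_data: Optional[Any] = None) -> Dict[str, Any]:
        self.calls += 1
        return {"train": raw_user_data, "valid": None, "test": raw_user_data}


class CachingPipeline:
    """Hypothetical pipeline mirroring the caching logic of preprocess()."""

    def __init__(self, raw_user_data: Any) -> None:
        self.raw_user_data = raw_user_data
        self.preprocessed_data: Optional[Dict[str, Any]] = None
        self._datasets: Optional[Dict[str, Any]] = None
        self._preprocessor = FakePreprocessor()

    def preprocess(self) -> None:
        # Only run the (possibly expensive) preprocessor on the first call.
        if self.preprocessed_data is None:
            self.preprocessed_data = self._preprocessor.preprocess(
                raw_user_data=self.raw_user_data
            )
        # Every call points the working datasets at the cached result.
        self._datasets = self.preprocessed_data


pipe = CachingPipeline(raw_user_data=[1, 2, 3])
pipe.preprocess()
pipe.preprocess()
print(pipe._preprocessor.calls)  # 1
```

In the code shown, a consequence of this design is that changing `raw_user_data` after the first call has no effect until `preprocessed_data` is reset to `None`.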

run(data=None)

Executes the complete pipeline from preprocessing to visualization.

Runs all pipeline steps in sequence and returns the result.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
data | Optional[Union[DatasetContainer, DataPackage]] | Optional data for prediction (overrides test data). | None

Returns:

Type | Description
--- | ---
Result | Complete pipeline results.

Source code in src/autoencodix/base/_base_pipeline.py
def run(
    self, data: Optional[Union[DatasetContainer, DataPackage]] = None
) -> Result:
    """Executes the complete pipeline from preprocessing to visualization.

    Runs all pipeline steps in sequence and returns the result.

    Args:
        data: Optional data for prediction (overrides test data).

    Returns:
        Complete pipeline results.
    """
    self.preprocess()
    self.fit()
    self.predict(data=data)
    self.visualize()
    return self.result
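`run()` is a thin wrapper that calls the four steps in a fixed order. The stub below records that order; `StubPipeline` is a hypothetical stand-in that trains nothing and only mimics the call sequence and the `data=None` fallback described above.

```python
from typing import Any, Dict, List, Optional


class StubPipeline:
    """Hypothetical stand-in that records which pipeline steps are called."""

    def __init__(self) -> None:
        self.calls: List[str] = []
        self.result: Dict[str, str] = {}

    def preprocess(self) -> None:
        self.calls.append("preprocess")

    def fit(self) -> None:
        self.calls.append("fit")

    def predict(self, data: Optional[Any] = None) -> None:
        # data=None means "use the test split"; anything else overrides it.
        self.calls.append("predict")
        self.result["predicted_on"] = "test split" if data is None else "user data"

    def visualize(self) -> None:
        self.calls.append("visualize")

    def run(self, data: Optional[Any] = None) -> Dict[str, str]:
        # Same sequence as run() in the listing above.
        self.preprocess()
        self.fit()
        self.predict(data=data)
        self.visualize()
        return self.result


pipe = StubPipeline()
result = pipe.run()
print(pipe.calls)  # ['preprocess', 'fit', 'predict', 'visualize']
```

Because `run()` only chains these calls, invoking the steps one by one is equivalent and lets you inspect `result` between steps.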

sample_latent_space(n_samples, split='test', epoch=-1)

Samples latent space points from the learned distribution.

Draws n_samples latent points from a single diagonal Gaussian whose mean and log-variance are the per-dimension averages of the posterior parameters stored for the given split and epoch (result.mus, and result.sigmas read as log-variances). n_samples is required and must be a positive integer.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
n_samples | int | Number of latent points to draw; must be a positive integer. | required
split | str | The split to sample from (train, valid, test). | 'test'
epoch | int | The epoch to sample from; the last epoch by default. | -1

Returns:

Name | Type | Description
--- | --- | ---
z | Tensor | The sampled latent points, of shape (n_samples, latent_dim).

Raises:

Type | Description
--- | ---
ValueError | If the model has not been trained, the latent statistics have not been computed, or n_samples is not a positive integer.
TypeError | If mu or logvar are not numpy arrays.

Source code in src/autoencodix/base/_base_pipeline.py
def sample_latent_space(
    self,
    n_samples: int,
    split: str = "test",
    epoch: int = -1,
) -> torch.Tensor:
    """Samples latent space points from the learned distribution.

    Draws `n_samples` points from a single diagonal Gaussian whose mean and
    log-variance are the per-dimension averages of the posterior parameters
    stored for the given split and epoch (`result.mus`, and `result.sigmas`
    read as log-variances). The draw uses the model's `reparameterize`.

    Args:
        n_samples: Number of latent points to draw; must be a positive integer.
        split: The split to sample from (train, valid, test), default is test.
        epoch: The epoch to sample from, default is the last epoch (-1).

    Returns:
        z: The sampled latent points, a torch.Tensor of shape
            (n_samples, latent_dim).

    Raises:
        ValueError: If the model has not been trained, the latent statistics
            have not been computed, or `n_samples` is not a positive integer.
        TypeError: If mu or logvar are not numpy arrays.
    """

    if not hasattr(self, "_trainer") or self._trainer is None:
        raise ValueError("Model is not trained yet. Please train the model first.")
    if self.result.mus is None or self.result.sigmas is None:
        raise ValueError("Model has not learned the latent space distribution yet.")
    if not isinstance(n_samples, int) or n_samples <= 0:
        raise ValueError("n_samples must be a positive integer.")

    mu = self.result.mus.get(split=split, epoch=epoch)
    logvar = self.result.sigmas.get(split=split, epoch=epoch)

    if not isinstance(mu, np.ndarray):
        raise TypeError(
            f"Expected value to be of type numpy.ndarray, got {type(mu)}. "
            "This can happen if the model was not trained with VAE loss or if you forgot to run predict()."
        )
    if not isinstance(logvar, np.ndarray):
        raise TypeError(
            f"Expected value to be of type numpy.ndarray, got {type(logvar)}."
        )

    mu_t = torch.from_numpy(mu).to(
        device=self._trainer._model.device, dtype=self._trainer._model.dtype
    )
    logvar_t = torch.from_numpy(logvar).to(
        device=self._trainer._model.device, dtype=self._trainer._model.dtype
    )

    with torch.no_grad():
        global_mu = mu_t.mean(dim=0)
        global_logvar = logvar_t.mean(dim=0)

        mu_exp = global_mu.expand(n_samples, -1)
        logvar_exp = global_logvar.expand(n_samples, -1)

        z = self._trainer._model.reparameterize(mu_exp, logvar_exp)
        return z
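The sampling step reduces the per-item posterior parameters of a split to one mean vector and one log-variance vector, then draws `n_samples` points with the reparameterization trick. The NumPy sketch below reproduces that arithmetic. It assumes the model's `reparameterize` computes `mu + eps * exp(0.5 * logvar)` with standard-normal `eps` (the usual VAE convention; `reparameterize` itself is not shown in this excerpt), and `sample_from_averaged_posterior` is a hypothetical helper name.

```python
import numpy as np


def sample_from_averaged_posterior(
    mu: np.ndarray, logvar: np.ndarray, n_samples: int, seed: int = 0
) -> np.ndarray:
    """Hypothetical NumPy analogue of the sampling step in sample_latent_space.

    mu, logvar: arrays of shape (n_items, latent_dim) for one split and epoch.
    Returns an array of shape (n_samples, latent_dim).
    """
    if not isinstance(n_samples, int) or n_samples <= 0:
        raise ValueError("n_samples must be a positive integer.")
    # Average the per-item posterior parameters over all items of the split.
    global_mu = mu.mean(axis=0)
    global_logvar = logvar.mean(axis=0)
    # Repeat the averaged parameters once per requested sample.
    latent_dim = global_mu.shape[0]
    mu_exp = np.broadcast_to(global_mu, (n_samples, latent_dim))
    logvar_exp = np.broadcast_to(global_logvar, (n_samples, latent_dim))
    # Assumed reparameterization: z = mu + eps * exp(0.5 * logvar), eps ~ N(0, I).
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal(size=(n_samples, latent_dim))
    return mu_exp + eps * np.exp(0.5 * logvar_exp)


mu = np.array([[0.0, 1.0], [2.0, 3.0]])
logvar = np.zeros_like(mu)  # unit variance for every item
z = sample_from_averaged_posterior(mu, logvar, n_samples=5000)
print(z.shape)  # (5000, 2)
```

A design note: averaging the log-variances and ignoring the spread of the per-item means gives a Gaussian that is narrower than the mixture of the per-item posteriors whenever those means differ. In the toy data the two items' means are 2 apart in each dimension, yet the samples keep unit standard deviation.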

save(file_path, save_all=False)

Saves the pipeline to a file.

Parameters:

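Name | Type | Description | Default
--- | --- | --- | ---
file_path | str | Path where the pipeline should be saved. | required
save_all | bool | Forwarded to Saver (see Saver for its effect). | False
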
Source code in src/autoencodix/base/_base_pipeline.py
def save(self, file_path: str, save_all: bool = False):
    """Saves the pipeline to a file.

    Args:
        file_path: Path where the pipeline should be saved.
        save_all: Forwarded to `Saver` (see `Saver` for its effect).
    """
    saver = Saver(file_path, save_all=save_all)
    saver.save(self)

show_result(split='all', **kwargs)

Displays key visualizations of model results.

This method generates the following visualizations:

1. Loss Curves: Displays the absolute loss curves to show the model's training and validation performance over epochs.
2. Latent Space Ridgeline Plot: Visualizes the distribution of the latent space representations across dimensions, giving a high-level overview of the learned embeddings.
3. Latent Space 2D Scatter Plot: Projects the latent space into two dimensions for a detailed view of the clustering or separation of data points.

These visualizations help in understanding the model's performance and the structure of the latent space representations. A plot is skipped with a warning if its data is missing; run predict() first to compute the latent spaces. An optional params keyword (one or more annotation column names) is forwarded to the latent-space plots; if it is omitted, the annotation_columns from the data config are used.

Source code in src/autoencodix/base/_base_pipeline.py
def show_result(self, split: str = "all", **kwargs):
    """Displays key visualizations of model results.

    This method generates the following visualizations:
    1. Loss Curves: Displays the absolute loss curves to provide insights into
       the model's training and validation performance over epochs.
    2. Latent Space Ridgeline Plot: Visualizes the distribution of the latent
       space representations across different dimensions, offering a high-level
       overview of the learned embeddings.
    3. Latent Space 2D Scatter Plot: Projects the latent space into two dimensions
       for a detailed view of the clustering or separation of data points.

    These visualizations help in understanding the model's performance and
    the structure of the latent space representations.
    """
    print("Creating plots ...")

    params: Optional[Union[List[str], str]] = kwargs.pop("params", None)
    # Check if params are empty and annotation columns are available in config
    if params is None and self.config.data_config.annotation_columns:
        params = self.config.data_config.annotation_columns

    if len(self.result.losses._data) != 0:
        self.visualizer.show_loss(plot_type="absolute")
    else:
        warnings.warn(
            "No loss data found in results. Skipping loss curve visualization."
        )

    if len(self.result.latentspaces._data) != 0:
        self.visualizer.show_latent_space(
            result=self.result, plot_type="Ridgeline", split=split, param=params
        )
        self.visualizer.show_latent_space(
            result=self.result, plot_type="2D-scatter", split=split, param=params
        )
    else:
        warnings.warn(
            "No latent spaces found in results. Please run predict() to "
            "calculate embeddings."
        )
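The `params` fallback in `show_result()` can be isolated as a small function. This sketch is hypothetical (`resolve_plot_params` is not part of the library, and the column names are made up); it mirrors the `kwargs.pop("params", None)` logic in the listing above.

```python
from typing import Any, Dict, List, Optional, Union


def resolve_plot_params(
    kwargs: Dict[str, Any], annotation_columns: Optional[List[str]]
) -> Optional[Union[List[str], str]]:
    """Hypothetical helper mirroring how show_result picks its `params`."""
    params = kwargs.pop("params", None)
    # An explicit params argument wins; otherwise fall back to the
    # annotation columns from the data config, if any are configured.
    if params is None and annotation_columns:
        params = annotation_columns
    return params


print(resolve_plot_params({}, ["cell_type"]))  # ['cell_type']
```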

visualize(config=None, **kwargs)

Creates visualizations of model results and performance.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
config | Optional[Union[None, DefaultConfig]] | Optional custom configuration for visualization. | None
**kwargs | | Additional configuration parameters. | {}

Raises:

Type | Description
--- | ---
NotImplementedError | If visualizer is not initialized.

Source code in src/autoencodix/base/_base_pipeline.py
def visualize(self, config: Optional[Union[None, DefaultConfig]] = None, **kwargs):
    """Creates visualizations of model results and performance.

    Args:
        config: Optional custom configuration for visualization.
        **kwargs: Additional configuration parameters.

    Raises:
        NotImplementedError: If visualizer is not initialized.
    """
    if self.visualizer is None:
        raise NotImplementedError("Visualizer not initialized")

    self.visualizer.visualize(result=self.result, config=self.config)

BasePreprocessor

Bases: ABC

Contains logic for data preprocessing in the Autoencodix framework.

This class defines the general preprocessing workflow and provides methods for handling different data modalities and data cases. Subclasses should implement the preprocess method to perform specific preprocessing steps.
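The preprocessor selects its processing routine by looking up `config.data_case` in a dictionary (`_get_process_function` in the listing below). The sketch shows that dispatch pattern with a trimmed-down, hypothetical `DataCase` stand-in (three members, values generated with `auto()` because the real values are not shown) and a hypothetical `MiniPreprocessor`.

```python
from enum import Enum, auto
from typing import Callable, Dict, Optional


class DataCase(Enum):
    """Stand-in with three members; the real enum's values are not shown here."""

    MULTI_BULK = auto()
    BULK_TO_BULK = auto()
    IMG_TO_IMG = auto()


class MiniPreprocessor:
    """Hypothetical class illustrating dictionary-based dispatch on DataCase."""

    def _process_bulk(self) -> str:
        return "bulk routine"

    def _get_process_function(
        self, datacase: DataCase
    ) -> Optional[Callable[[], str]]:
        process_map: Dict[DataCase, Callable[[], str]] = {
            # Several data cases may share one routine.
            DataCase.MULTI_BULK: self._process_bulk,
            DataCase.BULK_TO_BULK: self._process_bulk,
            # IMG_TO_IMG is deliberately left unregistered.
        }
        # dict.get returns None for data cases without a routine.
        return process_map.get(datacase)

    def dispatch(self, datacase: Optional[DataCase]) -> str:
        if datacase is None:
            raise TypeError("datacase can't be None.")
        process_function = self._get_process_function(datacase)
        if process_function is None:
            raise ValueError(f"Unsupported data case: {datacase}")
        return process_function()


pre = MiniPreprocessor()
print(pre.dispatch(DataCase.BULK_TO_BULK))  # bulk routine
```

With this pattern, supporting a new data case means registering one more entry in the map, and unregistered cases fail loudly with `ValueError`.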

Attributes:

Name | Type | Description
--- | --- | ---
config | | A DefaultConfig object containing preprocessing configurations.
processed_data | | A dictionary to store processed DataPackage objects for each data split.
bulk_genes_to_keep | Optional[Dict[str, List[str]]] | Optional dictionary mapping modality keys to lists of genes to keep for bulk data.
bulk_scalers | Optional[Dict[str, Any]] | Optional dictionary of scalers for bulk data.
sc_genes_to_keep | Optional[Dict[str, List[str]]] | Optional dictionary mapping modality keys to lists of genes to keep for single-cell data.
sc_scalers | Optional[Dict[str, Dict[str, Any]]] | Optional dictionary mapping modality keys to scalers for single-cell data.
sc_general_genes_to_keep | Optional[Dict[str, List]] | Optional dictionary mapping modality keys to lists of genes to keep, selected by filtering methods that are not single-cell specific.
data_readers | Dict[Enum, Any] | A dictionary mapping DataCase enum values to data reader instances for different modalities.
_dataset_container | Optional[DatasetContainer] | Optional DatasetContainer to hold the processed datasets.
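After splitting, `_process_data_case` (in the listing below) regroups the index dictionary so that each split carries only its own indices. The sketch reproduces that comprehension on toy data: the nesting `{group: {split: indices}}` and the `"paired"` key mirror the code, while the index values are made up for illustration.

```python
from typing import Dict, List

# Toy index structure: {index_group: {split_name: indices}}.
indices: Dict[str, Dict[str, List[int]]] = {
    "paired": {"train": [0, 1, 2], "valid": [3], "test": [4, 5]},
}
split_names = ["train", "valid", "test"]

processed_splits = {}
for split_name in split_names:
    # Keep, for every index group, only the entry that belongs to this split.
    split_indices = {
        name: {
            split: idx
            for split, idx in indices[name].items()
            if split == split_name
        }
        for name in indices
    }
    processed_splits[split_name] = {"indices": split_indices}

print(processed_splits["valid"])  # {'indices': {'paired': {'valid': [3]}}}
```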

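`_calc_k_filter` (in the listing below) returns `base_features` plus one extra feature for indices `i < remainder`, or `None` when `k_filter` is unset. Assuming a caller computes `base_features, remainder = divmod(k_filter, n)` for `n` modalities (the caller is not part of this excerpt), the per-modality budgets add up to exactly `k_filter`. `calc_k_filter` and `split_k_filter` below are hypothetical standalone versions.

```python
from typing import List, Optional


def calc_k_filter(
    i: int, remainder: int, base_features: int, k_filter: Optional[int]
) -> Optional[int]:
    """Mirror of _calc_k_filter with config.k_filter passed explicitly."""
    if k_filter is None:
        return None
    extra = 1 if i < remainder else 0
    return base_features + extra


def split_k_filter(k_filter: Optional[int], n_modalities: int) -> List[Optional[int]]:
    """Hypothetical caller: spread k_filter features over n_modalities."""
    base_features, remainder = divmod(k_filter or 0, n_modalities)
    return [
        calc_k_filter(i, remainder, base_features, k_filter)
        for i in range(n_modalities)
    ]


print(split_k_filter(10, 3))  # [4, 3, 3]
print(split_k_filter(None, 2))  # [None, None]
```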
Source code in src/autoencodix/base/_base_preprocessor.py
class BasePreprocessor(abc.ABC):
    """Contains logic for data preprocessing in the Autoencodix framework.

    This class defines the general preprocessing workflow and provides
    methods for handling different data modalities and data cases.
    Subclasses should implement the `preprocess` method to perform
    specific preprocessing steps.

    Attributes:
        config: A DefaultConfig object containing preprocessing configurations.
        processed_data: A dictionary to store processed DataPackage objects for each data split.
        bulk_genes_to_keep: Optional dictionary mapping modality keys to lists of genes to keep for bulk data.
        bulk_scalers: Optional dictionary of scalers for bulk data.
        sc_genes_to_keep: Optional dictionary mapping modality keys to lists of genes to keep for single-cell data.
        sc_scalers: Optional dictionary mapping modality keys to scalers for single-cell data.
        sc_general_genes_to_keep: Optional dictionary mapping modality keys to lists of genes to keep filtered by non-SC specific methods.
        data_readers: A dictionary mapping DataCase enum values to data reader instances for different modalities.
        _dataset_container: Optional DatasetContainer to hold the processed datasets.
    """

    def __init__(
        self,
        config: DefaultConfig,
        ontologies: Optional[Union[Tuple[Any, Any], Dict[Any, Any]]] = None,
    ):
        """Initializes the BasePreprocessor with a configuration object.

        Args:
            config: A DefaultConfig object containing preprocessing configurations.
            ontologies: Ontology information, if provided for Ontix.
        """
        self.config = config
        self._dataset_container: Optional[DatasetContainer] = None
        self.processed_data: Dict[str, Dict[str, Union[Any, DataPackage]]] = {}
        self.bulk_genes_to_keep: Optional[Dict[str, List[str]]] = None
        self.bulk_scalers: Optional[Dict[str, Any]] = None
        self.sc_genes_to_keep: Optional[Dict[str, List[str]]] = None
        self.sc_scalers: Optional[Dict[str, Dict[str, Any]]] = None
        self.sc_general_genes_to_keep: Optional[Dict[str, List]] = None
        self._ontologies: Optional[Union[Tuple[Any, Any], Dict[Any, Any]]] = ontologies
        self.data_readers: Dict[Enum, Any] = {
            DataCase.MULTI_SINGLE_CELL: SingleCellDataReader(),
            DataCase.MULTI_BULK: BulkDataReader(config=self.config),
            DataCase.BULK_TO_BULK: BulkDataReader(config=self.config),
            DataCase.SINGLE_CELL_TO_SINGLE_CELL: SingleCellDataReader(),
            DataCase.IMG_TO_BULK: {
                "bulk": BulkDataReader(config=self.config),
                "img": ImageDataReader(config=self.config),
            },
            DataCase.SINGLE_CELL_TO_IMG: {
                "sc": SingleCellDataReader(),
                "img": ImageDataReader(config=self.config),
            },
            DataCase.IMG_TO_IMG: ImageDataReader(config=self.config),
        }

    @abc.abstractmethod
    def preprocess(
        self,
        raw_user_data: Optional[DataPackage] = None,
        predict_new_data: bool = False,
    ) -> DatasetContainer:
        """To be implemented by subclasses for specific preprocessing steps.

        Args:
            raw_user_data: Users can provide raw data. This is an alternative to
                providing data via filepaths in the config. If this param is passed,
                the data reading step is skipped.
            predict_new_data: Indicates whether the user wants to predict with unseen data.
                If so, the data is only preprocessed, not split.

        Returns:
            A DatasetContainer with the preprocessed datasets.
        """
        pass

    def _general_preprocess(
        self,
        raw_user_data: Optional[DataPackage] = None,
        predict_new_data: bool = False,
    ) -> Dict[str, Dict[str, Union[Any, DataPackage]]]:
        """Orchestrates the preprocessing steps.

        This method determines the data case from the configuration and calls
        the appropriate processing function for that data case.

        Args:
            raw_user_data: Optional DataPackage containing user-provided data.
                If provided, the data reading step is skipped.
            predict_new_data: Boolean indicating whether to preprocess new unseen data
                without splitting it into train/validation/test sets.

        Returns:
            A dictionary containing processed DataPackage objects for each data split
            (e.g., 'train', 'valid', 'test').

        Raises:
            TypeError: If the configuration does not specify a data case.
            ValueError: If an unsupported data case is encountered.
        """
        self.predict_new_data = predict_new_data
        datacase = self.config.data_case
        if datacase is None:
            raise TypeError(
                "datacase can't be None. Please ensure the configuration specifies a valid DataCase."
            )
        if raw_user_data is None:
            self.from_key, self.to_key = self._get_translation_keys()
        else:
            self.from_key, self.to_key = self._get_user_translation_keys(
                raw_user_data=raw_user_data
            )
        process_function = self._get_process_function(datacase=datacase)
        if process_function:
            return process_function(raw_user_data=raw_user_data)
        else:
            raise ValueError(f"Unsupported data case: {datacase}")

    def _get_process_function(self, datacase: DataCase) -> Any:
        """Returns the appropriate processing function based on the data case.

        Args:
            datacase: The DataCase enum value representing the current data case.

        Returns:
            A callable function that performs the preprocessing for the given data case,
            or None if the data case is not supported.
        """
        process_map = {
            DataCase.MULTI_SINGLE_CELL: self._process_multi_single_cell,
            DataCase.MULTI_BULK: self._process_multi_bulk_case,
            DataCase.BULK_TO_BULK: self._process_multi_bulk_case,
            DataCase.SINGLE_CELL_TO_SINGLE_CELL: self._process_multi_single_cell,
            DataCase.IMG_TO_BULK: self._process_img_to_bulk_case,
            DataCase.SINGLE_CELL_TO_IMG: self._process_sc_to_img_case,
            DataCase.IMG_TO_IMG: self._process_img_to_img_case,
        }
        return process_map.get(datacase)

    def _process_data_case(
        self, data_package: DataPackage, modality_processors: Dict[Any, Any]
    ) -> Union[Dict[str, Dict[str, Union[Any, DataPackage]]], Dict[str, Any]]:
        """Processes the data package based on the provided modality processors.

        This method handles the common preprocessing steps for different data cases,
        including splitting the data package, removing NaNs, and applying
        modality-specific processors.

        Args:
            data_package: The DataPackage object to be processed.
            modality_processors: A dictionary mapping modality keys (e.g., 'multi_sc',
                'from_modality') to `(presplit_processor, postsplit_processor)` tuples.
                The presplit processor is applied to a modality's data before splitting;
                the postsplit processor receives and returns the dictionary of splits.

        Returns:
            A dictionary containing processed DataPackage objects for each data split.
        """
        if self.predict_new_data:
            # For new data there are no real splits (everything is "test" data), but
            # the preprocessing code expects the split structure, so we mock it.
            # Postsplit processing reads the modality keys from this structure, and
            # "train" is filled because the processing logic expects a train split.
            mock_split: Dict[str, Dict[str, Union[Any, DataPackage]]] = {
                "test": {
                    "data": data_package,
                    "indices": {"paired": np.array([])},
                },
                "valid": {"data": None, "indices": {"paired": np.array([])}},
                "train": {"data": data_package, "indices": {"paired": np.array([])}},
            }
            if self.config.skip_preprocessing:
                return mock_split

            clean_package = self._remove_nans(data_package=data_package)
            mock_split["test"]["data"] = clean_package
            for modality_key, (
                presplit_processor,
                postsplit_processor,
            ) in modality_processors.items():
                modality_data = clean_package[modality_key]
                if modality_data:
                    processed_modality_data = presplit_processor(modality_data)
                    # mock the split
                    clean_package[modality_key] = processed_modality_data
                    mock_split["test"]["data"] = clean_package
                    mock_split = postsplit_processor(mock_split)
            return mock_split
        # normal case without new data -----------------------------------
        if self.config.skip_preprocessing:
            split_packages, _ = self._split_data_package(data_package=data_package)
            return split_packages
        clean_package = self._remove_nans(data_package=data_package)
        for modality_key, (presplit_processor, _) in modality_processors.items():
            modality_data = clean_package[modality_key]
            if modality_data:
                processed_modality_data = presplit_processor(modality_data)
                clean_package[modality_key] = processed_modality_data
        split_packages, indices = self._split_data_package(data_package=clean_package)
        processed_splits = {}
        for modality_key, (_, postsplit_processor) in modality_processors.items():
            split_packages = postsplit_processor(split_packages)
        for split_name, split_package in split_packages.items():
            split_indices = {
                name: {
                    split: idx
                    for split, idx in indices[name].items()
                    if split == split_name
                }
                for name in indices.keys()
            }
            processed_splits[split_name] = {
                "data": split_package["data"],
                "indices": split_indices,
            }
        return processed_splits

    def _process_multi_single_cell(
        self, raw_user_data: Optional[DataPackage] = None
    ) -> Dict[str, Dict[str, Union[Any, DataPackage]]]:
        """Processes the MULTI_SINGLE_CELL case (also used for SINGLE_CELL_TO_SINGLE_CELL).

        Reads multi-single-cell data (unless user data is provided), removes NaNs,
        splits the data, and applies single-cell specific filtering before and
        after the split.

        Args:
            raw_user_data: Optional DataPackage containing user-provided data.

        Returns:
            A dictionary containing processed DataPackage objects for each data split.

        Raises:
            ValueError: If multi_sc in data_package is None.
        """
        if raw_user_data is None:
            screader = self.data_readers[DataCase.MULTI_SINGLE_CELL]  # type: ignore

            mudata = screader.read_data(config=self.config)
            data_package: DataPackage = DataPackage()
            data_package.multi_sc = mudata
        else:
            data_package = raw_user_data
        if self.config.requires_paired:
            common_ids = data_package.get_common_ids()
            if data_package.multi_sc is None:
                raise ValueError("multi_sc in data_package is None")
            data_package.multi_sc = {
                "multi_sc": data_package.multi_sc["multi_sc"][common_ids]
            }

        def presplit_processor(modality_data: Any) -> Any:
            if modality_data is None:
                return modality_data
            sc_filter = SingleCellFilter(
                data_info=self.config.data_config.data_info, config=self.config
            )
            return sc_filter.presplit_processing(multi_sc=modality_data)

        def postsplit_processor(
            split_data: Dict[str, Dict[str, Any]],
        ) -> Dict[str, Dict[str, Any]]:
            return self._postsplit_multi_single_cell(
                split_data=split_data, datapackage_key="multi_sc"
            )

        return self._process_data_case(
            data_package,
            modality_processors={"multi_sc": (presplit_processor, postsplit_processor)},
        )

    def _process_multi_bulk_case(
        self,
        raw_user_data: Optional[DataPackage] = None,
    ) -> Dict[str, Dict[str, Union[Any, DataPackage]]]:
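        """Processes the MULTI_BULK case (also used for BULK_TO_BULK).

        Reads multi-bulk data (unless user data is provided), removes NaNs,
        splits the data, and applies filtering and scaling to the bulk
        dataframes after the split.

        Args:
            raw_user_data: Optional DataPackage containing user-provided data.

        Returns:
            A dictionary containing processed DataPackage objects for each data split.

        Raises:
            ValueError: If paired data is required and the annotation or
                multi_bulk attribute of the DataPackage is None.
        """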
        if raw_user_data is None:
            bulkreader = self.data_readers[DataCase.MULTI_BULK]
            bulk_dfs, annotation = bulkreader.read_data()

            data_package = DataPackage(multi_bulk=bulk_dfs, annotation=annotation)
        else:
            data_package = raw_user_data
        if self.config.requires_paired:
            common_ids = data_package.get_common_ids()
            unpaired_data = data_package.multi_bulk
            unpaired_anno = data_package.annotation
            if unpaired_anno is None:
                raise ValueError("annotation attribute of datapackage cannot be None")
            if unpaired_data is None:
                raise ValueError("multi_bulk attribute of datapackage cannot be None")
            data_package.multi_bulk = {
                k: v.loc[common_ids] for k, v in unpaired_data.items()
            }

            data_package.annotation = {
                k: v.loc[common_ids]  # ty: ignore
                for k, v in unpaired_anno.items()  # ty: ignore
            }

        def presplit_processor(
            modality_data: Dict[str, Union[pd.DataFrame, None]],
        ) -> Dict[str, Union[pd.DataFrame, None]]:
            """For the multi_bulk modality we perform all operations after splitting at the moment."""
            return modality_data

        def postsplit_processor(
            split_data: Dict[str, Dict[str, Any]],
        ) -> Dict[str, Dict[str, Any]]:
            return self._postsplit_multi_bulk(split_data=split_data)

        return self._process_data_case(
            data_package,
            modality_processors={
                "multi_bulk": (presplit_processor, postsplit_processor)
            },
        )

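    def _calc_k_filter(
        self, i: int, remainder: int, base_features: int
    ) -> Optional[int]:
        """Return the k_filter budget for the i-th modality.

        Each modality receives base_features features and the first
        `remainder` modalities receive one extra, so the budgets add up to
        config.k_filter. Returns None when no k_filter is configured.
        """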
        if self.config.k_filter is None:
            return None
        extra = 1 if i < remainder else 0
        return base_features + extra

    def _postsplit_multi_single_cell(
        self,
        split_data: Dict[str, Dict[str, Any]],
        datapackage_key: str = "multi_sc",
        modality_key: Optional[str] = None,
    ) -> Dict[str, Dict[str, Any]]:
        """Post-split processing for multi-single-cell data.

        Applies filtering and scaling to the single-cell data after it has been
        split. Supports multiple MuData objects in the input dictionary.

        Args:
            split_data: A dictionary containing the split data for each data split.
            datapackage_key: The key in the DataPackage that contains the multi-single-cell data.
            modality_key: Optional specific modality key for backward compatibility.
                If provided, only that modality is processed.
                If None, all modalities in the dictionary are processed.

        Returns:
            A dictionary containing processed DataPackage objects for each data split.

        Raises:
            ValueError: If the train split or its data is None, or if modality_key
                is not found in the MuData dictionary.
        """
        processed_splits: Dict[str, Dict[str, Any]] = {}
        train_split: Optional[Dict[str, Any]] = split_data.get("train")

        if train_split is None:
            raise ValueError(
                "Train split data is None. Ensure that the data package contains valid train data."
            )

        train_data: Optional[Any] = train_split.get("data")
        if train_data is None:
            raise ValueError(
                "Train split data is None. Ensure that the data package contains valid train data."
            )

        # Get all modality keys from the train data
        mudata_dict = train_data[datapackage_key]

        if modality_key is not None:
            if modality_key not in mudata_dict:
                raise ValueError(
                    f"Specified modality_key '{modality_key}' not found in {list(mudata_dict.keys())}"
                )
            modality_keys = [modality_key]
            print(
                f"Processing single modality (backward compatibility): {modality_key}"
            )
        else:
            modality_keys = list(mudata_dict.keys())
            print(f"Processing {len(modality_keys)} MuData objects: {modality_keys}")

        # Initialize storage for scalers and gene filters for each modality.
        # The first time we get here we need a train split, and no scalers or
        # features to keep have been fitted yet. That is why in the predict_new
        # case the mock train split may be None: we never enter this branch.
        if (
            self.sc_scalers is None
            and self.sc_genes_to_keep is None
            and self.sc_general_genes_to_keep is None
        ) or ("modality" in datapackage_key):
            # Process each MuData object in the train split
            processed_mudata_dict = {}
            all_scalers = {}
            all_sc_genes_to_keep = {}
            all_general_genes_to_keep = {}

            for current_modality_key in modality_keys:
                print(f"Processing train modality: {current_modality_key}")

                sc_filter = SingleCellFilter(
                    data_info=self.config.data_config.data_info, config=self.config
                )

                # Single-cell specific filtering
                filtered_train, sc_genes_to_keep = sc_filter.sc_postsplit_processing(
                    mudata=mudata_dict[current_modality_key]
                )

                # General post-processing
                processed_train, general_genes_to_keep, scalers = (
                    sc_filter.general_postsplit_processing(
                        mudata=filtered_train, scaler_map=None, gene_map=None
                    )
                )

                # Store processed data and filters for this modality
                processed_mudata_dict[current_modality_key] = processed_train
                all_scalers[current_modality_key] = scalers
                all_sc_genes_to_keep[current_modality_key] = sc_genes_to_keep
                all_general_genes_to_keep[current_modality_key] = general_genes_to_keep

            # Store all scalers and gene filters
            self.sc_scalers = all_scalers
            self.sc_genes_to_keep = all_sc_genes_to_keep
            self.sc_general_genes_to_keep = all_general_genes_to_keep

            # Update train data with processed MuData objects
            train_data[datapackage_key] = processed_mudata_dict

        else:
            # Use existing scalers and gene filters
            all_scalers = self.sc_scalers  # type: ignore
            all_sc_genes_to_keep = self.sc_genes_to_keep  # type: ignore
            all_general_genes_to_keep = self.sc_general_genes_to_keep  # type: ignore

        # Store processed train split
        processed_splits["train"] = {
            "data": train_data,
            "indices": split_data["train"]["indices"],
        }

        # Process other splits (val, test, etc.)
        for split, split_package in split_data.items():
            if split == "train":
                continue

            data_package = split_package["data"]
            if data_package is None:
                processed_splits[split] = split_package
                continue

            print(f"Processing {split} split")
            processed_mudata_dict = {}

            # Process each MuData object in this split
            for current_modality_key in modality_keys:
                print(f"Processing {split} modality: {current_modality_key}")

                sc_filter = SingleCellFilter(
                    data_info=self.config.data_config.data_info, config=self.config
                )

                # Apply single-cell filtering using train-derived gene map
                filtered_sc_data, _ = sc_filter.sc_postsplit_processing(
                    mudata=data_package[datapackage_key][current_modality_key],
                    gene_map=all_sc_genes_to_keep[current_modality_key],
                )

                # Apply general processing using train-derived scalers and gene map
                processed_general_data, _, _ = sc_filter.general_postsplit_processing(
                    mudata=filtered_sc_data,
                    gene_map=all_general_genes_to_keep[current_modality_key],
                    scaler_map=all_scalers[current_modality_key],
                )

                processed_mudata_dict[current_modality_key] = processed_general_data

            # Update data package with all processed MuData objects
            data_package[datapackage_key] = processed_mudata_dict

            processed_splits[split] = {
                "data": data_package,
                "indices": split_package["indices"],
            }

        return processed_splits

    def _postsplit_multi_bulk(
        self,
        split_data: Dict[str, Dict[str, Any]],
        datapackage_key: str = "multi_bulk",
    ) -> Dict[str, Dict[str, Any]]:
        """Post-split processing for multi-bulk data.

        This method applies filtering and scaling to the bulk dataframes after they have been split.

        Args:
            split_data: A dictionary containing the split data for each data split.
            datapackage_key: The key in the DataPackage that contains the multi-bulk data.
        Returns:
            A dictionary containing processed DataPackage objects for each data split.
        Raises:
            ValueError: If the train split data is None.
        """

        train_split: Optional[Dict[str, Any]] = split_data.get("train")
        if train_split is None:
            raise ValueError(
                "Train split data is None. Ensure that the data package contains valid train data."
            )
        train_data: Optional[Any] = train_split.get("data")
        genes_to_keep_map: Dict[str, List[str]] = {}
        scalers: Dict[str, Any] = {}
        processed_splits: Dict[str, Dict[str, Any]] = {}

        if (self.bulk_scalers is None and self.bulk_genes_to_keep is None) or (
            "modality" in datapackage_key
        ):
            if train_data is None:
                raise ValueError(
                    "Train split data is None. Ensure that the data package contains valid train data."
                )
            # Get valid modality keys (those that are not None) and share the
            # k_filter budget only among them, so no features are allotted to
            # missing modalities.
            modality_keys = [
                k for k, v in train_data[datapackage_key].items() if v is not None
            ]
            n_modalities: int = len(modality_keys)
            remainder: int = 0
            base_features = 0
            if self.config.k_filter is not None and n_modalities > 0:
                base_features = self.config.k_filter // n_modalities
                remainder = self.config.k_filter % n_modalities

            for i, k in enumerate(modality_keys):
                v = train_data[datapackage_key][k]
                cur_k_filter = self._calc_k_filter(
                    i=i, base_features=base_features, remainder=remainder
                )
                self.config.data_config.data_info[k].k_filter = cur_k_filter

                data_processor = DataFilter(
                    data_info=self.config.data_config.data_info[k],
                    config=self.config,
                    ontologies=self._ontologies,
                )
                filtered_df, genes_to_keep = data_processor.filter(df=v)
                scaler = data_processor.fit_scaler(df=filtered_df)
                genes_to_keep_map[k] = genes_to_keep
                scalers[k] = scaler
                scaled_df = data_processor.scale(df=filtered_df, scaler=scaler)
                train_data[datapackage_key][k] = scaled_df
                # Check if indices stayed the same after filtering
                if not filtered_df.index.equals(v.index):
                    mismatched_indices = filtered_df.index.symmetric_difference(v.index)
                    raise ValueError(
                        f"Indices mismatch after filtering for modality {k}. "
                        f"Mismatched indices: {mismatched_indices}. "
                        "Ensure filtering does not alter the indices."
                    )

            self.bulk_scalers = scalers
            self.bulk_genes_to_keep = genes_to_keep_map  # type: ignore
        else:
            scalers, genes_to_keep_map = self.bulk_scalers, self.bulk_genes_to_keep  # type: ignore

        processed_splits["train"] = {
            "data": train_data,
            "indices": split_data["train"]["indices"],
        }

        for split_name, split_package in split_data.items():
            if split_name == "train":
                continue
            if split_package["data"] is None:
                processed_splits[split_name] = split_data[split_name]
                continue

            processed_package = split_package["data"]
            for k, v in processed_package[datapackage_key].items():
                if v is None:
                    continue
                data_processor = DataFilter(
                    data_info=self.config.data_config.data_info[k],
                    config=self.config,
                    ontologies=self._ontologies,
                )
                filtered_df, _ = data_processor.filter(
                    df=v, genes_to_keep=genes_to_keep_map[k]
                )
                scaled_df = data_processor.scale(df=filtered_df, scaler=scalers[k])
                processed_package[datapackage_key][k] = scaled_df
                if not filtered_df.index.equals(v.index):
                    raise ValueError(
                        f"Indices mismatch after filtering for modality {k}. "
                        "Ensure filtering does not alter the indices."
                    )

            processed_splits[split_name] = {
                "data": processed_package,
                "indices": split_package["indices"],
            }

        return processed_splits

    def _process_img_to_bulk_case(
        self, raw_user_data: Optional[DataPackage] = None
    ) -> Dict[str, Dict[str, Union[Any, DataPackage]]]:
        """Process IMG_TO_BULK case.

        Reads image and bulk data, prepares from/to modalities (IMG->BULK or BULK->IMG),
        performs data splitting, NaN removal, and applies normalization to image data
        and filtering/scaling to bulk dataframes.

        Args:
            raw_user_data: Optional DataPackage containing user-provided data.
                If provided, the data reading step is skipped.

        Returns:
            A dictionary containing processed DataPackage objects for each data split.

        Raises:
            ValueError: If requires_paired is set and the images, annotation, or
                multi_bulk attribute of the data package is None.
            TypeError: If from_key or to_key is None, indicating that translation keys must be specified.
        """

        if raw_user_data is None:
            bulkreader = self.data_readers[DataCase.IMG_TO_BULK]["bulk"]
            imgreader = self.data_readers[DataCase.IMG_TO_BULK]["img"]

            bulk_dfs, annotation_bulk = bulkreader.read_data()
            images, annotation_img = imgreader.read_data(config=self.config)

            annotation = {**annotation_bulk, **annotation_img}

            data_package = DataPackage(
                multi_bulk=bulk_dfs, img=images, annotation=annotation
            )

        else:
            data_package = raw_user_data

        if self.config.requires_paired:
            common_ids = data_package.get_common_ids()

            images = data_package.img
            if images is None:
                raise ValueError("Images cannot be None")
            data_package.img = {
                k: self.filter_imgdata_list(img_list=v, ids=common_ids)
                for k, v in images.items()
            }
            unpaired_data = data_package.multi_bulk
            unpaired_anno = data_package.annotation
            if unpaired_anno is None:
                raise ValueError("annotation attribute of datapackage cannot be None")
            if unpaired_data is None:
                raise ValueError("multi_bulk attribute of datapackage cannot be None")
            data_package.multi_bulk = {
                k: v.loc[common_ids] for k, v in unpaired_data.items()
            }

            data_package.annotation = {
                k: v.loc[common_ids]  # ty: ignore
                for k, v in unpaired_anno.items()  # ty: ignore
            }

        def presplit_processor(
            modality_data: Dict[str, Union[pd.DataFrame, List[ImgData]]],
        ) -> Dict[str, Union[pd.DataFrame, List[ImgData]]]:
            for modality_key, data in modality_data.items():
                if self._is_image_data(data=data):
                    modality_data[modality_key] = self._normalize_image_data(
                        images=data,  # type: ignore
                        info_key=modality_key,  # type: ignore
                    )
            # we don't need to filter bulk data here
            # because we do it in the postsplit step
            return modality_data

        def postsplit_processor(
            split_data: Dict[str, Dict[str, Any]], datapackage_key: str
        ) -> Dict[str, Dict[str, Any]]:
            if datapackage_key == "multi_bulk":
                return self._postsplit_multi_bulk(
                    split_data=split_data, datapackage_key=datapackage_key
                )
            return split_data  # for img data we don't need to do anything

        return self._process_data_case(
            data_package,
            modality_processors={
                "multi_bulk": (  # TODO change to multi_bulk and img for all translation cases and adjust processors accordingly
                    lambda data: presplit_processor(modality_data=data),
                    lambda data: postsplit_processor(
                        split_data=data, datapackage_key="multi_bulk"
                    ),
                ),
                "img": (
                    lambda data: presplit_processor(modality_data=data),
                    lambda data: postsplit_processor(
                        split_data=data, datapackage_key="img"
                    ),
                ),
            },
        )

    def _process_sc_to_img_case(
        self, raw_user_data: Optional[DataPackage] = None
    ) -> Dict[str, Dict[str, Union[Any, DataPackage]]]:
        """Process SC_TO_IMG case.

        Reads single-cell and image data, prepares from/to modalities (SC->IMG or IMG->SC),
        performs data splitting, NaN removal, and applies single-cell specific filtering
        to single-cell data and normalization to image data.

        Args:
            raw_user_data: Optional DataPackage containing user-provided data.

        Returns:
            A dictionary containing processed DataPackage objects for each data split.
        """
        if raw_user_data is None:
            screader = self.data_readers[DataCase.SINGLE_CELL_TO_IMG]["sc"]
            imgreader = self.data_readers[DataCase.SINGLE_CELL_TO_IMG]["img"]

            # only one mudata type in this case we know this
            mudata_dict = screader.read_data(config=self.config)
            images, annotation = imgreader.read_data(config=self.config)

            data_package = DataPackage(
                multi_sc=mudata_dict, img=images, annotation=annotation
            )
        else:
            data_package = raw_user_data
        if self.config.requires_paired:
            common_ids = data_package.get_common_ids()
            if data_package.multi_sc is None:
                raise ValueError("multi_sc in data_package is None")
            data_package.multi_sc = {
                "multi_sc": data_package.multi_sc["multi_sc"][common_ids]
            }
            images = data_package.img
            if images is None:
                raise ValueError("Images cannot be None")
            data_package.img = {
                k: self.filter_imgdata_list(img_list=v, ids=common_ids)
                for k, v in images.items()
            }

        def presplit_processor(
            modality_data: Dict[str, Union[Any, List[ImgData]]],
        ) -> Dict[str, Union[Any, List[ImgData]]]:
            was_image = False
            for modality_key, data in modality_data.items():
                if self._is_image_data(data=data):
                    was_image = True
                    modality_data[modality_key] = self._normalize_image_data(
                        images=data,  # type: ignore
                        info_key=modality_key,  # type: ignore
                    )

            if was_image:
                return modality_data
            else:
                sc_filter = SingleCellFilter(
                    data_info=self.config.data_config.data_info, config=self.config
                )
                return sc_filter.presplit_processing(multi_sc=modality_data)

        def postsplit_processor(
            split_data: Dict[str, Dict[str, Any]], datapackage_key: str
        ) -> Dict[str, Dict[str, Any]]:
            if datapackage_key == "multi_sc":
                return self._postsplit_multi_single_cell(
                    split_data=split_data, datapackage_key=datapackage_key
                )
            # No postsplit processing needed for image data
            return split_data

        return self._process_data_case(
            data_package,
            modality_processors={
                "multi_sc": (
                    lambda data: presplit_processor(modality_data=data),
                    lambda data: postsplit_processor(
                        split_data=data, datapackage_key="multi_sc"
                    ),
                ),
                "img": (
                    lambda data: presplit_processor(modality_data=data),
                    lambda data: postsplit_processor(
                        split_data=data, datapackage_key="img"
                    ),
                ),
            },
        )

    def _process_img_to_img_case(
        self, raw_user_data: Optional[DataPackage] = None
    ) -> Dict[str, DataPackage]:
        """Process IMG_TO_IMG case.

        Reads image data for from/to modalities, performs data splitting,
        NaN removal, and applies normalization to both from and to image data.

        Args:
            raw_user_data: Optional DataPackage containing user-provided data.
                If provided, the data reading step is skipped.

        Returns:
            A dictionary containing processed DataPackage objects for each data split.

        Raises:
            ValueError: If requires_paired is set and the images attribute of the data package is None.
            TypeError: If from_key or to_key is None, indicating that translation keys must be specified.
        """
        if raw_user_data is None:
            imgreader = self.data_readers[DataCase.IMG_TO_IMG]
            images, annotation = imgreader.read_data(config=self.config)

            data_package = DataPackage(img=images, annotation=annotation)
        else:
            data_package = raw_user_data

        if self.config.requires_paired:
            common_ids = data_package.get_common_ids()

            images = data_package.img
            if images is None:
                raise ValueError("Images cannot be None")
            data_package.img = {
                k: self.filter_imgdata_list(img_list=v, ids=common_ids)
                for k, v in images.items()
            }

        def presplit_processor(modality_data: Dict[str, List]) -> Dict[str, List]:
            """Processes img-to-img modality data with normalization for images."""
            print("calling normalize image in _process_img_to_img_case")
            return {
                k: self._normalize_image_data(v, k) for k, v in modality_data.items()
            }

        def postsplit_processor(
            split_data: Dict[str, Dict[str, Any]],
        ) -> Dict[str, Dict[str, Any]]:
            """No postsplit processing needed for image data."""
            return split_data

        return self._process_data_case(
            data_package,
            modality_processors={
                "img": (
                    presplit_processor,
                    postsplit_processor,
                ),
            },
        )

    def _split_data_package(
        self, data_package: DataPackage
    ) -> Tuple[Dict[str, Optional[Dict[str, Any]]], Dict[str, Any]]:
        """Splits a data package into train/validation/test sets.

        This method first uses PairedUnpairedSplitter to generate a single,
        synchronized set of indices for all modalities. It then uses
        DataPackageSplitter to apply these indices to the data.

        Args:
            data_package: The DataPackage to be split.

        Returns:
            A tuple containing:
            1. A dictionary of the split DataPackages.
            2. A dictionary of the synchronized integer indices used for the split.
        """
        pairing_splitter = PairedUnpairedSplitter(
            data_package=data_package, config=self.config
        )
        split_indices_config = pairing_splitter.split()
        data_package_splitter = DataPackageSplitter(
            data_package=data_package,
            config=self.config,
            indices=split_indices_config,
        )
        split_datasets = data_package_splitter.split()
        return split_datasets, split_indices_config

    def _is_image_data(self, data: Any) -> bool:
        """Check if data is image data.

        Determines if the provided data is a list of objects that are considered
        image data based on having an 'img' attribute.

        Args:
            data: The data to check.

        Returns:
            True if the data is image data, False otherwise.
        """
        if data is None:
            return False
        # Guard against empty lists before inspecting the first element.
        if isinstance(data, list) and len(data) > 0 and hasattr(data[0], "img"):
            return True
        return False

    def _remove_nans(self, data_package: DataPackage) -> DataPackage:
        """Remove NaN values from the data package.

        Utilizes NaNRemover to identify and remove rows containing NaN values
        in relevant annotation columns within the DataPackage.

        Args:
            data_package: The DataPackage from which to remove NaNs.

        Returns:
            The DataPackage with NaN values removed.
        """
        nanremover = NaNRemover(
            config=self.config,
        )
        return nanremover.remove_nan(data=data_package)

    def _normalize_image_data(self, images: List, info_key: str) -> List:
        """Process images with normalization.

        Normalizes a list of image data objects using ImageNormalizer based on
        the scaling method specified in the configuration for the given info_key.

        Args:
            images: A list of image data objects (each having an 'img' attribute).
            info_key: The key referencing data information in the configuration to get the scaling method.

        Returns:
            A list of processed image data objects with normalized image data.
        """

        scaling_method = self.config.data_config.data_info[info_key].scaling
        if scaling_method == "NOTSET":
            scaling_method = self.config.scaling
        processed_images = []
        normalizer = ImageNormalizer()  # Instance created once here

        for img in images:
            img.img = normalizer.normalize_image(  # Modify directly
                image=img.img, method=scaling_method
            )
            processed_images.append(img)

        return processed_images

    def _get_translation_keys(self) -> Tuple[Optional[str], Optional[str]]:
        """Extract from and to keys from config.

        Retrieves the 'from' and 'to' modality keys from the data configuration
        based on each data_info entry's 'translate_direction' setting.

        Returns:
            A tuple (from_key, to_key). Either element is None if no data_info
            entry has the corresponding translate_direction.
        """
        from_key, to_key = None, None
        for k, v in self.config.data_config.data_info.items():
            if v.translate_direction is None:
                continue
            if v.translate_direction == "from":
                from_key = k
            if v.translate_direction == "to":
                to_key = k
        return from_key, to_key

    def _get_user_translation_keys(
        self, raw_user_data: DataPackage
    ) -> Tuple[Optional[str], Optional[str]]:
        """Return the first from/to modality keys of user-provided data.

        Returns (None, None) if either modality dictionary is empty or the keys
        cannot be read.

        Raises:
            TypeError: If from_modality or to_modality is None.
        """
        # Check for None before calling len(), which would otherwise fail first.
        if raw_user_data.from_modality is None or raw_user_data.to_modality is None:
            raise TypeError(
                "from_modality and to_modality cannot be None for Translation"
            )
        if len(raw_user_data.from_modality) == 0 or len(raw_user_data.to_modality) == 0:
            return None, None
        try:
            return next(iter(raw_user_data.from_modality.keys())), next(
                iter(raw_user_data.to_modality.keys())
            )
        except Exception as e:
            print("error getting from or to keys")
            print(e)
            print("returning None")
            return None, None

    @abstractmethod
    def format_reconstruction(
        self, reconstruction: Dict[str, torch.Tensor], result: Optional[Result] = None
    ) -> DataPackage:
        pass

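    def filter_imgdata_list(self, img_list, ids):
        """Return the ImgData objects in img_list whose sample_id is in ids."""
        id_set = set(ids)  # O(1) membership checks instead of scanning ids each time
        return [imgdata for imgdata in img_list if imgdata.sample_id in id_set]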
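In `_postsplit_multi_bulk`, the global `k_filter` feature budget is shared across the bulk modalities: each modality gets `k_filter // n` features and, through `_calc_k_filter`, the first `k_filter % n` modalities get one extra. The standalone sketch below reproduces that base-plus-remainder arithmetic outside the class, assuming the budget is divided among the modalities passed in. `split_feature_budget` and the modality names are hypothetical, not part of autoencodix.

```python
from typing import Dict, List, Optional


def split_feature_budget(
    modality_keys: List[str], k_filter: Optional[int]
) -> Dict[str, Optional[int]]:
    """Hypothetical helper: share k_filter features among modalities.

    Mirrors the base/remainder arithmetic of _calc_k_filter: every modality
    gets k_filter // n features and the first k_filter % n get one extra.
    """
    if k_filter is None:
        # No feature filtering requested: every modality keeps None.
        return {key: None for key in modality_keys}
    n = len(modality_keys)
    if n == 0:
        return {}
    base, remainder = divmod(k_filter, n)
    return {
        key: base + (1 if i < remainder else 0)
        for i, key in enumerate(modality_keys)
    }


print(split_feature_budget(["rna", "meth", "mut"], 1000))
# {'rna': 334, 'meth': 333, 'mut': 333}
```

Handing the leftover features to the first modalities keeps the per-modality budgets within one feature of each other while the total still equals `k_filter`.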

__init__(config, ontologies=None)

Initializes the BasePreprocessor with a configuration object.

Args:

config: A DefaultConfig object containing preprocessing configurations.
ontologies: Ontology information, if provided for Ontix.
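`__init__` leaves `bulk_scalers`, `bulk_genes_to_keep`, `sc_scalers`, `sc_genes_to_keep`, and `sc_general_genes_to_keep` set to `None`. In the source above, the post-split steps fill them from the training split on the first call and reuse them on later calls, so validation, test, and new data are transformed with statistics learned from training data only. The following is a minimal, framework-free sketch of that "fit once, reuse afterwards" pattern, assuming a plain z-score scaler on lists of floats; `FitOncePreprocessor` is hypothetical and not an autoencodix class.

```python
import statistics
from typing import Dict, List, Optional, Tuple

Splits = Dict[str, Dict[str, List[float]]]


class FitOncePreprocessor:
    """Hypothetical sketch: fit per-modality statistics once, reuse them later."""

    def __init__(self) -> None:
        # Like bulk_scalers in BasePreprocessor: None until the first fit.
        self.scalers: Optional[Dict[str, Tuple[float, float]]] = None

    def process(self, splits: Splits) -> Splits:
        if self.scalers is None:
            # First call: fit mean and std per modality on the train split only.
            self.scalers = {
                key: (statistics.fmean(values), statistics.pstdev(values) or 1.0)
                for key, values in splits["train"].items()
            }
        scalers = self.scalers
        # Every split, including later unseen data, uses the stored train statistics.
        return {
            split: {
                key: [(x - scalers[key][0]) / scalers[key][1] for x in values]
                for key, values in modalities.items()
            }
            for split, modalities in splits.items()
        }


pre = FitOncePreprocessor()
print(pre.process({"train": {"rna": [1.0, 3.0]}, "test": {"rna": [5.0]}}))
# {'train': {'rna': [-1.0, 1.0]}, 'test': {'rna': [3.0]}}
print(pre.process({"new": {"rna": [2.0]}}))  # no train split needed now
# {'new': {'rna': [0.0]}}
```

Because the fitted state lives on the instance, the second call needs no `"train"` key, which loosely corresponds to predicting on new data with an already fitted preprocessor.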

Source code in src/autoencodix/base/_base_preprocessor.py
def __init__(
    self,
    config: DefaultConfig,
    ontologies: Optional[Union[Tuple[Any, Any], Dict[Any, Any]]] = None,
):
    """Initializes the BasePreprocessor with a configuration object.

    Args:
        config: A DefaultConfig object containing preprocessing configurations.
        ontologies: Ontology information, if provided for Ontix.
    """
    self.config = config
    self._dataset_container: Optional[DatasetContainer] = None
    self.processed_data: Optional[Dict[str, Dict[str, Union[Any, DataPackage]]]] = None
    self.bulk_genes_to_keep: Optional[Dict[str, List[str]]] = None
    self.bulk_scalers: Optional[Dict[str, Any]] = None
    self.sc_genes_to_keep: Optional[Dict[str, List[str]]] = None
    self.sc_scalers: Optional[Dict[str, Dict[str, Any]]] = None
    self.sc_general_genes_to_keep: Optional[Dict[str, List]] = None
    self._ontologies: Optional[Union[Tuple[Any, Any], Dict[Any, Any]]] = ontologies
    self.data_readers: Dict[Enum, Any] = {
        DataCase.MULTI_SINGLE_CELL: SingleCellDataReader(),
        DataCase.MULTI_BULK: BulkDataReader(config=self.config),
        DataCase.BULK_TO_BULK: BulkDataReader(config=self.config),
        DataCase.SINGLE_CELL_TO_SINGLE_CELL: SingleCellDataReader(),
        DataCase.IMG_TO_BULK: {
            "bulk": BulkDataReader(config=self.config),
            "img": ImageDataReader(config=self.config),
        },
        DataCase.SINGLE_CELL_TO_IMG: {
            "sc": SingleCellDataReader(),
            "img": ImageDataReader(config=self.config),
        },
        DataCase.IMG_TO_IMG: ImageDataReader(config=self.config),
    }

preprocess(raw_user_data=None, predict_new_data=False) abstractmethod

To be implemented by subclasses for specific preprocessing steps.

Args:

raw_user_data: Raw data provided directly by the user, as an alternative to providing file paths in the config. If this parameter is passed, the data reading step is skipped.
predict_new_data: Whether the user wants to predict on unseen data. If so, the data is not split and is only preprocessed.
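A subclass's `preprocess` normally splits the data but skips splitting when `predict_new_data` is set. `_split_data_package` in the source above also produces one synchronized set of indices and applies it to every modality, which keeps paired samples aligned across splits. The toy subclass below sketches both behaviours under stated assumptions: modalities are plain lists aligned by position, the split is a seeded shuffle into train and test only, and unsplit data is returned under a hypothetical `"new"` key. `SketchPreprocessor` and `ToyPreprocessor` are illustrative names, not the library's classes, and the real method returns a `DatasetContainer`.

```python
import abc
import random
from typing import Dict, List


class SketchPreprocessor(abc.ABC):
    """Hypothetical stand-in for the abstract preprocess() contract."""

    @abc.abstractmethod
    def preprocess(
        self, data: Dict[str, List[int]], predict_new_data: bool = False
    ) -> Dict[str, Dict[str, List[int]]]:
        ...


class ToyPreprocessor(SketchPreprocessor):
    def __init__(self, train_frac: float = 0.8, seed: int = 42) -> None:
        self.train_frac = train_frac
        self.seed = seed

    def preprocess(
        self, data: Dict[str, List[int]], predict_new_data: bool = False
    ) -> Dict[str, Dict[str, List[int]]]:
        if predict_new_data:
            # Assumption: unseen data is returned unsplit under one key.
            return {"new": data}
        n_samples = len(next(iter(data.values())))
        order = list(range(n_samples))
        random.Random(self.seed).shuffle(order)  # seeded for reproducibility
        cut = int(n_samples * self.train_frac)
        split_indices = {"train": order[:cut], "test": order[cut:]}
        # One index list per split, reused for every modality, keeps samples paired.
        return {
            split: {key: [values[i] for i in idx] for key, values in data.items()}
            for split, idx in split_indices.items()
        }


data = {"rna": [0, 1, 2, 3, 4], "protein": [10, 11, 12, 13, 14]}
splits = ToyPreprocessor().preprocess(data)
print({split: len(mods["rna"]) for split, mods in splits.items()})
# {'train': 4, 'test': 1}
```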

Source code in src/autoencodix/base/_base_preprocessor.py
@abc.abstractmethod
def preprocess(
    self,
    raw_user_data: Optional[DataPackage] = None,
    predict_new_data: bool = False,
) -> DatasetContainer:
    """To be implemented by subclasses for specific preprocessing steps.

    Args:
        raw_user_data: Users can provide raw data directly, as an alternative to
            providing file paths in the config. If this parameter is passed,
            the data-reading step is skipped.
        predict_new_data: Indicates whether the user wants to predict with unseen data.
            If so, the data is not split and is only preprocessed.
    """
    pass
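The contract above has two switches: passing `raw_user_data` skips reading from the config, and `predict_new_data=True` skips the train/validation split. The following is a minimal sketch of a subclass honoring both, under stated assumptions: `SketchPreprocessor`, `Log1pPreprocessor`, `_read_from_config`, the dict-of-lists data layout and the stateless `log1p` transform are all hypothetical stand-ins, not autoencodix classes or APIs.

```python
import abc
import math
from typing import Dict, List, Optional

# Hypothetical data layout: feature name -> list of values.
Split = Dict[str, List[float]]


class SketchPreprocessor(abc.ABC):
    """Hypothetical stand-in for BasePreprocessor (not the library class)."""

    @abc.abstractmethod
    def preprocess(
        self,
        raw_user_data: Optional[Split] = None,
        predict_new_data: bool = False,
    ) -> Dict[str, Split]:
        """Return preprocessed data keyed by split name."""


class Log1pPreprocessor(SketchPreprocessor):
    """Hypothetical subclass applying a stateless log1p transform."""

    def __init__(self, train_ratio: float = 0.75) -> None:
        self.train_ratio = train_ratio

    def _read_from_config(self) -> Split:
        # Placeholder for reading the file paths listed in a config.
        return {"x": [0.0, 1.0, 3.0, 7.0]}

    def preprocess(
        self,
        raw_user_data: Optional[Split] = None,
        predict_new_data: bool = False,
    ) -> Dict[str, Split]:
        # Skip the reading step when the user passes data directly.
        data = raw_user_data if raw_user_data is not None else self._read_from_config()
        values = [math.log1p(v) for v in data["x"]]
        if predict_new_data:
            # Unseen data: preprocess only, do not split.
            return {"test": {"x": values}}
        n_train = int(len(values) * self.train_ratio)
        return {"train": {"x": values[:n_train]}, "valid": {"x": values[n_train:]}}


splits = Log1pPreprocessor().preprocess()
print({name: len(part["x"]) for name, part in splits.items()})  # {'train': 3, 'valid': 1}
new = Log1pPreprocessor().preprocess({"x": [0.0]}, predict_new_data=True)
print(list(new))  # ['test']
```

A stateless transform keeps the sketch short; a real preprocessor that fits scalers would need to reuse the statistics fitted on training data when `predict_new_data` is set.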

BaseTrainer

Bases: ABC

General training logic for all autoencoder models.

This class sets up the model, optimizer, and data loaders. It also handles reproducibility and model-specific configurations. Subclasses must implement model training and prediction logic.
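The base class also provides a `_should_checkpoint` helper (see the source below). Counting epochs from 0, it returns `True` every `checkpoint_interval` epochs and always on the final epoch. A standalone restatement of that condition, where the free functions `should_checkpoint` and `checkpoint_epochs` are hypothetical helpers written only for illustration, lists the epochs that qualify:

```python
from typing import List


def should_checkpoint(epoch: int, checkpoint_interval: int, epochs: int) -> bool:
    # Same condition as BaseTrainer._should_checkpoint (epoch is 0-based).
    return (epoch + 1) % checkpoint_interval == 0 or epoch == epochs - 1


def checkpoint_epochs(checkpoint_interval: int, epochs: int) -> List[int]:
    # All 0-based epochs at which a checkpoint would be taken.
    return [e for e in range(epochs) if should_checkpoint(e, checkpoint_interval, epochs)]


print(checkpoint_epochs(checkpoint_interval=4, epochs=10))  # [3, 7, 9]
print(checkpoint_epochs(checkpoint_interval=5, epochs=10))  # [4, 9]
```

Because the last epoch is always included, the final model state is captured even when `epochs` is not a multiple of `checkpoint_interval`.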

Attributes:

_trainset: The dataset used for training.

_validset: The dataset used for validation, if provided.

_result: An object to store and manage training results.

_config: Configuration object containing training hyperparameters and settings.

_model_type: The autoencoder model class to be trained.

_loss_fn: Instantiated loss function specific to the model.

_trainloader: DataLoader for the training dataset.

_validloader: DataLoader for the validation dataset, if provided.

_model: The instantiated model architecture.

_optimizer: The optimizer used for training.

_fabric: Lightning Fabric wrapper for device and precision management.

ontologies: Ontology information, if provided (for Ontix).

Source code in src/autoencodix/base/_base_trainer.py
class BaseTrainer(abc.ABC):
    """General training logic for all autoencoder models.

    This class sets up the model, optimizer, and data loaders. It also handles
    reproducibility and model-specific configurations. Subclasses must implement
    model training and prediction logic.

    Attributes:
        _trainset: The dataset used for training.
        _validset: The dataset used for validation, if provided.
        _result: An object to store and manage training results.
        _config: Configuration object containing training hyperparameters and settings.
        _model_type: The autoencoder model class to be trained.
        _loss_fn: Instantiated loss function specific to the model.
        _trainloader: DataLoader for the training dataset.
        _validloader: DataLoader for the validation dataset, if provided.
        _model: The instantiated model architecture.
        _optimizer: The optimizer used for training.
        _fabric: Lightning Fabric wrapper for device and precision management.
        ontologies: Ontology information, if provided for Ontix
    """

    def __init__(
        self,
        trainset: Optional[BaseDataset],
        validset: Optional[BaseDataset],
        result: Result,
        config: DefaultConfig,
        model_type: Type[BaseAutoencoder],
        loss_type: Type[BaseLoss],
        ontologies: Optional[
            Union[Tuple, List]
        ] = None,  # Addition to Varix, mandatory for Ontix
        **kwargs,
    ):
        self._trainset = trainset
        self._model_type = model_type
        self._validset = validset
        self._result = result
        self._config = config
        self._loss_type = loss_type
        self.ontologies = ontologies
        self.setup_trainer()

    def setup_trainer(self, old_model=None):
        if old_model is None:
            self._input_validation()
            self._init_loaders()

        self._loss_fn = self._loss_type(config=self._config)

        self._handle_reproducibility()
        # Internal data handling
        self._model: BaseAutoencoder
        self._fabric = Fabric(
            accelerator=self._config.device,
            devices=self._config.n_gpus,
            precision=self._config.float_precision,
            strategy=self._config.gpu_strategy,
        )

        self._fabric.launch()
        self._setup_fabric(old_model=old_model)

        self._n_cpus = os.cpu_count()
        if self._n_cpus is None:
            self._n_cpus = 0

    def _setup_fabric(self, old_model=None):
        """
        Sets up the model, optimizer, and data loaders with Lightning Fabric.
        """
        self._init_model_architecture(
            ontologies=self.ontologies, old_model=old_model
        )  # Ontix

        self._optimizer = torch.optim.AdamW(
            params=self._model.parameters(),
            lr=self._config.learning_rate,
            weight_decay=self._config.weight_decay,
        )

        self._model, self._optimizer = self._fabric.setup(self._model, self._optimizer)
        if old_model is None:
            self._trainloader = self._fabric.setup_dataloaders(self._trainloader)  # type: ignore
            if self._validloader is not None:
                self._validloader = self._fabric.setup_dataloaders(self._validloader)  # type: ignore

    def _init_loaders(self):
        """Initializes the DataLoaders for training and validation datasets."""
        # g = torch.Generator()
        # g.manual_seed(self._config.global_seed)
        last_batch_is_one_sample = len(self._trainset) % self._config.batch_size == 1
        corrected_bs = (
            self._config.batch_size + 1
            if last_batch_is_one_sample
            else self._config.batch_size
        )
        if last_batch_is_one_sample:
            warnings.warn(
                f"Increased batch_size to {corrected_bs} for the trainset so that the last batch does not contain a single sample: "
                "dropping that sample would change the samples seen per epoch, and a batch of size one fails for models with BatchNorm."
            )

        self._trainloader = DataLoader(
            cast(BaseDataset, self._trainset),
            shuffle=True,
            batch_size=corrected_bs,
            worker_init_fn=self._seed_worker,
            # generator=g,
        )
        if self._validset:
            last_batch_is_one_sample = (
                len(self._validset) % self._config.batch_size == 1
            )
            corrected_bs = (
                self._config.batch_size + 1
                if last_batch_is_one_sample
                else self._config.batch_size
            )
            if last_batch_is_one_sample:
                warnings.warn(
                    f"Increased batch_size to {corrected_bs} for the validset so that the last batch does not contain a single sample: "
                    "dropping that sample would change the samples seen per epoch, and a batch of size one fails for models with BatchNorm."
                )

            self._validloader = DataLoader(
                dataset=self._validset,
                batch_size=corrected_bs,
                shuffle=False,
            )
        else:
            self._validloader = None  # type: ignore

    def _input_validation(self) -> None:
        if self._trainset is None:
            raise ValueError(
                "Trainset cannot be None. Check the indices you provided with a custom split, or make sure that the train_ratio attribute of the config is > 0."
            )
        if not isinstance(self._trainset, BaseDataset):
            raise TypeError(
                f"Expected train type to be an instance of BaseDataset, got {type(self._trainset)}."
            )
        if self._validset is None:
            print("training without validation")
        elif not isinstance(self._validset, BaseDataset):
            raise TypeError(
                f"Expected valid type to be an instance of BaseDataset, got {type(self._validset)}."
            )
        if self._config is None:
            raise ValueError("Config cannot be None.")

    def _handle_reproducibility(self) -> None:
        """Sets all relevant seeds for reproducibility if `config.reproducible` is True.

        On Apple Silicon (MPS), seeding is applied but full determinism is not
        guaranteed; a warning is printed instead of raising an error.
        """
        if self._config.reproducible:
            torch.use_deterministic_algorithms(True)
            torch.manual_seed(seed=self._config.global_seed)
            random.seed(self._config.global_seed)
            np.random.seed(self._config.global_seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed=self._config.global_seed)
                torch.cuda.manual_seed_all(seed=self._config.global_seed)
                torch.backends.cudnn.deterministic = True
                torch.backends.cudnn.benchmark = False
            elif torch.backends.mps.is_available():
                torch.mps.manual_seed(seed=self._config.global_seed)

                # torch.use_deterministic_algorithms(True, warn_only=True)

                print(
                    "Warning: MPS backend has limited support for deterministic algorithms. "
                    "Seeding is active, but full reproducibility is not guaranteed."
                )
            else:
                print(
                    f"Reproducibility settings for device {self._config.device} are not implemented (or not needed, e.g. for CPU)."
                )

    def _seed_worker(self, worker_id):
        worker_seed = self._config.global_seed + worker_id
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    def _init_model_architecture(self, ontologies: Optional[tuple], old_model=None) -> None:
        """Initializes the model architecture, based on the model type and input dimension."""
        if old_model:
            self._model = old_model
            return

        self._input_dim = cast(BaseDataset, self._trainset).get_input_dim()
        self.feature_order = self._trainset.feature_ids
        if ontologies is None:
            self._model = self._model_type(
                config=self._config, input_dim=self._input_dim
            )

        else:
            self._model = self._model_type(
                config=self._config,
                input_dim=self._input_dim,
                ontologies=ontologies,
                feature_order=self._trainset.feature_ids,  # type: ignore
            )

    def _should_checkpoint(self, epoch: int) -> bool:
        return (
            (epoch + 1) % self._config.checkpoint_interval == 0
            or epoch == self._config.epochs - 1
        )

    @abc.abstractmethod
    def train(self, epochs_overwrite: Optional[int] = None) -> Result:
        pass

    @abc.abstractmethod
    def decode(self, x: torch.Tensor) -> torch.Tensor:
        pass

    @abc.abstractmethod
    def predict(
        self, data: BaseDataset, model: Optional[torch.nn.Module] = None, **kwargs
    ) -> Result:
        pass

    @abc.abstractmethod
    def purge(self) -> None:
        """Cleans up any resources used during training, such as cached data or large attributes."""
        pass
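`_init_loaders` increases the batch size by one whenever `len(dataset) % batch_size == 1`, so that the last batch does not hold a single sample (PyTorch BatchNorm layers raise an error for a batch of one in training mode). The arithmetic can be checked in plain Python; `corrected_batch_size` mirrors the rule in the source, while `batch_sizes` is a hypothetical helper that reproduces the batch sizes of a DataLoader with the default `drop_last=False`.

```python
from typing import List


def corrected_batch_size(n_samples: int, batch_size: int) -> int:
    # Same rule as BaseTrainer._init_loaders: grow the batch by one when the
    # last batch would otherwise contain a single sample.
    if n_samples % batch_size == 1:
        return batch_size + 1
    return batch_size


def batch_sizes(n_samples: int, batch_size: int) -> List[int]:
    # Sizes of consecutive batches when drop_last=False (the DataLoader default).
    full, rest = divmod(n_samples, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


print(batch_sizes(65, 32))                            # [32, 32, 1]
print(batch_sizes(65, corrected_batch_size(65, 32)))  # [33, 32]
print(batch_sizes(7, corrected_batch_size(7, 2)))     # [3, 3, 1]
```

The last case shows a limit of a single increment: when both `batch_size` and `batch_size + 1` divide `n_samples - 1`, the singleton batch remains. Repeating the check until the remainder differs from one, or passing `drop_last=True` to the DataLoader at the cost of skipping a sample each epoch, are possible alternatives.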

purge() abstractmethod

Cleans up any resources used during training, such as cached data or large attributes.

Source code in src/autoencodix/base/_base_trainer.py
@abc.abstractmethod
def purge(self) -> None:
    """Cleans up any resources used during training, such as cached data or large attributes."""
    pass
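What `purge` releases is left to each subclass. A minimal sketch, assuming a trainer that keeps references to its data loaders and a cache of intermediate results (the `SketchTrainer` class and its `_cache` attribute are hypothetical), drops those references so the garbage collector can reclaim the memory:

```python
import gc
from typing import Any, Dict, List, Optional


class SketchTrainer:
    """Hypothetical trainer; only illustrates what a purge() override might do."""

    def __init__(self) -> None:
        self._trainloader: Optional[Any] = list(range(1_000))
        self._validloader: Optional[Any] = list(range(100))
        self._cache: Dict[str, List[float]] = {"latents": [0.0] * 10_000}

    def purge(self) -> None:
        # Drop references to large objects so they can be garbage-collected.
        self._trainloader = None
        self._validloader = None
        self._cache.clear()
        gc.collect()


trainer = SketchTrainer()
trainer.purge()
print(trainer._trainloader, trainer._cache)  # None {}
```

For trainers holding CUDA tensors, calling `torch.cuda.empty_cache()` after dropping the references additionally returns cached GPU memory to the driver.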

BaseVisualizer

Bases: ABC

Defines the interface for visualizing training results.
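For the relative loss plot, `_make_loss_plot` (in the source below) clips negative losses to zero, drops `total_loss`, terms containing `_factor`, and terms that are zero throughout, then divides each remaining term by the sum over its (split, epoch) group via `groupby(...).transform(lambda x: x / x.sum())`. The same computation can be sketched in plain Python, assuming a simplified row layout `(epoch, loss_term, split, value)` with no missing values; the term name `recon_loss` is hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# (epoch, loss_term, split, value): a simplified long-format row
Row = Tuple[int, str, str, float]


def relative_loss_shares(rows: List[Row]) -> Dict[Tuple[str, int, str], float]:
    """Share of each loss term within its (split, epoch) group."""
    # Clip negative values to zero, as the plotting code does.
    clipped = [(e, t, s, max(v, 0.0)) for e, t, s, v in rows]
    # Drop terms that are zero in every row.
    nonzero_terms = {t for _, t, _, v in clipped if v != 0}
    kept = [
        (e, t, s, v)
        for e, t, s, v in clipped
        if t != "total_loss" and "_factor" not in t and t in nonzero_terms
    ]
    totals: Dict[Tuple[str, int], float] = defaultdict(float)
    for e, _, s, v in kept:
        totals[(s, e)] += v
    # A group whose kept terms sum to zero gets 0.0 here (pandas would give NaN).
    return {
        (s, e, t): (v / totals[(s, e)] if totals[(s, e)] else 0.0)
        for e, t, s, v in kept
    }


rows = [
    (1, "recon_loss", "train", 3.0),
    (1, "var_loss", "train", 1.0),
    (1, "total_loss", "train", 4.0),
    (2, "recon_loss", "train", 1.0),
    (2, "var_loss", "train", 1.0),
    (2, "total_loss", "train", 2.0),
]
print(relative_loss_shares(rows)[("train", 1, "recon_loss")])  # 0.75
```

Within each split and epoch the shares sum to one, which is what the stacked area plot displays.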

Attributes:

plots: A nested dictionary to store various plots.

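`plots` is created with `nested_dict()`, and `save_plots` walks it with `nested_to_tuple`, building each file name by joining every key on the path to a figure with underscores. Neither helper is shown here, so the following is a sketch under assumptions: an auto-vivifying `defaultdict` and a generator yielding `(key_1, ..., key_n, leaf)` tuples, with strings standing in for figures and the key `latentspace` chosen purely for illustration.

```python
from collections import defaultdict
from typing import Any, Iterator, Tuple


def nested_dict() -> defaultdict:
    # Auto-vivifying dictionary: missing keys create further nested levels.
    return defaultdict(nested_dict)


def nested_to_tuple(d: Any, prefix: Tuple = ()) -> Iterator[Tuple]:
    # Yield (key_1, ..., key_n, leaf) for every leaf in a nested mapping.
    for key, value in d.items():
        if isinstance(value, dict):
            yield from nested_to_tuple(value, prefix + (key,))
        else:
            yield prefix + (key, value)


plots = nested_dict()
plots["latentspace"][10]["train"] = "fig_a"  # strings stand in for figures
plots["loss_absolute"] = "fig_b"

for item in nested_to_tuple(plots):
    # Same naming scheme as save_plots(which="all"): all keys except the leaf.
    print("_".join(str(x) for x in item[:-1]))
# latentspace_10_train
# loss_absolute
```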
Source code in src/autoencodix/base/_base_visualizer.py
class BaseVisualizer(abc.ABC):
    """Defines the interface for visualizing training results.

    Attributes:
        plots: A nested dictionary to store various plots.
    """

    def __init__(self):
        self.plots = nested_dict()

    def __setitem__(self, key, elem):
        self.plots[key] = elem

    ### Abstract Methods ###
    @abc.abstractmethod
    def visualize(self, result: Result, config: DefaultConfig) -> Result:
        pass

    @abc.abstractmethod
    def show_latent_space(
        self,
        result: Result,
        plot_type: str = "2D-scatter",
        labels: Optional[Union[list, pd.Series, None]] = None,
        param: Optional[Union[list, str]] = None,
        epoch: Optional[Union[int, None]] = None,
        split: str = "all",
    ) -> None:
        pass

    @abc.abstractmethod
    def show_weights(self) -> None:
        pass

    ### General Functions used by all Visualizers in similar way ###

    def show_loss(self, plot_type: str = "absolute") -> None:
        """
        Display the loss plot.
        Args:
            plot_type: Type of loss plot to display: "absolute" for the absolute
                loss plot or "relative" for the relative loss plot.
                Defaults to "absolute".
        Returns:
            None
        """
        if plot_type == "absolute":
            if "loss_absolute" not in self.plots.keys():
                print("Absolute loss plot not found in the plots dictionary")
                print(
                    "This happens when visualize() has not been run, or when the model was saved and loaded with `save_all=False`"
                )
            else:
                fig = self.plots["loss_absolute"]
                show_figure(fig)
                plt.show()
        if plot_type == "relative":
            if "loss_relative" not in self.plots.keys():
                print("Relative loss plot not found in the plots dictionary")

                print(
                    "This happens when visualize() has not been run, or when the model was saved and loaded with `save_all=False`"
                )
            else:
                fig = self.plots["loss_relative"]
                fig.show()
                # show_figure(fig)
                # plt.show()

        if plot_type not in ["absolute", "relative"]:
            print(
                "Type of loss plot not recognized. Please use 'absolute' or 'relative'"
            )

    def show_evaluation(
        self,
        param: str,
        metric: str,
        ml_alg: Optional[str] = None,
    ) -> None:
        """
        Displays the evaluation plot for a specific clinical parameter, metric, and optionally ML algorithm.
        Args:
            param: clinical parameter to visualize.
            metric: metric to visualize.
            ml_alg: ML algorithm to visualize. If None, plots all available algorithms.
        Returns:
            None
        """
        plt.ioff()
        if "ML_Evaluation" not in self.plots.keys():
            print("ML Evaluation plots not found in the plots dictionary")
            print("You need to run evaluate() method first")
            return None
        if param not in self.plots["ML_Evaluation"].keys():
            print(f"Parameter {param} not found in the ML Evaluation plots")
            print(f"Available parameters: {list(self.plots['ML_Evaluation'].keys())}")
            return None
        if metric not in self.plots["ML_Evaluation"][param].keys():
            print(f"Metric {metric} not found in the ML Evaluation plots for {param}")
            print(
                f"Available metrics: {list(self.plots['ML_Evaluation'][param].keys())}"
            )
            return None

        algs = list(self.plots["ML_Evaluation"][param][metric].keys())
        if ml_alg is not None:
            if ml_alg not in algs:
                print(f"ML algorithm {ml_alg} not found for {param} and {metric}")
                print(f"Available ML algorithms: {algs}")
                return None
            fig = self.plots["ML_Evaluation"][param][metric][ml_alg].figure
            show_figure(fig)
            plt.show()
        else:
            for alg in algs:
                print(f"Showing plot for ML algorithm: {alg}")
                fig = self.plots["ML_Evaluation"][param][metric][alg].figure
                show_figure(fig)
                plt.show()

    def save_plots(
        self, path: str, which: Union[str, list] = "all", format: str = "png"
    ) -> None:
        """
        Save specified plots to the given path in the specified format.

        Args:
            path: The directory path where the plots will be saved. It is created if it does not exist.
            which: A list of plot names, or a string specifying which plots to save.
                   If 'all', all plots in the plots dictionary are saved.
                   If a single plot name is given as a string, only that plot is saved.
                   Plot names that are not found are reported and skipped.
            format: The file format in which to save the plots (e.g., 'png', 'jpg').

        Returns:
            None
        """
        if not os.path.exists(path):
            os.makedirs(path)

        if not isinstance(which, list):
            ## Case when which is a string
            if which == "all":
                ## Case when all plots are to be saved
                if len(self.plots) == 0:
                    print("No plots found in the plots dictionary")
                    print("You need to run the visualize() method first")
                else:
                    for item in nested_to_tuple(self.plots):
                        fig = item[-1]  ## Figure is in last element of the tuple
                        filename = "_".join(str(x) for x in item[0:-1])
                        fullpath = os.path.join(path, filename)
                        if hasattr(fig, "savefig"):
                            fig.savefig(f"{fullpath}.{format}")
                        elif hasattr(fig, "save"):  # for seaborn objects plots
                            fig.save(f"{fullpath}.{format}")
            else:
                ## Case when a single plot is provided as string
                if which not in self.plots.keys():
                    print(f"Plot {which} not found in the plots dictionary")
                    print(f"All available plots are: {list(self.plots.keys())}")
                else:
                    for item in nested_to_tuple(
                        self.plots[which]
                    ):  # Plot all epochs and splits of type which
                        fig = item[-1]  ## Figure is in last element of the tuple
                        filename = which + "_" + "_".join(str(x) for x in item[0:-1])  # type: ignore
                        fullpath = os.path.join(path, filename)
                        if hasattr(fig, "savefig"):
                            fig.savefig(f"{fullpath}.{format}")
                        elif hasattr(fig, "save"):  # for seaborn objects plots
                            fig.save(f"{fullpath}.{format}")
        else:
            ## Case when which is a list of plot specified as strings
            for key in which:
                if key not in self.plots.keys():
                    print(f"Plot {key} not found in the plots dictionary")
                    print(f"All available plots are: {list(self.plots.keys())}")
                    continue
                else:
                    for item in nested_to_tuple(
                        self.plots[key]
                    ):  # Plot all epochs and splits of type key
                        fig = item[-1]  ## Figure is in last element of the tuple
                        filename = key + "_" + "_".join(str(x) for x in item[0:-1])
                        fullpath = os.path.join(path, filename)
                        if hasattr(fig, "savefig"):
                            fig.savefig(f"{fullpath}.{format}")
                        elif hasattr(fig, "save"):  # for seaborn objects plots
                            fig.save(f"{fullpath}.{format}")

    ### Utilities ###

    @staticmethod
    def _make_loss_format(result: Result, config: DefaultConfig) -> pd.DataFrame:
        loss_df_melt = pd.DataFrame()
        for term in result.sub_losses.keys():
            # Get the loss values and ensure it's a dictionary
            loss_values = result.sub_losses.get(key=term).get()

            # Convert to a plain dict if necessary
            if not isinstance(loss_values, dict):
                if hasattr(loss_values, "to_dict"):
                    # e.g. pandas objects
                    loss_values = loss_values.to_dict()  # type: ignore
                elif hasattr(loss_values, "shape"):
                    # e.g. numpy arrays: use positions as keys
                    loss_values = {i: val for i, val in enumerate(loss_values)}

            loss_df = pd.DataFrame.from_dict(loss_values, orient="index")  # type: ignore

            if term == "var_loss":
                loss_df = loss_df * config.beta
            loss_df["Epoch"] = loss_df.index + 1
            loss_df["Loss Term"] = term

            loss_df_melt = pd.concat(
                [
                    loss_df_melt,
                    loss_df.melt(
                        id_vars=["Epoch", "Loss Term"],
                        var_name="Split",
                        value_name="Loss Value",
                    ),
                ],
                axis=0,
            ).reset_index(drop=True)

        # Similar handling for the total losses
        loss_values = result.losses.get()
        if not isinstance(loss_values, dict):
            if hasattr(loss_values, "to_dict"):
                loss_values = loss_values.to_dict()  # type: ignore
            else:
                if hasattr(loss_values, "shape"):
                    loss_values = {i: val for i, val in enumerate(loss_values)}

        loss_df = pd.DataFrame.from_dict(loss_values, orient="index")  # type: ignore
        loss_df["Epoch"] = loss_df.index + 1
        loss_df["Loss Term"] = "total_loss"

        loss_df_melt = pd.concat(
            [
                loss_df_melt,
                loss_df.melt(
                    id_vars=["Epoch", "Loss Term"],
                    var_name="Split",
                    value_name="Loss Value",
                ),
            ],
            axis=0,
        ).reset_index(drop=True)

        loss_df_melt["Loss Value"] = loss_df_melt["Loss Value"].astype(float)
        return loss_df_melt

    @staticmethod
    def _make_loss_plot(
        df_plot: pd.DataFrame, plot_type: str
    ) -> matplotlib.figure.Figure:  # type: ignore
        """
        Generates a plot for visualizing loss values from a DataFrame.

        Args:
            df_plot: DataFrame containing the loss values to be plotted. It should have the columns:
                - "Loss Term": The type of loss term (e.g., "total_loss", "reconstruction_loss").
                - "Epoch": The epoch number.
                - "Loss Value": The value of the loss.
                - "Split": The data split (e.g., "train", "validation").

            plot_type: The type of plot to generate. It can be either "absolute" or "relative".
                - "absolute": Generates a line plot for each unique loss term.
                - "relative": Generates a density plot for each data split, excluding the "total_loss" term.

        Returns:
            The generated matplotlib figure containing the loss plots.
        """
        fig_width_abs = 5 * len(df_plot["Loss Term"].unique())
        fig_width_rel = 5 * len(df_plot["Split"].unique())
        if plot_type == "absolute":
            terms = df_plot["Loss Term"].unique()
            # squeeze=False keeps `axes` 2-D even when there is a single loss term
            fig, axes = plt.subplots(
                1,
                len(terms),
                figsize=(fig_width_abs, 5),
                sharey=False,
                squeeze=False,
            )
            for i, term in enumerate(terms):
                sns.lineplot(
                    data=df_plot[(df_plot["Loss Term"] == term)],
                    x="Epoch",
                    y="Loss Value",
                    hue="Split",
                    ax=axes[0, i],
                ).set_title(term)

            plt.close()

        if plot_type == "relative":
            # Check if loss values are positive
            if (df_plot["Loss Value"] < 0).any():
                # Warning
                warnings.warn(
                    "Loss values contain negative values. Check your loss function if correct. Loss will be clipped to zero for plotting."
                )
                df_plot["Loss Value"] = df_plot["Loss Value"].clip(lower=0)

            # Exclude loss terms where all Loss Value are zero or NaN over all epochs
            valid_terms = [
                term
                for term in df_plot["Loss Term"].unique()
                if (
                    (df_plot[df_plot["Loss Term"] == term]["Loss Value"].notna().any())
                    and (df_plot[df_plot["Loss Term"] == term]["Loss Value"] != 0).any()
                )
            ]
            exclude = (
                (df_plot["Loss Term"] != "total_loss")
                & ~(df_plot["Loss Term"].str.contains("_factor"))
                & (df_plot["Loss Term"].isin(valid_terms))
            )

            df_plot.loc[exclude, "Relative Loss Value"] = (
                df_plot[exclude]
                .groupby(["Split", "Epoch"])["Loss Value"]
                .transform(lambda x: x / x.sum())
            )
            fig = (
                (
                    so.Plot(
                        df_plot[exclude],
                        "Epoch",
                        "Relative Loss Value",
                        color="Loss Term",
                    ).add(so.Area(alpha=0.7), so.Stack())
                )
                .facet("Split")
                .layout(size=(fig_width_rel, 5))
            )

            # fig, axes = plt.subplots(1, 2, figsize=(fig_width_rel, 5), sharey=True)

            # ax = 0

            # for split in df_plot["Split"].unique():
            #     axes[ax] = sns.kdeplot(
            #         data=df_plot[exclude & (df_plot["Split"] == split)],
            #         x="Epoch",
            #         hue="Loss Term",
            #         multiple="fill",
            #         weights="Loss Value",
            #         clip=[0, df_plot["Epoch"].max()],
            #         ax=axes[ax],
            #     ).set_title(split)
            #     ax += 1

            # plt.close()

        return fig

    @staticmethod
    def _plot_model_weights(model: torch.nn.Module) -> matplotlib.figure.Figure:  # type: ignore
        """
        Visualization of model weights in encoder and decoder layers as heatmap for each layer as subplot.
        Handles non-symmetrical autoencoder architectures.
        Plots _mu layer for encoder as well.
        Uses node_names for decoder layers if model has ontologies.
        Args:
            model (torch.nn.Module): PyTorch model instance.
        Returns:
            fig (matplotlib.figure.Figure): Figure handle (of the last plot).
        """
        all_weights = []
        names = []
        node_names = None
        if hasattr(model, "ontologies"):
            if model.ontologies is not None:
                node_names = []
                for ontology in model.ontologies:
                    node_names.append(list(ontology.keys()))
                node_names.append(model.feature_order)

        # Collect encoder and decoder weights separately
        encoder_weights = []
        encoder_names = []
        decoder_weights = []
        decoder_names = []
        for name, param in model.named_parameters():
            # print(name)
            if "weight" in name and len(param.shape) == 2:
                if "encoder" in name and "var" not in name and "_mu" not in name:
                    encoder_weights.append(param.detach().cpu().numpy())
                    encoder_names.append(name[:-7])
                elif "_mu" in name:
                    encoder_weights.append(param.detach().cpu().numpy())
                    encoder_names.append(name[:-7])
                elif "decoder" in name and "var" not in name:
                    decoder_weights.append(param.detach().cpu().numpy())
                    decoder_names.append(name[:-7])
                elif (
                    "encoder" not in name
                    and "decoder" not in name
                    and "var" not in name
                ):
                    # fallback for models without explicit encoder/decoder in name
                    all_weights.append(param.detach().cpu().numpy())
                    names.append(name[:-7])

        if encoder_weights or decoder_weights:
            n_enc = len(encoder_weights)
            n_dec = len(decoder_weights)
            n_cols = max(n_enc, n_dec)
            fig, axes = plt.subplots(2, n_cols, sharex=False, figsize=(15 * n_cols, 15))
            if n_cols == 1:
                axes = axes.reshape(2, 1)
            # Plot encoder weights
            for i in range(n_enc):
                ax = axes[0, i]
                sns.heatmap(
                    encoder_weights[i],
                    cmap=sns.color_palette("Spectral", as_cmap=True),
                    center=0,
                    ax=ax,
                ).set(title=encoder_names[i])
                ax.set_ylabel("Out Node", size=12)
            # Hide unused encoder subplots
            for i in range(n_enc, n_cols):
                axes[0, i].axis("off")
            # Plot decoder weights
            for i in range(n_dec):
                ax = axes[1, i]
                heatmap_kwargs = {}

                sns.heatmap(
                    decoder_weights[i],
                    cmap=sns.color_palette("Spectral", as_cmap=True),
                    center=0,
                    ax=ax,
                    **heatmap_kwargs,
                ).set(title=decoder_names[i])
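                # node_names is only set when the model carries ontologies; checking it
                # avoids an AttributeError for models without an `ontologies` attribute.
                if node_names is not None:
                    # Heatmap cells are centred at i + 0.5, so offset the tick positions.
                    axes[1, i].set_xticks(
                        ticks=[t + 0.5 for t in range(len(node_names[i]))],  # type: ignore
                        labels=node_names[i],  # type: ignore
                        rotation=90,
                        fontsize=8,
                    )
                    axes[1, i].set_yticks(
                        ticks=[t + 0.5 for t in range(len(node_names[i + 1]))],  # type: ignore
                        labels=node_names[i + 1],  # type: ignore
                        rotation=0,
                        fontsize=8,
                    )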
                ax.set_xlabel("In Node", size=12)
                ax.set_ylabel("Out Node", size=12)
            # Hide unused decoder subplots
            for i in range(n_dec, n_cols):
                axes[1, i].axis("off")
        else:
            # fallback: plot all weights in order, split in half for encoder/decoder
            n_layers = len(all_weights) // 2
            # squeeze=False keeps `axes` two-dimensional even when n_layers == 1
            fig, axes = plt.subplots(
                2, n_layers, sharex=False, figsize=(5 * n_layers, 10), squeeze=False
            )
            for layer in range(n_layers):
                sns.heatmap(
                    all_weights[layer],
                    cmap=sns.color_palette("Spectral", as_cmap=True),
                    center=0,
                    ax=axes[0, layer],
                ).set(title=names[layer])
                sns.heatmap(
                    all_weights[n_layers + layer],
                    cmap=sns.color_palette("Spectral", as_cmap=True),
                    center=0,
                    ax=axes[1, layer],
                ).set(title=names[n_layers + layer])
                axes[1, layer].set_xlabel("In Node", size=12)
                axes[0, layer].set_ylabel("Out Node", size=12)
                axes[1, layer].set_ylabel("Out Node", size=12)

        fig.suptitle("Model Weights", size=20)
        plt.close(fig)  # detach from pyplot so it is not auto-displayed; it is returned instead
        return fig

    @staticmethod
    def _collect_all_metadata(result):
        all_metadata = pd.DataFrame()

        # 1) collect metadata from result.datasets

        # 1a) iterate over splits [train, valid, test] if they exist
        for split in ["train", "valid", "test"]:

            if hasattr(result.datasets, split) and result.datasets[split] is not None:
                if hasattr(result.datasets[split], "metadata"):
                    split_metadata = result.datasets[split].metadata

                    # 1b) if the split's metadata is a dict (one entry per modality), collect each modality's DataFrame
                    if isinstance(split_metadata, dict):
                        for modality, modality_data in split_metadata.items():
                            all_metadata = pd.concat(
                                [all_metadata, modality_data], axis=0
                            )
                    # 1c) if the split's metadata is already a DataFrame, collect it directly
                    elif isinstance(split_metadata, pd.DataFrame):
                        all_metadata = pd.concat([all_metadata, split_metadata], axis=0)
                else:
                    split_modalities = result.datasets[split].datasets
                    if isinstance(split_modalities, dict):
                        for modality, modality_data in split_modalities.items():
                            if hasattr(modality_data, "metadata"):
                                modality_metadata = modality_data.metadata
                                if isinstance(modality_metadata, pd.DataFrame):
                                    all_metadata = pd.concat(
                                        [all_metadata, modality_metadata], axis=0
                                    )

        # 2) collect metadata from result.new_datasets in the same way
        if hasattr(result, "new_datasets"):
            for split in ["train", "valid", "test"]:
                if (
                    hasattr(result.new_datasets, split)
                    and result.new_datasets[split] is not None
                ):
                    if hasattr(result.new_datasets[split], "metadata"):
                        split_metadata = result.new_datasets[split].metadata

                        if isinstance(split_metadata, dict):
                            for modality, modality_data in split_metadata.items():
                                all_metadata = pd.concat(
                                    [all_metadata, modality_data], axis=0
                                )
                        elif isinstance(split_metadata, pd.DataFrame):
                            all_metadata = pd.concat(
                                [all_metadata, split_metadata], axis=0
                            )
                    else:
                        split_modalities = result.new_datasets[split].datasets
                        if isinstance(split_modalities, dict):
                            for modality, modality_data in split_modalities.items():
                                if hasattr(modality_data, "metadata"):
                                    modality_metadata = modality_data.metadata
                                    if isinstance(modality_metadata, pd.DataFrame):
                                        all_metadata = pd.concat(
                                            [all_metadata, modality_metadata], axis=0
                                        )

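        # Keep only the first row per index label; the same sample can appear in
        # several modalities, or in both datasets and new_datasets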
        all_metadata = all_metadata.loc[~all_metadata.index.duplicated(keep="first")]

        return all_metadata
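The weight-heatmap method above sorts each parameter into an encoder, decoder, or fallback group purely by its name. The sketch below restates those rules as a standalone helper so they can be checked without building a model. `classify_weight_name` is hypothetical (not part of autoencodix), it assumes names as produced by `torch.nn.Module.named_parameters()` ending in `.weight`, and the example module names (`_encoder`, `_mu`, `_logvar`, `layers`) are made up.

```python
from typing import Optional, Tuple


def classify_weight_name(name: str) -> Optional[Tuple[str, str]]:
    """Hypothetical helper mirroring the name-based grouping rules above.

    Assumes `name` comes from `model.named_parameters()` and that the
    parameter is a 2-D weight matrix (the shape check is not repeated here).
    Returns (group, layer_name), or None if the parameter is not plotted.
    """
    if "weight" not in name:
        return None  # biases and other parameters are not plotted
    layer = name[: -len(".weight")]  # same as name[:-7]
    if "encoder" in name and "var" not in name and "_mu" not in name:
        return ("encoder", layer)
    if "_mu" in name:
        return ("encoder", layer)  # "_mu" layers are drawn with the encoder
    if "decoder" in name and "var" not in name:
        return ("decoder", layer)
    if "encoder" not in name and "decoder" not in name and "var" not in name:
        return ("fallback", layer)
    return None  # any remaining name containing "var" is skipped


names = [
    "_encoder.0.weight",
    "_encoder.0.bias",
    "_mu.weight",
    "_logvar.weight",
    "_decoder.2.weight",
    "layers.1.weight",
]
for n in names:
    print(n, "->", classify_weight_name(n))
```

In this sketch a log-variance head such as `_logvar.weight` is dropped because its name contains "var", so for a VAE-style model presumably only the mean head is drawn alongside the encoder layers.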

save_plots(path, which='all', format='png')

Save specified plots to the given path in the specified format.

Parameters:

path (str, required): The directory path where the plots will be saved. The directory is created if it does not exist.
which (Union[str, list], default 'all'): A list of plot names to save, or a string specifying which plots to save. If 'all', all plots in the plots dictionary are saved. If a single plot name is provided as a string, only that plot is saved. Names not found in the plots dictionary are reported and skipped.
format (str, default 'png'): The file format in which to save the plots (e.g., 'png', 'jpg').

Returns:

None

Raises:

ValueError: If the 'which' parameter is not a list or a string.

Source code in src/autoencodix/base/_base_visualizer.py
def save_plots(
    self, path: str, which: Union[str, list] = "all", format: str = "png"
) -> None:
    """
    Save specified plots to the given path in the specified format.

    Args:
        path: The directory path where the plots will be saved.
        which: A list of plot names to save or a string specifying which plots to save.
               If 'all', all plots in the plots dictionary will be saved.
               If a single plot name is provided as a string, only that plot will be saved.
        format: The file format in which to save the plots (e.g., 'png', 'jpg').

    Returns:
        None

    Raises:
        ValueError: If the 'which' parameter is not a list or a string.
    """
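    # Enforce the documented contract: `which` must be a string or a list.
    if not isinstance(which, (str, list)):
        raise ValueError(
            f"'which' must be a string or a list of strings, got {type(which).__name__}"
        )
    os.makedirs(path, exist_ok=True)

    if not isinstance(which, list):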
        ## Case when which is a string
        if which == "all":
            ## Case when all plots are to be saved
            if len(self.plots) == 0:
                print("No plots found in the plots dictionary")
                print("You need to run the visualize() method first")
            else:
                for item in nested_to_tuple(self.plots):
                    fig = item[-1]  ## Figure is in last element of the tuple
                    filename = "_".join(str(x) for x in item[0:-1])
                    fullpath = os.path.join(path, filename)
                    if hasattr(fig, "savefig"):
                        fig.savefig(f"{fullpath}.{format}")
                    elif hasattr(fig, "save"):  # for seaborn objects plots
                        fig.save(f"{fullpath}.{format}")
        else:
            ## Case when a single plot is provided as string
            if which not in self.plots.keys():
                print(f"Plot {which} not found in the plots dictionary")
                print(f"All available plots are: {list(self.plots.keys())}")
            else:
                for item in nested_to_tuple(
                    self.plots[which]
                ):  # Plot all epochs and splits of type which
                    fig = item[-1]  ## Figure is in last element of the tuple
                    filename = which + "_" + "_".join(str(x) for x in item[0:-1])  # type: ignore
                    fullpath = os.path.join(path, filename)
                    if hasattr(fig, "savefig"):
                        fig.savefig(f"{fullpath}.{format}")
                    elif hasattr(fig, "save"):  # for seaborn objects plots
                        fig.save(f"{fullpath}.{format}")
    else:
        ## Case when which is a list of plot specified as strings
        for key in which:
            if key not in self.plots.keys():
                print(f"Plot {key} not found in the plots dictionary")
                print(f"All available plots are: {list(self.plots.keys())}")
                continue
            else:
                for item in nested_to_tuple(
                    self.plots[key]
                ):  # Plot all epochs and splits of type key
                    fig = item[-1]  ## Figure is in last element of the tuple
                    filename = key + "_" + "_".join(str(x) for x in item[0:-1])
                    fullpath = os.path.join(path, filename)
                    if hasattr(fig, "savefig"):
                        fig.savefig(f"{fullpath}.{format}")
                    elif hasattr(fig, "save"):  # for seaborn objects plots
                        fig.save(f"{fullpath}.{format}")
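`save_plots` builds each file name by joining the keys on the path to a figure with underscores. The implementation of `nested_to_tuple` is not shown here, so the sketch below assumes it yields tuples of the form `(key1, key2, ..., figure)`. `flatten_plots` and `planned_filenames` are hypothetical helpers, the plot keys are invented, and strings stand in for figure objects.

```python
from typing import Any, Iterator, List, Tuple


def flatten_plots(plots: Any, prefix: Tuple[Any, ...] = ()) -> Iterator[Tuple[Any, ...]]:
    """Hypothetical stand-in for nested_to_tuple: yields (key1, key2, ..., leaf)."""
    if isinstance(plots, dict):
        for key, value in plots.items():
            yield from flatten_plots(value, prefix + (key,))
    else:
        yield prefix + (plots,)


def planned_filenames(plots: dict, fmt: str = "png") -> List[str]:
    """File names save_plots(which="all") would write, under the assumption above."""
    names = []
    for item in flatten_plots(plots):
        stem = "_".join(str(x) for x in item[:-1])  # all keys; the figure is item[-1]
        names.append(f"{stem}.{fmt}")
    return names


plots = {
    "loss_absolute": "fig_a",
    "latent2D": {"epoch_10": {"train": "fig_b", "valid": "fig_c"}},
}
print(planned_filenames(plots))
```

Under the same assumption, passing `which="latent2D"` should produce the same names for that entry, because the method prepends the selected key itself before joining the remaining keys.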

show_evaluation(param, metric, ml_alg=None)

Displays the evaluation plot for a specific clinical parameter, metric, and, optionally, ML algorithm.

Parameters:

param (str, required): Clinical parameter to visualize.
metric (str, required): Metric to visualize.
ml_alg (Optional[str], default None): ML algorithm to visualize. If None, the plots for all available algorithms are shown.

Returns:

None

Source code in src/autoencodix/base/_base_visualizer.py
def show_evaluation(
    self,
    param: str,
    metric: str,
    ml_alg: Optional[str] = None,
) -> None:
    """
    Displays the evaluation plot for a specific clinical parameter, metric, and optionally ML algorithm.

    Args:
        param: Clinical parameter to visualize.
        metric: Metric to visualize.
        ml_alg: ML algorithm to visualize. If None, the plots for all available algorithms are shown.

    Returns:
        None
    """
    plt.ioff()
    if "ML_Evaluation" not in self.plots.keys():
        print("ML Evaluation plots not found in the plots dictionary")
        print("You need to run evaluate() method first")
        return None
    if param not in self.plots["ML_Evaluation"].keys():
        print(f"Parameter {param} not found in the ML Evaluation plots")
        print(f"Available parameters: {list(self.plots['ML_Evaluation'].keys())}")
        return None
    if metric not in self.plots["ML_Evaluation"][param].keys():
        print(f"Metric {metric} not found in the ML Evaluation plots for {param}")
        print(
            f"Available metrics: {list(self.plots['ML_Evaluation'][param].keys())}"
        )
        return None

    algs = list(self.plots["ML_Evaluation"][param][metric].keys())
    if ml_alg is not None:
        if ml_alg not in algs:
            print(f"ML algorithm {ml_alg} not found for {param} and {metric}")
            print(f"Available ML algorithms: {algs}")
            return None
        fig = self.plots["ML_Evaluation"][param][metric][ml_alg].figure
        show_figure(fig)
        plt.show()
    else:
        for alg in algs:
            print(f"Showing plot for ML algorithm: {alg}")
            fig = self.plots["ML_Evaluation"][param][metric][alg].figure
            show_figure(fig)
            plt.show()
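`show_evaluation` walks a four-level dictionary, `plots["ML_Evaluation"][param][metric][ml_alg]`, and stops at the first missing key. The sketch below restates that lookup as a pure function so the selection logic can be tested without displaying anything. `select_evaluation_plots` is hypothetical, the parameter, metric, and algorithm names in the example are invented, and strings stand in for the stored plot objects.

```python
from typing import Any, Dict, List, Optional, Tuple


def select_evaluation_plots(
    plots: Dict[str, Any], param: str, metric: str, ml_alg: Optional[str] = None
) -> List[Tuple[str, Any]]:
    """Hypothetical helper: resolve plots["ML_Evaluation"][param][metric][alg].

    Returns (algorithm, plot) pairs to display; an empty list means a key was
    missing, mirroring the early `return None` branches of show_evaluation.
    """
    by_param = plots.get("ML_Evaluation")
    if by_param is None or param not in by_param:
        return []
    by_metric = by_param[param]
    if metric not in by_metric:
        return []
    by_alg = by_metric[metric]
    if ml_alg is None:
        return list(by_alg.items())  # one plot per available algorithm
    if ml_alg not in by_alg:
        return []
    return [(ml_alg, by_alg[ml_alg])]


plots = {"ML_Evaluation": {"TUMOR_STAGE": {"roc_auc": {"Linear": "p1", "RF": "p2"}}}}
print(select_evaluation_plots(plots, "TUMOR_STAGE", "roc_auc"))
print(select_evaluation_plots(plots, "TUMOR_STAGE", "roc_auc", "RF"))
```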

show_loss(plot_type='absolute')

Display the loss plot.

Parameters:

plot_type (str, default 'absolute'): Type of loss plot to display: 'absolute' for the absolute loss plot or 'relative' for the relative loss plot.

Returns:

None

Source code in src/autoencodix/base/_base_visualizer.py
def show_loss(self, plot_type: str = "absolute") -> None:
    """
    Display the loss plot.

    Args:
        plot_type: Type of loss plot to display: "absolute" for the absolute
            loss plot or "relative" for the relative loss plot.
            Defaults to "absolute".

    Returns:
        None
    """
    hint = (
        "This happens when visualize() has not been run, or when the model "
        "was saved and loaded with `save_all=False`"
    )
    if plot_type == "absolute":
        if "loss_absolute" not in self.plots.keys():
            print("Absolute loss plot not found in the plots dictionary")
            print(hint)
        else:
            fig = self.plots["loss_absolute"]
            show_figure(fig)
            plt.show()
    elif plot_type == "relative":
        if "loss_relative" not in self.plots.keys():
            print("Relative loss plot not found in the plots dictionary")
            print(hint)
        else:
            fig = self.plots["loss_relative"]
            fig.show()
    else:
        print(
            "Type of loss plot not recognized. Please use 'absolute' or 'relative'"
        )
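A table-driven variant of the dispatch in `show_loss` maps each accepted `plot_type` to its key in the plots dictionary, so supporting another loss plot would only need a new dictionary entry. This is a hypothetical sketch (`LOSS_PLOT_KEYS` and `resolve_loss_plot` are not part of autoencodix); it returns the stored object instead of displaying it.

```python
from typing import Any, Dict, Optional

# Maps the plot_type argument of show_loss to the key it looks up in self.plots.
LOSS_PLOT_KEYS: Dict[str, str] = {
    "absolute": "loss_absolute",
    "relative": "loss_relative",
}


def resolve_loss_plot(plots: Dict[str, Any], plot_type: str = "absolute") -> Optional[Any]:
    """Hypothetical helper: return the stored loss plot, or None with a message."""
    key = LOSS_PLOT_KEYS.get(plot_type)
    if key is None:
        print("Type of loss plot not recognized. Please use 'absolute' or 'relative'")
        return None
    if key not in plots:
        print(f"{plot_type.capitalize()} loss plot not found in the plots dictionary")
        return None
    return plots[key]


print(resolve_loss_plot({"loss_absolute": "A"}))
```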