Skip to content

Data Module

BalancedBatchSampler

Bases: Sampler[List[int]]

A custom PyTorch Sampler that avoids creating a final batch of size 1.

This sampler behaves like a standard BatchSampler but with a key difference in handling the last batch. If the last batch would normally have a size of 1, this sampler redistributes the last two batches to be of roughly equal size. For example, if a dataset of 129 samples is used with a batch size of 128, instead of yielding batches of [128, 1], it will yield two balanced batches, such as [65, 64].

This is particularly useful for avoiding issues with layers like BatchNorm, which require batch sizes greater than 1, without having to drop data (drop_last=True).

Parameters:

Name Type Description Default
data_source Sized

The dataset to sample from.

required
batch_size int

The target number of samples in each batch.

required
shuffle bool

If True, the sampler will shuffle the indices at the start of each epoch.

True
Source code in src/autoencodix/data/_sampler.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
class BalancedBatchSampler(Sampler[List[int]]):
    """
    A custom PyTorch Sampler that avoids creating a final batch of size 1.

    This sampler behaves like a standard `BatchSampler` but with a key
    difference in handling the last batch. If the last batch would normally
    have a size of 1, this sampler redistributes the last two batches to be
    of roughly equal size. For example, if a dataset of 129 samples is used
    with a batch size of 128, instead of yielding batches of [128, 1], it
    will yield two balanced batches of sizes [64, 65].

    With `batch_size == 2` the last three indices cannot be split into two
    batches that both hold more than one sample, so they are yielded as one
    final batch of size 3. With `batch_size == 1` every batch has size 1 by
    definition and nothing is redistributed.

    This is particularly useful for avoiding issues with layers like
    BatchNorm, which require batch sizes greater than 1, without having to
    drop data (`drop_last=True`).

    Args:
        data_source: The dataset to sample from.
        batch_size: The target number of samples in each batch.
        shuffle: If True, the sampler will shuffle the indices at the start of each epoch.
    """

    def __init__(self, data_source: Sized, batch_size: int, shuffle: bool = True):
        """Initializes the BalancedBatchSampler.

        Args:
            data_source: The dataset to sample from.
            batch_size: The target number of samples in each batch.
            shuffle: If True, the sampler will shuffle the indices at the start of each epoch.

        Raises:
            ValueError: If `batch_size` is not a positive integer.
        """
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError(
                f"batch_size should be a positive integer, but got {batch_size}"
            )

        self.data_source = data_source
        self.batch_size = batch_size
        self.shuffle = shuffle

    def _has_lone_straggler(self, n_samples: int) -> bool:
        """Returns True if plain batching would end with a batch of size 1.

        Only relevant when there is more than one batch to begin with.
        """
        return n_samples > self.batch_size and n_samples % self.batch_size == 1

    def __iter__(self) -> Iterator[List[int]]:
        """
        Returns an iterator over batches of indices.
        """
        n_samples = len(self.data_source)
        if n_samples == 0:
            return

        # Build the index order once: a random permutation when shuffling,
        # otherwise the natural order.
        if self.shuffle:
            indices = torch.randperm(n_samples)
        else:
            indices = torch.arange(n_samples)

        if not self._has_lone_straggler(n_samples):
            # Standard behavior: yield batches of size `batch_size`.
            # The last batch will have size > 1 or there will be no remainder.
            for start_idx in range(0, n_samples, self.batch_size):
                end_idx = min(start_idx + self.batch_size, n_samples)
                yield indices[start_idx:end_idx].tolist()
            return

        # All batches except the last full one are yielded unchanged.
        num_full_batches = n_samples // self.batch_size - 1
        for i in range(num_full_batches):
            start_idx = i * self.batch_size
            end_idx = start_idx + self.batch_size
            yield indices[start_idx:end_idx].tolist()

        # The remaining `batch_size + 1` indices are redistributed.
        remaining_indices = indices[num_full_batches * self.batch_size :]

        if self.batch_size == 2:
            # Three indices are left; any two-way split would contain a batch
            # of size 1, so keep them together in a single batch.
            yield remaining_indices.tolist()
            return

        # Split the remaining indices into two roughly equal halves.
        split_point = (self.batch_size + 1) // 2
        yield remaining_indices[:split_point].tolist()
        yield remaining_indices[split_point:].tolist()

    def __len__(self) -> int:
        """
        Returns the total number of batches in an epoch.
        """
        n_samples = len(self.data_source)
        if n_samples == 0:
            return 0

        if self._has_lone_straggler(n_samples):
            if self.batch_size == 2:
                # The straggler is merged into the last batch (see __iter__),
                # so the count equals the number of full batches.
                return n_samples // self.batch_size
            # Redistribution creates one extra batch compared to floor division.
            return n_samples // self.batch_size + 1

        # Standard ceiling division to calculate number of batches.
        return (n_samples + self.batch_size - 1) // self.batch_size

__init__(data_source, batch_size, shuffle=True)

Initializes the BalancedBatchSampler with `data_source` (the dataset to sample from), `batch_size` (the target number of samples in each batch), and `shuffle` (if True, the sampler will shuffle the indices at the start of each epoch).

Source code in src/autoencodix/data/_sampler.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def __init__(self, data_source: Sized, batch_size: int, shuffle: bool = True):
    """Validate ``batch_size`` and store the sampler configuration.

    Args:
        data_source: The dataset to sample from.
        batch_size: The target number of samples in each batch.
        shuffle: If True, the sampler will shuffle the indices at the start of each epoch.

    Raises:
        ValueError: If ``batch_size`` is not an ``int`` or is not positive.
    """
    is_positive_int = isinstance(batch_size, int) and batch_size > 0
    if not is_positive_int:
        raise ValueError(
            f"batch_size should be a positive integer, but got {batch_size}"
        )

    # Plain attribute storage; nothing is derived from the inputs here.
    self.shuffle = shuffle
    self.batch_size = batch_size
    self.data_source = data_source

__iter__()

Returns an iterator over batches of indices.

Source code in src/autoencodix/data/_sampler.py
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
def __iter__(self) -> Iterator[List[int]]:
    """
    Yields lists of indices, one list per batch.

    When plain batching would leave a single index in the final batch (and
    there is more than one batch), the last ``batch_size + 1`` indices are
    split into two batches instead.
    """
    total = len(self.data_source)
    if total == 0:
        return

    step = self.batch_size
    # Random permutation when shuffling, natural order otherwise.
    order = torch.randperm(total) if self.shuffle else torch.arange(total)

    lone_straggler = total > step and total % step == 1
    if not lone_straggler:
        # Ordinary batching; the final batch may be shorter than `step`.
        for lo in range(0, total, step):
            hi = min(lo + step, total)
            yield order[lo:hi].tolist()
        return

    # Every full batch except the last one goes out unchanged.
    untouched = total // step - 1
    for k in range(untouched):
        yield order[k * step : (k + 1) * step].tolist()

    # Whatever is left (`step + 1` indices) is cut into two pieces.
    tail = order[untouched * step :]
    cut = (step + 1) // 2
    yield tail[:cut].tolist()
    yield tail[cut:].tolist()

__len__()

Returns the total number of batches in an epoch.

Source code in src/autoencodix/data/_sampler.py
85
86
87
88
89
90
91
92
93
94
95
96
97
98
def __len__(self) -> int:
    """
    Returns how many batches one pass over the data source produces.
    """
    total = len(self.data_source)
    if not total:
        return 0

    full_batches, leftover = divmod(total, self.batch_size)
    if total > self.batch_size and leftover == 1:
        # The last full batch and the lone leftover sample are re-split into
        # two batches, i.e. one more batch than floor division gives.
        return full_batches + 1

    # Ceiling division: a partial final batch still counts as a batch.
    return -(-total // self.batch_size)

DataFilter

Preprocesses dataframes, including filtering and scaling.

This class separates the filtering logic that needs to be applied consistently across train, validation, and test sets from the scaling logic that is typically fitted on the training data and then applied to the other sets.

Attributes:

Name Type Description
data_info

Configuration object containing preprocessing parameters.

filtered_features Optional[Set[str]]

Set of features to keep after filtering on the training data. None initially.

_scaler

The fitted scaler object. None initially.

ontologies

Ontology information, if provided for Ontix.

config

Configuration object containing default parameters.

Source code in src/autoencodix/data/_filter.py
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
class DataFilter:
    """Preprocesses dataframes, including filtering and scaling.

    This class separates the filtering logic that needs to be applied consistently
    across train, validation, and test sets from the scaling logic that is
    typically fitted on the training data and then applied to the other sets.

    Attributes:
        data_info: Configuration object containing preprocessing parameters.
        filtered_features: Set of features to keep after filtering on the training data. None initially.
        _scaler: The fitted scaler object. None initially.
        ontologies: Ontology information, if provided for Ontix.
        config: Configuration object containing default parameters.
        method: Name of the scaling method resolved by `_init_scaler`.
    """

    def __init__(
        self,
        data_info: DataInfo,
        config: DefaultConfig,
        ontologies: Optional[tuple] = None,
    ):  # Addition to Varix, mandatory for Ontix
        """Initializes the DataFilter with a configuration.

        Args:
            data_info: Configuration object containing preprocessing parameters.
            config: Configuration object containing default parameters.
            ontologies: Ontology information, if provided for Ontix.
        """
        self.data_info = data_info
        self.config = config
        self.filtered_features: Optional[Set[str]] = None
        self._scaler = None
        self.ontologies = ontologies  # Addition to Varix, mandatory for Ontix
        self._init_scaler()

    def _filter_nonzero_variance(self, df: pd.DataFrame) -> pd.DataFrame:
        """Removes features with zero variance.

        Args:
            df: Input dataframe.

        Returns:
            Filtered dataframe containing only columns with non-zero variance.
        """
        var = pd.Series(np.var(df, axis=0), index=df.columns)
        return df[var[var > 0].index]

    def _filter_by_variance(
        self, df: pd.DataFrame, k: Optional[int]
    ) -> Union[pd.Series, pd.DataFrame]:
        """Keeps top k features by variance.

        Args:
            df: Input dataframe.
            k: Number of top variance features to keep. If None or greater
               than number of columns, all features are kept.

        Returns:
            Filtered dataframe with top k variance features.
        """
        if k is None or k > df.shape[1]:
            warnings.warn(
                "WARNING: k is None or greater than number of columns, keeping all features."
            )
            return df
        var = pd.Series(np.var(df, axis=0), index=df.columns)
        return df[var.sort_values(ascending=False).index[:k]]

    def _filter_by_mad(
        self, df: pd.DataFrame, k: Optional[int]
    ) -> Union[pd.Series, pd.DataFrame]:
        """Keeps top k features by median absolute deviation.

        Args:
            df: Input dataframe.
            k: Number of top MAD features to keep. If None or greater
               than number of columns, all features are kept.

        Returns:
            Filtered dataframe with top k MAD features.
        """
        if k is None or k > df.shape[1]:
            return df
        mads = pd.Series(median_abs_deviation(df, axis=0), index=df.columns)
        return df[mads.sort_values(ascending=False).index[:k]]

    def _filter_by_correlation(
        self, df: pd.DataFrame, k: Optional[int]
    ) -> Union[pd.Series, pd.DataFrame]:
        """Filters features using correlation-based clustering.

        This method clusters features based on their correlation distance and
        selects a representative feature (medoid) from each cluster.

        Args:
            df: Input dataframe.
            k: Number of clusters to create. If None or greater
               than number of columns, all features are kept.

        Returns:
            Filtered dataframe with one representative feature (medoid) per cluster.
        """
        if k is None or k > df.shape[1]:
            warnings.warn(
                "WARNING: k is None or greater than number of columns, keeping all features."
            )
            return df

        # Features become rows so that pairwise distances are feature-to-feature.
        X = df.transpose().values
        dist_matrix = squareform(pdist(X, metric="correlation"))

        # NOTE(review): the distance matrix is handed to fit() with only
        # `n_clusters` configured, i.e. it is not declared as precomputed
        # distances -- confirm this is intended.
        clustering = AgglomerativeClustering(
            n_clusters=k,
        ).fit(dist_matrix)

        medoid_indices = []
        for i in range(k):
            cluster_points = np.where(clustering.labels_ == i)[0]
            if len(cluster_points) > 0:
                # The medoid is the point with minimum sum of distances to other points in the cluster
                cluster_dist_matrix = dist_matrix[
                    np.ix_(cluster_points, cluster_points)
                ]
                sum_distances = np.sum(cluster_dist_matrix, axis=1)
                medoid_idx = cluster_points[np.argmin(sum_distances)]
                medoid_indices.append(medoid_idx)

        df_filt: Union[pd.DataFrame, pd.Series] = df.iloc[:, medoid_indices]
        return df_filt

    def filter(
        self, df: pd.DataFrame, genes_to_keep: Optional[List] = None
    ) -> Tuple[Union[pd.Series, pd.DataFrame], List[str]]:
        """Applies the configured filtering method to the dataframe.

        This method is intended to be called on the training data to determine
        which features to keep. The kept feature names are returned to the
        caller; this method does not itself assign `filtered_features`.

        Args:
            df: Input dataframe to be filtered (typically the training set).
            genes_to_keep: A list of gene names to explicitly keep.
                If provided, other filtering methods will be ignored.

        Returns:
            A tuple containing:
                - The filtered dataframe.
                - A list of column names (features) that were kept. This list
                  always matches the columns of the returned dataframe.

        Raises:
            KeyError: If some genes in `genes_to_keep` are not present in the dataframe.
        """
        if genes_to_keep is not None:
            try:
                subset = df[genes_to_keep]
            except KeyError as e:
                # Chain the original error so the missing labels stay visible.
                raise KeyError(
                    f"Some genes in genes_to_keep are not present in the dataframe: {e}"
                ) from e
            return subset, genes_to_keep

        MIN_FILTER = 2
        filtering_method = FilterMethod(self.data_info.filtering)

        if df.shape[0] < MIN_FILTER or df.empty:
            warnings.warn(
                f"WARNING: df is too small for filtering, needs to have at least {MIN_FILTER}"
            )
            return df, df.columns.tolist()

        filtered_df = df.copy()

        ## Remove features which are not in the ontology for Ontix architecture
        ## must be done before other filtering is applied
        if getattr(self, "ontologies", None) is not None:
            # A set keeps the membership tests below O(1) per column.
            ontology_features: Set = set()
            for values in self.ontologies[-1].values():
                ontology_features.update(values)

            columns = filtered_df.columns.tolist()
            missing_features = [f for f in columns if f not in ontology_features]
            ## Filter out features not in the ontology
            feature_order = [f for f in columns if f in ontology_features]
            if missing_features:
                print(
                    f"Features in feature_order not found in all_feature_names: {missing_features}"
                )

            filtered_df = filtered_df.loc[:, feature_order]

        # No further filtering requested: report the columns of the frame that
        # is actually returned (the ontology step above may have dropped some).
        if (
            filtering_method == FilterMethod.NOFILT
            or self.data_info.k_filter is None
        ):
            return filtered_df, filtered_df.columns.tolist()

        if filtering_method == FilterMethod.NONZEROVAR:
            filtered_df = self._filter_nonzero_variance(filtered_df)
        elif filtering_method == FilterMethod.VAR:
            filtered_df = self._filter_nonzero_variance(filtered_df)
            filtered_df = self._filter_by_variance(filtered_df, self.data_info.k_filter)
        elif filtering_method == FilterMethod.MAD:
            filtered_df = self._filter_nonzero_variance(filtered_df)
            filtered_df = self._filter_by_mad(filtered_df, self.data_info.k_filter)
        elif filtering_method == FilterMethod.CORR:
            filtered_df = self._filter_nonzero_variance(filtered_df)
            filtered_df = self._filter_by_correlation(
                filtered_df, self.data_info.k_filter
            )
        elif filtering_method == FilterMethod.VARCORR:
            filtered_df = self._filter_nonzero_variance(filtered_df)
            # Variance pre-filter keeps ten times as many features as requested.
            filtered_df = self._filter_by_variance(
                filtered_df,
                self.data_info.k_filter * 10 if self.data_info.k_filter else None,
            )
            if self.data_info.k_filter is not None:
                # Apply correlation filter on the already variance-filtered data
                num_features_after_var = filtered_df.shape[1]
                k_corr = min(self.data_info.k_filter, num_features_after_var)
                filtered_df = self._filter_by_correlation(filtered_df, k_corr)

        return filtered_df, filtered_df.columns.tolist()

    def _init_scaler(self) -> None:
        """Initializes the scaler based on the configured scaling method."""
        self.method = self.data_info.scaling

        if self.method == "NOTSET":
            # if not set in data config, we use the global scaling config
            self.method = self.config.scaling
        if self.method == "MINMAX":
            self._scaler = MinMaxScaler(clip=True)
        elif self.method == "STANDARD":
            self._scaler = StandardScaler()
        elif self.method == "ROBUST":
            self._scaler = RobustScaler()
        elif self.method == "MAXABS":
            self._scaler = MaxAbsScaler()
        else:
            # Any other method name (including "LOG1P") has no scaler object.
            self._scaler = None

    def fit_scaler(self, df: Union[pd.Series, pd.DataFrame]) -> Any:
        """Fits the scaler to the input dataframe (typically the training set).

        The scaler is re-created from the configuration before fitting.

        Args:
            df: Input dataframe to fit the scaler on.

        Returns:
            The fitted scaler object, or None if no scaler is configured.
        """
        self._init_scaler()
        if self._scaler is not None:
            self._scaler.fit(df)
        else:
            warnings.warn("No scaling applied.")
        return self._scaler

    def scale(
        self, df: Union[pd.Series, pd.DataFrame], scaler: Any
    ) -> Union[pd.Series, pd.DataFrame]:
        """Applies the fitted scaler to the input dataframe.

        Args:
            df: Input dataframe to be scaled.
            scaler: The fitted scaler object. Ignored when the method is "LOG1P".

        Returns:
            Scaled dataframe, or the input unchanged if `scaler` is None.
        """
        if self.method == "LOG1P":
            X_log = np.log1p(df.values)
            # NOTE(review): the denominator applies log1p to the column maxima
            # of the already log-transformed values, and is computed from `df`
            # itself rather than from fitted statistics -- confirm intended.
            X_norm = X_log / np.log1p(np.max(X_log, axis=0))
            df_scaled = pd.DataFrame(X_norm, columns=df.columns, index=df.index)
            return df_scaled
        if scaler is None:
            warnings.warn("No scaler has been fitted yet or scaling is set to none.")
            return df
        df_scaled = pd.DataFrame(
            scaler.transform(df), columns=df.columns, index=df.index
        )
        return df_scaled

    @property
    def available_methods(self) -> List[str]:
        """Lists all available filtering methods.

        Returns:
            List of available filtering method names.
        """
        return [method.value for method in FilterMethod]

available_methods property

Lists all available filtering methods.

Returns:

Type Description
List[str]

List of available filtering method names.

__init__(data_info, config, ontologies=None)

Initializes the DataFilter with a configuration.

Parameters:

Name Type Description Default
data_info DataInfo

Configuration object containing preprocessing parameters.

required
config DefaultConfig

Configuration object containing default parameters.

required
ontologies Optional[tuple]

Ontology information, if provided for Ontix.

None
Source code in src/autoencodix/data/_filter.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def __init__(
    self,
    data_info: DataInfo,
    config: DefaultConfig,
    ontologies: Optional[tuple] = None,
):  # `ontologies` is an addition to Varix and mandatory for Ontix
    """Stores the configuration objects and sets up the scaler.

    Args:
        data_info: Configuration object containing preprocessing parameters.
        config: Configuration object containing default parameters.
        ontologies: Ontology information, if provided for Ontix.
    """
    # State that starts out empty.
    self._scaler = None
    self.filtered_features: Optional[Set[str]] = None

    # Configuration handed in by the caller.
    self.ontologies = ontologies  # addition to Varix, mandatory for Ontix
    self.config = config
    self.data_info = data_info

    # Runs last: it is called only after data_info and config are stored.
    self._init_scaler()

filter(df, genes_to_keep=None)

Applies the configured filtering method to the dataframe.

This method is intended to be called on the training data to determine which features to keep. The filtered_features attribute will be set based on the result.

Parameters:

Name Type Description Default
df DataFrame

Input dataframe to be filtered (typically the training set).

required
genes_to_keep Optional[List]

A list of gene names to explicitly keep. If provided, other filtering methods will be ignored.

None

Returns:

Type Description
Tuple[Union[Series, DataFrame], List[str]]

A tuple containing (1) the filtered dataframe and (2) a list of column names (features) that were kept.

Raises:

Type Description
KeyError

If some genes in genes_to_keep are not present in the dataframe.

Source code in src/autoencodix/data/_filter.py
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
def filter(
    self, df: pd.DataFrame, genes_to_keep: Optional[List] = None
) -> Tuple[Union[pd.Series, pd.DataFrame], List[str]]:
    """Applies the configured filtering method to the dataframe.

    This method is intended to be called on the training data to determine
    which features to keep. The kept feature names are returned to the
    caller; this method does not itself assign `filtered_features`.

    Args:
        df: Input dataframe to be filtered (typically the training set).
        genes_to_keep: A list of gene names to explicitly keep.
            If provided, other filtering methods will be ignored.

    Returns:
        A tuple containing:
            - The filtered dataframe.
            - A list of column names (features) that were kept. This list
              always matches the columns of the returned dataframe.

    Raises:
        KeyError: If some genes in `genes_to_keep` are not present in the dataframe.
    """
    if genes_to_keep is not None:
        try:
            subset = df[genes_to_keep]
        except KeyError as e:
            # Chain the original error so the missing labels stay visible.
            raise KeyError(
                f"Some genes in genes_to_keep are not present in the dataframe: {e}"
            ) from e
        return subset, genes_to_keep

    MIN_FILTER = 2
    filtering_method = FilterMethod(self.data_info.filtering)

    if df.shape[0] < MIN_FILTER or df.empty:
        warnings.warn(
            f"WARNING: df is too small for filtering, needs to have at least {MIN_FILTER}"
        )
        return df, df.columns.tolist()

    filtered_df = df.copy()

    ## Remove features which are not in the ontology for Ontix architecture
    ## must be done before other filtering is applied
    if getattr(self, "ontologies", None) is not None:
        # A set keeps the membership tests below O(1) per column.
        ontology_features = set()
        for values in self.ontologies[-1].values():
            ontology_features.update(values)

        columns = filtered_df.columns.tolist()
        missing_features = [f for f in columns if f not in ontology_features]
        ## Filter out features not in the ontology
        feature_order = [f for f in columns if f in ontology_features]
        if missing_features:
            print(
                f"Features in feature_order not found in all_feature_names: {missing_features}"
            )

        filtered_df = filtered_df.loc[:, feature_order]

    # No further filtering requested: report the columns of the frame that is
    # actually returned (the ontology step above may have dropped some).
    if filtering_method == FilterMethod.NOFILT or self.data_info.k_filter is None:
        return filtered_df, filtered_df.columns.tolist()

    if filtering_method == FilterMethod.NONZEROVAR:
        filtered_df = self._filter_nonzero_variance(filtered_df)
    elif filtering_method == FilterMethod.VAR:
        filtered_df = self._filter_nonzero_variance(filtered_df)
        filtered_df = self._filter_by_variance(filtered_df, self.data_info.k_filter)
    elif filtering_method == FilterMethod.MAD:
        filtered_df = self._filter_nonzero_variance(filtered_df)
        filtered_df = self._filter_by_mad(filtered_df, self.data_info.k_filter)
    elif filtering_method == FilterMethod.CORR:
        filtered_df = self._filter_nonzero_variance(filtered_df)
        filtered_df = self._filter_by_correlation(
            filtered_df, self.data_info.k_filter
        )
    elif filtering_method == FilterMethod.VARCORR:
        filtered_df = self._filter_nonzero_variance(filtered_df)
        # Variance pre-filter keeps ten times as many features as requested.
        filtered_df = self._filter_by_variance(
            filtered_df,
            self.data_info.k_filter * 10 if self.data_info.k_filter else None,
        )
        if self.data_info.k_filter is not None:
            # Apply correlation filter on the already variance-filtered data
            num_features_after_var = filtered_df.shape[1]
            k_corr = min(self.data_info.k_filter, num_features_after_var)
            filtered_df = self._filter_by_correlation(filtered_df, k_corr)

    return filtered_df, filtered_df.columns.tolist()

fit_scaler(df)

Fits the scaler to the input dataframe (typically the training set).

Parameters:

Name Type Description Default
df Union[Series, DataFrame]

Input dataframe to fit the scaler on.

required

Returns:

Type Description
Any

The fitted scaler object.

Source code in src/autoencodix/data/_filter.py
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
def fit_scaler(self, df: Union[pd.Series, pd.DataFrame]) -> Any:
    """Re-creates the configured scaler and fits it on ``df``.

    Args:
        df: Input dataframe to fit the scaler on (typically the training set).

    Returns:
        The fitted scaler object, or None when no scaler is configured
        (a warning is emitted in that case).
    """
    # Start from a fresh scaler so state from an earlier fit is not reused.
    self._init_scaler()

    if self._scaler is None:
        warnings.warn("No scaling applied.")
        return self._scaler

    self._scaler.fit(df)
    return self._scaler

scale(df, scaler)

Applies the fitted scaler to the input dataframe.

Parameters:

Name Type Description Default
df Union[Series, DataFrame]

Input dataframe to be scaled.

required
scaler Any

The fitted scaler object.

required

Returns:

Type Description
Union[Series, DataFrame]

Scaled dataframe.

Source code in src/autoencodix/data/_filter.py
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
def scale(
    self, df: Union[pd.Series, pd.DataFrame], scaler: Any
) -> Union[pd.Series, pd.DataFrame]:
    """Returns a scaled copy of ``df`` with the same index and columns.

    Args:
        df: Input dataframe to be scaled.
        scaler: The fitted scaler object. Not used when the configured
            method is "LOG1P".

    Returns:
        Scaled dataframe; ``df`` itself is returned unchanged when the method
        is not "LOG1P" and ``scaler`` is None.
    """
    if self.method == "LOG1P":
        # log1p-transform, then divide each column by log1p of its maximum
        # transformed value.
        logged = np.log1p(df.values)
        denominator = np.log1p(np.max(logged, axis=0))
        return pd.DataFrame(
            logged / denominator, columns=df.columns, index=df.index
        )

    if scaler is None:
        warnings.warn("No scaler has been fitted yet or scaling is set to none.")
        return df

    transformed = scaler.transform(df)
    return pd.DataFrame(transformed, columns=df.columns, index=df.index)

DataPackage dataclass

Represents a data package containing multiple types of data.

Source code in src/autoencodix/data/datapackage.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
@dataclass
class DataPackage:
    """Represents a data package containing multiple types of data.

    Every field is either ``None`` or a dictionary mapping a sub-key to the
    data object stored under it. In addition to attribute access the package
    supports ``package["field"]`` lookup/assignment and iteration over
    ``("field.sub_key", value)`` pairs.
    """

    multi_sc: Optional[Dict[str, MuData]] = None  # # ty: ignore[invalid-type-form]
    multi_bulk: Optional[Dict[str, pd.DataFrame]] = None
    annotation: Optional[Dict[str, Union[pd.DataFrame, None]]] = None
    img: Optional[Dict[str, List[ImgData]]] = None

    from_modality: Optional[
        Dict[
            str,
            Union[
                pd.DataFrame,
                List[ImgData],
                MuData,  # ty: ignore[invalid-type-form]
                AnnData,  # ty: ignore[invalid-type-form]
            ],
        ]  # ty: ignore[invalid-type-form]
    ] = field(default_factory=dict, repr=False)
    to_modality: Optional[
        Dict[
            str,
            Union[
                pd.DataFrame,
                List[ImgData],
                MuData,  # ty: ignore[invalid-type-form]
                AnnData,  # ty: ignore[invalid-type-form]
            ],  # ty: ignore[invalid-type-form]
        ]  # ty: ignore[invalid-type-form]
    ] = field(default_factory=dict, repr=False)

    def __getitem__(self, key: str) -> Any:
        """Allow dictionary-like access to top-level attributes.

        Args:
            key: Name of the attribute to read.

        Returns:
            The value of the attribute called ``key``.

        Raises:
            KeyError: If the package has no attribute called ``key``.
        """
        if hasattr(self, key):
            return getattr(self, key)
        raise KeyError(f"{key} not found in DataPackage.")

    def __setitem__(self, key: str, value: Any) -> None:
        """Allow dictionary-like item assignment to top-level attributes.

        Args:
            key: Name of an existing attribute.
            value: New value for that attribute.

        Raises:
            KeyError: If the package has no attribute called ``key``.
        """
        if hasattr(self, key):
            setattr(self, key, value)
        else:
            raise KeyError(f"{key} not found in DataPackage.")

    def __iter__(self) -> Iterator[Tuple[str, Any]]:
        """Make DataPackage iterable, yielding (key, value) pairs.

        For dictionary attributes, yields nested items as
        (parent_key.child_key, value). Attributes that are ``None`` are
        skipped; any other attribute is yielded as (attribute_name, value).
        """
        for attr_name in self.__annotations__:
            attr_value = getattr(self, attr_name)

            if attr_value is None:
                continue
            if isinstance(attr_value, dict):
                for sub_key, sub_value in attr_value.items():
                    yield f"{attr_name}.{sub_key}", sub_value
            else:
                yield attr_name, attr_value

    def format_shapes(self) -> str:
        """Format the shape dictionary in a clean, readable way.

        Returns:
            One header line per data type followed by one indented line per
            sub-entry, or ``"Empty DataPackage"`` when nothing has a shape.
        """
        lines = []

        for data_type, data_info in self.shape().items():
            # Skip data types that hold nothing at all.
            if not data_info:
                continue

            sub_items = []
            for subtype, shape in data_info.items():
                if shape is None:
                    continue
                if isinstance(shape, tuple):
                    sub_items.append(
                        f"{subtype}: {shape[0]} samples × {shape[1]} features"
                    )
                else:
                    sub_items.append(f"{subtype}: {shape} items")

            # Only emit the header when at least one sub-entry was rendered.
            if sub_items:
                lines.append(f"{data_type}:")
                lines.extend(f"  {item}" for item in sub_items)

        if not lines:
            return "Empty DataPackage"

        return "\n".join(lines)

    def __str__(self) -> str:
        """Return the formatted shape summary."""
        return self.format_shapes()

    def __repr__(self) -> str:
        """Return the same text as ``str(self)``."""
        return self.__str__()

    def is_empty(self) -> bool:
        """Check if the data package is empty.

        A field counts as empty when it is ``None`` or has length 0, so an
        empty dictionary is treated exactly like ``None`` for every field.

        Returns:
            True if all six fields are empty, False otherwise.
        """
        containers = (
            self.multi_sc,
            self.multi_bulk,
            self.annotation,
            self.img,
            self.from_modality,
            self.to_modality,
        )
        return all(
            container is None or len(container) == 0 for container in containers
        )

    def get_n_samples(self) -> Dict[str, Dict[str, int]]:
        """Get the number of samples for each data type in nested dictionary format.

        Returns:
            Dictionary with nested structure: {modality_type: {sub_key: count}}.
            Sub-entries that are ``None`` or have length 0 are left out. An
            extra ``"paired_count"`` entry holds the result of
            ``_calculate_paired_count``.
        """
        n_samples: Dict[str, Dict[str, int]] = {}

        for attr_name in self.__annotations__:
            attr_value = getattr(self, attr_name)

            if isinstance(attr_value, dict):
                # Dictionary attributes (multi_sc, multi_bulk, etc.)
                sub_counts = {}
                for sub_key, sub_value in attr_value.items():
                    if sub_value is None or len(sub_value) == 0:
                        continue
                    sub_counts[sub_key] = self._get_n_samples(sub_value)
                n_samples[attr_name] = sub_counts
            else:
                # Non-dictionary attributes (including None, which counts 0)
                count = self._get_n_samples(attr_value)
                n_samples[attr_name] = {attr_name: count}

        paired_count = self._calculate_paired_count()
        n_samples["paired_count"] = {"paired_count": paired_count}

        return n_samples

    def _calculate_paired_count(self) -> int:
        """Return the smallest positive sample count found in the package.

        The counts of all data objects that hold at least one sample are
        collected and their minimum is used as the paired count.

        Returns:
            The minimum positive sample count, or 0 if nothing holds data.
        """
        all_counts = []

        for attr_name in self.__annotations__:
            attr_value = getattr(self, attr_name)
            if attr_value is None:
                continue

            if isinstance(attr_value, dict):
                candidates = [v for v in attr_value.values() if v is not None]
            else:
                candidates = [attr_value]

            for candidate in candidates:
                count = self._get_n_samples(candidate)
                if count > 0:
                    all_counts.append(count)

        return min(all_counts) if all_counts else 0

    def get_common_ids(self) -> List[str]:
        """Get the common sample IDs across modalities that have data.

        Returns:
            Sorted list of sample IDs that are present in every data object
            that yields at least one ID; empty list if there is none.
        """
        all_ids = []

        for attr_name in self.__annotations__:
            attr_value = getattr(self, attr_name)
            if attr_value is None:
                continue

            if isinstance(attr_value, dict):
                candidates = [v for v in attr_value.values() if v is not None]
            else:
                candidates = [attr_value]

            for candidate in candidates:
                ids = self._get_sample_ids(candidate)
                if ids:
                    all_ids.append(set(ids))

        if not all_ids:
            return []

        return sorted(set.intersection(*all_ids))

    def _get_sample_ids(
        self,
        dataobj: Union[
            MuData,  # ty: ignore[invalid-type-form]
            pd.DataFrame,
            List[ImgData],
            AnnData,
        ],  # ty: ignore[invalid-type-form]
    ) -> List[str]:
        """Extract sample IDs from a data object.

        Args:
            dataobj: Data object to extract IDs from.

        Returns:
            List of sample IDs: the stringified index for a DataFrame, the
            stringified ``obs`` index for AnnData/MuData, the ``sample_id``
            of every list element that has one, and an empty list for
            ``None`` or any other type.
        """
        if dataobj is None:
            return []

        if isinstance(dataobj, pd.DataFrame):
            return dataobj.index.astype(str).tolist()
        if isinstance(dataobj, list):
            # Elements without a sample_id attribute are silently skipped.
            return [
                img_data.sample_id
                for img_data in dataobj
                if hasattr(img_data, "sample_id")
            ]
        if isinstance(dataobj, (AnnData, MuData)):
            return dataobj.obs.index.astype(str).tolist()  # ty: ignore
        return []

    def _get_n_samples(
        self,
        dataobj: Union[
            MuData,  # ty: ignore[invalid-type-form]
            pd.DataFrame,
            List[ImgData],
            AnnData,
            Dict,
        ],
    ) -> int:
        """Get the number of samples for a specific attribute.

        Args:
            dataobj: Data object (or dictionary of data objects) to count.

        Returns:
            Number of samples; 0 for ``None``, an empty dictionary or a
            MuData without modalities. For a non-empty dictionary only its
            first value is counted.

        Raises:
            ValueError: If ``dataobj`` is of an unsupported type.
        """
        if dataobj is None:
            return 0

        if isinstance(dataobj, pd.DataFrame):
            return dataobj.shape[0]
        if isinstance(dataobj, dict):
            if not dataobj:
                return 0
            # The first entry stands in for the whole dictionary.
            first_value = next(iter(dataobj.values()))
            return self._get_n_samples(first_value)
        if isinstance(dataobj, list):
            return len(dataobj)
        if isinstance(dataobj, AnnData):
            return dataobj.obs.shape[0]
        if isinstance(dataobj, MuData):
            if not dataobj.mod:  # ty: ignore
                return 0
            return dataobj.n_obs  # ty: ignore
        raise ValueError(
            f"Unknown data type {type(dataobj)} for dataobj. Probably you've implemented a new attribute in the DataPackage class or changed the data type of an existing attribute."
        )

    def shape(self) -> Dict[str, Dict[str, Any]]:
        """Get the shape of the data for each data type in nested dictionary format.

        Returns:
            Dictionary with nested structure: {modality_type: {sub_key: shape}}.
            Empty dictionary attributes map to ``{}``; non-dictionary
            attributes map to ``{attribute_name: shape}``.
        """
        shapes: Dict[str, Dict[str, Any]] = {}

        for attr_name in self.__annotations__:
            attr_value = getattr(self, attr_name)

            if not isinstance(attr_value, dict):
                shapes[attr_name] = {attr_name: self._get_single_shape(attr_value)}
            elif attr_value:
                shapes[attr_name] = self._get_shape_from_dict(attr_value)
            else:
                shapes[attr_name] = {}

        return shapes

    def _get_single_shape(self, dataobj: Any) -> Optional[Union[Tuple, int]]:
        """Get shape for a single data object.

        Returns:
            The length for a list, the ``shape`` tuple for a DataFrame or
            AnnData, ``(n_obs, n_vars)`` for MuData, and ``None`` for ``None``
            or any other type.
        """
        if dataobj is None:
            return None
        if isinstance(dataobj, list):
            return len(dataobj)
        if isinstance(dataobj, (pd.DataFrame, AnnData)):
            return dataobj.shape
        if isinstance(dataobj, MuData):
            return (dataobj.n_obs, dataobj.n_vars)
        return None

    def _get_shape_from_dict(self, data_dict: Dict) -> Dict[str, Any]:
        """Recursively process dictionary to extract shapes of contained data objects.

        Args:
            data_dict: Dictionary containing data objects

        Returns:
            Dictionary with shapes information
        """
        result: Dict[str, Any] = {}
        for key, value in data_dict.items():
            if isinstance(value, (pd.DataFrame, AnnData)):
                result[key] = value.shape
            elif isinstance(value, list):
                # For lists of objects, just store the length
                result[key] = len(value)
            elif isinstance(value, MuData):
                result[key] = (value.n_obs, value.n_vars)
            elif isinstance(value, dict):
                # Recursively process nested dictionaries
                result[key] = self._get_shape_from_dict(value)
            elif value is None:
                result[key] = None
            else:
                # Unknown types get a descriptive placeholder instead of an
                # exception so that one odd entry cannot break the summary.
                result[key] = f"<{type(value).__name__}>"

        return result

    def get_modality_key(self, direction: str) -> Optional[str]:
        """Get the first key for a specific direction's modality.

        Args:
            direction: Either 'from' or 'to'

        Returns:
            First key of the modality dictionary or None if empty

        Raises:
            ValueError: If ``direction`` is neither 'from' nor 'to'.
        """
        if direction not in ["from", "to"]:
            raise ValueError(f"Direction must be 'from' or 'to', got {direction}")

        modality_dict = self.from_modality if direction == "from" else self.to_modality
        if not modality_dict:
            return None

        return next(iter(modality_dict.keys()), None)

__getitem__(key)

Allow dictionary-like access to top-level attributes.

Source code in src/autoencodix/data/datapackage.py
46
47
48
49
50
def __getitem__(self, key: str) -> Any:
    """Return the attribute called ``key``, dictionary style.

    Raises:
        KeyError: If the object has no attribute called ``key``.
    """
    if not hasattr(self, key):
        raise KeyError(f"{key} not found in DataPackage.")
    return getattr(self, key)

__iter__()

Make DataPackage iterable, yielding (key, value) pairs.

For dictionary attributes, yields nested items as (parent_key.child_key, value).

Source code in src/autoencodix/data/datapackage.py
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def __iter__(self) -> Iterator[Tuple[str, Any]]:
    """Iterate over the stored data as (key, value) pairs.

    Attributes that are ``None`` are skipped. Dictionary attributes are
    flattened into one pair per entry, keyed ``"<attribute>.<sub_key>"``;
    every other attribute is yielded as ``(attribute_name, value)``.
    """
    for field_name in self.__annotations__:
        content = getattr(self, field_name)
        if content is None:
            continue
        if not isinstance(content, dict):
            yield field_name, content
            continue
        for child_key, child_value in content.items():
            yield f"{field_name}.{child_key}", child_value

__setitem__(key, value)

Allow dictionary-like item assignment to top-level attributes.

Source code in src/autoencodix/data/datapackage.py
52
53
54
55
56
57
def __setitem__(self, key: str, value: Any) -> None:
    """Assign ``value`` to the existing attribute ``key``, dictionary style.

    Raises:
        KeyError: If the object has no attribute called ``key``.
    """
    if not hasattr(self, key):
        raise KeyError(f"{key} not found in DataPackage.")
    setattr(self, key, value)

format_shapes()

Format the shape dictionary in a clean, readable way.

Source code in src/autoencodix/data/datapackage.py
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
def format_shapes(self) -> str:
    """Render the result of ``self.shape()`` as a multi-line summary.

    Returns:
        One header line per data type followed by one indented line per
        sub-entry whose shape is not ``None``; ``"Empty DataPackage"`` when
        no line was produced.
    """
    rendered = []

    for group_name, group_shapes in self.shape().items():
        if not group_shapes:
            continue

        entries = []
        for sub_name, dims in group_shapes.items():
            if dims is None:
                continue
            if isinstance(dims, tuple):
                entries.append(f"{sub_name}: {dims[0]} samples × {dims[1]} features")
            else:
                entries.append(f"{sub_name}: {dims} items")

        # A group whose entries were all None produces no header either.
        if not entries:
            continue
        rendered.append(f"{group_name}:")
        rendered += [f"  {entry}" for entry in entries]

    return "\n".join(rendered) if rendered else "Empty DataPackage"

get_common_ids()

Get the common sample IDs across modalities that have data.

Returns:

Type Description
List[str]

List of sample IDs that are present in all modalities with data

Source code in src/autoencodix/data/datapackage.py
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
def get_common_ids(self) -> List[str]:
    """Return the sample IDs shared by every data object that has IDs.

    Returns:
        Sorted list of IDs present in all non-empty ID sets collected via
        ``self._get_sample_ids``; an empty list when no IDs were found.
    """
    id_sets = []

    for field_name in self.__annotations__:
        content = getattr(self, field_name)
        if content is None:
            continue

        # Dictionary attributes contribute each non-None entry separately.
        if isinstance(content, dict):
            candidates = [entry for entry in content.values() if entry is not None]
        else:
            candidates = [content]

        for candidate in candidates:
            found = self._get_sample_ids(candidate)
            if found:
                id_sets.append(set(found))

    if not id_sets:
        return []

    return sorted(set.intersection(*id_sets))

get_modality_key(direction)

Get the first key for a specific direction's modality.

Parameters:

Name Type Description Default
direction str

Either 'from' or 'to'

required

Returns:

Type Description
Optional[str]

First key of the modality dictionary or None if empty

Source code in src/autoencodix/data/datapackage.py
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
def get_modality_key(self, direction: str) -> Optional[str]:
    """Return the first key of the 'from' or 'to' modality dictionary.

    Args:
        direction: Either 'from' or 'to'.

    Returns:
        The first key of the selected dictionary, or None if that
        dictionary is empty or unset.

    Raises:
        ValueError: If ``direction`` is neither 'from' nor 'to'.
    """
    if direction == "from":
        candidates = self.from_modality
    elif direction == "to":
        candidates = self.to_modality
    else:
        raise ValueError(f"Direction must be 'from' or 'to', got {direction}")

    if not candidates:
        return None

    return next(iter(candidates.keys()), None)

get_n_samples()

Get the number of samples for each data type in nested dictionary format.

Returns:

Type Description
Dict[str, Dict[str, int]]

Dictionary with nested structure: {modality_type: {sub_key: count}}

Source code in src/autoencodix/data/datapackage.py
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def get_n_samples(self) -> Dict[str, Dict[str, int]]:
    """Count the samples of every stored data object.

    Returns:
        Nested dictionary ``{attribute: {sub_key: count}}``. Entries of a
        dictionary attribute that are ``None`` or have length 0 are left
        out; non-dictionary attributes are reported under their own name.
        The extra ``"paired_count"`` entry holds the value returned by
        ``self._calculate_paired_count()``.
    """
    totals: Dict[str, Dict[str, int]] = {}

    for field_name in self.__annotations__:
        content = getattr(self, field_name)

        if not isinstance(content, dict):
            totals[field_name] = {field_name: self._get_n_samples(content)}
            continue

        per_key = {}
        for child_key, child in content.items():
            if child is not None and len(child) != 0:
                per_key[child_key] = self._get_n_samples(child)
        totals[field_name] = per_key

    totals["paired_count"] = {"paired_count": self._calculate_paired_count()}
    return totals

is_empty()

Check if the data package is empty.

Source code in src/autoencodix/data/datapackage.py
110
111
112
113
114
115
116
117
118
119
120
121
def is_empty(self) -> bool:
    """Check if the data package is empty.

    A field counts as empty when it is ``None`` or has length 0, so an empty
    dictionary is treated exactly like ``None`` for every field (previously
    only ``multi_bulk``, ``from_modality`` and ``to_modality`` behaved
    that way).

    Returns:
        True if all six fields are empty, False otherwise.
    """
    containers = (
        self.multi_sc,
        self.multi_bulk,
        self.annotation,
        self.img,
        self.from_modality,
        self.to_modality,
    )
    return all(
        container is None or len(container) == 0 for container in containers
    )

shape()

Get the shape of the data for each data type in nested dictionary format.

Returns:

Type Description
Dict[str, Dict[str, Any]]

Dictionary with nested structure: {modality_type: {sub_key: shape}}

Source code in src/autoencodix/data/datapackage.py
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
def shape(self) -> Dict[str, Dict[str, Any]]:
    """Collect the shape of every stored data object.

    Returns:
        Nested dictionary ``{attribute: {sub_key: shape}}``. An empty
        dictionary attribute maps to ``{}``; a non-dictionary attribute maps
        to ``{attribute: shape}`` using ``self._get_single_shape``.
    """
    summary: Dict[str, Dict[str, Any]] = {}

    for field_name in self.__annotations__:
        content = getattr(self, field_name)

        if not isinstance(content, dict):
            summary[field_name] = {field_name: self._get_single_shape(content)}
        elif content:
            summary[field_name] = self._get_shape_from_dict(content)
        else:
            summary[field_name] = {}

    return summary

DataPackageSplitter

Splits DataPackage objects into training, validation, and testing sets.

Supports paired and unpaired (translation) splitting.

Attributes:

Name Type Description
data_package

The original DataPackage to split.

config

The configuration settings for the splitting process.

indices

The indices for each split (train/val/test).

Source code in src/autoencodix/data/_datapackage_splitter.py
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
class DataPackageSplitter:
    """Splits DataPackage objects into training, validation, and testing sets.

    Supports paired and unpaired (translation) splitting.

    Attributes:
        indices: Nested mapping that is looked up as
            ``indices[attribute][modality][split]`` to obtain the positions
            belonging to one split (train/valid/test).
        config: The configuration settings for the splitting process.
    """

    def __init__(
        self,
        data_package: DataPackage,
        config: DefaultConfig,
        indices: Dict[str, Dict[str, Dict[str, np.ndarray]]],
    ) -> None:
        """Store the package, its split indices and the configuration.

        Args:
            data_package: The DataPackage to split.
            config: Configuration object; only ``requires_paired`` is read
                by this class.
            indices: Split indices per attribute, modality and split name.

        Raises:
            TypeError: If ``data_package`` is not a DataPackage.
        """
        self._data_package = data_package
        self.indices = indices
        self.config = config

        if isinstance(self._data_package, DataPackage):
            return
        raise TypeError(
            f"Expected data_package to be of type DataPackage, got {type(self._data_package)}"
        )

    def _shallow_copy(self, value: Any) -> Any:
        """Return ``copy.copy(value)``, or ``value`` itself if copying raises AttributeError."""
        try:
            duplicate = copy.copy(value)
        except AttributeError:
            return value
        return duplicate

    def _indexing(self, obj: Any, indices: np.ndarray) -> Any:
        """Select the entries of ``obj`` at the given positions.

        Args:
            obj: A pd.DataFrame, list, AnnData, MuData, or None.
            indices: Positions to keep.

        Returns:
            ``obj.iloc[indices]`` for a DataFrame, a new list for a list,
            ``obj[indices]`` for AnnData/MuData, and None if ``obj`` is None.

        Raises:
            TypeError: If ``obj`` is of any other type.
        """
        if obj is None:
            return None
        if isinstance(obj, pd.DataFrame):
            return obj.iloc[indices]
        if isinstance(obj, list):
            return [obj[position] for position in indices]
        if isinstance(obj, (AnnData, MuData)):
            return obj[indices]
        raise TypeError(
            f"Unsupported type for indexing: {type(obj)}. "
            "Supported types are pd.DataFrame, list, AnnData, and MuData."
        )

    def _split_data_package(self, split: str) -> Optional[DataPackage]:
        """Build the DataPackage for one split.

        Every attribute of the wrapped package that is not None is treated
        as a mapping ``modality -> data`` and each data object is indexed
        with ``self.indices[attribute][modality][split]``.

        Args:
            split: Name of the split to extract (used as the innermost key
                of ``self.indices``).

        Returns:
            A new DataPackage holding the indexed data, or None if
            ``self.indices`` is empty.
        """
        if len(self.indices) == 0:
            return None

        selected = {
            attribute: {
                modality: self._indexing(
                    data, self.indices[attribute][modality][split]
                )
                for modality, data in content.items()
            }
            for attribute, content in vars(self._data_package).items()
            if content is not None
        }
        return DataPackage(**selected)

    def _split_mudata(
        self,
        mudata: MuData,  # ty: ignore[invalid-type-form]
        indices_map: Dict[str, Dict[str, np.ndarray]],
        split: str,
    ) -> MuData:  # ty: ignore[invalid-type-form]
        """Restrict every modality of a MuData object to one split.

        Note that the modalities are reassigned on the passed-in object, so
        ``mudata`` itself is modified and then returned.

        Args:
            mudata: The MuData object to split.
            indices_map: Mapping ``modality -> split -> indices``.
            split: The split type ("train", "valid", or "test").

        Returns:
            The same ``mudata`` object with each modality indexed.
        """
        for modality, modality_data in mudata.mod.items():
            mudata.mod[modality] = self._indexing(
                modality_data, indices_map[modality][split]
            )
        return mudata

    def _requires_paired(self) -> bool:
        """Return ``config.requires_paired``, treating None as True."""
        flag = self.config.requires_paired
        return True if flag is None else flag

    def split(self) -> Dict[str, Optional[Dict[str, Any]]]:
        """Splits the underlying DataPackage into train, valid, and test subsets.

        Returns:
            A dictionary with the keys "train", "valid", and "test". Each
            value is None when ``self.indices`` is None; otherwise it is a
            dictionary with a "data" key (the split DataPackage) and an
            "indices" key (the complete ``self.indices`` mapping).

        Raises:
            ValueError: If no data package is available for splitting.
        """
        if self._data_package is None:
            raise ValueError("No data package available for splitting")

        outcome: Dict[str, Optional[Dict[str, Any]]] = {}
        for split_name in ("train", "valid", "test"):
            if self.indices is None:
                outcome[split_name] = None
            else:
                outcome[split_name] = {
                    "data": self._split_data_package(split=split_name),
                    "indices": self.indices,
                }

        return outcome

split()

Splits the underlying DataPackage into train, valid, and test subsets.

Returns: A dictionary containing the split data packages for "train", "valid", and "test". Each entry contains a "data" key with the DataPackage and an "indices" key with the corresponding indices.

Raises: ValueError: If no data package is available for splitting. TypeError: If indices are not provided for unpaired translation case.

Source code in src/autoencodix/data/_datapackage_splitter.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def split(self) -> Dict[str, Optional[Dict[str, Any]]]:
    """Splits the underlying DataPackage into train, valid, and test subsets.

    Returns:
        A dictionary with the keys "train", "valid", and "test". Each value
        is None when ``self.indices`` is None; otherwise it is a dictionary
        with a "data" key (the result of ``self._split_data_package``) and
        an "indices" key (the complete ``self.indices`` mapping).

    Raises:
        ValueError: If no data package is available for splitting.
    """
    if self._data_package is None:
        raise ValueError("No data package available for splitting")

    outcome: Dict[str, Optional[Dict[str, Any]]] = {}
    for split_name in ("train", "valid", "test"):
        if self.indices is None:
            outcome[split_name] = None
        else:
            outcome[split_name] = {
                "data": self._split_data_package(split=split_name),
                "indices": self.indices,
            }

    return outcome

DataSplitter

Splits data into train, validation, and test sets and validates the splits.

Also allows custom splits to be provided. Empty splits are allowed here (e.g. test_ratio=0); this might raise an error later in the pipeline when such a split is expected to be non-empty. However, it allows a more flexible usage of the pipeline (e.g. when the user only wants to run the fit step).

Constraints: 1. Split ratios must sum to 1 2. Each non-empty split must have at least min_samples_per_split samples 3. Any split ratio must be <= 1.0 4. Custom splits must contain 'train', 'valid', and 'test' keys and non-overlapping indices

Attributes:

Name Type Description
_config

Configuration object containing split ratios

_custom_splits

Optional pre-defined split indices

Source code in src/autoencodix/data/_datasplitter.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
class DataSplitter:
    """
    Splits sample indices into train, validation, and test sets and validates the splits.

    Custom (pre-defined) splits can be provided instead of ratio-based splitting.
    Empty splits are allowed (e.g. test_ratio=0); this may raise an error later in
    the pipeline when that split is expected to be non-empty, but it allows a more
    flexible usage of the pipeline (e.g. when the user only wants to run the fit step).

    Constraints:
    1. Split ratios must sum to 1 (compared with a small floating-point tolerance)
    2. Each non-empty split must have at least min_samples_per_split samples
    3. Any split ratio must lie within [0, 1]
    4. Custom splits must contain 'train', 'valid', and 'test' keys and non-overlapping indices

    Attributes:
        _config: Configuration object containing split ratios, the minimum split
            size and the global seed
        _custom_splits: Optional pre-defined split indices
        _test_ratio: Fraction of the samples assigned to the test split
        _valid_ratio: Fraction of the samples assigned to the validation split
        _train_ratio: Fraction of the samples assigned to the training split
        _min_samples: Minimum number of samples a non-empty split must contain

    """

    def __init__(
        self,
        config: DefaultConfig,
        custom_splits: Optional[Dict[str, np.ndarray]] = None,
    ):
        """
        Initialize DataSplitter with configuration and optional custom splits.

        The ratios are validated immediately; custom splits are validated only
        when a non-empty mapping is passed.

        Args:
            config (DefaultConfig): Configuration object containing split ratios
            custom_splits (Optional[Dict[str, np.ndarray]]): Pre-defined split indices
        Raises:
            ValueError: If the ratios or the custom splits violate the constraints
        """
        self._config = config
        self._test_ratio = self._config.test_ratio
        self._valid_ratio = self._config.valid_ratio
        self._train_ratio = self._config.train_ratio
        self._min_samples = self._config.min_samples_per_split
        self._custom_splits = custom_splits

        self._validate_ratios()
        if self._custom_splits:
            self._validate_custom_splits(self._custom_splits)

    def _validate_ratios(self) -> None:
        """
        Validate that the splitting ratios meet required constraints.

        Returns:
            None
        Raises:
            ValueError: If a ratio is outside [0, 1] or the ratios do not sum to 1

        """
        named_ratios = (
            ("Test", self._test_ratio),
            ("Validation", self._valid_ratio),
            ("Train", self._train_ratio),
        )
        for label, ratio in named_ratios:
            if not 0 <= ratio <= 1:
                raise ValueError(f"{label} ratio must be between 0 and 1, got {ratio}")

        # Exact float comparison rejects valid inputs such as 0.6 + 0.3 + 0.1,
        # which evaluates to 0.9999999999999999, so compare with a tolerance.
        total = self._test_ratio + self._valid_ratio + self._train_ratio
        if not np.isclose(total, 1.0, rtol=0.0, atol=1e-9):
            raise ValueError("Split ratios must sum to 1")

    def _validate_split_sizes(self, n_samples: int) -> None:
        """
        Validate that each non-empty split will have sufficient samples.

        Sizes are estimated by truncating ``n_samples * ratio``; the training size
        is derived from ``1 - test_ratio - valid_ratio``.

        Args:
            n_samples: Total number of samples in dataset
        Returns:
            None
        Raises:
            ValueError: If any non-empty split would have too few samples

        """
        expected_sizes = (
            (
                "Training",
                self._train_ratio,
                int(n_samples * (1 - self._test_ratio - self._valid_ratio)),
            ),
            ("Validation", self._valid_ratio, int(n_samples * self._valid_ratio)),
            ("Test", self._test_ratio, int(n_samples * self._test_ratio)),
        )
        for label, ratio, size in expected_sizes:
            # Splits with ratio 0 are intentionally empty and exempt from the check.
            if ratio > 0 and size < self._min_samples:
                raise ValueError(
                    f"{label} set would have {size} samples, "
                    f"which is less than minimum required ({self._min_samples})"
                )

    def _validate_custom_splits(self, splits: Dict[str, np.ndarray]) -> None:
        """
        Validate custom splits for correctness.

        Checks that all three split keys are present, that the training split and
        every non-empty validation/test split reach the minimum size, and that no
        index appears in more than one split. Index bounds are not checked here
        because the dataset size is only known in ``split``.

        Args:
            splits: Custom split indices
        Returns:
            None
        Raises:
            ValueError: If custom splits violate constraints

        """
        # A tuple (not a set) keeps the pair order, and therefore the reported
        # overlap, deterministic across runs.
        required_keys = ("train", "valid", "test")
        if not all(key in splits for key in required_keys):
            raise ValueError(
                f"Custom splits must contain all of: {list(required_keys)}. "
                f"Got: {list(splits.keys())}. "
                "If you want to pass empty splits, pass an empty array."
            )

        # The training split must always reach the minimum size.
        if len(splits["train"]) < self._min_samples:
            raise ValueError(
                f"Custom training split has {len(splits['train'])} samples, "
                f"which is less than minimum required ({self._min_samples})"
            )

        # For non-empty validation and test splits, check minimum size
        if len(splits["valid"]) > 0 and len(splits["valid"]) < self._min_samples:
            raise ValueError(
                f"Custom validation split has {len(splits['valid'])} samples, "
                f"which is less than minimum required ({self._min_samples})"
            )

        if len(splits["test"]) > 0 and len(splits["test"]) < self._min_samples:
            raise ValueError(
                f"Custom test split has {len(splits['test'])} samples, "
                f"which is less than minimum required ({self._min_samples})"
            )

        # Check for overlap between splits; build each index set only once.
        index_sets = {key: set(splits[key]) for key in required_keys}
        for k1, k2 in itertools.combinations(required_keys, 2):
            intersection = index_sets[k1] & index_sets[k2]
            if intersection:
                raise ValueError(
                    f"Overlapping indices found between splits '{k1}' and '{k2}': {intersection}"
                )

    def split(
        self,
        n_samples: int,
    ) -> Dict[str, np.ndarray]:
        """
        Split data into train, validation, and test sets.

        When custom splits were supplied they are bounds-checked against
        ``n_samples`` and returned as-is. Otherwise the indices ``0..n_samples-1``
        are split according to the configured ratios, seeded with
        ``config.global_seed``.

        Args:
            n_samples: Total number of samples in the dataset

        Returns:
            Dictionary containing indices for each split, with empty arrays for splits with ratio=0

        Raises:
            ValueError: If resulting splits would violate size constraints
            AssertionError: If a custom split contains an index outside
                ``[0, n_samples - 1]``
        """
        # NOTE(review): the ratio-based size check also runs when custom splits
        # are used; confirm that this is intended.
        self._validate_split_sizes(n_samples)
        indices = np.arange(n_samples)

        if self._custom_splits:
            max_index = n_samples - 1
            for split_indices in self._custom_splits.values():
                if len(split_indices) == 0:
                    continue
                if np.max(split_indices) > max_index or np.min(split_indices) < 0:
                    raise AssertionError(
                        f"Custom split indices must be within range [0, {max_index}]"
                    )
            return self._custom_splits

        # all three 0 case already handled in _validate_ratios (sum to 1)
        if self._test_ratio == 0 and self._valid_ratio == 0:
            return {
                "train": indices,
                "valid": np.array([], dtype=int),
                "test": np.array([], dtype=int),
            }
        if self._train_ratio == 0 and self._valid_ratio == 0:
            return {
                "train": np.array([], dtype=int),
                "valid": np.array([], dtype=int),
                "test": indices,
            }
        if self._train_ratio == 0 and self._test_ratio == 0:
            return {
                "train": np.array([], dtype=int),
                "valid": indices,
                "test": np.array([], dtype=int),
            }

        # Every remaining branch performs at least one seeded random split.
        seed = self._config.global_seed

        if self._train_ratio == 0:
            valid_indices, test_indices = train_test_split(
                indices, test_size=self._test_ratio, random_state=seed
            )
            return {
                "train": np.array([], dtype=int),
                "valid": valid_indices,
                "test": test_indices,
            }

        if self._test_ratio == 0:
            train_indices, valid_indices = train_test_split(
                indices, test_size=self._valid_ratio, random_state=seed
            )
            return {
                "train": train_indices,
                "valid": valid_indices,
                "test": np.array([], dtype=int),
            }

        if self._valid_ratio == 0:
            train_indices, test_indices = train_test_split(
                indices, test_size=self._test_ratio, random_state=seed
            )
            return {
                "train": train_indices,
                "valid": np.array([], dtype=int),
                "test": test_indices,
            }

        # Normal case: split into all three sets
        train_valid_indices, test_indices = train_test_split(
            indices, test_size=self._test_ratio, random_state=seed
        )

        # The validation ratio is rescaled because the test samples are already removed.
        train_indices, valid_indices = train_test_split(
            train_valid_indices,
            test_size=self._valid_ratio / (1 - self._test_ratio),
            random_state=seed,
        )

        return {"train": train_indices, "valid": valid_indices, "test": test_indices}

__init__(config, custom_splits=None)

Initialize DataSplitter with configuration and optional custom splits.

Parameters:

Name Type Description Default
config DefaultConfig

Configuration object containing split ratios

required
custom_splits Optional[Dict[str, ndarray]]

Pre-defined split indices

None
Source code in src/autoencodix/data/_datasplitter.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def __init__(
    self,
    config: DefaultConfig,
    custom_splits: Optional[Dict[str, np.ndarray]] = None,
):
    """
    Store the split settings and validate them right away.

    Args:
        config (DefaultConfig): Configuration object providing the test, validation
            and train ratios as well as the minimum number of samples per split
        custom_splits (Optional[Dict[str, np.ndarray]]): Pre-defined split indices;
            validated only when a non-empty mapping is given
    """
    self._config = config
    self._custom_splits = custom_splits

    # Cache the settings that the validation and split methods read.
    self._test_ratio = config.test_ratio
    self._valid_ratio = config.valid_ratio
    self._train_ratio = config.train_ratio
    self._min_samples = config.min_samples_per_split

    self._validate_ratios()
    if custom_splits:
        self._validate_custom_splits(custom_splits)

split(n_samples)

Split data into train, validation, and test sets.

Parameters:

Name Type Description Default
n_samples int

Total number of samples in the dataset

required

Returns:

Type Description
Dict[str, ndarray]

Dictionary containing indices for each split, with empty arrays for splits with ratio=0

Raises:

Type Description
ValueError

If resulting splits would violate size constraints

Source code in src/autoencodix/data/_datasplitter.py
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
def split(
    self,
    n_samples: int,
) -> Dict[str, np.ndarray]:
    """
    Produce the index arrays for the train, validation, and test splits.

    Custom splits, when present, are bounds-checked against ``n_samples`` and
    returned unchanged. Otherwise the indices ``0..n_samples-1`` are divided
    according to the configured ratios using ``config.global_seed``.

    Args:
        n_samples: Total number of samples in the dataset

    Returns:
        Dictionary with the keys 'train', 'valid' and 'test'; splits with ratio=0
        map to empty integer arrays

    Raises:
        ValueError: If resulting splits would violate size constraints
        AssertionError: If a custom split holds an index outside ``[0, n_samples - 1]``
    """
    self._validate_split_sizes(n_samples)
    all_indices = np.arange(n_samples)

    custom = self._custom_splits
    if custom:
        upper = n_samples - 1
        for idx in custom.values():
            if len(idx) == 0:
                continue
            # Both bound violations share the same error message.
            if np.max(idx) > upper or np.min(idx) < 0:
                raise AssertionError(
                    f"Custom split indices must be within range [0, {upper}]"
                )
        return custom

    def _no_samples() -> np.ndarray:
        return np.array([], dtype=int)

    without_train = self._train_ratio == 0
    without_valid = self._valid_ratio == 0
    without_test = self._test_ratio == 0

    # Exactly one non-empty split: it receives every index.
    # (All three being 0 is ruled out by the ratio validation.)
    if without_test and without_valid:
        return {"train": all_indices, "valid": _no_samples(), "test": _no_samples()}
    if without_train and without_valid:
        return {"train": _no_samples(), "valid": _no_samples(), "test": all_indices}
    if without_train and without_test:
        return {"train": _no_samples(), "valid": all_indices, "test": _no_samples()}

    # Exactly one empty split: a single two-way split is enough.
    if without_train:
        valid_part, test_part = train_test_split(
            all_indices,
            test_size=self._test_ratio,
            random_state=self._config.global_seed,
        )
        return {"train": _no_samples(), "valid": valid_part, "test": test_part}

    if without_test:
        train_part, valid_part = train_test_split(
            all_indices,
            test_size=self._valid_ratio,
            random_state=self._config.global_seed,
        )
        return {"train": train_part, "valid": valid_part, "test": _no_samples()}

    if without_valid:
        train_part, test_part = train_test_split(
            all_indices,
            test_size=self._test_ratio,
            random_state=self._config.global_seed,
        )
        return {"train": train_part, "valid": _no_samples(), "test": test_part}

    # All three splits are non-empty: carve off the test set first, then divide
    # the remainder with the validation ratio rescaled to that remainder.
    remainder, test_part = train_test_split(
        all_indices,
        test_size=self._test_ratio,
        random_state=self._config.global_seed,
    )
    rescaled_valid = self._valid_ratio / (1 - self._test_ratio)
    train_part, valid_part = train_test_split(
        remainder,
        test_size=rescaled_valid,
        random_state=self._config.global_seed,
    )

    return {"train": train_part, "valid": valid_part, "test": test_part}

DatasetContainer dataclass

A container for datasets used in training, validation, and testing.

Attributes: train (Dataset): the training dataset. valid (Dataset): the validation dataset. test (Dataset): the testing dataset.

Source code in src/autoencodix/data/_datasetcontainer.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
@dataclass
class DatasetContainer:
    """Holds the datasets of the train, validation, and test splits.

    Every split can be reached as an attribute (``container.train``) or through
    dictionary-style indexing (``container["train"]``). All fields default to None.

    Attributes:
        train: The training dataset.
        valid: The validation dataset.
        test: The testing dataset.
    """

    train: Optional[Union[BaseDataset, None]] = None
    valid: Optional[Union[BaseDataset, None]] = None
    test: Optional[Union[BaseDataset, None]] = None

    def __getitem__(self, key: str) -> BaseDataset:
        """Return the dataset stored for ``key``; raise KeyError for unknown split names."""
        if key in {"train", "valid", "test"}:
            return getattr(self, key)
        raise KeyError(f"Invalid key: {key}. Must be 'train', 'valid', or 'test'.")

    def __setitem__(self, key: str, value: BaseDataset):
        """Store ``value`` as the dataset for ``key``; raise KeyError for unknown split names."""
        if key in {"train", "valid", "test"}:
            setattr(self, key, value)
            return
        raise KeyError(f"Invalid key: {key}. Must be 'train', 'valid', or 'test'.")

__getitem__(key)

Allows dictionary-like access to datasets.

Source code in src/autoencodix/data/_datasetcontainer.py
23
24
25
26
27
def __getitem__(self, key: str) -> BaseDataset:
    """Return the dataset stored for ``key``; raise KeyError for unknown split names."""
    if key in {"train", "valid", "test"}:
        return getattr(self, key)
    raise KeyError(f"Invalid key: {key}. Must be 'train', 'valid', or 'test'.")

__setitem__(key, value)

Allows dictionary-like assignment of datasets.

Source code in src/autoencodix/data/_datasetcontainer.py
29
30
31
32
33
def __setitem__(self, key: str, value: BaseDataset):
    """Store ``value`` as the dataset for ``key``; raise KeyError for unknown split names."""
    if key in {"train", "valid", "test"}:
        setattr(self, key, value)
        return
    raise KeyError(f"Invalid key: {key}. Must be 'train', 'valid', or 'test'.")

GeneralPreprocessor

Bases: BasePreprocessor

Preprocessor for handling multi-modal data.

Attributes:

Name Type Description
_datapackage_dict Optional[Dict[str, Any]]

Dictionary holding DataPackage objects for each data split.

_dataset_container Optional[DatasetContainer]

Container holding processed datasets for each split.

_reverse_mapping_multi_bulk Dict[str, Dict[str, Tuple[List[int], List[str]]]]

Reverse mapping for multi-bulk data reconstruction.

_reverse_mapping_multi_sc Dict[str, Dict[str, Tuple[List[int], List[str]]]]

Reverse mapping for multi-single-cell data reconstruction.

Source code in src/autoencodix/data/general_preprocessor.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
class GeneralPreprocessor(BasePreprocessor):
    """Preprocessor for handling multi-modal data.

    For every split, the modalities of a multi-bulk or multi-single-cell
    DataPackage are concatenated column-wise into one numeric dataset. While
    doing so, the column range and feature ids of each modality (and layer) are
    recorded so that a reconstruction can be split back into its modalities.

    Attributes:
        _datapackage_dict: Dictionary holding DataPackage objects for each data split.
        _dataset_container: Container holding processed datasets for each split.
        _reverse_mapping_multi_bulk: Reverse mapping for multi-bulk data reconstruction
            (split -> modality -> (column indices, feature ids)).
        _reverse_mapping_multi_sc: Reverse mapping for multi-single-cell data reconstruction
            (split -> modality -> layer -> (column indices, feature ids)).

    """

    def __init__(
        self, config: DefaultConfig, ontologies: Optional[Union[Tuple, Dict]] = None
    ) -> None:
        """Set up empty state; the actual work happens in ``preprocess``.

        Args:
            config: Configuration object, forwarded to the base preprocessor.
            ontologies: Optional ontologies, forwarded to the base preprocessor.
        """
        super().__init__(config=config, ontologies=ontologies)
        self._datapackage_dict: Optional[Dict[str, Any]] = None
        self._dataset_container: Optional[DatasetContainer] = None
        # Reverse mappings for reconstruction
        self._reverse_mapping_multi_bulk: Dict[
            str, Dict[str, Tuple[List[int], List[str]]]
        ] = {"train": {}, "test": {}, "valid": {}}
        # The single-cell mapping has one more level (layer) than the bulk mapping.
        self._reverse_mapping_multi_sc: Dict[
            str, Dict[str, Dict[str, Tuple[List[int], List[str]]]]
        ] = {"train": {}, "test": {}, "valid": {}}

    def _combine_layers(
        self, modality_name: str, modality_data: Any
    ) -> List[np.ndarray]:
        """Collect the configured layers of one modality in configuration order.

        Args:
            modality_name: Key of the modality in ``config.data_config.data_info``.
            modality_data: Object exposing ``X`` and ``layers`` (presumably an
                AnnData; verify against the caller).

        Returns:
            The selected layer matrices. A layer name that is neither "X" nor
            present in ``modality_data.layers`` is skipped.
        """
        layer_list: List[np.ndarray] = []
        selected_layers = self.config.data_config.data_info[
            modality_name
        ].selected_layers
        for layer_name in selected_layers:
            if layer_name == "X":
                layer_list.append(modality_data.X)
            elif layer_name in modality_data.layers:
                layer_list.append(modality_data.layers[layer_name])
        return layer_list

    def _combine_modality_data(
        self,
        mudata: md.MuData,  # ty: ignore[invalid-type-form]
    ) -> Union[np.ndarray, sp.sparse.spmatrix]:  # ty: ignore[invalid-type-form]
        """Concatenate all selected layers of all modalities column-wise.

        Also records, for the current split, the column range and feature ids of
        every modality/layer in ``_reverse_mapping_multi_sc``.

        Args:
            mudata: Container whose ``mod`` mapping holds the modalities.

        Returns:
            A CSR matrix when every collected layer is sparse, otherwise a dense
            array (sparse layers are densified first).
        """
        modality_data_list: List[np.ndarray] = []
        start_idx = 0

        for modality_name, modality_data in mudata.mod.items():
            # Start a fresh layer mapping for this modality in the current split.
            self._reverse_mapping_multi_sc[self._split][modality_name] = {}
            layers = self.config.data_config.data_info[modality_name].selected_layers
            for layer_name in layers:
                if layer_name == "X":
                    n_feats = modality_data.shape[1]
                else:
                    n_feats = modality_data.layers[layer_name].shape[1]

                end_idx = start_idx + n_feats
                feature_ids = modality_data.var_names.tolist()
                self._reverse_mapping_multi_sc[self._split][modality_name][
                    layer_name
                ] = (
                    list(range(start_idx, end_idx)),
                    feature_ids,
                )
                start_idx = end_idx

            combined_layers = self._combine_layers(
                modality_name=modality_name, modality_data=modality_data
            )
            modality_data_list.extend(combined_layers)
        all_sparse = all(issparse(arr) for arr in modality_data_list)
        if all_sparse:
            combined = sp.sparse.hstack(modality_data_list, format="csr")
        else:
            # Mixed input: densify the sparse layers so they can be concatenated.
            dense_layers = [
                arr.toarray() if issparse(arr) else arr  # ty: ignore
                for arr in modality_data_list
            ]
            combined = np.concatenate(dense_layers, axis=1)

        return combined

    def _create_numeric_dataset(
        self,
        data: Union[np.ndarray, sp.sparse.spmatrix],
        config: DefaultConfig,
        split_ids: np.ndarray,
        metadata: pd.DataFrame,
        ids: List[str],
        feature_ids: List[str],
    ) -> NumericDataset:
        """Wrap a combined feature matrix in a ``NumericDataset``.

        Args:
            data: Combined feature matrix (dense or sparse).
            config: Configuration passed through to the dataset.
            split_ids: Indices of the split, passed as ``split_indices``.
            metadata: Sample metadata.
            ids: Sample ids, passed as ``sample_ids``.
            feature_ids: Feature ids in column order.

        Returns:
            The constructed dataset.
        """
        # keep sparse data sparse until batch level in training for memory efficiency
        ds = NumericDataset(
            data=data,
            config=config,
            split_indices=split_ids,
            metadata=metadata,
            sample_ids=ids,
            feature_ids=feature_ids,
        )
        return ds

    def _process_data_package(self, data_dict: Dict[str, Any]) -> BaseDataset:
        """Turn the data package of one split into a numeric dataset.

        Args:
            data_dict: Mapping with the keys "data" (the DataPackage) and
                "indices" (the split indices).

        Returns:
            The dataset built from the concatenated modalities.

        Raises:
            ValueError: If a multi-bulk entry is not a DataFrame.
            NotImplementedError: If bulk modalities differ in sample count, if
                the single-cell data has no "multi_sc" entry, or if the package
                holds neither multi_bulk nor multi_sc data.
        """
        data, split_ids = data_dict["data"], data_dict["indices"]
        # MULTI-BULK
        if data.multi_bulk is not None:
            metadata = data.annotation
            bulk_dict: Dict[str, pd.DataFrame] = data.multi_bulk

            # Check if all DataFrames have the same number of samples
            sample_counts = {}
            for subkey, df in bulk_dict.items():
                if not isinstance(df, pd.DataFrame):
                    raise ValueError(
                        f"Expected a DataFrame for '{subkey}', got {type(df)}"
                    )
                sample_counts[subkey] = df.shape[0]

            # Validate all modalities have the same number of samples
            unique_sample_counts = set(sample_counts.values())
            if len(unique_sample_counts) > 1:
                sample_count_str = ", ".join(
                    [f"{k}: {v} samples" for k, v in sample_counts.items()]
                )
                raise NotImplementedError(
                    "Different sample counts across modalities are not currently "
                    "supported for Varix and Vanillix. "
                    "Set requires_paired=True in config. "
                    f"Found: {sample_count_str}. All modalities must have the same number of samples."
                )

            # Record which columns of the combined matrix belong to which modality.
            combined_cols: List[str] = []
            start_idx = 0
            for subkey, df in bulk_dict.items():
                n_feats = df.shape[1]
                end_idx = start_idx + n_feats
                self._reverse_mapping_multi_bulk[self._split][subkey] = (
                    list(range(start_idx, end_idx)),
                    df.columns.tolist(),
                )
                combined_cols.extend(df.columns.tolist())
                start_idx = end_idx

            combined_df = pd.concat(bulk_dict.values(), axis=1)
            return self._create_numeric_dataset(
                data=combined_df.values,
                config=self.config,
                split_ids=split_ids,
                metadata=metadata,
                ids=combined_df.index.tolist(),
                feature_ids=combined_cols,
            )
        # MULTI-SINGLE-CELL
        elif data.multi_sc is not None:
            mudata: md.MuData = data.multi_sc.get(  # ty: ignore[invalid-type-form]
                "multi_sc", None
            )  # ty: ignore[invalid-type-form]
            if mudata is None:
                raise NotImplementedError(
                    "Unpaired multi Single Cell case not implemented for Varix and Vanillix, set requires_paired=True in config"
                )
            combined_data = self._combine_modality_data(mudata)

            # collect feature IDs in concatenation order
            feature_ids: List[str] = []
            for layers in self._reverse_mapping_multi_sc[self._split].values():
                for _, fids in layers.values():
                    feature_ids.extend(fids)
            return self._create_numeric_dataset(
                data=combined_data,
                config=self.config,
                split_ids=split_ids,
                metadata=mudata.obs,
                ids=mudata.obs_names.tolist(),
                feature_ids=feature_ids,
            )
        else:
            raise NotImplementedError(
                "GeneralPreprocessor only handles multi_bulk or multi_sc."
            )

    def preprocess(
        self,
        raw_user_data: Optional[DataPackage] = None,
        predict_new_data: bool = False,
    ) -> DatasetContainer:
        """Run the common preprocessing and build one dataset per split.

        Args:
            raw_user_data: Optional user data, forwarded to ``_general_preprocess``.
            predict_new_data: Forwarded to ``_general_preprocess``.

        Returns:
            Container with a dataset for each split; splits without data hold None.

        Raises:
            TypeError: If the common preprocessing returns None.
        """
        # run common preprocessing
        # NOTE: the reverse mappings are not cleared here; entries are overwritten
        # per split and modality while the datasets are built.
        self._datapackage_dict = self._general_preprocess(
            raw_user_data=raw_user_data, predict_new_data=predict_new_data
        )
        if self._datapackage_dict is None:
            raise TypeError("Datapackage cannot be None")

        # prepare container
        ds_container: DatasetContainer = DatasetContainer()

        for split in ["train", "test", "valid"]:
            split_data = self._datapackage_dict.get(split)
            # The helpers called below read the current split from self._split.
            self._split = split
            if not split_data or split_data["data"] is None:
                ds_container[split] = None  # type: ignore
                continue
            ds = self._process_data_package(split_data)
            ds_container[split] = ds
        self._dataset_container = ds_container
        return ds_container

    def format_reconstruction(
        self, reconstruction: torch.Tensor, result: Optional[Result] = None
    ) -> DataPackage:
        """Convert a reconstruction matrix back into a DataPackage.

        The split is inferred from the number of rows of ``reconstruction``.

        Args:
            reconstruction: Reconstructed feature matrix (samples in rows).
            result: Not used by this implementation.

        Returns:
            DataPackage with the reconstruction split back into its modalities.

        Raises:
            NotImplementedError: If the configured data case is neither
                multi-bulk nor multi-single-cell.
            ValueError: If no split matches the number of samples.
        """
        self._split = self._match_split(n_samples=reconstruction.shape[0])
        if self.config.data_case == DataCase.MULTI_BULK:
            return self._reverse_multi_bulk(reconstruction)
        elif self.config.data_case == DataCase.MULTI_SINGLE_CELL:
            return self._reverse_multi_sc(reconstruction)
        else:
            raise NotImplementedError(
                f"Reconstruction not implemented for {self.config.data_case}"
            )

    def _match_split(self, n_samples: int) -> str:
        """Match the split based on the number of samples.

        Returns the first split whose paired sample count equals ``n_samples``.
        NOTE(review): if two splits have the same sample count the first one
        wins, which may not be the intended split.

        Raises:
            ValueError: If preprocessing has not been run or no split matches.
        """
        if self._datapackage_dict is None:
            raise ValueError(
                "No preprocessed data available; call preprocess() before "
                "formatting a reconstruction."
            )
        # NOTE(review): the prints below are debug output; consider using logging.
        print(f"n_samples in format recon: {n_samples}")
        for split, split_data in self._datapackage_dict.items():
            print(split)
            data = split_data.get("data")
            if data is None:
                continue
            ref_n = data.get_n_samples()["paired_count"]
            print(f"n_samples from datatpackge: {ref_n}")
            # The "paired_count" entry is itself indexed by "paired_count".
            if n_samples == ref_n["paired_count"]:
                return split
        raise ValueError(
            f"Cannot find matching split for {n_samples} samples in the dataset."
        )

    def _reverse_multi_bulk(
        self, reconstruction: Union[np.ndarray, torch.Tensor]
    ) -> DataPackage:
        """Split a reconstruction into one DataFrame per bulk modality.

        Uses the column ranges recorded for the current split; rows are indexed
        by the sample ids of that split's dataset.

        Args:
            reconstruction: Reconstructed feature matrix.

        Returns:
            DataPackage with ``multi_bulk`` and ``annotation`` filled.
        """
        data_package = DataPackage(
            multi_bulk={},
            multi_sc=None,
            annotation=None,
            img=None,
            from_modality=None,
            to_modality=None,
        )
        # reconstruct each bulk subkey
        dfs: Dict[str, pd.DataFrame] = {}
        for subkey, (indices, fids) in self._reverse_mapping_multi_bulk[
            self._split
        ].items():
            arr = self._slice_tensor(
                reconstruction=reconstruction,
                indices=indices,
            )
            dfs[subkey] = pd.DataFrame(
                arr,
                columns=fids,
                index=self._dataset_container[self._split].sample_ids,
            )
        data_package.annotation = self._dataset_container[self._split].metadata

        data_package.multi_bulk = dfs
        return data_package

    def _slice_tensor(
        self, reconstruction: Union[np.ndarray, torch.Tensor], indices: List[int]
    ) -> np.ndarray:
        """Select the given columns and return them as a NumPy array.

        Args:
            reconstruction: Matrix to slice (torch tensor or NumPy array).
            indices: Column indices to select.

        Returns:
            The selected columns; tensors are detached and moved to the CPU first.

        Raises:
            TypeError: If ``reconstruction`` is neither a tensor nor an array.
        """
        if isinstance(reconstruction, torch.Tensor):
            arr = reconstruction[:, indices].detach().cpu().numpy()
        elif isinstance(reconstruction, np.ndarray):
            arr = reconstruction[:, indices]
        else:
            raise TypeError(
                f"Expected reconstruction to be a torch.Tensor or np.ndarray, got {type(reconstruction)}"
            )
        return arr

    def _reverse_multi_sc(self, reconstruction: torch.Tensor) -> DataPackage:
        """Split a reconstruction into one AnnData per single-cell modality.

        For every modality the recorded layers are cut out of the reconstruction;
        the "X" layer becomes the AnnData matrix and the remaining layers are
        stored in ``layers``.

        Args:
            reconstruction: Reconstructed feature matrix.

        Returns:
            DataPackage whose ``multi_sc`` holds a MuData under the key
            "multi_sc", with ``annotation`` filled.
        """
        data_package = DataPackage(
            multi_bulk=None,
            multi_sc=None,
            annotation=None,
            img=None,
            from_modality=None,
            to_modality=None,
        )
        modalities: Dict[str, AnnData] = {}

        for modality_name, layers in self._reverse_mapping_multi_sc[
            self._split
        ].items():
            # rebuild each layer as DataFrame
            layers_dict: Dict[str, pd.DataFrame] = {}
            for layer_name, (indices, fids) in layers.items():
                arr = self._slice_tensor(reconstruction=reconstruction, indices=indices)
                layers_dict[layer_name] = pd.DataFrame(
                    arr,
                    columns=fids,
                    index=self._dataset_container[self._split].sample_ids,
                )

            # extract X layer for AnnData var
            feature_ids = layers.get("X", (None, []))[1]
            var = pd.DataFrame(index=feature_ids)
            X_df = layers_dict.pop("X", None)
            adata = AnnData(
                X=X_df.values if X_df is not None else None,
                obs=self._dataset_container[self._split].metadata,
                var=var,
                layers={k: v.values for k, v in layers_dict.items()},
            )
            modalities[modality_name] = adata

        data_package.multi_sc = {"multi_sc": md.MuData(modalities)}
        data_package.annotation = self._dataset_container[self._split].metadata
        return data_package

ImageDataset

Bases: TensorAwareDataset

A custom PyTorch dataset that handles image data with proper dtype conversion.

Attributes:

Name Type Description
raw_data

List of ImgData objects containing original image data and metadata.

config

Configuration object for dataset settings.

mytype

Enum indicating the dataset type (set to DataSetTypes.IMG).

data

List of image tensors converted to the appropriate dtype.

sample_ids

List of identifiers for each sample.

split_indices

Optional numpy array of indices for splitting the dataset.

feature_ids

Optional list of identifiers for each feature (set to None for images).

metadata

Optional pandas DataFrame containing additional metadata.

Source code in src/autoencodix/data/_image_dataset.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
class ImageDataset(TensorAwareDataset):
    """
    PyTorch dataset for image data whose images are stored as converted tensors.

    Every image is converted to a tensor exactly once, while the dataset is
    being constructed, so indexing only hands back already-converted tensors.

    Attributes:
        raw_data: List of ImgData objects containing original image data and metadata.
        config: Configuration object for dataset settings.
        mytype: Enum indicating the dataset type (set to DataSetTypes.IMG).
        data: List of image tensors converted to the appropriate dtype.
        sample_ids: List of identifiers for each sample.
        split_indices: Optional split indices, stored as passed in.
        feature_ids: Optional list of identifiers for each feature (set to None for images).
        metadata: Optional pandas DataFrame containing additional metadata.
    """

    def __init__(
        self,
        data: List[ImgData],
        config: DefaultConfig,
        split_indices: Optional[Dict[str, Any]] = None,
        metadata: Optional[pd.DataFrame] = None,
    ):
        """
        Set up the dataset and convert every image to a tensor once.

        Args:
            data: List of image data objects
            config: Configuration object; must not be None
            split_indices: Optional split indices, stored unchanged
            metadata: Optional metadata DataFrame, stored unchanged

        Raises:
            ValueError: If config is None
        """
        self.raw_data = data  # originals are kept next to the converted tensors
        self.config = config
        self.mytype = DataSetTypes.IMG

        if self.config is None:
            raise ValueError("config cannot be None")

        # The dtype comes from _get_target_dtype(); conversion happens only here
        self.data = self._convert_all_images_to_tensors(self._get_target_dtype())

        # One sample id per image, in the same order as the images
        self.sample_ids = [entry.sample_id for entry in data]

        self.feature_ids = None
        self.metadata = metadata
        self.split_indices = split_indices

    def _convert_all_images_to_tensors(self, dtype: torch.dtype) -> List[torch.Tensor]:
        """
        Turn the ``img`` of every entry in ``raw_data`` into a tensor of ``dtype``.

        Args:
            dtype: Target dtype for the tensors

        Returns:
            List of converted image tensors, in the order of ``raw_data``
        """
        n_images = len(self.raw_data)
        print(f"Converting {n_images} images to {dtype} tensors...")
        return [self._to_tensor(entry.img, dtype) for entry in self.raw_data]

    def __len__(self):
        """Return the number of converted images."""
        return len(self.data)

    def __getitem__(self, idx):
        """Look up one entry; the tensor was already converted during init.

        Returns:
            Tuple of (index, image tensor, sample_id)
        """
        image = self.data[idx]
        sample_id = self.sample_ids[idx]
        return idx, image, sample_id

    def get_input_dim(self) -> Tuple[int, ...]:
        """
        Report the input dimension, taken from the first stored image.

        Returns:
            The shape of the first image tensor
        """
        # Only the first image is inspected; all images are expected to share its shape
        first_image = self.data[0]
        return first_image.shape

__getitem__(idx)

Get the item at the given index; the data is already converted to the proper dtype. Returns: a tuple of (index, image tensor, sample_id).

Source code in src/autoencodix/data/_image_dataset.py
79
80
81
82
83
84
def __getitem__(self, idx):
    """Look up one entry - the stored data is already converted to the proper dtype.

    Returns:
        Tuple of (index, image tensor, sample_id)
    """
    image = self.data[idx]
    sample_id = self.sample_ids[idx]
    return idx, image, sample_id

__init__(data, config, split_indices=None, metadata=None)

Initialize the dataset. Args: data: list of image data objects; config: configuration object.

Source code in src/autoencodix/data/_image_dataset.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def __init__(
    self,
    data: List[ImgData],
    config: DefaultConfig,
    split_indices: Optional[Dict[str, Any]] = None,
    metadata: Optional[pd.DataFrame] = None,
):
    """
    Set up the dataset and convert every image to a tensor once.

    Args:
        data: List of image data objects
        config: Configuration object; must not be None
        split_indices: Optional split indices, stored unchanged
        metadata: Optional metadata DataFrame, stored unchanged

    Raises:
        ValueError: If config is None
    """
    self.raw_data = data  # originals are kept next to the converted tensors
    self.config = config
    self.mytype = DataSetTypes.IMG

    if self.config is None:
        raise ValueError("config cannot be None")

    # The dtype comes from _get_target_dtype(); conversion happens only here
    self.data = self._convert_all_images_to_tensors(self._get_target_dtype())

    # One sample id per image, in the same order as the images
    self.sample_ids = [entry.sample_id for entry in data]

    self.feature_ids = None
    self.metadata = metadata
    self.split_indices = split_indices

get_input_dim()

Gets the input dimension of the dataset's feature space.

Returns:

Type Description
Tuple[int, ...]

The input dimension of the dataset's feature space

Source code in src/autoencodix/data/_image_dataset.py
86
87
88
89
90
91
92
93
def get_input_dim(self) -> Tuple[int, ...]:
    """
    Report the input dimension, taken from the first stored image.

    Returns:
        The shape of the first entry in ``self.data``
    """
    # Only the first image is inspected; all images are expected to share its shape
    first_image = self.data[0]
    return first_image.shape

ImagePreprocessor

Bases: GeneralPreprocessor

Preprocessor for image data: runs the general preprocessing and wraps the image data of each split in an ImageDataset.

Attributes:

Name Type Description
data_config

Configuration specific to data handling and preprocessing.

dataset_dicts

Dictionary holding datasets for different splits (train/test/valid).

Source code in src/autoencodix/data/_image_processor.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
class ImagePreprocessor(GeneralPreprocessor):
    """
    Preprocessor for image data.

    Runs the general preprocessing step and wraps the image data of every
    available split in an ``ImageDataset``.

    Attributes:
        data_config: Configuration specific to data handling and preprocessing.
        dataset_dicts: Dictionary holding datasets for different splits (train/test/valid).
    """

    def __init__(
        self, config: DefaultConfig, ontologies: Optional[Union[Tuple, Dict]] = None
    ):
        """
        Initialize the preprocessor.

        Args:
            config: Configuration object; its ``data_config`` is stored on the instance.
            ontologies: Optional ontologies, forwarded unchanged to the parent class.
        """
        super().__init__(config=config, ontologies=ontologies)
        self.data_config = config.data_config

    def preprocess(
        self,
        raw_user_data: Optional[DataPackage] = None,
        predict_new_data: bool = False,
    ) -> DatasetContainer:
        """
        Preprocess the data according to the configuration.

        Args:
            raw_user_data: The raw data package provided by the user.
            predict_new_data: Flag indicating if new data is being predicted.
        Returns:
            A DatasetContainer with the processed train, test and valid datasets.
            A split that the general preprocessing did not produce is passed
            as None.
        Raises:
            TypeError: If the "data" entry of a split is not a DataPackage.
        """
        self.dataset_dicts = self._general_preprocess(
            raw_user_data=raw_user_data, predict_new_data=predict_new_data
        )
        datasets = {}
        for split in ("train", "test", "valid"):
            cur_split = self.dataset_dicts.get(split)
            if cur_split is None:
                print(f"split is None: {split}")
                continue
            cur_data = cur_split.get("data")
            if not isinstance(cur_data, DataPackage):
                raise TypeError(
                    f"expected type of cur_data to be DataPackage, got {type(cur_data)}"
                )
            cur_indices = cur_split.get("indices")
            datasets[split] = self._process_dp(dp=cur_data, indices=cur_indices)

        # Skipped splits have no entry in `datasets`; .get() hands None to the
        # container instead of raising KeyError.
        # NOTE(review): assumes DatasetContainer accepts None for a split - confirm.
        return DatasetContainer(
            train=datasets.get("train"),
            test=datasets.get("test"),
            valid=datasets.get("valid"),
        )

    def _process_dp(self, dp: DataPackage, indices: Dict[str, Any]) -> ImageDataset:
        """
        Build an ImageDataset from the first image entry of a DataPackage.

        Args:
            dp: DataPackage whose ``img`` attribute maps names to image data.
            indices: Split indices, passed through to the ImageDataset.
        Returns:
            An ImageDataset holding the images stored under the first key of ``dp.img``.
        Raises:
            ValueError: If ``dp.img`` is None or an empty dict.
            TypeError: If ``dp.img`` is not a dict.
        """
        import warnings

        if dp.img is None:
            raise ValueError("no img attribute found in datapackage")
        # The type check has to come before any dict access on dp.img
        if not isinstance(dp.img, dict):
            raise TypeError(
                f"Expected `img` attribute of DataPackage to be `dict`, got {type(dp.img)}"
            )
        if not dp.img:
            raise ValueError("img attribute of DataPackage is an empty dict")

        first_key = next(iter(dp.img))
        if len(dp.img) > 1:
            warnings.warn(
                f"got multiple image datasets for Imagix: {dp.img.keys()}, "
                f"we only support a single image dataset in this case, using: {first_key}"
            )

        # Annotation stored under the image key wins; "paired" is the fallback
        metadata = None
        if dp.annotation is not None:
            metadata = dp.annotation.get(first_key)
            if metadata is None:
                metadata = dp.annotation.get("paired")

        return ImageDataset(
            data=dp.img[first_key],
            config=self.config,
            split_indices=indices,
            metadata=metadata,
        )

preprocess(raw_user_data=None, predict_new_data=False)

Preprocess the data according to the configuration.

Parameters:

Name Type Description Default
raw_user_data Optional[DataPackage]

The raw data package provided by the user.

None
predict_new_data bool

Flag indicating if new data is being predicted.

False

Returns: A DatasetContainer with processed training, validation, and test datasets.

Source code in src/autoencodix/data/_image_processor.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def preprocess(
    self,
    raw_user_data: Optional[DataPackage] = None,
    predict_new_data: bool = False,
) -> DatasetContainer:
    """
    Preprocess the data according to the configuration.

    Args:
        raw_user_data: The raw data package provided by the user.
        predict_new_data: Flag indicating if new data is being predicted.
    Returns:
        A DatasetContainer with the processed train, test and valid datasets.
        A split that the general preprocessing did not produce is passed
        as None.
    Raises:
        TypeError: If the "data" entry of a split is not a DataPackage.
    """
    self.dataset_dicts = self._general_preprocess(
        raw_user_data=raw_user_data, predict_new_data=predict_new_data
    )
    datasets = {}
    for split in ("train", "test", "valid"):
        cur_split = self.dataset_dicts.get(split)
        if cur_split is None:
            print(f"split is None: {split}")
            continue
        cur_data = cur_split.get("data")
        if not isinstance(cur_data, DataPackage):
            raise TypeError(
                f"expected type of cur_data to be DataPackage, got {type(cur_data)}"
            )
        cur_indices = cur_split.get("indices")
        datasets[split] = self._process_dp(dp=cur_data, indices=cur_indices)

    # Skipped splits have no entry in `datasets`; .get() hands None to the
    # container instead of raising KeyError.
    # NOTE(review): assumes DatasetContainer accepts None for a split - confirm.
    return DatasetContainer(
        train=datasets.get("train"),
        test=datasets.get("test"),
        valid=datasets.get("valid"),
    )

ImgData dataclass

Stores image data along with its associated metadata.

Attributes:

Name Type Description
img

The image data as a NumPy array.

sample_id str

A unique identifier for the image sample.

annotation Union[Series, DataFrame]

A pandas Series or DataFrame containing annotations or metadata related to the image.

Source code in src/autoencodix/data/_imgdataclass.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
@dataclass
class ImgData:
    """Container pairing one image with its sample identifier and annotation.

    Attributes:
        img : The image data as a NumPy array.
        sample_id: A unique identifier for the image sample.
        annotation: A pandas Series or DataFrame with annotations or metadata for the image.
    """

    img: np.ndarray
    sample_id: str
    annotation: Union[pd.Series, pd.DataFrame]

    def __repr__(self):
        # Only the shapes are reported; the pixel data itself is left out.
        pieces = [
            "ImgData(\n",
            f"    sample_id={self.sample_id!r},\n",
            f"    img_shape={self.img.shape},\n",
            f"    annotation_shape={self.annotation.shape}\n",
            " .  img: actual image data is not shown for brevity, use img attribute to access it",
            ")",
        ]
        return "".join(pieces)

MultiModalDataset

Bases: BaseDataset, Dataset

Handles multiple datasets of different modalities.

Attributes:

Name Type Description
datasets

Dictionary of datasets for each modality.

n_modalities

Number of modalities.

sample_to_modalities

Mapping from sample IDs to available modalities.

sample_ids List[Any]

List of all unique sample IDs across modalities.

config

Configuration object.

data

Data from the first modality (for compatibility).

feature_ids

Feature IDs (currently None, to be implemented).

_id_to_idx

Reverse lookup tables for sample IDs to indices per modality.

paired_sample_ids

List of sample IDs that have data in all modalities.

unpaired_sample_ids

List of sample IDs that do not have data in all modalities.

Source code in src/autoencodix/data/_multimodal_dataset.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
class MultiModalDataset(BaseDataset, torch.utils.data.Dataset):  # type: ignore
    """Handles multiple datasets of different modalities.

    Indexing and ``len()`` operate on the paired samples only, i.e. the
    sample IDs that occur in every modality.

    Attributes:
        datasets: Dictionary of datasets for each modality.
        modalities: List of modality names (the keys of ``datasets``).
        n_modalities: Number of modalities.
        sample_to_modalities: Mapping from sample IDs to available modalities.
        sample_ids: List of all unique sample IDs across modalities.
        config: Configuration object.
        data: Data from the first modality (for compatibility).
        feature_ids: Feature IDs (currently None, to be implemented).
        _id_to_idx: Reverse lookup tables for sample IDs to indices per modality.
        paired_sample_ids: List of sample IDs that have data in all modalities.
        unpaired_sample_ids: List of sample IDs that do not have data in all modalities.
    """

    def __init__(self, datasets: Dict[str, BaseDataset], config: DefaultConfig):
        """
        Initialize the MultiModalDataset.

        Args:
            datasets: Dictionary of datasets for each modality.
            config: Configuration object.

        Raises:
            ValueError: If ``datasets`` is empty or a dataset has no sample_ids.
        """
        if not datasets:
            raise ValueError("datasets must contain at least one modality")
        # Validate before building the sample map: _build_sample_map iterates
        # sample_ids and would otherwise fail first with an unhelpful TypeError.
        for ds_name, ds in datasets.items():
            if ds.sample_ids is None:
                raise ValueError(f"There are no sample_ids for {ds_name}")

        self.datasets = datasets
        self.modalities = list(datasets.keys())
        self.n_modalities = len(self.datasets.keys())
        self.sample_to_modalities = self._build_sample_map()
        self.sample_ids: List[Any] = list(self.sample_to_modalities.keys())
        self.config = config
        self.data = next(iter(self.datasets.values())).data
        self.feature_ids = None  # TODO

        # Build reverse lookup tables once
        self._id_to_idx = {
            mod: {sid: idx for idx, sid in enumerate(ds.sample_ids)}  # type: ignore
            for mod, ds in self.datasets.items()
        }
        self.paired_sample_ids = self._get_paired_sample_ids()
        # Filter the ordered id list instead of a set difference so the order
        # of the unpaired ids is the same on every run.
        paired = set(self.paired_sample_ids)
        self.unpaired_sample_ids = [
            sid for sid in self.sample_ids if sid not in paired
        ]

    def _to_df(self, modality: Optional[str] = None) -> pd.DataFrame:
        """Convert the dataset to a pandas DataFrame.

        Each modality becomes one frame indexed by sample ID whose columns are
        prefixed with ``"<modality>_"``; the frames are joined column-wise on
        the index with an inner join.

        Args:
            modality: Name of a single modality to convert; None converts all.

        Returns:
            DataFrame representation of the dataset

        Raises:
            ValueError: If ``modality`` is not a key of ``datasets``.
            TypeError: If a modality's data is neither a tensor nor a list of tensors.
        """
        selected = list(self.datasets.keys()) if modality is None else [modality]

        # None marks "no frame yet"; an empty join result must stay empty
        # rather than being replaced by the next modality.
        df_all: Optional[pd.DataFrame] = None
        for mod_name in selected:
            if mod_name not in self.datasets:
                raise ValueError(f"Unknown modality: {mod_name}")

            ds = self.datasets[mod_name]
            if isinstance(ds.data, torch.Tensor):
                df = pd.DataFrame(
                    # detach/cpu so tensors on an accelerator or with grad convert too
                    ds.data.detach().cpu().numpy(),
                    columns=ds.feature_ids,
                    index=ds.sample_ids,
                )
            elif isinstance(ds.data, list):
                # Handle image modality: one flattened row per image
                tensor_list = ds.data
                if not isinstance(tensor_list[0], torch.Tensor):
                    raise TypeError(
                        f" Image List is not a List[torch.Tensor], but a {type(tensor_list[0])} and cannot be converted to DataFrame."
                    )

                rows = [
                    (
                        t.detach().flatten().cpu().numpy()
                        if isinstance(t, torch.Tensor)
                        else t.flatten()
                    )
                    for t in tensor_list
                ]

                df = pd.DataFrame(
                    rows,
                    index=ds.sample_ids,
                    columns=["Pixel_" + str(i) for i in range(len(rows[0]))],
                )
            else:
                raise TypeError(
                    f"Data is not a torch.Tensor or image data, but a {type(ds.data)} and cannot be converted to DataFrame."
                )

            df = df.add_prefix(f"{mod_name}_")
            if df_all is None:
                df_all = df
            else:
                df_all = pd.concat([df_all, df], axis=1, join="inner")

        return pd.DataFrame() if df_all is None else df_all

    def _build_sample_map(self):
        """Map every sample ID to the set of modality names it occurs in."""
        sample_to_mods = {}
        for modality, dataset in self.datasets.items():
            for sid in dataset.sample_ids:
                sample_to_mods.setdefault(sid, set()).add(modality)
        return sample_to_mods

    def _get_paired_sample_ids(self):
        """Return the sample IDs that are present in every modality."""
        required = set(self.datasets)
        return [
            sid
            for sid, mods in self.sample_to_modalities.items()
            if required <= mods
        ]

    def __len__(self):
        """Return the number of paired samples."""
        return len(self.paired_sample_ids)

    def __getitem__(self, idx: Union[int, str]):
        """Return the data of one sample for every modality.

        Args:
            idx: An int is used as a position in ``paired_sample_ids``; any
                other value is used directly as the sample ID.

        Returns:
            Dict with key "sample_id" plus one key per modality; a modality
            that does not contain the sample maps to None.
        """
        sid = self.paired_sample_ids[idx] if isinstance(idx, int) else idx
        out = {"sample_id": sid}
        for mod in self.modalities:
            if sid not in self._id_to_idx[mod]:  # missing modality
                out[mod] = None
                continue
            # The modality datasets return a 3-tuple; only the middle item is kept
            _, data, _ = self.datasets[mod][self._id_to_idx[mod][sid]]
            out[mod] = data
        return out

    @property
    def is_fully_paired(self) -> bool:
        """Returns True if all samples are fully paired across all modalities (no unpaired samples)."""

        return len(self.unpaired_sample_ids) == 0

is_fully_paired property

Returns True if all samples are fully paired across all modalities (no unpaired samples).

__init__(datasets, config)

Initialize the MultiModalDataset.

Parameters:

Name Type Description Default
datasets Dict[str, BaseDataset]

Dictionary of datasets for each modality.

required
config DefaultConfig

Configuration object.

required
Source code in src/autoencodix/data/_multimodal_dataset.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def __init__(self, datasets: Dict[str, BaseDataset], config: DefaultConfig):
    """
    Initialize the MultiModalDataset.

    Args:
        datasets: Dictionary of datasets for each modality.
        config: Configuration object.

    Raises:
        ValueError: If ``datasets`` is empty or a dataset has no sample_ids.
    """
    if not datasets:
        raise ValueError("datasets must contain at least one modality")
    # Validate before building the sample map: _build_sample_map iterates
    # sample_ids and would otherwise fail first with an unhelpful TypeError.
    for ds_name, ds in datasets.items():
        if ds.sample_ids is None:
            raise ValueError(f"There are no sample_ids for {ds_name}")

    self.datasets = datasets
    self.modalities = list(datasets.keys())
    self.n_modalities = len(self.datasets.keys())
    self.sample_to_modalities = self._build_sample_map()
    self.sample_ids: List[Any] = list(self.sample_to_modalities.keys())
    self.config = config
    self.data = next(iter(self.datasets.values())).data
    self.feature_ids = None  # TODO

    # Build reverse lookup tables once
    self._id_to_idx = {
        mod: {sid: idx for idx, sid in enumerate(ds.sample_ids)}  # type: ignore
        for mod, ds in self.datasets.items()
    }
    self.paired_sample_ids = self._get_paired_sample_ids()
    # Filter the ordered id list instead of a set difference so the order
    # of the unpaired ids is the same on every run.
    paired = set(self.paired_sample_ids)
    self.unpaired_sample_ids = [
        sid for sid in self.sample_ids if sid not in paired
    ]

NaNRemover

Removes NaN values from multi-modal datasets.

This object identifies and removes NaN values from various data structures commonly used in single-cell and multi-modal omics, including AnnData, MuData, and Pandas DataFrames. It supports processing of X matrices, layers, and observation annotations within AnnData objects, as well as handling bulk and annotation data within a DataPackage.

Attributes:

Name Type Description
config

Configuration object containing settings for data processing.

relevant_cols

List of columns in metadata to check for NaNs.

Source code in src/autoencodix/data/_nanremover.py
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
class NaNRemover:
    """Cleans NaN values in multi-modal datasets.

    This object handles NaN values in the data structures of a DataPackage:
    bulk DataFrames lose every column that contains a NaN, configured
    non-numeric annotation columns get NaN replaced by the label "missing",
    and AnnData objects (on their own or inside MuData) get NaN in X and in
    all layers replaced by zero.

    Attributes:
        config: Configuration object containing settings for data processing.
        relevant_cols: List of columns in metadata to check for NaNs.
    """

    def __init__(
        self,
        config: DefaultConfig,
    ):
        """Initialize the NaNRemover with configuration settings.

        Args:
            config: Configuration object containing settings for data processing.
        """
        self.config = config
        self.relevant_cols = self.config.data_config.annotation_columns

    def _process_modality(self, adata: ad.AnnData) -> ad.AnnData:
        """Converts NaN values in AnnData object to zero and metadata NaNs to 'missing'.

        Works on a copy; the AnnData passed in is not modified.

        Args:
            adata: The AnnData object to process.
        Returns:
            The processed AnnData object with NaN values replaced.
        """
        adata = adata.copy()

        # Handle X matrix
        if sparse.issparse(adata.X):
            if hasattr(adata.X, "data"):
                # Only the explicitly stored entries of a sparse matrix can hold NaN
                adata.X.data = np.nan_to_num(  # ty:  ignore
                    adata.X.data, nan=0.0
                )  # ty: ignore[invalid-assignment]
                adata.X.eliminate_zeros()  # ty: ignore
        else:
            adata.X = np.nan_to_num(adata.X, nan=0.0)

        # Handle all layers
        for layer_name, layer_data in adata.layers.items():
            if sparse.issparse(layer_data):
                if hasattr(layer_data, "data"):
                    layer_data.data = np.nan_to_num(layer_data.data, nan=0.0)
                    layer_data.eliminate_zeros()
            else:
                adata.layers[layer_name] = np.nan_to_num(layer_data, nan=0.0)

        # Handle obs metadata
        if self.relevant_cols is not None:
            print(adata.obs.columns)
            for col in self.relevant_cols:
                if col not in adata.obs.columns:
                    continue
                # Numeric columns are left untouched
                if pd.api.types.is_numeric_dtype(adata.obs[col]):
                    continue
                labels = adata.obs[col]
                # The .cat accessor only exists on categorical columns, so
                # object/string columns are converted first.
                if not isinstance(labels.dtype, pd.CategoricalDtype):
                    labels = labels.astype("category")
                # add_categories raises if the category already exists
                if "missing" not in labels.cat.categories:
                    labels = labels.cat.add_categories(["missing"])
                adata.obs[col] = labels.fillna("missing")
        return adata

    def remove_nan(self, data: DataPackage) -> DataPackage:
        """Cleans NaN values in all applicable DataPackage components.

        - ``multi_bulk``: every column containing a NaN is dropped.
        - ``annotation``: NaN in configured, non-numeric columns is replaced
          by "missing"; entries that are None are removed.
        - ``multi_sc``: every AnnData gets NaN in X and layers set to zero.
        - ``from_modality`` / ``to_modality``: MuData, AnnData, dicts of
          AnnData and DataFrames are cleaned in the same way; other types are
          skipped with a warning.

        The DataPackage is modified in place and returned.

        Args:
            data: The DataPackage object containing multi-modal data.

        Returns:
            The DataPackage object with NaN values handled in its components.
        """
        # Handle bulk data
        if data.multi_bulk:
            for key, df in data.multi_bulk.items():
                data.multi_bulk[key] = df.dropna(axis=1)

        # Handle annotation data
        if data.annotation is not None:
            non_na = {}
            for k, v in data.annotation.items():
                if v is None:
                    continue
                if self.relevant_cols is not None:
                    for col in self.relevant_cols:
                        # Fill with "missing" if column is not integer or float
                        if col in v.columns and not pd.api.types.is_numeric_dtype(
                            v[col]
                        ):
                            v.fillna(value={col: "missing"}, inplace=True)

                non_na[k] = v
            data.annotation = non_na  # type: ignore

        # Handle MuData in multi_sc
        if data.multi_sc is not None and self.config.requires_paired:
            mudata = data.multi_sc["multi_sc"]
            # Process each modality
            for mod_name, mod_data in mudata.mod.items():
                processed_mod = self._process_modality(adata=mod_data)
                data.multi_sc["multi_sc"].mod[mod_name] = processed_mod

        elif data.multi_sc is not None:
            print(f"data in multi_sc: {data.multi_sc}")
            processed = {}

            for k, v in data.multi_sc.items():
                # Per the original note, the reader yields one modality per
                # MuData here. Collecting all of them keeps that case identical
                # and avoids an unbound or stale result when there is none.
                cleaned = {
                    modkey: self._process_modality(adata=adata)
                    for modkey, adata in v.mod.items()
                }
                processed[k] = md.MuData(cleaned) if cleaned else v
            data.multi_sc = processed

        # Handle from_modality and to_modality (for translation cases)
        for direction in ["from_modality", "to_modality"]:
            modality_dict = getattr(data, direction)
            if not modality_dict:
                continue

            for mod_key, mod_value in modality_dict.items():
                # Handle MuData objects - use the proper import
                if isinstance(mod_value, md.MuData):
                    # Process each modality in the MuData
                    for inner_mod_name, inner_mod_data in mod_value.mod.items():
                        processed_mod = self._process_modality(inner_mod_data)
                        mod_value.mod[inner_mod_name] = processed_mod

                    # Ensure cell alignment if there are multiple modalities
                    if len(mod_value.mod) > 1:
                        common_cells = list(
                            set.intersection(
                                *(set(mod.obs_names) for mod in mod_value.mod.values())
                            )
                        )
                        mod_value = mod_value[common_cells]

                    modality_dict[mod_key] = mod_value

                # Handle AnnData objects directly
                elif isinstance(mod_value, ad.AnnData):
                    processed_mod = self._process_modality(mod_value)
                    modality_dict[mod_key] = processed_mod

                # Handle other types of data (e.g., dictionaries of AnnData objects)
                elif isinstance(mod_value, dict):
                    for sub_key, sub_value in mod_value.items():
                        if isinstance(sub_value, ad.AnnData):
                            processed_mod = self._process_modality(sub_value)
                            mod_value[sub_key] = processed_mod

                elif isinstance(mod_value, pd.DataFrame):
                    # Drops NaN-containing columns on the caller's frame in place
                    mod_value.dropna(axis=1, inplace=True)
                    modality_dict[mod_key] = mod_value

                else:
                    warnings.warn(
                        f"Skipping unknown type in {direction}.{mod_key}: {type(mod_value)}"
                    )

        return data

__init__(config)

Initialize the NaNRemover with configuration settings. Args: config: configuration object containing settings for data processing.

Source code in src/autoencodix/data/_nanremover.py
27
28
29
30
31
32
33
34
35
36
37
def __init__(
    self,
    config: DefaultConfig,
):
    """Store the configuration and the annotation columns to check for NaNs.

    Args:
        config: Configuration object containing settings for data processing.
    """
    self.config = config
    # Columns are read from the data section of the configuration
    data_config = self.config.data_config
    self.relevant_cols = data_config.annotation_columns

remove_nan(data)

Removes NaN values from all applicable DataPackage components.

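Iterates through the bulk data, annotation data, and multi-modal single-cell data (MuData and AnnData objects) within the provided DataPackage and cleans NaN values: bulk columns containing NaN are dropped, NaN in the configured non-numeric annotation columns is replaced with "missing", and NaN in AnnData matrices and layers is set to zero.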

Parameters:

Name Type Description Default
data DataPackage

The DataPackage object containing multi-modal data.

required

Returns:

Type Description
DataPackage

The DataPackage object with NaN values removed from its components.

Source code in src/autoencodix/data/_nanremover.py
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
def remove_nan(self, data: DataPackage) -> DataPackage:
    """Removes NaN values from all applicable DataPackage components.

    The package is cleaned in place, component by component:

    * ``multi_bulk``: every DataFrame loses the columns that contain NaN.
    * ``annotation``: ``None`` entries are dropped; in the remaining frames,
      the non-numeric columns listed in ``self.relevant_cols`` get their NaN
      values replaced by the string ``"missing"``.
    * ``multi_sc``: every modality is passed through
      ``self._process_modality``.
    * ``from_modality`` / ``to_modality``: MuData, AnnData, dicts of AnnData
      and DataFrames are cleaned; any other type is skipped with a warning.

    Args:
        data: The DataPackage object containing multi-modal data.

    Returns:
        The same DataPackage object with NaN values removed from its components.
    """
    # Handle bulk data: drop every column that contains at least one NaN.
    if data.multi_bulk:
        for key, df in data.multi_bulk.items():
            data.multi_bulk[key] = df.dropna(axis=1)

    # Handle annotation data
    if data.annotation is not None:
        non_na = {}
        for k, v in data.annotation.items():
            if v is None:
                continue
            if self.relevant_cols is not None:
                for col in self.relevant_cols:
                    # Fill with "missing" if column is not integer or float
                    if col in v.columns and not pd.api.types.is_numeric_dtype(
                        v[col]
                    ):
                        v.fillna(value={col: "missing"}, inplace=True)

            non_na[k] = v
        data.annotation = non_na  # type: ignore

    # Handle MuData in multi_sc
    if data.multi_sc is not None and self.config.requires_paired:
        mudata = data.multi_sc["multi_sc"]
        # Process each modality
        for mod_name, mod_data in mudata.mod.items():
            processed_mod = self._process_modality(adata=mod_data)
            data.multi_sc["multi_sc"].mod[mod_name] = processed_mod

    elif data.multi_sc is not None:
        print(f"data in multi_sc: {data.multi_sc}")
        processed = {}

        for k, v in data.multi_sc.items():
            # Start from the untouched container so that an entry without any
            # modality keeps its own data instead of raising on an unbound
            # name or inheriting the result computed for the previous key.
            cleaned = v
            # we know from screader that there is only one modality
            for modkey, adata in v.mod.items():
                processed_mod = self._process_modality(adata=adata)
                cleaned = md.MuData({modkey: processed_mod})
            processed[k] = cleaned
        data.multi_sc = processed

    # Handle from_modality and to_modality (for translation cases)
    for direction in ["from_modality", "to_modality"]:
        modality_dict = getattr(data, direction)
        if not modality_dict:
            continue

        for mod_key, mod_value in modality_dict.items():
            # Handle MuData objects - use the proper import
            if isinstance(mod_value, md.MuData):
                # Process each modality in the MuData
                for inner_mod_name, inner_mod_data in mod_value.mod.items():
                    processed_mod = self._process_modality(inner_mod_data)
                    mod_value.mod[inner_mod_name] = processed_mod

                # Ensure cell alignment if there are multiple modalities
                if len(mod_value.mod) > 1:
                    common_cells = list(
                        set.intersection(
                            *(set(mod.obs_names) for mod in mod_value.mod.values())
                        )
                    )
                    mod_value = mod_value[common_cells]

                modality_dict[mod_key] = mod_value

            # Handle AnnData objects directly
            elif isinstance(mod_value, ad.AnnData):
                processed_mod = self._process_modality(mod_value)
                modality_dict[mod_key] = processed_mod

            # Handle other types of data (e.g., dictionaries of AnnData objects)
            elif isinstance(mod_value, dict):
                for sub_key, sub_value in mod_value.items():
                    if isinstance(sub_value, ad.AnnData):
                        processed_mod = self._process_modality(sub_value)
                        mod_value[sub_key] = processed_mod

            elif isinstance(mod_value, pd.DataFrame):
                # Same rule as for bulk data: drop columns containing NaN.
                mod_value.dropna(axis=1, inplace=True)
                modality_dict[mod_key] = mod_value

            else:
                warnings.warn(
                    f"Skipping unknown type in {direction}.{mod_key}: {type(mod_value)}"
                )

    return data

NumericDataset

Bases: TensorAwareDataset

A custom PyTorch dataset that handles tensors.

Attributes:

Name Type Description
data

The input features as a torch.Tensor.

config

Configuration object containing settings for data processing.

sample_ids

Optional list of sample identifiers.

feature_ids

Optional list of feature identifiers.

metadata

Optional pandas DataFrame containing metadata.

split_indices

Optional numpy array for data splitting.

mytype

Enum indicating the dataset type (set to DataSetTypes.NUM).

Source code in src/autoencodix/data/_numeric_dataset.py
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
class NumericDataset(TensorAwareDataset):
    """PyTorch dataset for purely numeric feature matrices.

    When the stored data is a NumPy array or a torch tensor it is converted to
    a tensor of the target dtype once, at construction time. Any other
    container is left untouched; rows that turn out to be SciPy sparse are
    densified one at a time in ``__getitem__``.

    Attributes:
        data: The input features as a torch.Tensor.
        config: Configuration object containing settings for data processing.
        sample_ids: Optional list of sample identifiers.
        feature_ids: Optional list of feature identifiers.
        metadata: Optional pandas DataFrame containing metadata.
        split_indices: Optional numpy array for data splitting.
        mytype: Enum indicating the dataset type (set to DataSetTypes.NUM).
    """

    def __init__(
        self,
        data: Union[torch.Tensor, np.ndarray, sp.sparse.spmatrix],
        config: DefaultConfig,
        sample_ids: Union[None, List[Any]] = None,
        feature_ids: Union[None, List[Any]] = None,
        metadata: Optional[Union[pd.Series, pd.DataFrame]] = None,
        split_indices: Optional[Union[Dict[str, Any], List[Any], np.ndarray]] = None,
    ):
        """Initialize the dataset.

        Args:
            data: Input features
            config: Configuration object
            sample_ids: Optional sample identifiers
            feature_ids: Optional feature identifiers
            metadata: Optional metadata
            split_indices: Optional split indices

        Raises:
            ValueError: If no configuration is available after the parent
                constructor has run.
        """
        super().__init__(
            data=data, sample_ids=sample_ids, config=config, feature_ids=feature_ids
        )

        if self.config is None:
            raise ValueError("config cannot be None")

        # Resolve the dtype once; dense inputs are converted right away.
        dtype = self._get_target_dtype()
        self.target_dtype = dtype

        # Sparse matrices are deliberately kept sparse to save memory on large
        # single-cell data; they are densified per row in __getitem__.
        dense_types = (np.ndarray, torch.Tensor)
        if isinstance(self.data, dense_types):
            self.data = self._to_tensor(data, dtype)

        self.metadata = metadata
        self.split_indices = split_indices
        self.mytype = DataSetTypes.NUM

    @no_type_check
    def __getitem__(self, index: int) -> Union[
        Tuple[
            Union[torch.Tensor, int],
            Union[torch.Tensor, "ImgData"],  # ty: ignore  # noqa: F821
            Any,
        ],
        Dict[str, Tuple[Any, torch.Tensor, Any]],
    ]:
        """Fetch one sample together with its label.

        Args:
            index: Index of the sample to retrieve.

        Returns:
            A tuple ``(index, sample, label)``. The label is the matching entry
            of ``sample_ids`` when sample ids are set, otherwise the index
            itself.
        """
        sample = self.data[index]
        label = index if self.sample_ids is None else self.sample_ids[index]

        if issparse(sample):
            # A sparse row comes back two-dimensional; densify it and drop the
            # leading axis.
            dense = sample.toarray()
            sample = torch.tensor(dense, dtype=self.target_dtype).squeeze(0)

        return index, sample, label

    def __len__(self) -> int:
        """Return the number of samples (rows) in the dataset."""
        n_samples = self.data.shape[0]
        return n_samples

__getitem__(index)

Retrieves a single sample and its corresponding label.

Parameters:

Name Type Description Default
index int

Index of the sample to retrieve.

required

Returns:

Type Description
Union[Tuple[Union[Tensor, int], Union[Tensor, 'ImgData'], Any], Dict[str, Tuple[Any, Tensor, Any]]]

A tuple containing the index, the data sample and its label, or a dictionary mapping keys to such tuples in case we have multiple uncombined data at this step.

Source code in src/autoencodix/data/_numeric_dataset.py
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
@no_type_check
def __getitem__(self, index: int) -> Union[
    Tuple[
        Union[torch.Tensor, int],
        Union[torch.Tensor, "ImgData"],  # ty: ignore  # noqa: F821
        Any,
    ],
    Dict[str, Tuple[Any, torch.Tensor, Any]],
]:
    """Fetch one sample together with its label.

    Args:
        index: Index of the sample to retrieve.

    Returns:
        A tuple ``(index, sample, label)``. The label is the matching entry of
        ``sample_ids`` when sample ids are set, otherwise the index itself.
    """
    sample = self.data[index]
    label = index if self.sample_ids is None else self.sample_ids[index]

    if issparse(sample):
        # A sparse row comes back two-dimensional; densify it and drop the
        # leading axis.
        dense = sample.toarray()
        sample = torch.tensor(dense, dtype=self.target_dtype).squeeze(0)

    return index, sample, label

__init__(data, config, sample_ids=None, feature_ids=None, metadata=None, split_indices=None)

Initialize the dataset

Parameters:

Name Type Description Default
data Union[Tensor, ndarray, spmatrix]

Input features

required
config DefaultConfig

Configuration object

required
sample_ids Union[None, List[Any]]

Optional sample identifiers

None
feature_ids Union[None, List[Any]]

Optional feature identifiers

None
metadata Optional[Union[Series, DataFrame]]

Optional metadata

None
split_indices Optional[Union[Dict[str, Any], List[Any], ndarray]]

Optional split indices

None
Source code in src/autoencodix/data/_numeric_dataset.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
def __init__(
    self,
    data: Union[torch.Tensor, np.ndarray, sp.sparse.spmatrix],
    config: DefaultConfig,
    sample_ids: Union[None, List[Any]] = None,
    feature_ids: Union[None, List[Any]] = None,
    metadata: Optional[Union[pd.Series, pd.DataFrame]] = None,
    split_indices: Optional[Union[Dict[str, Any], List[Any], np.ndarray]] = None,
):
    """Initialize the dataset.

    Args:
        data: Input features
        config: Configuration object
        sample_ids: Optional sample identifiers
        feature_ids: Optional feature identifiers
        metadata: Optional metadata
        split_indices: Optional split indices

    Raises:
        ValueError: If no configuration is available after the parent
            constructor has run.
    """
    super().__init__(
        data=data, sample_ids=sample_ids, config=config, feature_ids=feature_ids
    )

    if self.config is None:
        raise ValueError("config cannot be None")

    # Resolve the dtype once; dense inputs are converted right away.
    dtype = self._get_target_dtype()
    self.target_dtype = dtype

    # Sparse matrices are deliberately kept sparse to save memory on large
    # single-cell data; they are densified per row when a sample is fetched.
    dense_types = (np.ndarray, torch.Tensor)
    if isinstance(self.data, dense_types):
        self.data = self._to_tensor(data, dtype)

    self.metadata = metadata
    self.split_indices = split_indices
    self.mytype = DataSetTypes.NUM

__len__()

Returns the number of samples (rows) in the dataset

Source code in src/autoencodix/data/_numeric_dataset.py
203
204
205
def __len__(self) -> int:
    """Return the number of samples (rows) in the dataset."""
    n_samples = self.data.shape[0]
    return n_samples

SingleCellFilter

Filter and scale single-cell data, returning a MuData object with synchronized metadata.

Attributes:

Name Type Description
data_info

Configuration for filtering and scaling (can be a single DataInfo or a dict of DataInfo per modality).

total_features

Total number of features to keep across all modalities.

config

Configuration object containing settings for data processing.

_is_data_info_dict

Internal flag indicating if data_info is a dictionary.

Source code in src/autoencodix/data/_sc_filter.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
class SingleCellFilter:
    """Filter and scale single-cell data, returning a MuData object with synchronized metadata.

    Attributes:
        data_info: Configuration for filtering and scaling (can be a single DataInfo or a dict of DataInfo per modality).
        total_features: Total number of features to keep across all modalities.
        config: Configuration object containing settings for data processing.
        _is_data_info_dict: Internal flag indicating if data_info is a dictionary.
    """

    def __init__(
        self, data_info: Union[Dict[str, DataInfo], DataInfo], config: DefaultConfig
    ):
        """Initialize single-cell filter.

        Args:
            data_info: Either a single data_info object for all modalities or a dictionary of data_info objects for each modality.
            config: Configuration object containing settings for data processing.
        """
        self.data_info = data_info
        self.total_features = config.k_filter
        self._is_data_info_dict = isinstance(data_info, dict)
        self.config = config

    def _get_data_info_for_modality(self, mod_key: str) -> DataInfo:
        """Get the data_info configuration for a specific modality.

        Args:
            mod_key: The modality key (e.g., "RNA", "METH")

        Returns:
            The data_info configuration for the modality. When a single
            data_info object was supplied, that object is returned for every
            key.

        Raises:
            ValueError: If data_info is a dictionary without an entry for
                ``mod_key``.
        """
        if self._is_data_info_dict:
            info = self.data_info.get(mod_key)  # type: ignore
            if info is None:
                raise ValueError(f"No data info found for modality {mod_key}")
            return info
        return self.data_info  # type: ignore

    def _get_layers_for_modality(self, mod_key: str, mod_data) -> List[str]:
        """Get the layers to process for a specific modality.

        Args:
            mod_key: The modality key (e.g., "RNA", "METH")
            mod_data: The AnnData object for the modality

        Returns:
            List of layer names to process. If the configured selection is
            None, empty, or contains no existing layer, returns ['X'] for the
            default layer.
        """
        data_info = self._get_data_info_for_modality(mod_key)
        # An unset selection is treated like an empty one instead of failing
        # when iterated.
        selected_layers = data_info.selected_layers or []

        # Validate that the specified layers exist
        available_layers = list(mod_data.layers.keys())
        valid_layers = []

        for layer in selected_layers:
            if layer == "X" or layer in available_layers:
                valid_layers.append(layer)
            else:
                print(
                    f"Warning: Layer '{layer}' not found in modality '{mod_key}'. Skipping."
                )

        return valid_layers or ["X"]

    def _presplit_processing(
        self,
        mudata: MuData,  # type: ignore[invalid-type-form]
    ) -> MuData:  # type: ignore[invalid-type-form]
        """Preprocess the data using modality-specific configurations.

        For each modality, cells are filtered by ``min_genes`` and, when
        ``log_transform`` is enabled, ``log1p`` is applied to X and to every
        selected layer.

        Args:
            mudata: Multi-modal container whose modalities are updated in place.

        Returns:
            Preprocessed data
        """
        print(f"mudata: {mudata}")
        for mod_key, mod_data in mudata.mod.items():
            data_info = self._get_data_info_for_modality(mod_key)
            if data_info is not None:
                sc.pp.filter_cells(mod_data, min_genes=data_info.min_genes)
                # Resolved unconditionally so missing-layer warnings are still
                # reported when no transform is requested.
                layers_to_process = self._get_layers_for_modality(mod_key, mod_data)

                # Without a log transform the layers stay as they are, so the
                # full AnnData copy per layer is skipped entirely.
                if data_info.log_transform:
                    for layer in layers_to_process:
                        if layer == "X":
                            sc.pp.log1p(mod_data)
                        else:
                            # log1p works on X, so route the layer through X
                            # of a scratch copy and write the result back.
                            temp_view = mod_data.copy()
                            temp_view.X = mod_data.layers[layer].copy()
                            sc.pp.log1p(temp_view)
                            mod_data.layers[layer] = temp_view.X.copy()

                mudata.mod[mod_key] = mod_data

        return mudata

    def presplit_processing(
        self,
        multi_sc: Union[MuData, Dict[str, MuData]],  # ty: ignore[invalid-type-form]
    ) -> Union[MuData, Dict[str, MuData]]:  # ty: ignore[invalid-type-form]
        """Process each modality independently to filter cells based on min_genes.

        Args:
            multi_sc: Either a single MuData object or a dictionary of MuData objects.

        Returns:
            The processed MuData when a single MuData object was passed,
            otherwise a dictionary mapping the input keys to processed MuData
            objects.
        """
        from mudata import MuData

        if isinstance(multi_sc, MuData):
            return self._presplit_processing(mudata=multi_sc)
        return {k: self._presplit_processing(mudata=v) for k, v in multi_sc.items()}

    def _to_dataframe(self, mod_data, layer=None) -> pd.DataFrame:
        """Transform a modality's AnnData object to a pandas DataFrame.

        Args:
            mod_data: Modality data to be transformed
            layer: Layer to convert to DataFrame. If None, uses X.

        Returns:
            Transformed DataFrame, indexed by ``obs_names`` with ``var_names``
            as columns.
        """
        if layer is None or layer == "X":
            data = mod_data.X
        else:
            data = mod_data.layers[layer]

        # Convert to dense array if sparse
        if isinstance(data, np.ndarray):
            matrix = data
        else:  # Assuming it's a sparse matrix
            matrix = data.toarray()

        return pd.DataFrame(
            matrix, columns=mod_data.var_names, index=mod_data.obs_names
        )

    def _from_dataframe(self, df: pd.DataFrame, mod_data, layer=None):
        """Update a modality's AnnData object with the values from a DataFrame.

        This also synchronizes the `obs` and `var` metadata to match the filtered data.

        Args:
            df: DataFrame containing the updated values
            mod_data: Modality data to be updated
            layer: Layer to update with DataFrame values. If None, updates X.

        Returns:
            Updated AnnData object
        """
        # Filter the AnnData object to match the rows and columns of the DataFrame
        filtered_mod_data = mod_data[df.index, df.columns].copy()

        # Update the data matrix with the filtered and scaled values; a layer
        # is created or overwritten by the same assignment.
        if layer is None or layer == "X":
            filtered_mod_data.X = df.values
        else:
            filtered_mod_data.layers[layer] = df.values

        return filtered_mod_data

    def sc_postsplit_processing(
        self,
        mudata: MuData,  # ty: ignore[invalid-type-form]
        gene_map: Optional[
            Dict[str, List[str]]
        ] = None,  # ty: ignore[invalid-type-form]
    ) -> Tuple[MuData, Dict[str, List[str]]]:  # ty: ignore[invalid-type-form]
        """Filter genes per modality and apply the same filter to all layers.

        Genes are selected either from ``gene_map`` or, when no entry exists
        for a modality, by the ``min_cells`` threshold evaluated on X. The
        resulting mask is applied to X and to every selected layer.

        Args:
            mudata: Input multi-modal data container
            gene_map: Optional override of genes to keep per modality

        Returns:
            - Processed MuData with filtered modalities
            - Mapping of modality to kept gene names

        Raises:
            ValueError: If a modality has no data info or a selected layer is
                missing.
        """
        kept_genes = {}
        processed_mods = {}

        for mod_key, adata in mudata.mod.items():
            # Get configuration for this modality
            info = self._get_data_info_for_modality(mod_key)
            if info is None:
                raise ValueError(f"No data info for modality '{mod_key}'")

            # Determine which genes to keep
            if gene_map and mod_key in gene_map:
                # Use provided gene list if available
                genes_to_keep = gene_map[mod_key]
                var_mask = adata.var_names.isin(genes_to_keep)
            else:
                # Filter genes based on minimum cells expressing each gene
                var_mask = sc.pp.filter_genes(
                    adata.copy(), min_cells=info.min_cells, inplace=False
                )[0]
                genes_to_keep = adata.var_names[var_mask].tolist()

            kept_genes[mod_key] = genes_to_keep

            # Create new AnnData with filtered X layer
            filtered_adata = AnnData(
                X=adata.X[:, var_mask],
                obs=adata.obs.copy(),
                var=adata.var[var_mask].copy(),
                uns=adata.uns.copy(),
                obsm=adata.obsm.copy(),
            )

            # Normalize if configured (only X is normalized here)
            if info.normalize_counts:
                sc.pp.normalize_total(filtered_adata)

            # Copy filtered layers
            for layer in self._get_layers_for_modality(mod_key, adata):
                if layer == "X":
                    continue

                if layer not in adata.layers:
                    raise ValueError(
                        f"Layer '{layer}' not found in modality '{mod_key}'"
                    )

                filtered_adata.layers[layer] = adata.layers[layer][:, var_mask].copy()

            processed_mods[mod_key] = filtered_adata

        # Construct new MuData from filtered modalities
        return md.MuData(processed_mods), kept_genes

    def _apply_general_filtering(
        self, df: pd.DataFrame, data_info: DataInfo, gene_list: Optional[List]
    ) -> Tuple[Union[pd.Series, pd.DataFrame], List]:
        """Run ``DataFilter.filter`` on ``df`` with the modality's data_info.

        Args:
            df: Data to filter.
            data_info: Filtering configuration for the modality.
            gene_list: Optional genes to keep, forwarded as ``genes_to_keep``.

        Returns:
            Whatever ``DataFilter.filter`` returns; unpacked by the caller as
            ``(filtered_data, gene_list)``.
        """
        data_processor = DataFilter(data_info=data_info, config=self.config)
        return data_processor.filter(df=df, genes_to_keep=gene_list)

    def _apply_scaling(
        self, df: pd.DataFrame, data_info: DataInfo, scaler: Any
    ) -> Tuple[Union[pd.Series, pd.DataFrame], Any]:
        """Scale ``df``, fitting a new scaler when none is supplied.

        Args:
            df: Data to scale.
            data_info: Scaling configuration for the modality.
            scaler: Pre-fitted scaler, or None to fit one on ``df``.

        Returns:
            The scaled data and the scaler that was used.
        """
        data_processor = DataFilter(data_info=data_info, config=self.config)
        if scaler is None:
            scaler = data_processor.fit_scaler(df=df)
        scaled_df = data_processor.scale(df=df, scaler=scaler)
        return scaled_df, scaler

    def general_postsplit_processing(
        self,
        mudata: MuData,  # ty: ignore[invalid-type-form]
        gene_map: Optional[Dict[str, List]],
        scaler_map: Optional[Dict[str, Dict[str, Any]]] = None,
    ) -> Tuple[
        MuData,  # ty: ignore[invalid-type-form]
        Dict[str, List],
        Dict[str, Dict[str, Any]],  # ty: ignore[invalid-type-form]
    ]:  # ty: ignore[invalid-type-form]
        """Process single-cell data with proper MuData handling.

        Args:
            mudata: Input multi-modal data container
            gene_map: Optional override of genes to keep per modality
            scaler_map: Optional pre-fitted scalers per modality and layer

        Returns:
            - Processed MuData with filtered and scaled modalities
            - Mapping of modality to the gene list returned by the filter
            - Mapping of modality to the scalers used for X and each layer

        Raises:
            ValueError: If no data info is available for a modality.
        """
        feature_distribution = self.distribute_features_across_modalities(
            mudata, self.total_features
        )
        out_gene_map = {}
        out_scaler_map = {mod_key: {} for mod_key in mudata.mod.keys()}

        # Dictionary to store processed modalities
        processed_modalities = {}

        for mod_key, original_mod in mudata.mod.items():
            data_info = self._get_data_info_for_modality(mod_key)
            # Validate before touching the object so a missing configuration
            # surfaces as the intended ValueError.
            if data_info is None:
                raise ValueError(f"No data info found for modality {mod_key}")
            data_info.k_filter = feature_distribution[mod_key]

            # Create working copy of the modality data
            mod_data = original_mod.copy()

            # Process X matrix
            x_df = self._to_dataframe(mod_data, layer=None)
            filtered_x, gene_list = self._apply_general_filtering(
                df=x_df,
                gene_list=gene_map.get(mod_key) if gene_map else None,
                data_info=data_info,
            )
            out_gene_map[mod_key] = gene_list

            # Apply scaling to X
            scaled_x, x_scaler = self._apply_scaling(
                df=filtered_x,
                data_info=data_info,
                scaler=scaler_map[mod_key].get("X") if scaler_map else None,
            )
            out_scaler_map[mod_key]["X"] = x_scaler

            # Create new AnnData for this modality
            processed_adata = self._create_new_adata(
                scaled_x,
                original_adata=mod_data,
                obs_names=mod_data.obs_names.tolist(),
                var_names=filtered_x.columns.tolist(),
            )

            # Process layers
            layers_to_process = self._get_layers_for_modality(mod_key, mod_data)
            for layer in layers_to_process:
                if layer == "X":
                    continue

                # Process layer data
                layer_df = self._to_dataframe(mod_data, layer=layer)
                filtered_layer = layer_df[filtered_x.columns]  # Match X's columns

                # Apply scaling with same genes as X
                scaled_layer, layer_scaler = self._apply_scaling(
                    df=filtered_layer,
                    data_info=data_info,
                    scaler=scaler_map[mod_key].get(layer) if scaler_map else None,
                )
                out_scaler_map[mod_key][layer] = layer_scaler

                # Store in new AnnData
                processed_adata.layers[layer] = scaled_layer.values

            # Store processed modality
            processed_modalities[mod_key] = processed_adata

        # Create new MuData from processed modalities
        new_mudata = md.MuData(processed_modalities)

        return new_mudata, out_gene_map, out_scaler_map

    def _create_new_adata(self, df, original_adata, obs_names, var_names):
        """Build an AnnData from ``df`` that reuses the original's metadata.

        Args:
            df: DataFrame whose values become X.
            original_adata: AnnData supplying obs, uns, obsm and varm.
            obs_names: Observation names used to select rows of the original obs.
            var_names: Variable names used as the index of the new var frame.

        Returns:
            The newly constructed AnnData with an empty ``layers`` mapping.
        """
        return AnnData(
            X=df.values,
            obs=original_adata.obs.loc[obs_names],
            var=pd.DataFrame(index=var_names),
            layers={},
            uns=original_adata.uns.copy(),
            obsm=original_adata.obsm.copy(),
            varm=original_adata.varm.copy(),
        )

    def distribute_features_across_modalities(
        self,
        mudata: MuData,  # ty: ignore[invalid-type-form]
        total_features: Optional[int],  # ty: ignore[invalid-type-form]
    ) -> Dict[str, Optional[int]]:
        """Distributes a total number of features across modalities evenly.

        Any remainder is handed out one feature at a time to the first
        modalities. As a side effect the per-modality share is written to the
        ``k_filter`` attribute of the matching data_info object.

        Args:
            mudata: Multi-modal data object
            total_features: Total number of features to distribute across all
                modalities. If None, every modality is mapped to None and no
                data_info is modified.

        Returns:
            Dictionary mapping modality keys to number of features to keep
        """
        valid_modalities = list(mudata.mod.keys())
        if total_features is None:
            return {k: None for k in valid_modalities}
        n_modalities = len(valid_modalities)

        if n_modalities == 0:
            return {}

        base_features = total_features // n_modalities
        remainder = total_features % n_modalities

        # Distribute features
        feature_distribution = {}
        for i, mod_key in enumerate(valid_modalities):
            # Add one extra feature to early modalities if there's remainder
            extra = 1 if i < remainder else 0
            feature_distribution[mod_key] = base_features + extra

            # Set k_filter in data_info if available; plain assignment covers
            # both the existing-attribute and the new-attribute case.
            data_info = self._get_data_info_for_modality(mod_key)
            if data_info is not None:
                data_info.k_filter = feature_distribution[mod_key]

        return feature_distribution

__init__(data_info, config)

Initialize single-cell filter.

Args:
    data_info: Either a single data_info object for all modalities or a dictionary of data_info objects for each modality.
    config: Configuration object containing settings for data processing.

Source code in src/autoencodix/data/_sc_filter.py
30
31
32
33
34
35
36
37
38
39
40
41
42
def __init__(
    self, data_info: Union[Dict[str, DataInfo], DataInfo], config: DefaultConfig
):
    """Create a single-cell filter.

    Args:
        data_info: Either one data_info object shared by all modalities or a
            dictionary holding one data_info object per modality.
        config: Configuration object containing settings for data processing;
            its ``k_filter`` value becomes the total feature budget.
    """
    self.data_info = data_info
    # Total number of features to keep across all modalities.
    feature_budget = config.k_filter
    self.total_features = feature_budget
    # Remember the container kind so per-modality lookups know how to resolve.
    per_modality = isinstance(data_info, dict)
    self._is_data_info_dict = per_modality
    self.config = config

distribute_features_across_modalities(mudata, total_features)

Distributes a total number of features across modalities evenly.

Parameters:

Name Type Description Default
mudata MuData

Multi-modal data object

required
total_features Optional[int]

Total number of features to distribute across all modalities

required

Returns:

Type Description
Dict[str, int]

Dictionary mapping modality keys to number of features to keep

Source code in src/autoencodix/data/_sc_filter.py
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
def distribute_features_across_modalities(
    self,
    mudata: MuData,  # ty: ignore[invalid-type-form]
    total_features: Optional[int],  # ty: ignore[invalid-type-form]
) -> Dict[str, Optional[int]]:
    """
    Distributes a total number of features across modalities evenly.

    The budget is split with integer division. When it does not divide
    evenly, the first ``total_features % n_modalities`` modalities (in the
    iteration order of ``mudata.mod``) receive one extra feature. For every
    modality that has a data info object, the assigned number is also
    written to that object's ``k_filter`` attribute.

    Args:
        mudata: Multi-modal data object
        total_features: Total number of features to distribute across all
            modalities, or None, in which case every modality maps to None.

    Returns:
        Dictionary mapping modality keys to number of features to keep.
        All values are None when ``total_features`` is None (no ``k_filter``
        is written in that case); the dictionary is empty when ``mudata``
        has no modalities.
    """
    valid_modalities = list(mudata.mod.keys())
    if total_features is None:
        return dict.fromkeys(valid_modalities)
    if not valid_modalities:
        return {}

    base_features, remainder = divmod(total_features, len(valid_modalities))

    feature_distribution: Dict[str, Optional[int]] = {}
    for i, mod_key in enumerate(valid_modalities):
        # Early modalities absorb the remainder, one extra feature each.
        n_features = base_features + (1 if i < remainder else 0)
        feature_distribution[mod_key] = n_features

        # Mirror the per-modality budget onto its data info, if there is one.
        data_info = self._get_data_info_for_modality(mod_key)
        if data_info is not None:
            data_info.k_filter = n_features

    return feature_distribution

general_postsplit_processing(mudata, gene_map, scaler_map=None)

Process single-cell data with proper MuData handling.

Args:

- mudata: Input multi-modal data container.
- gene_map: Optional override of genes to keep per modality.
- scaler_map: Optional pre-fitted scalers per modality and layer.

Returns: The processed MuData with filtered and scaled modalities, the mapping of modality to kept genes, and the mapping of modality and layer to the scaler that was used.

Source code in src/autoencodix/data/_sc_filter.py
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
def general_postsplit_processing(
    self,
    mudata: MuData,  # ty: ignore[invalid-type-form]
    gene_map: Optional[Dict[str, List]],
    scaler_map: Optional[Dict[str, Dict[str, Any]]] = None,
) -> Tuple[
    MuData,  # ty: ignore[invalid-type-form]
    Dict[str, List],
    Dict[str, Dict[str, Any]],  # ty: ignore[invalid-type-form]
]:  # ty: ignore[invalid-type-form]
    """Process single-cell data with proper MuData handling

    For every modality the X matrix is filtered and scaled. Each further
    layer selected for the modality is restricted to the columns kept for X
    and scaled as well. The results are collected in a new MuData object.

    Args:
        mudata: Input multi-modal data container
        gene_map: Optional override of genes to keep per modality
        scaler_map: Optional pre-fitted scalers per modality and layer
            (the main matrix is stored under the layer key "X")
    Returns:
        A tuple of
        - the processed MuData with filtered and scaled modalities,
        - the mapping of modality to the gene list returned by the filtering,
        - the mapping of modality and layer to the scaler returned by scaling.
    Raises:
        ValueError: If no data info is available for a modality.
    """
    feature_distribution = self.distribute_features_across_modalities(
        mudata, self.total_features
    )
    out_gene_map = {}
    out_scaler_map = {mod_key: {} for mod_key in mudata.mod.keys()}

    # Dictionary to store processed modalities
    processed_modalities = {}

    for mod_key, original_mod in mudata.mod.items():
        data_info = self._get_data_info_for_modality(mod_key)
        # Validate before touching data_info; otherwise a missing entry would
        # surface as an AttributeError instead of the intended ValueError.
        if data_info is None:
            raise ValueError(f"No data info found for modality {mod_key}")
        data_info.k_filter = feature_distribution[mod_key]

        # Create working copy of the modality data
        mod_data = original_mod.copy()

        # Process X matrix
        x_df = self._to_dataframe(mod_data, layer=None)
        filtered_x, gene_list = self._apply_general_filtering(
            df=x_df,
            gene_list=gene_map.get(mod_key) if gene_map else None,
            data_info=data_info,
        )
        out_gene_map[mod_key] = gene_list

        # Apply scaling to X
        scaled_x, x_scaler = self._apply_scaling(
            df=filtered_x,
            data_info=data_info,
            scaler=scaler_map[mod_key].get("X") if scaler_map else None,
        )
        out_scaler_map[mod_key]["X"] = x_scaler

        # Create new AnnData for this modality
        processed_adata = self._create_new_adata(
            scaled_x,
            original_adata=mod_data,
            obs_names=mod_data.obs_names.tolist(),
            var_names=filtered_x.columns.tolist(),
        )

        # Process layers
        layers_to_process = self._get_layers_for_modality(mod_key, mod_data)
        for layer in layers_to_process:
            if layer == "X":
                continue

            # Process layer data
            layer_df = self._to_dataframe(mod_data, layer=layer)
            filtered_layer = layer_df[filtered_x.columns]  # Match X's columns

            # Apply scaling with same genes as X
            scaled_layer, layer_scaler = self._apply_scaling(
                df=filtered_layer,
                data_info=data_info,
                scaler=scaler_map[mod_key].get(layer) if scaler_map else None,
            )
            out_scaler_map[mod_key][layer] = layer_scaler

            # Store in new AnnData
            processed_adata.layers[layer] = scaled_layer.values

        # Store processed modality
        processed_modalities[mod_key] = processed_adata

    # Create new MuData from processed modalities
    new_mudata = md.MuData(processed_modalities)

    return new_mudata, out_gene_map, out_scaler_map

presplit_processing(multi_sc)

Process each modality independently to filter cells based on min_genes.

Parameters:

Name Type Description Default
multi_sc Union[MuData, Dict[str, MuData]]

Either a single MuData object or a dictionary of MuData objects.

required

Returns: A dictionary mapping modality keys to processed MuData objects.

Source code in src/autoencodix/data/_sc_filter.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
def presplit_processing(
    self,
    multi_sc: Union[MuData, Dict[str, MuData]],  # ty: ignore[invalid-type-form]
) -> Dict[str, MuData]:  # ty: ignore[invalid-type-form]
    """
    Run the pre-split processing on one MuData object or on several.

    Args:
        multi_sc: Either a single MuData object or a dictionary of MuData objects.
    Returns:
        For a single MuData object, whatever ``_presplit_processing`` returns
        for it. For a dictionary, a dictionary with the same keys whose
        values are the ``_presplit_processing`` results of the entries.
    """
    from mudata import MuData

    if not isinstance(multi_sc, MuData):
        # Dictionary input: process every entry and keep its key.
        return {
            name: self._presplit_processing(mudata=single)
            for name, single in multi_sc.items()
        }
    return self._presplit_processing(mudata=multi_sc)

sc_postsplit_processing(mudata, gene_map=None)

Process each modality independently to filter genes based on X layer, then consistently apply the same filtering to all layers.

Args: mudata : Input multi-modal data container gene_map : Optional override of genes to keep per modality

Returns:

Type Description
MuData
  • Processed MuData with filtered modalities
Dict[str, List[str]]
  • Mapping of modality to kept gene names
Source code in src/autoencodix/data/_sc_filter.py
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
def sc_postsplit_processing(
    self,
    mudata: MuData,  # ty: ignore[invalid-type-form]
    gene_map: Optional[
        Dict[str, List[str]]
    ] = None,  # ty: ignore[invalid-type-form]
) -> Tuple[MuData, Dict[str, List[str]]]:  # ty: ignore[invalid-type-form]
    """
    Filter the genes of every modality and apply the same selection to its layers.

    The genes kept for a modality come from ``gene_map`` when it has an entry
    for that modality. Otherwise they come from ``sc.pp.filter_genes`` run on
    a copy of the modality with the ``min_cells`` value of its data info. The
    resulting mask is applied to X, ``var`` and every selected layer.

    Args:
        mudata: Input multi-modal data container
        gene_map: Optional override of genes to keep per modality

    Returns:
        - Processed MuData with filtered modalities
        - Mapping of modality to kept gene names

    Raises:
        ValueError: If a modality has no data info, or a selected layer is
            missing from the modality.
    """
    genes_by_modality = {}
    filtered_by_modality = {}

    for mod_key, adata in mudata.mod.items():
        info = self._get_data_info_for_modality(mod_key)
        if info is None:
            raise ValueError(f"No data info for modality '{mod_key}'")

        if not gene_map or mod_key not in gene_map:
            # No override: keep genes expressed in at least min_cells cells.
            var_mask = sc.pp.filter_genes(
                adata.copy(), min_cells=info.min_cells, inplace=False
            )[0]
            genes_to_keep = adata.var_names[var_mask].tolist()
        else:
            # Caller supplied the gene list for this modality.
            genes_to_keep = gene_map[mod_key]
            var_mask = adata.var_names.isin(genes_to_keep)
        genes_by_modality[mod_key] = genes_to_keep

        # Build a fresh AnnData restricted to the kept genes.
        filtered_adata = AnnData(
            X=adata.X[:, var_mask],
            obs=adata.obs.copy(),
            var=adata.var[var_mask].copy(),
            uns=adata.uns.copy(),
            obsm=adata.obsm.copy(),
        )
        if info.normalize_counts:
            sc.pp.normalize_total(filtered_adata)

        # Carry over the selected layers with the same gene mask ("X" is done).
        for layer in self._get_layers_for_modality(mod_key, adata):
            if layer != "X":
                if layer not in adata.layers:
                    raise ValueError(
                        f"Layer '{layer}' not found in modality '{mod_key}'"
                    )
                filtered_adata.layers[layer] = adata.layers[layer][
                    :, var_mask
                ].copy()

        filtered_by_modality[mod_key] = filtered_adata

    return md.MuData(filtered_by_modality), genes_by_modality

StackixDataset

Bases: NumericDataset

Dataset for handling multiple modalities in Stackix models.

This dataset holds individual BaseDataset objects for multiple data modalities and provides a consistent interface for accessing them during training. It's designed to work specifically with StackixTrainer.

Attributes:

Name Type Description
dataset_dict

Dictionary mapping modality names to dataset objects

modality_keys

List of modality names

Source code in src/autoencodix/data/_stackix_dataset.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
class StackixDataset(NumericDataset):
    """
    Dataset for handling multiple modalities in Stackix models.

    This dataset holds individual BaseDataset objects for multiple data modalities
    and provides a consistent interface for accessing them during training.
    It's designed to work specifically with StackixTrainer.

    Attributes:
        dataset_dict: Dictionary mapping modality names to dataset objects
        modality_keys: List of modality names
    """

    def __init__(
        self,
        dataset_dict: Dict[str, BaseDataset],
        config: DefaultConfig,
    ):
        """
        Initialize a StackixDataset instance.

        The ``data`` tensors of all modalities are concatenated along
        dimension 1 and handed to the base class together with the sample
        ids, split indices and metadata of the first modality.

        Args:
            dataset_dict: Dictionary mapping modality names to dataset objects
            config: Configuration object

        Raises:
            ValueError: If the datasets dictionary is empty or if modality datasets have different numbers of samples
            NotImplementedError: If the datasets have incompatible shapes for concatenation
        """
        if not dataset_dict:
            raise ValueError("dataset_dict cannot be empty")

        # Use first modality for base class initialization
        first_modality = next(iter(dataset_dict.values()))
        try:
            data = torch.cat(
                [v.data for v in dataset_dict.values() if hasattr(v, "data")], dim=1
            )
        except Exception as err:
            # Chain the original error so the real cause stays visible.
            raise NotImplementedError(
                "Data modalities have different shapes, set requires_paired=True in config"
            ) from err
        super().__init__(
            data=data,
            sample_ids=first_modality.sample_ids,
            config=config,
            split_indices=first_modality.split_indices,
            metadata=first_modality.metadata,
            feature_ids=[
                v.feature_ids
                for v in dataset_dict.values()
                if hasattr(v, "feature_ids")
            ],
        )

        self.dataset_dict = dataset_dict
        self.modality_keys = list(dataset_dict.keys())

        # Ensure all datasets have the same number of samples
        # NOTE(review): if `data` is laid out as (samples, features), a count
        # mismatch presumably already fails in torch.cat above - confirm.
        sample_counts = [len(dataset) for dataset in dataset_dict.values()]
        if not all(count == sample_counts[0] for count in sample_counts):
            raise ValueError(
                "All modality datasets must have the same number of samples"
            )

    def __len__(self) -> int:
        """Return the number of samples, taken from the first modality dataset."""
        return len(next(iter(self.dataset_dict.values())))

    def __getitem__(
        self, index: int
    ) -> Union[Tuple[torch.Tensor, Any], Dict[str, Tuple[torch.Tensor, Any]]]:
        """
        Get the sample at ``index`` from every modality.

        Each modality dataset in ``dataset_dict`` is indexed with the same
        ``index``; the results are returned keyed by modality name.

        Args:
            index: Index of the sample to retrieve

        Returns:
            Dictionary mapping each modality name to the item its dataset
            returns for ``index``

        """
        return {
            modality: dataset[index]
            for modality, dataset in self.dataset_dict.items()
        }

    def get_modality_item(self, modality: str, index: int) -> Tuple[torch.Tensor, Any]:
        """
        Get a sample for a specific modality.

        Args:
            modality: The modality name to retrieve data from
            index: Index of the sample to retrieve

        Returns:
            The item the modality's dataset returns for the sample index

        Raises:
            KeyError: If the requested modality doesn't exist in the dataset
        """
        if modality not in self.dataset_dict:
            raise KeyError(f"Modality '{modality}' not found in dataset")

        return self.dataset_dict[modality][index]

__getitem__(index)

Get a single sample and its label from the dataset.

Every modality dataset in dataset_dict is indexed with the same index, and the results are returned together, keyed by modality name.

Args:

- index: Index of the sample to retrieve.

Returns:

Type Description
Union[Tuple[Tensor, Any], Dict[str, Tuple[Tensor, Any]]]

Dictionary of (data tensor, label) pairs for each modality

Source code in src/autoencodix/data/_stackix_dataset.py
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
def __getitem__(
    self, index: int
) -> Union[Tuple[torch.Tensor, Any], Dict[str, Tuple[torch.Tensor, Any]]]:
    """
    Get the sample at ``index`` from every modality.

    Each modality dataset in ``dataset_dict`` is indexed with the same
    ``index``; the results are returned keyed by modality name.

    Args:
        index: Index of the sample to retrieve

    Returns:
        Dictionary mapping each modality name to the item its dataset
        returns for ``index``

    """
    samples = {}
    for modality, dataset in self.dataset_dict.items():
        samples[modality] = dataset[index]
    return samples

__init__(dataset_dict, config)

Initialize a StackixDataset instance.

Parameters:

Name Type Description Default
dataset_dict Dict[str, BaseDataset]

Dictionary mapping modality names to dataset objects

required
config DefaultConfig

Configuration object

required

Raises:

Type Description
ValueError

If the datasets dictionary is empty or if modality datasets have different numbers of samples

NotImplementedError

If the datasets have incompatible shapes for concatenation

Source code in src/autoencodix/data/_stackix_dataset.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def __init__(
    self,
    dataset_dict: Dict[str, BaseDataset],
    config: DefaultConfig,
):
    """
    Initialize a StackixDataset instance.

    The ``data`` tensors of all modalities are concatenated along
    dimension 1 and handed to the base class together with the sample
    ids, split indices and metadata of the first modality.

    Args:
        dataset_dict: Dictionary mapping modality names to dataset objects
        config: Configuration object

    Raises:
        ValueError: If the datasets dictionary is empty or if modality datasets have different numbers of samples
        NotImplementedError: If the datasets have incompatible shapes for concatenation
    """
    if not dataset_dict:
        raise ValueError("dataset_dict cannot be empty")

    # Use first modality for base class initialization
    first_modality = next(iter(dataset_dict.values()))
    try:
        data = torch.cat(
            [v.data for v in dataset_dict.values() if hasattr(v, "data")], dim=1
        )
    except Exception as err:
        # Chain the original error so the real cause stays visible.
        raise NotImplementedError(
            "Data modalities have different shapes, set requires_paired=True in config"
        ) from err
    super().__init__(
        data=data,
        sample_ids=first_modality.sample_ids,
        config=config,
        split_indices=first_modality.split_indices,
        metadata=first_modality.metadata,
        feature_ids=[
            v.feature_ids
            for v in dataset_dict.values()
            if hasattr(v, "feature_ids")
        ],
    )

    self.dataset_dict = dataset_dict
    self.modality_keys = list(dataset_dict.keys())

    # Ensure all datasets have the same number of samples
    # NOTE(review): if `data` is laid out as (samples, features), a count
    # mismatch presumably already fails in torch.cat above - confirm.
    sample_counts = [len(dataset) for dataset in dataset_dict.values()]
    if not all(count == sample_counts[0] for count in sample_counts):
        raise ValueError(
            "All modality datasets must have the same number of samples"
        )

__len__()

Return the number of samples in the dataset.

Source code in src/autoencodix/data/_stackix_dataset.py
74
75
76
def __len__(self) -> int:
    """Report the sample count, read from the first modality dataset."""
    first_dataset = next(iter(self.dataset_dict.values()))
    return len(first_dataset)

get_modality_item(modality, index)

Get a sample for a specific modality. Args: modality: The modality name to retrieve data from index: Index of the sample to retrieve

Returns:

Type Description
Tuple[Tensor, Any]

Tuple of (data tensor, label) for the specified modality and sample index

Raises:

Type Description
KeyError

If the requested modality doesn't exist in the dataset

Source code in src/autoencodix/data/_stackix_dataset.py
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def get_modality_item(self, modality: str, index: int) -> Tuple[torch.Tensor, Any]:
    """
    Fetch one sample from the dataset of a single modality.

    Args:
        modality: The modality name to retrieve data from
        index: Index of the sample to retrieve

    Returns:
        The item the modality's dataset returns for the sample index

    Raises:
        KeyError: If the requested modality doesn't exist in the dataset
    """
    datasets = self.dataset_dict
    if modality in datasets:
        return datasets[modality][index]
    raise KeyError(f"Modality '{modality}' not found in dataset")

StackixPreprocessor

Bases: BasePreprocessor

Preprocessor for Stackix architecture, which handles multiple modalities separately.

Unlike GeneralPreprocessor which combines all modalities, StackixPreprocessor keeps modalities separate for individual VAE training in the Stackix architecture.

Attributes:

- config: Configuration parameters for preprocessing and model architecture.
- _datapackage: Dictionary storing processed data splits.
- _dataset_container: Container for processed datasets by split.

Source code in src/autoencodix/data/_stackix_preprocessor.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
class StackixPreprocessor(BasePreprocessor):
    """Preprocessor for Stackix architecture, which handles multiple modalities separately.

    Unlike GeneralPreprocessor which combines all modalities, StackixPreprocessor
    keeps modalities separate for individual VAE training in the Stackix architecture.

    Attributes:
        config: Configuration parameters for preprocessing and model architecture
        _datapackage: Dictionary storing processed data splits
        _dataset_container: Container for processed datasets by split
        _layer_indices: Per-modality mapping of layer name to its
            ``[start, end]`` column range; set by ``_build_dataset_dict``
    """

    def __init__(
        self, config: DefaultConfig, ontologies: Optional[Union[Tuple, Dict]] = None
    ) -> None:
        """Initialize the StackixPreprocessor with the given configuration.

        Args:
            config: Configuration parameters for preprocessing
            ontologies: Accepted for signature compatibility; it is not used
                in this initializer and not forwarded to the base class.
        """
        super().__init__(config=config)
        self._datapackage: Optional[Dict[str, Any]] = None
        self._dataset_container: Optional[DatasetContainer] = None

    def preprocess(
        self, raw_user_data: Optional[DataPackage] = None, predict_new_data=False
    ) -> DatasetContainer:
        """Execute preprocessing steps for Stackix architecture.

        Args:
            raw_user_data: Raw user data to preprocess, or None to use self._datapackage
            predict_new_data: Forwarded unchanged to ``_general_preprocess``.

        Returns:
            Container with MultiModalDataset for each split. Splits that are
            missing from the datapackage or have no data are stored as None.

        Raises:
            TypeError: Documented for a datapackage that is None after
                preprocessing; it is not raised directly in this method
                (NOTE(review): presumably raised by ``_general_preprocess``).
        """
        self._datapackage = self._general_preprocess(
            raw_user_data, predict_new_data=predict_new_data
        )
        self._dataset_container = DatasetContainer()

        for split in ["train", "valid", "test"]:
            if (
                split not in self._datapackage
                or self._datapackage[split].get("data") is None
            ):
                self._dataset_container[split] = None
                continue
            dataset_dict = self._build_dataset_dict(
                datapackage=self._datapackage[split]["data"],
                split_indices=self._datapackage[split]["indices"],
            )
            stackix_ds = MultiModalDataset(
                datasets=dataset_dict,
                config=self.config,
            )
            self._dataset_container[split] = stackix_ds
        return self._dataset_container

    def _extract_primary_data(self, modality_data: Any) -> np.ndarray:
        """Return the ``X`` matrix of a modality, densified if it is sparse."""
        primary_data = modality_data.X
        if issparse(primary_data):
            primary_data = primary_data.toarray()
        return primary_data

    @no_type_check
    def _combine_layers(
        self, modality_name: str, modality_data: Any
    ) -> Tuple[np.ndarray, Dict[str, List[int]]]:
        """Combine layers from a modality and return the combined data and indices.

        The layers named in the modality's ``selected_layers`` config entry
        are concatenated column-wise in that order. Names that are neither
        "X" nor present in ``modality_data.layers`` are skipped silently.

        Args:
            modality_name: Name of the modality
            modality_data: Data for the modality

        Returns:
            The combined data (an empty array when no layer was collected) and
            a dictionary mapping each collected layer name to its
            ``[start_idx, end_idx]`` column range in the combined data
        """
        layer_list: List[np.ndarray] = []
        layer_indices: Dict[str, List[int]] = {}

        selected_layers: List[str] = self.config.data_config.data_info[
            modality_name
        ].selected_layers

        start_idx = 0
        for layer_name in selected_layers:
            if layer_name == "X":
                data = self._extract_primary_data(modality_data)
                layer_list.append(data)
                end_idx = start_idx + data.shape[1]
                layer_indices[layer_name] = [start_idx, end_idx]
                start_idx += data.shape[1]
                continue
            elif layer_name in modality_data.layers:
                layer_data = modality_data.layers[layer_name]
                if issparse(layer_data):
                    layer_data = layer_data.toarray()
                layer_list.append(layer_data)
                end_idx = start_idx + layer_data.shape[1]
                layer_indices[layer_name] = [start_idx, end_idx]
                start_idx += layer_data.shape[1]

        combined_data: np.ndarray = (
            np.concatenate(layer_list, axis=1) if layer_list else np.array([])
        )
        return combined_data, layer_indices

    def _build_dataset_dict(
        self, datapackage: DataPackage, split_indices: np.ndarray
    ) -> Dict[str, NumericDataset]:
        """For each separate entry in our datapackage we build a NumericDataset
        and store it in a dictionary with the modality as key.

        As a side effect, the layer column ranges of every single-cell
        modality are stored in ``self._layer_indices``.

        Args:
            datapackage: DataPackage containing the data to be processed
            split_indices: List of indices for splitting the data
        Returns:
            Dictionary mapping modality names to NumericDataset objects

        Raises:
            ValueError: If no matching annotation can be determined
            TypeError: If the multi_sc entry is an AnnData instead of a MuData
        """
        dataset_dict: Dict[str, NumericDataset] = {}
        layer_id_dict: Dict[str, Dict[str, List]] = {}
        for k, _ in datapackage:
            attr_name, dict_key = k.split(
                "."
            )  # see DataPackage __iter__ method for why this makes sense
            metadata = None
            if datapackage.annotation is not None:  # prevents error in Single Cell case
                # case where each numeric data has it's own annotation/metadata
                metadata = datapackage.annotation.get(dict_key)
                if metadata is None:
                    # case where there is one "paired" metadata for all numeric data
                    metadata = datapackage.annotation.get("paired")
                # case where we have the unpaired case, but we have one metadata that included all samples across all numeric data
                if metadata is None:
                    if not len(datapackage.annotation.keys()) == 1:
                        raise ValueError(
                            f"annotation key needs to be either 'paired' match a key of the numeric data or only one key exists that holds all unpaired data, please adjust config, got: {datapackage.annotation.keys()}"
                        )
                    metadata_key = next(iter(datapackage.annotation.keys()))
                    metadata = datapackage.annotation.get(metadata_key)

            if attr_name == "multi_bulk":
                df = datapackage[attr_name][dict_key]
                ds = NumericDataset(
                    data=df.values,
                    config=self.config,
                    sample_ids=df.index,
                    feature_ids=df.columns,
                    metadata=metadata,
                    split_indices=split_indices,
                )
                dataset_dict[dict_key] = ds
            elif attr_name == "multi_sc":
                mudata = datapackage["multi_sc"]["multi_sc"]
                if isinstance(mudata, ad.AnnData):
                    raise TypeError(
                        "Expected a MuData object, but got an AnnData object."
                    )

                print("building dataset_dict")
                for mod_name, mod_data in mudata.mod.items():
                    # _combine_layers already returns this modality's layers
                    # concatenated column-wise. Use it as is, so that a
                    # modality's dataset never contains the columns of the
                    # modalities processed before it and the stored layer
                    # indices (which start at 0 per modality) stay valid.
                    mod_matrix, indices = self._combine_layers(
                        modality_name=mod_name, modality_data=mod_data
                    )
                    layer_id_dict[mod_name] = indices
                    # NOTE(review): feature_ids holds one name per variable,
                    # while the data has one column block per selected layer;
                    # confirm that NumericDataset accepts this.
                    ds = NumericDataset(
                        data=mod_matrix,
                        config=self.config,
                        sample_ids=mudata.obs_names,
                        feature_ids=mod_data.var_names,
                        metadata=mod_data.obs,
                        split_indices=split_indices,
                    )
                    dataset_dict[mod_name] = ds
            else:
                continue
        self._layer_indices = layer_id_dict
        return dataset_dict

    def format_reconstruction(
        self, reconstruction: Any, result: Optional[Result] = None
    ) -> DataPackage:
        """Convert the per-modality reconstructions into the original datapackage format.

        The reconstructions are read from ``result.sub_reconstructions``; the
        ``reconstruction`` argument itself is overwritten and not used. Each
        reconstructed tensor is turned back into a pd.DataFrame (multi bulk)
        or into a modality of a MuData object (multi single cell).

        Args:
            reconstruction: The reconstructed tensor (ignored, see above)
            result: Result object providing ``sub_reconstructions``
        Returns:
            DataPackage with reconstructed data in original format

        Raises:
            ValueError: If ``result`` is None or the data case is unsupported
            TypeError: If ``result.sub_reconstructions`` is not a dictionary
        """

        if result is None:
            raise ValueError(
                "Result object is not provided. This is needed for the StackixPreprocessor."
            )
        reconstruction = result.sub_reconstructions
        if not isinstance(reconstruction, dict):
            raise TypeError(
                f"Expected value to be of type dict for Stackix, got {type(reconstruction)}."
            )

        if self.config.data_case == DataCase.MULTI_BULK:
            return self._format_multi_bulk(reconstructions=reconstruction)

        elif self.config.data_case == DataCase.MULTI_SINGLE_CELL:
            return self._format_multi_sc(reconstructions=reconstruction)
        else:
            raise ValueError(
                f"Unsupported data_case {self.config.data_case} for StackixPreprocessor."
            )

    def _format_multi_bulk(
        self, reconstructions: Dict[str, torch.Tensor]
    ) -> DataPackage:
        """Format reconstructed tensors as DataFrames for the multi-bulk case.

        Sample ids, feature ids and metadata are taken from the datasets of
        the "test" split.

        Args:
            reconstructions: Dictionary of reconstructed tensors for each modality

        Returns:
            DataPackage with ``multi_bulk`` DataFrames and matching ``annotation``

        Raises:
            TypeError: If a reconstruction is not a torch.Tensor
            ValueError: If the dataset container or its test split is missing
        """
        multi_bulk_dict = {}
        annotation_dict = {}
        dp = DataPackage()
        for name, reconstruction in reconstructions.items():
            if not isinstance(reconstruction, torch.Tensor):
                raise TypeError(
                    f"Expected value to be of type torch.Tensor, got {type(reconstruction)}."
                )
            if self._dataset_container is None:
                raise ValueError("Dataset container is not initialized.")
            stackix_ds = self._dataset_container["test"]
            if stackix_ds is None:
                raise ValueError("No dataset found for split: test")
            dataset_dict = stackix_ds.datasets
            df = pd.DataFrame(
                # detach/cpu so tensors that track gradients or live on an
                # accelerator can be converted as well
                reconstruction.detach().cpu().numpy(),
                index=dataset_dict[name].sample_ids,
                columns=dataset_dict[name].feature_ids,
            )
            multi_bulk_dict[name] = df
            annotation_dict[name] = dataset_dict[name].metadata

        dp["multi_bulk"] = multi_bulk_dict
        dp["annotation"] = annotation_dict
        return dp

    def _format_multi_sc(self, reconstructions: Dict[str, torch.Tensor]) -> DataPackage:
        """Formats reconstructed tensors back into a MuData object for single-cell data.

        This uses the stored layer indices to accurately split the reconstructed tensor
        back into the original layers.

        Args:
            reconstructions: Dictionary of reconstructed tensors for each modality

        Returns:
            DataPackage containing the reconstructed MuData object

        Raises:
            ValueError: If the dataset container, the layer indices, the test
                split or a modality's dataset is missing
            TypeError: If a reconstruction is not a torch.Tensor
        """
        import mudata as md

        dp = DataPackage()
        modalities = {}

        if self._dataset_container is None:
            raise ValueError("Dataset container is not initialized.")
        if not hasattr(self, "_layer_indices"):
            raise ValueError(
                "Layer indices not found. Make sure _build_dataset_dict was called."
            )

        stackix_ds = self._dataset_container["test"]
        if stackix_ds is None:
            raise ValueError("No dataset found for split: test")

        dataset_dict = stackix_ds.datasets

        # Process each modality in the reconstruction
        for mod_name, recon_tensor in reconstructions.items():
            if not isinstance(recon_tensor, torch.Tensor):
                raise TypeError(
                    f"Expected value to be of type torch.Tensor, got {type(recon_tensor)}."
                )
            if mod_name not in dataset_dict:
                raise ValueError(f"Modality {mod_name} not found in dataset dictionary")
            original_dataset = dataset_dict[mod_name]

            layer_indices = self._layer_indices[mod_name]

            # Convert once; detach/cpu so tensors that track gradients or live
            # on an accelerator can be converted as well.
            recon_array = recon_tensor.detach().cpu().numpy()

            start_idx, end_idx = layer_indices["X"]
            x_data = recon_array[:, start_idx:end_idx]

            var_names = original_dataset.feature_ids

            mod_data = ad.AnnData(
                X=x_data,
                obs=original_dataset.metadata,
                var=pd.DataFrame(index=var_names[0 : x_data.shape[1]]),
            )

            # Add additional layers based on stored indices
            for layer_name, ids in layer_indices.items():
                if layer_name == "X":
                    continue  # X is already set

                layer_data = recon_array[:, ids[0] : ids[1]]
                mod_data.layers[layer_name] = layer_data

            modalities[mod_name] = mod_data

        # Create MuData object from all modalities
        mdata = md.MuData(modalities)

        # Create and return DataPackage
        dp["multi_sc"] = {"multi_sc": mdata}
        return dp

__init__(config, ontologies=None)

Initialize the StackixPreprocessor with the given configuration. Args: config: Configuration parameters for preprocessing

Source code in src/autoencodix/data/_stackix_preprocessor.py
30
31
32
33
34
35
36
37
38
39
def __init__(
    self,
    config: DefaultConfig,
    ontologies: Optional[Union[Tuple, Dict]] = None,
) -> None:
    """Set up a StackixPreprocessor with empty preprocessing state.

    Args:
        config: Configuration parameters for preprocessing.
        ontologies: Accepted by the signature but not used here; it is not
            forwarded to the parent initializer.
    """
    super().__init__(config=config)
    # Both slots start out empty; nothing is computed at construction time.
    self._dataset_container: Optional[DatasetContainer] = None
    self._datapackage: Optional[Dict[str, Any]] = None

format_reconstruction(reconstruction, result=None)

Takes the reconstructed tensor together with the modality it comes from and uses the dataset_dict to restore the format of the original datapackage; instead of the original .data attribute, that attribute is populated with the reconstructed tensor (as a pd.DataFrame or MuData object).

Parameters:

Name Type Description Default
reconstruction Any

The reconstructed tensor

required
result Optional[Result]

Optional[Result] containing additional information

None

Returns: DataPackage with reconstructed data in original format

Source code in src/autoencodix/data/_stackix_preprocessor.py
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
def format_reconstruction(
    self, reconstruction: Any, result: Optional[Result] = None
) -> DataPackage:
    """Rebuild a DataPackage in the original layout from reconstructed tensors.

    The per-modality reconstructions are taken from
    ``result.sub_reconstructions``; the ``reconstruction`` argument is never
    read. Formatting is delegated to ``_format_multi_bulk`` or
    ``_format_multi_sc`` depending on ``self.config.data_case``.

    Args:
        reconstruction: Not used; superseded by ``result.sub_reconstructions``.
        result: Result object carrying the per-modality reconstructions.

    Returns:
        DataPackage with reconstructed data in original format.

    Raises:
        ValueError: If ``result`` is None, or if ``config.data_case`` is
            neither MULTI_BULK nor MULTI_SINGLE_CELL.
        TypeError: If ``result.sub_reconstructions`` is not a dict.
    """
    if result is None:
        raise ValueError(
            "Result object is not provided. This is needed for the StackixPreprocessor."
        )

    sub_recons = result.sub_reconstructions
    if not isinstance(sub_recons, dict):
        raise TypeError(
            f"Expected value to be of type dict for Stackix, got {type(sub_recons)}."
        )

    # Guard-clause dispatch on the configured data case.
    if self.config.data_case == DataCase.MULTI_BULK:
        return self._format_multi_bulk(reconstructions=sub_recons)
    if self.config.data_case == DataCase.MULTI_SINGLE_CELL:
        return self._format_multi_sc(reconstructions=sub_recons)

    raise ValueError(
        f"Unsupported data_case {self.config.data_case} for StackixPreprocessor."
    )

preprocess(raw_user_data=None, predict_new_data=False)

Execute preprocessing steps for Stackix architecture.

Args: raw_user_data: Raw user data to preprocess, or None to use self._datapackage

Returns:

Type Description
DatasetContainer

Container with MultiModalDataset for each split

Raises:

Type Description
TypeError

If datapackage is None after preprocessing

Source code in src/autoencodix/data/_stackix_preprocessor.py
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def preprocess(
    self, raw_user_data: Optional[DataPackage] = None, predict_new_data: bool = False
) -> DatasetContainer:
    """Execute preprocessing steps for Stackix architecture.

    Runs ``_general_preprocess`` and then builds one MultiModalDataset per
    split ("train", "valid", "test"). A split that is absent from the
    datapackage, or whose "data" entry is None, is stored as None.

    Args:
        raw_user_data: Raw user data to preprocess; forwarded unchanged to
            ``_general_preprocess``.
        predict_new_data: Forwarded unchanged to ``_general_preprocess``.

    Returns:
        Container with MultiModalDataset for each split

    Raises:
        TypeError: If datapackage is None after preprocessing
    """
    datapackage = self._general_preprocess(
        raw_user_data, predict_new_data=predict_new_data
    )
    self._datapackage = datapackage
    # Fail with an explicit message instead of the incidental
    # "argument of type 'NoneType' is not iterable" from the membership test.
    if datapackage is None:
        raise TypeError("Datapackage is None after preprocessing.")

    container = DatasetContainer()
    self._dataset_container = container

    for split in ["train", "valid", "test"]:
        if split not in datapackage or datapackage[split].get("data") is None:
            container[split] = None
            continue
        split_entry = datapackage[split]
        dataset_dict = self._build_dataset_dict(
            datapackage=split_entry["data"],
            split_indices=split_entry["indices"],
        )
        container[split] = MultiModalDataset(
            datasets=dataset_dict,
            config=self.config,
        )
    return container

TensorAwareDataset

Bases: BaseDataset

Handles dtype mapping and tensor conversion logic.

Source code in src/autoencodix/data/_numeric_dataset.py
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
class TensorAwareDataset(BaseDataset):
    """
    Handles dtype mapping and tensor conversion logic.
    """

    @staticmethod
    def _to_tensor(
        data: Union[torch.Tensor, np.ndarray, Any], dtype: torch.dtype
    ) -> torch.Tensor:
        """
        Convert data to tensor with specified dtype.

        Args:
            data: Input data to convert
            dtype: Desired data type

        Returns:
            Tensor with the specified dtype. An input tensor is cloned and
            detached first, so the result does not alias the input.
        """
        if isinstance(data, torch.Tensor):
            return data.clone().detach().to(dtype)
        return torch.tensor(data, dtype=dtype)

    @staticmethod
    def _map_float_precision_to_dtype(float_precision: str) -> torch.dtype:
        """
        Map fabric precision types to torch tensor dtypes.

        Args:
            float_precision: Precision type (e.g., 'bf16-mixed', '16-mixed')

        Returns:
            Corresponding torch dtype; torch.float32 for unrecognized values
        """
        precision_mapping = {
            "transformer-engine": torch.float32,  # Default for transformer-engine
            "transformer-engine-float16": torch.float16,
            "16-true": torch.float16,
            "16-mixed": torch.float16,
            "bf16-true": torch.bfloat16,
            "bf16-mixed": torch.bfloat16,
            "32-true": torch.float32,
            "64-true": torch.float64,
            "64": torch.float64,
            "32": torch.float32,
            "16": torch.float16,
            "bf16": torch.bfloat16,
        }
        # Default to torch.float32 if the precision is not recognized
        return precision_mapping.get(float_precision, torch.float32)

    def _to_df(self) -> pd.DataFrame:
        """
        Convert the dataset to a pandas DataFrame.

        A single tensor becomes a frame with ``feature_ids`` as columns. A
        list of tensors (image modality) is flattened to one row per tensor
        with columns named ``Pixel_<i>``. Both use ``sample_ids`` as index.

        Returns:
            DataFrame representation of the dataset

        Raises:
            TypeError: If ``self.data`` is neither a tensor nor a list of
                tensors.
        """
        if isinstance(self.data, torch.Tensor):
            # detach + cpu: Tensor.numpy() rejects tensors that track
            # gradients or live on an accelerator.
            return pd.DataFrame(
                self.data.detach().cpu().numpy(),
                columns=self.feature_ids,
                index=self.sample_ids,
            )

        if isinstance(self.data, list) and all(
            isinstance(item, torch.Tensor) for item in self.data
        ):
            # Handle image modality: flatten each tensor into one row.
            rows = [t.detach().flatten().cpu().numpy() for t in self.data]
            # An empty list yields no pixel columns instead of an IndexError.
            n_pixels = len(rows[0]) if rows else 0
            return pd.DataFrame(
                rows,
                index=self.sample_ids,
                columns=["Pixel_" + str(i) for i in range(n_pixels)],
            )

        raise TypeError(
            "Data is not a torch.Tensor and cannot be converted to DataFrame."
        )

    def _get_target_dtype(self) -> torch.dtype:
        """Get the target dtype based on config, with MPS compatibility check.

        Returns:
            The dtype mapped from ``config.float_precision``; float64 is
            replaced by float32 when ``config.device`` is "mps".
        """
        target_dtype = self._map_float_precision_to_dtype(self.config.float_precision)

        # MPS doesn't support float64, so fallback to float32
        if target_dtype == torch.float64 and self.config.device == "mps":
            print("Warning: MPS doesn't support float64, using float32 instead")
            target_dtype = torch.float32

        return target_dtype

XModalPreprocessor

Bases: GeneralPreprocessor

Preprocessor for cross-modal data, handling multiple data types and their transformations.

Attributes:

Name Type Description
data_config

Configuration specific to data handling.

dataset_dicts

Dictionary holding datasets for different splits (train, test, valid).

Source code in src/autoencodix/data/_xmodal_preprocessor.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
class XModalPreprocessor(GeneralPreprocessor):
    """Preprocessor for cross-modal data, handling multiple data types and their transformations.

    Attributes:
        data_config: Configuration specific to data handling.
        dataset_dicts: Dictionary holding datasets for different splits (train, test, valid).
    """

    def __init__(
        self, config: DefaultConfig, ontologies: Optional[Union[Tuple, Dict]] = None
    ):
        """Initializes the XModalPreprocessor

        Args:
            config: Configuration object for the preprocessor.
            ontologies: Optional ontologies for data processing.
        """
        super().__init__(config=config, ontologies=ontologies)
        self.data_config = config.data_config

    def preprocess(
        self,
        raw_user_data: Optional[DataPackage] = None,
        predict_new_data: bool = False,
    ) -> DatasetContainer:
        """Preprocess the data according to the configuration.

        Args:
            raw_user_data: Optional raw data provided by the user.
            predict_new_data: Flag indicating if new data is being predicted.

        Returns:
            DatasetContainer with one MultiModalDataset per available split;
            splits that are missing from the preprocessing result are None.

        Raises:
            TypeError: If the "data" entry of a split is not a DataPackage.
        """
        self.dataset_dicts = self._general_preprocess(
            raw_user_data=raw_user_data, predict_new_data=predict_new_data
        )
        datasets: Dict[str, MultiModalDataset] = {}
        for split in ["train", "test", "valid"]:
            cur_split = self.dataset_dicts.get(split)
            if cur_split is None:
                print(f"split is None: {split}")
                continue
            cur_data = cur_split.get("data")
            if not isinstance(cur_data, DataPackage):
                raise TypeError(
                    f"expected type of cur_data to be DataPackage, got {type(cur_data)}"
                )
            cur_indices = cur_split.get("indices")
            datasets[split] = MultiModalDataset(
                datasets=self._process_dp(dp=cur_data, indices=cur_indices),
                config=self.config,
            )

        # Skipped splits have no entry in `datasets`; .get() hands None to the
        # container instead of failing with a KeyError.
        return DatasetContainer(
            train=datasets.get("train"),
            test=datasets.get("test"),
            valid=datasets.get("valid"),
        )

    def format_reconstruction(self, reconstruction, result=None):
        """Placeholder: does nothing and returns None.

        Args:
            reconstruction: Unused.
            result: Unused.
        """
        return None

    @staticmethod
    def _resolve_metadata(dp: DataPackage, sub_key: str):
        """Pick the annotation entry that belongs to one data entry.

        Lookup order: the entry named ``sub_key``, then the entry named
        "paired", then the single remaining entry if exactly one exists.

        Args:
            dp: DataPackage whose ``annotation`` mapping is searched.
            sub_key: Second part of the "<attribute>.<sub_key>" data key.

        Returns:
            The matching annotation entry, or None when ``dp.annotation`` is None.

        Raises:
            ValueError: If no entry matches and there is not exactly one
                annotation key to fall back on.
        """
        if dp.annotation is None:  # prevents error in SingleCell case
            return None

        metadata = dp.annotation.get(sub_key)
        if metadata is None:
            metadata = dp.annotation.get("paired")
        # case where we have the unpaired case, but we have one metadata that included all samples across all numeric data
        if metadata is None:
            if len(dp.annotation.keys()) != 1:
                raise ValueError(
                    f"annotation key needs to be either 'paired' match a key of the numeric data or only one key exists that holds all unpaired data, please adjust config, got: {dp.annotation.keys()}"
                )
            metadata_key = next(iter(dp.annotation.keys()))
            metadata = dp.annotation.get(metadata_key)
        return metadata

    def _process_dp(self, dp: DataPackage, indices: Dict[str, Any]):
        """Processes a DataPackage into a dictionary of BaseDataset objects.

        Args:
            dp: The DataPackage to process.
            indices: The indices for splitting the data.

        Returns:
            A dictionary mapping modality names to BaseDataset objects.

        Raises:
            ValueError: If an entry has an unexpected type, required metadata
                is missing, or the annotation cannot be resolved.
            NotImplementedError: If a single-cell modality selects anything
                other than exactly the "X" layer, or the DataPackage exposes
                an unsupported attribute.
        """
        dataset_dict: Dict[str, BaseDataset] = {}
        for k, v in dp:
            dp_key, sub_key = k.split(".")
            data = v
            metadata = self._resolve_metadata(dp, sub_key)

            if dp_key == "multi_bulk":
                if not isinstance(data, pd.DataFrame):
                    raise ValueError(
                        f"Expected data for multi_bulk: {k}, {v} to be pd.DataFrame, got {type(data)}"
                    )
                if metadata is None:
                    raise ValueError("metadata cannot be None")
                metadata_num = metadata.loc[
                    data.index
                ]  # needed when we have only one annotation df containing metadata for all modalities
                dataset_dict[k] = NumericDataset(
                    data=data.values,
                    config=self.config,
                    sample_ids=data.index,
                    feature_ids=data.columns,
                    split_indices=indices,
                    metadata=metadata_num,
                )
            elif dp_key == "img":
                if not isinstance(data, list):
                    raise ValueError(
                        f"Expected data for img: {k} to be a list, got {type(data)}"
                    )
                # An empty list is rejected here too instead of failing on data[0].
                if not data or not isinstance(data[0], ImgData):
                    raise ValueError(
                        f"Expected data for img: {k} to be a non-empty list of ImgData"
                    )
                dataset_dict[k] = ImageDataset(
                    data=data,
                    config=self.config,
                    split_indices=indices,
                    metadata=metadata,
                )
            elif dp_key == "multi_sc":
                if not isinstance(data, md.MuData):
                    raise ValueError(
                        f"Expected data for multi_sc: {k} to be md.MuData, got {type(data)}"
                    )
                for mod_key, mod_data in data.mod.items():
                    selected_layers = self.config.data_config.data_info[
                        mod_key
                    ].selected_layers
                    # Only exactly ["X"] is supported, because mod_data.X is
                    # what gets used below.
                    if len(selected_layers) != 1 or selected_layers[0] != "X":
                        raise NotImplementedError(
                            "Xmodalix works only with X layer of single cell data as of now"
                        )
                    # NOTE(review): every modality is written under the same
                    # key `k`, so with several modalities in one MuData only
                    # the last one survives - confirm this is intended.
                    dataset_dict[k] = NumericDataset(
                        data=mod_data.X,
                        config=self.config,
                        sample_ids=mod_data.obs_names,
                        feature_ids=mod_data.var_names,
                        split_indices=indices,
                        metadata=mod_data.obs,
                    )

            elif dp_key == "annotation":
                pass

            else:
                raise NotImplementedError(
                    f"Got datapackage attribute: {k}, probably you have added an attribute to the Datapackage class without adjusting this method. Only supports: ['multi_bulk', 'multi_sc', 'img' and 'annotation']"
                )
        return dataset_dict

__init__(config, ontologies=None)

Initializes the XModalPreprocessor Args: config: Configuration object for the preprocessor. ontologies: Optional ontologies for data processing.

Source code in src/autoencodix/data/_xmodal_preprocessor.py
27
28
29
30
31
32
33
34
35
36
def __init__(
    self,
    config: DefaultConfig,
    ontologies: Optional[Union[Tuple, Dict]] = None,
):
    """Create the XModalPreprocessor and keep a reference to the data config.

    Args:
        config: Configuration object for the preprocessor; its
            ``data_config`` attribute is stored on the instance.
        ontologies: Optional ontologies, forwarded to the parent initializer.
    """
    super().__init__(config=config, ontologies=ontologies)
    self.data_config = config.data_config

preprocess(raw_user_data=None, predict_new_data=False)

Preprocess the data according to the configuration. Args: raw_user_data: Optional raw data provided by the user. predict_new_data: Flag indicating if new data is being predicted.

Source code in src/autoencodix/data/_xmodal_preprocessor.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def preprocess(
    self,
    raw_user_data: Optional[DataPackage] = None,
    predict_new_data: bool = False,
) -> DatasetContainer:
    """Preprocess the data according to the configuration.

    Args:
        raw_user_data: Optional raw data provided by the user.
        predict_new_data: Flag indicating if new data is being predicted.

    Returns:
        DatasetContainer with one MultiModalDataset per available split;
        splits that are missing from the preprocessing result are None.

    Raises:
        TypeError: If the "data" entry of a split is not a DataPackage.
    """
    self.dataset_dicts = self._general_preprocess(
        raw_user_data=raw_user_data, predict_new_data=predict_new_data
    )
    datasets: Dict[str, MultiModalDataset] = {}
    for split in ["train", "test", "valid"]:
        cur_split = self.dataset_dicts.get(split)
        if cur_split is None:
            print(f"split is None: {split}")
            continue
        cur_data = cur_split.get("data")
        if not isinstance(cur_data, DataPackage):
            raise TypeError(
                f"expected type of cur_data to be DataPackage, got {type(cur_data)}"
            )
        cur_indices = cur_split.get("indices")
        datasets[split] = MultiModalDataset(
            datasets=self._process_dp(dp=cur_data, indices=cur_indices),
            config=self.config,
        )

    # Skipped splits have no entry in `datasets`; .get() hands None to the
    # container instead of failing with a KeyError.
    return DatasetContainer(
        train=datasets.get("train"),
        test=datasets.get("test"),
        valid=datasets.get("valid"),
    )