Skip to content

Utils Module

AnnDataConverter

Utility class for converting datasets into AnnData or multimodal AnnData dictionaries.

Source code in src/autoencodix/utils/adata_converter.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
class AnnDataConverter:
    """Namespace of static helpers that convert datasets into AnnData dictionaries."""

    @staticmethod
    def _numeric_ds_to_adata(ds: NumericDataset) -> Dict[str, ad.AnnData]:
        """Wrap a NumericDataset in a one-entry dictionary of AnnData.

        Args:
            ds: The numeric dataset to convert; its ``metadata`` must be a
                pandas DataFrame.

        Returns:
            A dictionary with the single key ``"global"``. The AnnData stored there
            uses ``ds.data`` as ``X``, ``ds.feature_ids`` as the ``var`` index and a
            copy of ``ds.metadata`` (index cast to ``str``) as ``obs``.

        Raises:
            ValueError: If ``ds.metadata`` is not a pandas DataFrame.
        """
        if not isinstance(ds.metadata, pd.DataFrame):
            raise ValueError(
                f"metadata needs to be pd.DataFrame, got {type(ds.metadata)}"
            )

        # Copy first so the index cast below never modifies the dataset's own frame.
        obs_frame = ds.metadata.copy()
        obs_frame.index = obs_frame.index.astype(str)

        var_frame = pd.DataFrame(index=pd.Index(ds.feature_ids, dtype=str))
        adata = ad.AnnData(
            X=ds.data.detach().cpu().numpy(),
            var=var_frame,
            obs=obs_frame,
        )
        return {"global": adata}

    @staticmethod
    def _parse_multimodal(mds: MultiModalDataset) -> Dict[str, ad.AnnData]:
        """Convert every modality of a MultiModalDataset.

        Args:
            mds: The multimodal dataset whose ``datasets`` mapping is converted.

        Returns:
            A dictionary keyed by modality name holding the result of
            ``_numeric_ds_to_adata`` for that modality.

        Raises:
            NotImplementedError: If any modality is not a NumericDataset.
        """
        converted: Dict[str, ad.AnnData] = {}
        for name, modality_ds in mds.datasets.items():
            if isinstance(modality_ds, NumericDataset):
                # NOTE(review): _numeric_ds_to_adata returns {"global": AnnData}, so the
                # value stored here is itself a dict rather than a bare AnnData as the
                # annotation states - confirm which shape callers rely on.
                converted[name] = AnnDataConverter._numeric_ds_to_adata(modality_ds)  # type: ignore
                continue
            raise NotImplementedError(
                f"Feature Importance is only implemented for NumericDataset, got type: {type(modality_ds)}"
            )
        return converted

    @staticmethod
    def dataset_to_adata(
        datasetcontainer: DatasetContainer,
        split: Literal["train", "valid", "test"] = "train",
    ) -> Optional[Dict[str, ad.AnnData]]:
        """Convert one split of a DatasetContainer into a dictionary of AnnData.

        Args:
            datasetcontainer: Container holding train/valid/test datasets.
            split: The dataset split to convert. Defaults to "train".

        Returns:
            For a NumericDataset a dictionary keyed ``"global"``; for a
            MultiModalDataset a dictionary keyed by modality name; ``None``
            (after emitting a warning) when the split holds no dataset.

        Raises:
            ValueError: If the specified split does not exist in the DatasetContainer.
            NotImplementedError: If the dataset type is not supported.
        """
        if not hasattr(datasetcontainer, split):
            raise ValueError(
                f"Split: {split} not present in DatasetContainer: {datasetcontainer}"
            )

        selected = datasetcontainer[split]

        if isinstance(selected, MultiModalDataset):
            return AnnDataConverter._parse_multimodal(selected)
        if isinstance(selected, NumericDataset):
            return AnnDataConverter._numeric_ds_to_adata(selected)
        if selected is not None:
            raise NotImplementedError(
                f"Conversion not implemented for type: {type(selected)}"
            )

        # Empty split: warn instead of failing so callers can handle None themselves.
        import warnings

        warnings.warn(f"No dataset found for split: {split}, returning None")
        return None

dataset_to_adata(datasetcontainer, split='train') staticmethod

Convert a DatasetContainer split to an AnnData or multimodal AnnData dictionary.

Parameters:

Name Type Description Default
datasetcontainer DatasetContainer

Container holding train/valid/test datasets.

required
split Literal['train', 'valid', 'test']

The dataset split to convert. Defaults to "train".

'train'

Returns:

Type Description
Optional[Dict[str, AnnData]]

A dictionary of AnnData objects: keyed "global" for a NumericDataset, or keyed by modality name for a MultiModalDataset. Returns None (with a warning) when the requested split holds no dataset.

Raises:

Type Description
ValueError

If the specified split does not exist in the DatasetContainer.

NotImplementedError

If the dataset type is not supported.

Source code in src/autoencodix/utils/adata_converter.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
@staticmethod
def dataset_to_adata(
    datasetcontainer: DatasetContainer,
    split: Literal["train", "valid", "test"] = "train",
) -> Optional[Dict[str, ad.AnnData]]:
    """Convert one split of a DatasetContainer into a dictionary of AnnData.

    Args:
        datasetcontainer: Container holding train/valid/test datasets.
        split: The dataset split to convert. Defaults to "train".

    Returns:
        The dictionary produced for a MultiModalDataset or a NumericDataset, or
        ``None`` (after emitting a warning) when the split holds no dataset.

    Raises:
        ValueError: If the specified split does not exist in the DatasetContainer.
        NotImplementedError: If the dataset type is not supported.
    """
    split_is_known = hasattr(datasetcontainer, split)
    if not split_is_known:
        raise ValueError(
            f"Split: {split} not present in DatasetContainer: {datasetcontainer}"
        )

    chosen = datasetcontainer[split]

    # Multimodal is tested before numeric, then the empty split, then the fallback.
    if isinstance(chosen, MultiModalDataset):
        return AnnDataConverter._parse_multimodal(chosen)
    if isinstance(chosen, NumericDataset):
        return AnnDataConverter._numeric_ds_to_adata(chosen)
    if chosen is not None:
        raise NotImplementedError(
            f"Conversion not implemented for type: {type(chosen)}"
        )

    import warnings

    warnings.warn(f"No dataset found for split: {split}, returning None")
    return None

BulkDataReader

Reads bulk data from files based on configuration.

Supports both paired and unpaired data reading strategies.

Attributes:

Name Type Description
config

Configuration object

Source code in src/autoencodix/utils/_bulkreader.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
class BulkDataReader:
    """Reads bulk data from files based on configuration.

    Supports both paired and unpaired data reading strategies. Entries with
    ``data_type == "IMG"`` and single-cell entries are never loaded by this class.

    Attributes:
        config: Configuration object
    """

    # Suffixes (lower-case) that ``_read_tabular_data`` parses as delimited text.
    _TEXT_SUFFIXES = (".csv", ".txt", ".tsv")

    def __init__(self, config: DefaultConfig):
        """Initialize the BulkDataReader with a configuration.

        Args:
            config: Configuration object containing data paths and specifications.
        """
        self.config = config

    def read_data(self) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
        """Read all data according to the configuration.

        Paired reading is used when ``config.requires_paired`` is truthy or ``None``;
        unpaired reading is used otherwise.

        Returns:
            A tuple containing (bulk_dataframes, annotation_dataframes)
        """
        if self.config.requires_paired or self.config.requires_paired is None:
            return self.read_paired_data()
        else:
            return self.read_unpaired_data()

    def read_paired_data(
        self,
    ) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
        """Read bulk numeric data and align all modalities on their shared samples.

        Only bulk ``NUMERIC`` entries and ``ANNOTATION`` entries are loaded; image,
        single-cell and any other entries are skipped before touching the file.
        The shared samples keep the order in which they appear in the first
        numeric modality, so the result is reproducible between runs.

        Returns:
            Tuple containing two Dicts:
                 1. with name of the data as key and pandas DataFrame as value
                 2. with str 'paired' as key and a common annotation/metadata
                    DataFrame as value. Samples missing from the annotation file
                    appear as all-NaN rows; without an annotation file the frame
                    has the shared samples as index and no columns.
        """
        bulk_dfs: Dict[str, pd.DataFrame] = {}
        annotation_df = pd.DataFrame()
        has_annotation = False

        # First pass: load only the entries that are actually used below.
        for key, info in self.config.data_config.data_info.items():
            is_bulk_numeric = info.data_type == "NUMERIC" and not info.is_single_cell
            is_annotation = info.data_type == "ANNOTATION"
            if not (is_bulk_numeric or is_annotation):
                continue

            file_path = os.path.join(info.file_path)
            df = self._read_tabular_data(file_path, info.sep or "\t")

            if is_bulk_numeric:
                bulk_dfs[key] = df
            else:
                # If several annotation entries exist, the last one wins.
                has_annotation = True
                annotation_df = df

        # Intersect the sample ids of all numeric modalities.
        common_samples: Optional[Set[str]] = None
        for df in bulk_dfs.values():
            current_samples = set(df.index)
            if common_samples is None:
                common_samples = current_samples
            else:
                common_samples &= current_samples

        # Second pass: filter to common samples
        if common_samples:
            # Iterating a set of strings gives a different order per interpreter run;
            # follow the first modality's row order instead.
            first_df = next(iter(bulk_dfs.values()))
            common_samples_list = [
                sample
                for sample in dict.fromkeys(first_df.index)
                if sample in common_samples
            ]

            bulk_dfs = {
                key: df.reindex(common_samples_list) for key, df in bulk_dfs.items()
            }

            if has_annotation:
                annotation = annotation_df.reindex(common_samples_list)
            else:
                # Create empty annotation with common sample indices
                annotation = pd.DataFrame(index=common_samples_list)
        else:
            # Nothing to align on: data is returned unfiltered.
            warnings.warn("No common samples found across datasets")
            annotation = annotation_df

        return bulk_dfs, {"paired": annotation}

    def read_unpaired_data(
        self,
    ) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
        """Read data without enforcing sample alignment across modalities.

        Returns:
            A tuple containing (bulk_dataframes, annotation_dataframes), after
            passing both through ``_validate_and_filter_unpaired``.
        """
        bulk_dfs: Dict[str, pd.DataFrame] = {}
        annotations: Dict[str, pd.DataFrame] = {}

        for key, info in self.config.data_config.data_info.items():
            if info.data_type == "IMG" or info.is_single_cell:
                continue  # Skip image and single-cell data

            # Read main data file
            file_path = os.path.join(info.file_path)
            df = self._read_tabular_data(file_path=file_path, sep=info.sep)

            if info.data_type == "NUMERIC":
                bulk_dfs[key] = df

                # A modality-specific annotation file is stored under the modality key.
                if hasattr(info, "extra_anno_file") and info.extra_anno_file:
                    extra_anno_file = os.path.join(info.extra_anno_file)
                    annotations[key] = self._read_tabular_data(
                        file_path=extra_anno_file, sep=info.sep
                    )

            elif info.data_type == "ANNOTATION":
                annotations[key] = df

        bulk_dfs, annotations = self._validate_and_filter_unpaired(
            bulk_dfs, annotations
        )

        return bulk_dfs, annotations

    def _validate_and_filter_unpaired(
        self,
        bulk_dfs: Dict[str, pd.DataFrame],
        annotations: Dict[str, pd.DataFrame],
    ) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
        """Validates that all samples in bulk data have a corresponding annotation.

        If a single global annotation file is provided, it creates a perfectly
        matched annotation dataframe for each bulk dataframe.

        Warns and drops samples that do not have a corresponding annotation.

        Args:
            bulk_dfs: Dictionary of bulk data modalities and their dataframes.
            annotations: Dictionary of annotation dataframes, possibly one global one.

        Returns:
            A tuple of two dictionaries:
            1. The filtered bulk dataframes.
            2. The new, synchronized annotation dataframes, with keys matching the
               bulk dataframes. Empty when no annotations were given; returned
               unchanged when the annotations are not a single global file.
        """
        if not annotations:
            warnings.warn(
                "No annotation files were provided. Cannot validate sample annotations."
            )
            return bulk_dfs, {}

        annotation_keys = set(annotations)
        bulk_keys = set(bulk_dfs)

        # Heuristic for "one global annotation file": exactly one annotation whose
        # key does not name a bulk modality.
        is_global = len(annotation_keys) == 1 and not (annotation_keys & bulk_keys)
        if not is_global:
            # Handle the case where annotations are already meant to be paired by key
            # (Or a more complex case we are not handling yet)
            warnings.warn(
                "Proceeding without global annotation synchronization. Assuming annotations are pre-aligned by key."
            )
            return bulk_dfs, annotations

        global_annotation_df = annotations[next(iter(annotation_keys))]
        annotation_samples = global_annotation_df.index

        filtered_bulk_dfs: Dict[str, pd.DataFrame] = {}
        synchronized_annotations: Dict[str, pd.DataFrame] = {}

        for key, data_df in bulk_dfs.items():
            data_samples = data_df.index
            valid_ids = data_samples.intersection(annotation_samples)

            # Warn only when samples are really dropped (a shorter intersection can
            # also come from duplicate labels in the data index).
            missing_ids = sorted(set(data_samples) - set(valid_ids))
            if missing_ids:
                warnings.warn(
                    f"For data modality '{key}', {len(missing_ids)} sample(s) "
                    f"were found without a corresponding annotation and will be dropped: {missing_ids}"
                )

            # Filter both the data and the annotation to the valid IDs
            filtered_bulk_dfs[key] = data_df.loc[valid_ids]
            synchronized_annotations[key] = global_annotation_df.loc[valid_ids]

        return filtered_bulk_dfs, synchronized_annotations

    def _read_tabular_data(
        self, file_path: str, sep: Union[str, None] = None
    ) -> pd.DataFrame:
        """Read tabular data from a parquet or delimited text file.

        The file type is chosen from the suffix, compared case-insensitively.

        Args:
            file_path: Path to the data file.
            sep: Separator passed to ``pandas.read_csv`` for .csv/.txt/.tsv files;
                ignored for parquet. The first column becomes the index.

        Returns:
            The loaded DataFrame.

        Raises:
            ValueError: If the suffix is not .parquet, .csv, .txt or .tsv.
        """
        lowered = file_path.lower()
        if lowered.endswith(".parquet"):
            print(f"reading parquet: {file_path}")
            return pd.read_parquet(file_path)
        if lowered.endswith(self._TEXT_SUFFIXES):
            return pd.read_csv(file_path, sep=sep, index_col=0)
        raise ValueError(
            f"Unsupported file type for {file_path}. Supported formats: .parquet, .csv, .txt, .tsv"
        )

__init__(config)

Initialize the BulkDataReader with a configuration.

Parameters:

Name Type Description Default
config DefaultConfig

Configuration object containing data paths and specifications.

required
Source code in src/autoencodix/utils/_bulkreader.py
19
20
21
22
23
24
25
def __init__(self, config: DefaultConfig) -> None:
    """Initialize the BulkDataReader with a configuration.

    The configuration is only stored on the instance; no file is read here.

    Args:
        config: Configuration object containing data paths and specifications.
    """
    self.config = config

read_data()

Read all data according to the configuration.

Returns:

Type Description
Tuple[Dict[str, DataFrame], Dict[str, DataFrame]]

A tuple containing (bulk_dataframes, annotation_dataframes)

Source code in src/autoencodix/utils/_bulkreader.py
27
28
29
30
31
32
33
34
35
36
def read_data(self) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
    """Dispatch to the paired or unpaired reading strategy.

    Unpaired reading is chosen only when ``config.requires_paired`` is set to a
    falsy value other than ``None``; an unset (``None``) flag means paired.

    Returns:
        A tuple containing (bulk_dataframes, annotation_dataframes)
    """
    paired_flag = self.config.requires_paired
    explicitly_unpaired = paired_flag is not None and not paired_flag
    if explicitly_unpaired:
        return self.read_unpaired_data()
    return self.read_paired_data()

read_paired_data()

Reads numeric paired data

Returns:

Type Description
Tuple[Dict[str, DataFrame], Dict[str, DataFrame]]

Tuple containing two Dicts: (1) the name of each dataset as key and its pandas DataFrame as value; (2) the str 'paired' as key and the common annotation/metadata DataFrame as value.

Source code in src/autoencodix/utils/_bulkreader.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
def read_paired_data(
    self,
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
    """Read bulk numeric data and align all modalities on their shared samples.

    Only bulk ``NUMERIC`` entries and ``ANNOTATION`` entries are loaded; image,
    single-cell and any other entries are skipped before touching the file.
    The shared samples keep the order in which they appear in the first numeric
    modality, so the result is reproducible between runs.

    Returns:
        Tuple containing two Dicts:
             1. with name of the data as key and pandas DataFrame as value
             2. with str 'paired' as key and a common annotation/metadata
                DataFrame as value. Samples missing from the annotation file
                appear as all-NaN rows; without an annotation file the frame has
                the shared samples as index and no columns.
    """
    import warnings

    bulk_dfs: Dict[str, pd.DataFrame] = {}
    annotation_df = pd.DataFrame()
    has_annotation = False

    # First pass: load only the entries that are actually used below.
    for key, info in self.config.data_config.data_info.items():
        is_bulk_numeric = info.data_type == "NUMERIC" and not info.is_single_cell
        is_annotation = info.data_type == "ANNOTATION"
        if not (is_bulk_numeric or is_annotation):
            continue

        file_path = os.path.join(info.file_path)
        df = self._read_tabular_data(file_path, info.sep or "\t")

        if df is None:
            continue

        if is_bulk_numeric:
            bulk_dfs[key] = df
        else:
            # If several annotation entries exist, the last one wins.
            has_annotation = True
            annotation_df = df

    # Intersect the sample ids of all numeric modalities.
    common_samples: Optional[Set[str]] = None
    for df in bulk_dfs.values():
        current_samples = set(df.index)
        if common_samples is None:
            common_samples = current_samples
        else:
            common_samples &= current_samples

    # Second pass: filter to common samples
    if common_samples:
        # Iterating a set of strings gives a different order per interpreter run;
        # follow the first modality's row order instead.
        first_df = next(iter(bulk_dfs.values()))
        common_samples_list = [
            sample
            for sample in dict.fromkeys(first_df.index)
            if sample in common_samples
        ]

        bulk_dfs = {
            key: df.reindex(common_samples_list) for key, df in bulk_dfs.items()
        }

        if has_annotation:
            annotation = annotation_df.reindex(common_samples_list)
        else:
            # Create empty annotation with common sample indices
            annotation = pd.DataFrame(index=common_samples_list)
    else:
        # Nothing to align on: data is returned unfiltered.
        warnings.warn("No common samples found across datasets")
        annotation = annotation_df

    return bulk_dfs, {"paired": annotation}

read_unpaired_data()

Read data without enforcing sample alignment across modalities.

Returns:

Type Description
Tuple[Dict[str, DataFrame], Dict[str, DataFrame]]

A tuple containing (bulk_dataframes, annotation_dataframes)

Source code in src/autoencodix/utils/_bulkreader.py
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
def read_unpaired_data(
    self,
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
    """Load every bulk modality independently, without aligning samples.

    Image and single-cell entries are skipped. A ``NUMERIC`` entry may bring its
    own annotation file (``extra_anno_file``), stored under the same key;
    ``ANNOTATION`` entries are stored under their own key. The collected frames
    are finally passed through ``_validate_and_filter_unpaired``.

    Returns:
        A tuple containing (bulk_dataframes, annotation_dataframes)
    """
    numeric_frames: Dict[str, pd.DataFrame] = {}
    anno_frames: Dict[str, pd.DataFrame] = {}

    for name, entry in self.config.data_config.data_info.items():
        if entry.data_type == "IMG" or entry.is_single_cell:
            continue

        frame = self._read_tabular_data(
            file_path=os.path.join(entry.file_path), sep=entry.sep
        )
        if frame is None:
            continue

        if entry.data_type == "ANNOTATION":
            anno_frames[name] = frame
            continue
        if entry.data_type != "NUMERIC":
            continue

        numeric_frames[name] = frame

        # Optional per-modality annotation file; absent or empty means none.
        extra_path = getattr(entry, "extra_anno_file", None)
        if not extra_path:
            continue
        extra_frame = self._read_tabular_data(
            file_path=os.path.join(extra_path), sep=entry.sep
        )
        if extra_frame is not None:
            anno_frames[name] = extra_frame

    numeric_frames, anno_frames = self._validate_and_filter_unpaired(
        numeric_frames, anno_frames
    )
    return numeric_frames, anno_frames

ImageDataReader

Reads and processes image data.

Reads all images from the specified directory, processes them, and returns a list of ImgData objects.

Source code in src/autoencodix/utils/_imgreader.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
class ImageDataReader:
    """Reads and processes image data.

    Reads all images from the specified directory, processes them,
    and returns a list of ImgData objects.
    """

    # Lower-case file suffixes accepted as images.
    _SUPPORTED_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".tif", ".tiff"})

    def __init__(self, config: DefaultConfig):
        """Store the configuration; nothing is read at construction time.

        Args:
            config: Configuration object; ``img_path_col`` is read from it later.
        """
        self.config = config

    def validate_image_path(self, image_path: Union[str, Path]) -> bool:
        """Check that the path is an existing file with an allowed extension.

        Allowed are (independent of capitalization):
            - jpg
            - jpeg
            - png
            - tif
            - tiff

        Args:
            image_path: path or str of image to read

        Returns:
            True if the path exists, is a regular file and has an allowed
            extension; False otherwise.
        """
        path = Path(image_path) if isinstance(image_path, str) else image_path
        return (
            path.exists()
            and path.is_file()
            and path.suffix.lower() in self._SUPPORTED_EXTENSIONS
        )

    def parse_image_to_tensor(
        self,
        image_path: Union[str, Path],
        to_h: Optional[int] = None,
        to_w: Optional[int] = None,
    ) -> np.ndarray:
        """Read an image, resize it, convert it to grayscale and return it channel-first.

        Args:
            image_path: The path to the image file.
            to_h: The desired height of the output tensor; the image's own height
                when None.
            to_w: The desired width of the output tensor; the image's own width
                when None.

        Returns:
            The processed image as an array of shape ``(1, to_h, to_w)``.

        Raises:
            FileNotFoundError: If the path does not exist, is not a file, has an
                unsupported extension, or the image cannot be decoded.
            ImageProcessingError: If the decoded image is neither 2-D nor 3-D.
        """
        if not self.validate_image_path(image_path):
            raise FileNotFoundError(f"Invalid image path: {image_path}")
        image_path = Path(image_path)

        # TIFFs are read as stored; this may yield a 2-D array for grayscale files.
        if image_path.suffix.lower() in {".tif", ".tiff"}:
            image = cv2.imread(str(image_path), cv2.IMREAD_UNCHANGED)
        else:
            image = cv2.imread(str(image_path))

        if image is None:
            raise FileNotFoundError(f"Failed to read image: {image_path}")

        # Validate the rank before unpacking the shape.
        if not (2 <= image.ndim <= 3):
            raise ImageProcessingError(
                f"Image has unsupported shape: {image.shape}. "
                "Supported shapes are 2D and 3D."
            )

        # Only height and width are needed; a 2-D image has no channel axis to unpack.
        h, w = image.shape[:2]
        if to_h is None:
            to_h = h
        if to_w is None:
            to_w = w

        image = cv2.resize(image, (to_w, to_h), interpolation=cv2.INTER_AREA)

        if image.ndim == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        if image.ndim == 2:
            image = np.expand_dims(image, axis=2)

        # (H, W, 1) -> (1, H, W)
        return image.transpose(2, 0, 1)

    def read_all_images_from_dir(
        self,
        img_dir: str,
        to_h: Optional[int],
        to_w: Optional[int],
        annotation_df: pd.DataFrame,
        is_paired: Union[bool, None] = None,
    ) -> List[ImgData]:
        """Reads all images from a specified directory, processes them, returns list of ImgData objects.

        Files are visited in sorted name order. Only images whose file name occurs
        in the ``config.img_path_col`` column of ``annotation_df`` are decoded and
        returned.

        Args:
            img_dir: The directory containing the images.
            to_h: The desired height of the output tensors.
            to_w: The desired width of the output tensors.
            annotation_df: DataFrame containing image annotations.
            is_paired: Whether the images are paired with annotations. True or None
                pre-filters the file list by the annotation column.

        Returns:
            List of processed image data objects; ``sample_id`` is the index label
            of the first annotation row matching the file name.

        Raises:
            ValueError: If the annotation DataFrame is missing required columns.
        """
        img_col = self.config.img_path_col
        if img_col not in annotation_df.columns:
            raise ValueError(
                f" The defined column for image paths: {self.config.img_path_col} column is missing in the annotation_df\
                             you can define this in the config via the param `img_path_col`"
            )

        # Sorted so the returned order does not depend on the file system.
        paths = [
            os.path.join(img_dir, f)
            for f in sorted(os.listdir(img_dir))
            if Path(f).suffix.lower() in self._SUPPORTED_EXTENSIONS
        ]
        if is_paired or is_paired is None:
            # Build the lookup once instead of once per candidate path.
            annotated_names = set(annotation_df[img_col])
            paths = [p for p in paths if os.path.basename(p) in annotated_names]

        imgs = []
        for p in paths:
            img_name = os.path.basename(p)
            subset: Union[pd.Series, pd.DataFrame] = annotation_df[
                annotation_df[img_col] == img_name
            ]
            # Skip unannotated files before paying for decoding them.
            if subset.empty:
                continue
            img = self.parse_image_to_tensor(image_path=p, to_h=to_h, to_w=to_w)
            imgs.append(
                ImgData(
                    img=img,
                    sample_id=str(subset.index[0]),
                    annotation=subset,
                )
            )
        return imgs

    def read_annotation_file(self, data_info: DataInfo) -> pd.DataFrame:
        """Reads annotation file and returns DataFrame with file contents.

        ``data_info.extra_anno_file`` is read when it is set, otherwise
        ``data_info.file_path``.

        Args:
            data_info: specific part of the Configuration object for input data

        Returns:
            DataFrame with annotation data.

        Raises:
            ValueError: If the file is not .parquet, .csv, .txt or .tsv.
        """
        anno_file = (
            os.path.join(data_info.file_path)
            if data_info.extra_anno_file is None
            else os.path.join(data_info.extra_anno_file)
        )
        sep = data_info.sep
        if anno_file.endswith(".parquet"):
            annotation = pd.read_parquet(anno_file)
        elif anno_file.endswith((".csv", ".txt", ".tsv")):
            annotation = pd.read_csv(anno_file, sep=sep, index_col=0, engine="python")
        else:
            raise ValueError(f"Unsupported file type for: {anno_file}")
        return annotation

    def read_data(
        self, config: DefaultConfig
    ) -> Tuple[Dict[str, List[ImgData]], Dict[str, pd.DataFrame]]:
        """Read image data from the specified directory based on configuration.

        Loading is best-effort per source: when one image source fails, the error
        is printed and that source is left out of both returned dictionaries.

        Args:
            config: The configuration object containing the data configuration.

        Returns:
            A Tuple of Dicts:
            1. Dict with type of image data as key and actual List of ImgData as value.
            2. Dict with type of image data as key and DataFrame of annotation data as value.

        Raises:
            ValueError: If no image data is found in the configuration.
        """
        # Find all image data sources in config
        image_sources = {
            k: v
            for k, v in config.data_config.data_info.items()
            if v.data_type == "IMG"
        }

        if not image_sources:
            raise ValueError("No image data found in the configuration.")

        result = {}
        annotation = {}
        for key, img_info in image_sources.items():
            try:
                result[key], annotation[key] = self._read_data(config, img_info)
                print(f"Successfully loaded {len(result[key])} images for {key}")
            except Exception as e:
                # Deliberately continue with the remaining sources.
                print(f"Error loading images for {key}: {str(e)}")

        return result, annotation

    def _read_data(
        self, config: DefaultConfig, img_info: DataInfo
    ) -> Tuple[List[ImgData], pd.DataFrame]:
        """Read data for a specific image source.

        Args:
            config: The configuration object containing the data configuration.
            img_info: The specific image info configuration.

        Returns:
            A Tuple of:
            1. The List of ImgData read from ``img_info.file_path``.
            2. A DataFrame concatenating the annotation rows of those images.

        Raises:
            ValueError: If neither an image-specific nor a global annotation file
                is configured, or if no image could be matched to an annotation.
        """
        img_dir = img_info.file_path
        img_size_finder: ImageSizeFinder = ImageSizeFinder(config)
        to_h, to_w = img_size_finder.get_nearest_quadratic_image_size()

        if img_info.extra_anno_file is not None:
            # Use image-specific annotation file if provided
            annotation = self.read_annotation_file(img_info)
        else:
            # Otherwise use the first global annotation entry of the configuration
            anno_info = next(
                (
                    f
                    for f in config.data_config.data_info.values()
                    if f.data_type == "ANNOTATION"
                ),
                None,
            )
            if anno_info is None:
                raise ValueError("No annotation data found in the configuration.")
            annotation = self.read_annotation_file(anno_info)

        images = self.read_all_images_from_dir(
            img_dir=img_dir,
            to_h=to_h,
            to_w=to_w,
            annotation_df=annotation,
            is_paired=config.requires_paired,
        )
        if not images:
            # pd.concat([]) would fail with an unhelpful "No objects to concatenate".
            raise ValueError(
                f"No images in '{img_dir}' matched the annotation column "
                f"'{self.config.img_path_col}'."
            )
        annotations: pd.DataFrame = pd.concat([img.annotation for img in images])

        return images, annotations

parse_image_to_tensor(image_path, to_h=None, to_w=None)

Reads an image from the given path, optionally resizes it, and converts it to a tensor.

Parameters:

Name Type Description Default
image_path Union[str, Path]

The path to the image file.

required
to_h Optional[int]

The desired height of the output tensor, by default None.

None
to_w Optional[int]

The desired width of the output tensor, by default None.

None

Returns:

Type Description
ndarray

The processed image as a tensor.

Raises:

Type Description
FileNotFoundError

If the image path is invalid or the image cannot be read.

ImageProcessingError

If the image format is unsupported or an unexpected error occurs during processing.

Source code in src/autoencodix/utils/_imgreader.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
def parse_image_to_tensor(
    self,
    image_path: Union[str, Path],
    to_h: Optional[int] = None,
    to_w: Optional[int] = None,
) -> np.ndarray:
    """Read an image, optionally resize it, and return it as a channel-first array.

    The image is loaded with OpenCV, resized to ``(to_h, to_w)``, reduced to a
    single grayscale channel and returned with shape ``(1, to_h, to_w)``.

    Args:
        image_path: The path to the image file.
        to_h: The desired height of the output array. Defaults to the height
            of the image on disk.
        to_w: The desired width of the output array. Defaults to the width
            of the image on disk.

    Returns:
        The processed image as a NumPy array of shape ``(1, to_h, to_w)``.

    Raises:
        FileNotFoundError: If ``validate_image_path`` rejects the path or
            OpenCV cannot read the file.
        ImageProcessingError: If the image format is unsupported or the
            decoded image is neither 2D nor 3D.
    """
    if not self.validate_image_path(image_path):
        raise FileNotFoundError(f"Invalid image path: {image_path}")

    image_path = Path(image_path)
    suffix = image_path.suffix.lower()
    SUPPORTED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".tif", ".tiff"}
    if suffix not in SUPPORTED_EXTENSIONS:
        # Sorted so that the message is deterministic (set order is not).
        raise ImageProcessingError(
            f"Unsupported image format: {image_path.suffix}. "
            f"Supported formats are: {', '.join(sorted(SUPPORTED_EXTENSIONS))}"
        )

    # TIFFs are read unchanged (keeps bit depth / channel layout as stored);
    # every other format uses OpenCV's default 3-channel BGR decoding.
    if suffix in {".tif", ".tiff"}:
        image = cv2.imread(str(image_path), cv2.IMREAD_UNCHANGED)
    else:
        image = cv2.imread(str(image_path))

    # cv2.imread signals failure by returning None instead of raising.
    if image is None:
        raise FileNotFoundError(f"Failed to read image: {image_path}")

    # Validate the rank before touching the shape, so that unsupported
    # arrays produce the documented error instead of an unpacking error.
    if not (2 <= image.ndim <= 3):
        raise ImageProcessingError(
            f"Image has unsupported shape: {image.shape}. "
            "Supported shapes are 2D and 3D."
        )

    # Only the first two axes are needed; a grayscale image read with
    # IMREAD_UNCHANGED is 2D and has no channel axis to unpack.
    h, w = image.shape[:2]
    if to_h is None:
        to_h = h
    if to_w is None:
        to_w = w

    # cv2.resize expects the target size as (width, height).
    image = cv2.resize(image, (to_w, to_h), interpolation=cv2.INTER_AREA)

    if image.ndim == 3:
        # Images with an alpha channel need the 4-channel conversion code.
        conversion = (
            cv2.COLOR_BGRA2GRAY if image.shape[2] == 4 else cv2.COLOR_BGR2GRAY
        )
        image = cv2.cvtColor(image, conversion)
    if image.ndim == 2:
        image = np.expand_dims(image, axis=2)

    # (H, W, C) -> (C, H, W)
    return image.transpose(2, 0, 1)

read_all_images_from_dir(img_dir, to_h, to_w, annotation_df, is_paired=None)

Reads all images from a specified directory, processes them, returns list of ImgData objects.

Parameters:

Name Type Description Default
img_dir str

The directory containing the images.

required
to_h Optional[int]

The desired height of the output tensors.

required
to_w Optional[int]

The desired width of the output tensors.

required
annotation_df DataFrame

DataFrame containing image annotations.

required
is_paired Union[bool, None]

Whether the images are paired with annotations.

None

Returns:

Type Description
List[ImgData]

List of processed image data objects.

Raises:

Type Description
ValueError

If the annotation DataFrame is missing required columns.

Source code in src/autoencodix/utils/_imgreader.py
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
def read_all_images_from_dir(
    self,
    img_dir: str,
    to_h: Optional[int],
    to_w: Optional[int],
    annotation_df: pd.DataFrame,
    is_paired: Union[bool, None] = None,
) -> List[ImgData]:
    """Read every supported image in a directory and wrap it in ``ImgData``.

    Files are matched to annotation rows by file name: a row belongs to an
    image when its ``self.config.img_path_col`` value equals the image's
    base name. Images without a matching annotation row are not returned.

    Args:
        img_dir: The directory containing the images.
        to_h: The desired height of the output arrays.
        to_w: The desired width of the output arrays.
        annotation_df: DataFrame containing image annotations. Must contain
            the column named by ``self.config.img_path_col``.
        is_paired: When True or None, files that are not listed in the
            annotation column are filtered out before being read. When
            False, every supported file is read first and unannotated ones
            are dropped afterwards.

    Returns:
        List of processed image data objects, one per annotated image. The
        sample id is the index label of the first matching annotation row.

    Raises:
        ValueError: If the annotation DataFrame is missing the image path column.
    """
    path_col = self.config.img_path_col
    if path_col not in annotation_df.columns:
        raise ValueError(
            f" The defined column for image paths: {self.config.img_path_col} column is missing in the annotation_df\
                         you can define this in the config via the param `img_path_col`"
        )

    SUPPORTED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".tif", ".tiff"}
    paths = [
        os.path.join(img_dir, f)
        for f in os.listdir(img_dir)
        if Path(f).suffix.lower() in SUPPORTED_EXTENSIONS
    ]
    if is_paired or is_paired is None:
        # Build the lookup once; rebuilding a list per file made this
        # filter quadratic in the number of images.
        annotated_names = set(annotation_df[path_col].tolist())
        paths = [p for p in paths if os.path.basename(p) in annotated_names]

    imgs = []
    for p in paths:
        img = self.parse_image_to_tensor(image_path=p, to_h=to_h, to_w=to_w)
        img_path = os.path.basename(p)
        subset: Union[pd.Series, pd.DataFrame] = annotation_df[
            annotation_df[path_col] == img_path
        ]
        if not subset.empty:
            imgs.append(
                ImgData(
                    img=img,
                    sample_id=str(subset.index[0]),
                    annotation=subset,
                )
            )
    return imgs

read_annotation_file(data_info)

Reads the annotation file and returns a DataFrame with the file contents. Args: data_info: the part of the Configuration object that describes the input data. Returns: DataFrame with annotation data.

Source code in src/autoencodix/utils/_imgreader.py
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
def read_annotation_file(self, data_info: DataInfo) -> pd.DataFrame:
    """Load the annotation table referenced by a data-info entry.

    The file is taken from ``data_info.extra_anno_file`` when that is set and
    from ``data_info.file_path`` otherwise. The reader is picked from the
    file name ending.

    Args:
        data_info: specific part of the Configuration object for input data
    Returns:
        DataFrame with annotation data. For ``.csv``, ``.txt`` and ``.tsv``
        files the first column is used as the index and ``data_info.sep`` as
        the delimiter; ``.parquet`` files are read as stored.
    Raises:
        ValueError: If the file name has none of the supported endings.

    """
    if data_info.extra_anno_file is None:
        location = data_info.file_path
    else:
        location = data_info.extra_anno_file
    anno_file = os.path.join(location)
    delimiter = data_info.sep

    if anno_file.endswith(".parquet"):
        return pd.read_parquet(anno_file)
    if anno_file.endswith((".csv", ".txt", ".tsv")):
        return pd.read_csv(anno_file, sep=delimiter, index_col=0, engine="python")
    raise ValueError(f"Unsupported file type for: {anno_file}")

read_data(config)

Read image data from the specified directory based on configuration.

Parameters:

Name Type Description Default
config DefaultConfig

The configuration object containing the data configuration.

required

Returns:

Type Description
Tuple[Dict[str, List[ImgData]], Dict[str, DataFrame]]

A tuple of two dicts: (1) a dict with the type of image data as key and the list of ImgData as value; (2) a dict with the type of image data as key and the DataFrame of annotation data as value.

Raises:

Type Description
Exception

If no image data is found in the configuration or other validation errors occur.

Source code in src/autoencodix/utils/_imgreader.py
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
def read_data(
    self, config: DefaultConfig
) -> Tuple[Dict[str, List[ImgData]], Dict[str, pd.DataFrame]]:
    """Load every image source declared in the configuration.

    Each entry of ``config.data_config.data_info`` whose ``data_type`` is
    ``"IMG"`` is passed to ``self._read_data``. Loading is best-effort: a
    source that fails is reported on stdout and left out of both result
    dicts, and the remaining sources are still processed.

    Args:
        config: The configuration object containing the data configuration.

    Returns:
        A Tuple of Dicts:
        1. Dict with type of image data as key and actual List of ImgData as value.
        2. Dict with type of image data as key and DataFrame of annotation data as value.

    Raises:
        ValueError: If the configuration contains no image data source.
    """
    # Collect the image entries first so that a missing source fails fast.
    image_sources = {}
    for name, info in config.data_config.data_info.items():
        if info.data_type == "IMG":
            image_sources[name] = info

    if len(image_sources) == 0:
        raise ValueError("No image data found in the configuration.")

    result, annotation = {}, {}
    for key, img_info in image_sources.items():
        try:
            result[key], annotation[key] = self._read_data(config, img_info)
            print(f"Successfully loaded {len(result[key])} images for {key}")
        except Exception as e:
            # Deliberately not re-raised: one broken source must not abort
            # the others.
            print(f"Error loading images for {key}: {e!s}")

    return result, annotation

validate_image_path(image_path)

Checks if file extension is allowed:

Allowed are (independent of capitalization): - jpg - jpeg - png - tif - tiff

Parameters:

Name Type Description Default
image_path Union[str, Path]

path or str of image to read

required
Source code in src/autoencodix/utils/_imgreader.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
def validate_image_path(self, image_path: Union[str, Path]) -> bool:
    """Tell whether a path points to an existing file with an allowed image extension.

    Allowed extensions (case-insensitive):
        - jpg
        - jpeg
        - png
        - tif
        - tiff

    Args:
        image_path: path or str of image to read

    Returns:
        True only if the path exists, is a regular file and carries one of
        the allowed extensions; False otherwise.
    """
    allowed_suffixes = {".jpg", ".jpeg", ".png", ".tif", ".tiff"}

    candidate = image_path
    if isinstance(candidate, str):
        candidate = Path(candidate)

    if not candidate.exists():
        return False
    if not candidate.is_file():
        return False
    return candidate.suffix.lower() in allowed_suffixes

ModelOutput dataclass

A structured output dataclass for autoencoder models.

This class is used to encapsulate the outputs of autoencoder models in a consistent format, allowing for flexibility in the type of outputs returned by different architectures.

Attributes:

Name Type Description
reconstruction Tensor

The reconstructed input data.

latent_mean Optional[Tensor]

The mean of the latent space distribution, applicable for models like VAEs.

latent_logvar Optional[Tensor]

The log variance of the latent space distribution, applicable for models like VAEs.

additional_info Optional[dict]

A dictionary to store any additional information or intermediate outputs.

Source code in src/autoencodix/utils/_model_output.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
@dataclass
class ModelOutput:
    """Uniform container for what an autoencoder model returns.

    Wrapping the outputs in one dataclass gives every architecture the same
    return format, while the optional fields leave room for outputs that
    only some architectures produce.

    Attributes:
        reconstruction: The reconstructed input data.
        latentspace: The latent representation of the input
            (NOTE(review): presumably the encoded/sampled latent vector — confirm).
        latent_mean: The mean of the latent space distribution, applicable for models like VAEs.
        latent_logvar: The log variance of the latent space distribution, applicable for models like VAEs.
        additional_info: A dictionary to store any additional information or intermediate outputs.
    """

    # Required outputs.
    reconstruction: torch.Tensor
    latentspace: torch.Tensor
    # Optional outputs; stay None when a model does not provide them.
    latent_mean: Optional[torch.Tensor] = None
    latent_logvar: Optional[torch.Tensor] = None
    additional_info: Optional[dict] = None

    def __iter__(self):
        """Iterate over a single item: this instance itself."""
        yield from (self,)

Result dataclass

A dataclass to store results from the pipeline with predefined keys.

Attributes:

Name Type Description
latentspaces TrainingDynamics

TrainingDynamics object storing latent space representations for 'train', 'valid', and 'test' splits.

sample_ids TrainingDynamics

TrainingDynamics object storing sample identifiers for 'train', 'valid', and 'test' splits.

reconstructions TrainingDynamics

TrainingDynamics object storing reconstructed outputs for 'train', 'valid', and 'test' splits.

mus TrainingDynamics

TrainingDynamics object storing mean values of latent distributions for 'train', 'valid', and 'test' splits.

sigmas TrainingDynamics

TrainingDynamics object storing standard deviations of latent distributions for 'train', 'valid', and 'test' splits.

losses TrainingDynamics

TrainingDynamics object storing the total loss for different epochs and splits ('train', 'valid', 'test').

sub_losses LossRegistry

LossRegistry object (extendable) for all sublosses.

preprocessed_data Tensor

torch.Tensor containing data after preprocessing.

model Union[Dict[str, Module], Module]

final trained torch.nn.Module model.

model_checkpoints TrainingDynamics

TrainingDynamics object storing model state at each checkpoint.

datasets Optional[DatasetContainer]

Optional[DatasetContainer] containing train, valid, and test datasets.

new_datasets Optional[DatasetContainer]

Optional[DatasetContainer] containing new train, valid, and test datasets.

adata_latent Optional[AnnData]

Optional[AnnData] containing latent representations as AnnData.

final_reconstruction Optional[Union[DataPackage, MuData]]

Optional[Union[DataPackage, MuData]] containing final reconstruction results.

sub_results Optional[Dict[str, Any]]

Optional[Dict[str, Any]] containing sub-results for multi-task or multi-modal models.

sub_reconstructions Optional[Dict[str, Any]]

Optional[Dict[str, Any]] containing sub-reconstructions for multi-task or multi-modal models.

embedding_evaluation DataFrame

pd.DataFrame containing embedding evaluation results.

Source code in src/autoencodix/utils/_result.py
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
@dataclass
class Result:
    """A dataclass to store results from the pipeline with predefined keys.

    Fields can be read and written both as attributes and with dictionary
    syntax (``result["losses"]``).

    Attributes:
        latentspaces: TrainingDynamics object storing latent space representations for 'train', 'valid', and 'test' splits.
        sample_ids: TrainingDynamics object storing sample identifiers for 'train', 'valid', and 'test' splits.
        reconstructions: TrainingDynamics object storing reconstructed outputs for 'train', 'valid', and 'test' splits.
        mus: TrainingDynamics object storing mean values of latent distributions for 'train', 'valid', and 'test' splits.
        sigmas: TrainingDynamics object storing standard deviations of latent distributions for 'train', 'valid', and 'test' splits.
        losses: TrainingDynamics object storing the total loss for different epochs and splits ('train', 'valid', 'test').
        sub_losses: LossRegistry object (extendable) for all sublosses.
        preprocessed_data: torch.Tensor containing data after preprocessing.
        model: final trained torch.nn.Module model.
        model_checkpoints: TrainingDynamics object storing model state at each checkpoint.
        datasets: Optional[DatasetContainer] containing train, valid, and test datasets.
        new_datasets: Optional[DatasetContainer] containing new train, valid, and test datasets.
        adata_latent: Optional[AnnData] containing latent representations as AnnData.
        final_reconstruction: Optional[Union[DataPackage, MuData]] containing final reconstruction results.
        sub_results: Optional[Dict[str, Any]] containing sub-results for multi-task or multi-modal models.
        sub_reconstructions: Optional[Dict[str, Any]] containing sub-reconstructions for multi-task or multi-modal models.
        embedding_evaluation: pd.DataFrame containing embedding evaluation results.
    """

    latentspaces: TrainingDynamics = field(default_factory=TrainingDynamics)
    sample_ids: TrainingDynamics = field(default_factory=TrainingDynamics)
    reconstructions: TrainingDynamics = field(default_factory=TrainingDynamics)
    mus: TrainingDynamics = field(default_factory=TrainingDynamics)
    sigmas: TrainingDynamics = field(default_factory=TrainingDynamics)
    losses: TrainingDynamics = field(default_factory=TrainingDynamics)
    sub_losses: LossRegistry = field(default_factory=LossRegistry)
    preprocessed_data: torch.Tensor = field(default_factory=torch.Tensor)
    model: Union[Dict[str, torch.nn.Module], torch.nn.Module] = field(
        default_factory=torch.nn.Module
    )
    model_checkpoints: TrainingDynamics = field(default_factory=TrainingDynamics)

    datasets: Optional[DatasetContainer] = field(
        default_factory=lambda: DatasetContainer(train=None, valid=None, test=None)
    )
    new_datasets: Optional[DatasetContainer] = field(
        default_factory=lambda: DatasetContainer(train=None, valid=None, test=None)
    )

    adata_latent: Optional[AnnData] = field(default_factory=AnnData)
    final_reconstruction: Optional[
        Union[DataPackage, MuData]  # ty: ignore[invalid-type-form]
    ] = field(default=None)
    sub_results: Optional[Dict[str, Any]] = field(default=None)
    sub_reconstructions: Optional[Dict[str, Any]] = field(default=None)

    # Embedding evaluation results
    embedding_evaluation: pd.DataFrame = field(default_factory=pd.DataFrame)

    # plots: Dict[str, Any] = field(
    #     default_factory=nested_dict
    # )  ## Nested dictionary of plots as figure handles

    def __getitem__(self, key: str) -> Any:
        """Retrieve the value associated with a specific key.

        Args:
            key: The name of the attribute to retrieve.
        Returns:
            The value of the specified attribute.
        Raises:
            KeyError: If the key is not an attribute of this object.

        """
        if not hasattr(self, key):
            raise KeyError(
                f"Invalid key: '{key}'. Allowed keys are: {', '.join(self.__annotations__.keys())}"
            )
        return getattr(self, key)

    def __setitem__(self, key: str, value: Any) -> None:
        """Assign a value to a specific attribute.

        Args:
            key: The name of the attribute to set.
            value: The value to assign to the attribute.
        Raises:
            KeyError: If the key is not an attribute of this object.

        """
        if not hasattr(self, key):
            raise KeyError(
                f"Invalid key: '{key}'. Allowed keys are: {', '.join(self.__annotations__.keys())}"
            )
        setattr(self, key, value)

    def _is_empty_value(self, value: Any) -> bool:
        """Check whether an attribute value of a Result object counts as empty.

        Emptiness is defined per type: a TrainingDynamics without stored
        data, a tensor without elements, a module without parameters, a
        DatasetContainer whose three splits are all None, or a LossRegistry
        whose entries are all None. Any other value (including None) is
        treated as non-empty.

        Args:
            value: The value to check.
        Returns:
            True if the value is empty, False otherwise.

        """

        if isinstance(value, TrainingDynamics):
            return len(value._data) == 0
        elif isinstance(value, torch.Tensor):
            return value.numel() == 0
        elif isinstance(value, torch.nn.Module):
            return sum(p.numel() for p in value.parameters()) == 0
        elif isinstance(value, DatasetContainer):
            return all(v is None for v in [value.train, value.valid, value.test])
        elif isinstance(value, LossRegistry):
            # single Nones are handled in update method (skipped)
            return all(v is None for _, v in value.losses())

        return False

    def update(self, other: "Result") -> None:
        """Update the current Result object with values from another Result object.

        Values of ``other`` that count as empty (see ``_is_empty_value``) are
        ignored. For TrainingDynamics fields the data of ``other`` is merged
        into the current data per epoch and split. For the LossRegistry every
        non-None entry is merged the same way. All other attributes are
        replaced by the value from ``other``.

        Args:
            other: The Result object to update from.
        Raises:
            TypeError: If the input object is not a Result instance.
            ValueError: If a LossRegistry entry of ``other`` is neither None
                nor a TrainingDynamics object.

        """
        if not isinstance(other, Result):
            raise TypeError(f"Expected Result object, got {type(other)}")

        for field_name in self.__annotations__.keys():
            current_value = getattr(self, field_name)
            other_value = getattr(other, field_name)
            if self._is_empty_value(other_value):
                continue

            if isinstance(current_value, TrainingDynamics) and isinstance(
                other_value, TrainingDynamics
            ):
                # Store the merged dynamics. Falling through to the plain
                # replacement below would discard the merge and keep only
                # the data of `other`.
                merged = self._update_traindynamics(current_value, other_value)
                setattr(self, field_name, merged)
            elif isinstance(current_value, LossRegistry):
                for key, value in other_value.losses():
                    if value is None:
                        continue
                    if not isinstance(value, TrainingDynamics):
                        raise ValueError(
                            f"Expected TrainingDynamics object, got {type(value)}"
                        )
                    updated_dynamic = self._update_traindynamics(
                        current_value=current_value.get(key=key), other_value=value
                    )
                    current_value.set(key=key, value=updated_dynamic)
            else:
                # NOTE(review): values that `_is_empty_value` does not
                # recognise (e.g. None) also replace the current value —
                # confirm this is intended for the Optional fields.
                setattr(self, field_name, other_value)

    def _update_traindynamics(
        self, current_value: TrainingDynamics, other_value: TrainingDynamics
    ) -> TrainingDynamics:
        """Merge the data of one TrainingDynamics object into another.

        Every non-None split entry of ``other_value`` is added to
        ``current_value`` under the same epoch and split. Epochs that are
        None or empty in ``other_value`` are skipped, so the corresponding
        data in ``current_value`` is kept. The epochs of the result are
        sorted in ascending order.

        Args:
            current_value: The current TrainingDynamics object to update
                (modified in place).
            other_value: The TrainingDynamics object to update from.

        Returns:
            The updated ``current_value``, or ``other_value`` if
            ``current_value`` (or its data) is None.

        Examples:
            >>> current = TrainingDynamics()
            >>> current._data = {1: {"train": np.array([1, 2, 3])},
            ...                   2: None}

            >>> other = TrainingDynamics()
            >>> other._data = {1: {"train": np.array([4, 5, 6])},
            ...                 2: {"train": np.array([7, 8, 9])}}
            >>> # after update
            >>> print(current._data)
            {1: {"train": np.array([4, 5, 6])}, # taken from other
            2: {"train": np.array([7, 8, 9])}} # filled, because current was None

        """

        if current_value is None:
            return other_value
        if current_value._data is None:
            return other_value

        for epoch, split_data in other_value._data.items():
            if split_data is None:
                continue
            if len(split_data) == 0:
                continue

            # An epoch that exists but holds None needs a dict before
            # individual splits can be added to it.
            if epoch in current_value._data and current_value._data[epoch] is None:
                current_value._data[epoch] = {}

            # New and existing epochs are handled alike: add every split
            # that actually carries data, leave the rest untouched.
            for split, data in split_data.items():
                if data is not None:
                    current_value.add(epoch=epoch, data=data, split=split)

        # Ensure ordering
        current_value._data = dict(sorted(current_value._data.items()))

        return current_value

    def __str__(self) -> str:
        """Provide a readable string representation of the Result object's public attributes.

        TrainingDynamics, modules, dicts and tensors are summarised in one
        short line each; every other value is printed as is.

        Returns:
            Formatted string showing all public attributes and their values
        """
        output = ["Result Object Public Attributes:", "-" * 30]

        for name in self.__annotations__.keys():
            if name.startswith("__"):
                continue

            value = getattr(self, name)
            if isinstance(value, TrainingDynamics):
                output.append(f"{name}: TrainingDynamics object")
            elif isinstance(value, torch.nn.Module):
                output.append(f"{name}: {value.__class__.__name__}")
            elif isinstance(value, dict):
                output.append(f"{name}: Dict with {len(value)} items")
            elif isinstance(value, torch.Tensor):
                output.append(f"{name}: Tensor of shape {tuple(value.shape)}")
            else:
                output.append(f"{name}: {value}")

        return "\n".join(output)

    def __repr__(self) -> str:
        """Return the same representation as __str__ for consistency."""
        return self.__str__()

    def get_latent_df(
        self, epoch: int, split: str, modality: Optional[str] = None
    ) -> pd.DataFrame:
        """Return latent representations as a DataFrame.

        Retrieves latent vectors and their corresponding sample IDs for a given
        epoch and data split. If a specific modality is provided, the results
        are restricted to that modality. Column names are inferred from model
        ontologies if available; otherwise, generic latent dimension labels are
        used.

        Args:
            epoch: The epoch number to retrieve latents from.
            split: The dataset split to query (e.g., "train", "valid", "test").
            modality: Optional modality name to filter the latents and sample IDs.

        Returns:
            A DataFrame where rows correspond to samples, columns represent latent
            dimensions, and the index contains sample IDs. If the latents cannot
            be retrieved, a warning is issued and an empty DataFrame is returned.
        """
        try:
            latents = self.latentspaces.get(epoch=epoch, split=split)
            ids = self.sample_ids.get(epoch=epoch, split=split)
            if modality is not None:  # for x-modalix and other multi-modal models
                latents = latents[modality]
                ids = ids[modality]
            if hasattr(self.model, "ontologies") and self.model.ontologies is not None:
                cols = list(self.model.ontologies[0].keys())
            else:
                cols = ["LatDim_" + str(i) for i in range(latents.shape[1])]
            return pd.DataFrame(latents, index=ids, columns=cols)
        except Exception as e:
            import warnings

            # Best-effort accessor: report the problem and hand back an
            # empty frame instead of raising.
            warnings.warn(
                f"We could not create the latent space DataFrame.\n"
                f"This usually happens if you try to visualize after saving and loading "
                f"the pipeline object with `save_all=False`. This memory-efficient saving mode "
                f"does not retain past training loss data.\n\n"
                f"Original error message: {e}"
            )

            return pd.DataFrame()

    def get_reconstructions_df(
        self, epoch: int, split: str, modality: Optional[str] = None
    ) -> pd.DataFrame:
        """Return reconstructions as a DataFrame.

        Retrieves reconstructed features and their corresponding sample IDs for a
        given epoch and data split. If a specific modality is provided, the results
        are restricted to that modality. Column names are the feature
        identifiers of the training dataset.

        Args:
            epoch: The epoch number to retrieve reconstructions from.
            split: The dataset split to query (e.g., "train", "valid", "test").
            modality: Optional modality name to filter the reconstructions and
                sample IDs.

        Returns:
            A DataFrame where rows correspond to samples, columns represent
            reconstructed features, and the index contains sample IDs.
        """
        reconstructions = self.reconstructions.get(epoch=epoch, split=split)
        ids = self.sample_ids.get(epoch=epoch, split=split)
        if modality is not None:
            reconstructions = reconstructions[modality]
            ids = ids[modality]

        cols = self.datasets.train.feature_ids
        return pd.DataFrame(reconstructions, index=ids, columns=cols)

__getitem__(key)

Retrieve the value associated with a specific key.

Parameters:

Name Type Description Default
key str

The name of the attribute to retrieve.

required

Returns: The value of the specified attribute. Raises: KeyError - If the key is not a valid attribute of the Results class.

Source code in src/autoencodix/utils/_result.py
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
def __getitem__(self, key: str) -> Any:
    """Look up an attribute of this object with dictionary syntax.

    Args:
        key: The name of the attribute to retrieve.
    Returns:
        The value of the specified attribute.
    Raises:
        KeyError: If the object has no attribute with that name. The message
            lists the annotated field names as allowed keys.

    """
    if hasattr(self, key):
        return getattr(self, key)

    allowed = ", ".join(self.__annotations__.keys())
    raise KeyError(f"Invalid key: '{key}'. Allowed keys are: {allowed}")

__repr__()

Return the same representation as `__str__` for consistency.

Source code in src/autoencodix/utils/_result.py
308
309
310
def __repr__(self) -> str:
    """Delegate to ``__str__`` so both representations are identical."""
    text = self.__str__()
    return text

__setitem__(key, value)

Assign a value to a specific attribute.

Parameters:

Name Type Description Default
key str

The name of the attribute to set.

required
value Any

The value to assign to the attribute.

required

Raises: KeyError: If the key is not a valid attribute of the Results class.

Source code in src/autoencodix/utils/_result.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def __setitem__(self, key: str, value: Any) -> None:
    """Set an existing attribute of this object with dictionary syntax.

    Args:
        key: The name of the attribute to set.
        value: The value to assign to the attribute.
    Raises:
        KeyError: If the object has no attribute with that name; new
            attributes cannot be created this way.

    """
    if hasattr(self, key):
        setattr(self, key, value)
        return

    allowed = ", ".join(self.__annotations__.keys())
    raise KeyError(f"Invalid key: '{key}'. Allowed keys are: {allowed}")

__str__()

Provide a readable string representation of the Result object's public attributes.

Returns:

Type Description
str

Formatted string showing all public attributes and their values

Source code in src/autoencodix/utils/_result.py
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
def __str__(self) -> str:
    """Build a human-readable summary of the object's annotated attributes.

    Every annotated field whose name does not start with a double underscore
    gets one ``name: summary`` line. TrainingDynamics objects, torch modules,
    dicts and tensors are summarised compactly; any other value is rendered
    with its default formatting.

    Returns:
        The header, a separator line and one line per attribute, joined with
        newlines.
    """
    lines = ["Result Object Public Attributes:", "-" * 30]

    public_names = (
        name for name in self.__annotations__.keys() if not name.startswith("__")
    )
    for name in public_names:
        value = getattr(self, name)
        # The order of these checks is significant and mirrors the type
        # precedence used for the summaries.
        if isinstance(value, TrainingDynamics):
            summary = "TrainingDynamics object"
        elif isinstance(value, torch.nn.Module):
            summary = f"{value.__class__.__name__}"
        elif isinstance(value, dict):
            summary = f"Dict with {len(value)} items"
        elif isinstance(value, torch.Tensor):
            summary = f"Tensor of shape {tuple(value.shape)}"
        else:
            summary = f"{value}"
        lines.append(f"{name}: {summary}")

    return "\n".join(lines)

get_latent_df(epoch, split, modality=None)

Return latent representations as a DataFrame.

Retrieves latent vectors and their corresponding sample IDs for a given epoch and data split. If a specific modality is provided, the results are restricted to that modality. Column names are inferred from model ontologies if available; otherwise, generic latent dimension labels are used.

Parameters:

Name Type Description Default
epoch int

The epoch number to retrieve latents from.

required
split str

The dataset split to query (e.g., "train", "valid", "test").

required
modality Optional[str]

Optional modality name to filter the latents and sample IDs.

None

Returns:

Type Description
DataFrame

A DataFrame where rows correspond to samples, columns represent latent

DataFrame

dimensions, and the index contains sample IDs.

Source code in src/autoencodix/utils/_result.py
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
def get_latent_df(
    self, epoch: int, split: str, modality: Optional[str] = None
) -> pd.DataFrame:
    """Return latent representations as a DataFrame.

    Retrieves latent vectors and their corresponding sample IDs for a given
    epoch and data split. If a specific modality is provided, the results
    are restricted to that modality. Column names are taken from the keys of
    the model's first ontology level if the model has ontologies; otherwise,
    generic ``LatDim_<i>`` labels are used.

    Args:
        epoch: The epoch number to retrieve latents from.
        split: The dataset split to query (e.g., "train", "valid", "test").
        modality: Optional modality name to filter the latents and sample IDs.

    Returns:
        A DataFrame where rows correspond to samples, columns represent latent
        dimensions, and the index contains sample IDs. If the latents cannot
        be retrieved or assembled, a warning is emitted and an empty
        DataFrame is returned instead of raising.
    """
    try:
        latents = self.latentspaces.get(epoch=epoch, split=split)
        ids = self.sample_ids.get(epoch=epoch, split=split)
        if modality is not None:  # for x-modalix and other multi-modal models
            latents = latents[modality]
            ids = ids[modality]
        ontologies = getattr(self.model, "ontologies", None)
        if ontologies is not None:
            cols = list(ontologies[0].keys())
        else:
            cols = [f"LatDim_{i}" for i in range(latents.shape[1])]
        return pd.DataFrame(latents, index=ids, columns=cols)
    except Exception as e:
        # Deliberate best-effort: callers get an empty frame plus a warning
        # rather than an exception when the latents are unavailable.
        import warnings

        warnings.warn(
            f"We could not create the latent space DataFrame for "
            f"epoch={epoch}, split={split!r}, modality={modality!r}.\n"
            f"This can happen if you request latent spaces after saving and loading "
            f"the pipeline object with `save_all=False`; this memory-efficient saving "
            f"mode may not retain results from past training epochs.\n\n"
            f"Original error message: {e}",
            stacklevel=2,
        )

        return pd.DataFrame()

get_reconstructions_df(epoch, split, modality=None)

Return reconstructions as a DataFrame.

Retrieves reconstructed features and their corresponding sample IDs for a given epoch and data split. If a specific modality is provided, the results are restricted to that modality. Column names are based on the dataset's feature identifiers.

Parameters:

Name Type Description Default
epoch int

The epoch number to retrieve reconstructions from.

required
split str

The dataset split to query (e.g., "train", "valid", "test").

required
modality Optional[str]

Optional modality name to filter the reconstructions and sample IDs.

None

Returns:

Type Description
DataFrame

A DataFrame where rows correspond to samples, columns represent

DataFrame

reconstructed features, and the index contains sample IDs.

Source code in src/autoencodix/utils/_result.py
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
def get_reconstructions_df(
    self, epoch: int, split: str, modality: Optional[str] = None
) -> pd.DataFrame:
    """Assemble the stored reconstructions for one epoch/split into a DataFrame.

    The reconstructions and sample IDs are looked up for the requested epoch
    and split. When ``modality`` is given, both lookups are additionally
    indexed by that modality. Columns are labelled with the feature IDs of
    the training dataset.

    Args:
        epoch: The epoch number to retrieve reconstructions from.
        split: The dataset split to query (e.g., "train", "valid", "test").
        modality: Optional modality name used to select the reconstructions
            and sample IDs.

    Returns:
        A DataFrame with one row per sample (indexed by sample ID) and one
        column per reconstructed feature.
    """
    recon = self.reconstructions.get(epoch=epoch, split=split)
    sample_ids = self.sample_ids.get(epoch=epoch, split=split)
    if modality is not None:
        recon, sample_ids = recon[modality], sample_ids[modality]

    return pd.DataFrame(
        recon,
        index=sample_ids,
        columns=self.datasets.train.feature_ids,
    )

update(other)

Update the current Result object with values from another Result object.

For TrainingDynamics, merges the data across epochs and splits, overwriting entries that already exist. For all other attributes, replaces the current value with the other value.

Parameters:

Name Type Description Default
other Result

The Result object to update from.

required

Raises: TypeError: If the input object is not a Result instance

Source code in src/autoencodix/utils/_result.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
def update(self, other: "Result") -> None:
    """Update the current Result object with values from another Result object.

    Each annotated field is handled according to the type of the value
    currently held by ``self``:

    * TrainingDynamics: merged with the other object's value via
      ``_update_traindynamics`` and the merged object is stored.
    * LossRegistry: every non-None loss of the other registry is merged into
      the matching entry of the current registry.
    * Anything else: replaced with the other object's value.

    Fields whose value in ``other`` is considered empty by
    ``_is_empty_value`` are left untouched.

    Args:
        other: The Result object to update from.
    Raises:
        TypeError: If the input object is not a Result instance.
        ValueError: If a loss in the other LossRegistry is neither None nor
            a TrainingDynamics object.

    """
    if not isinstance(other, Result):
        raise TypeError(f"Expected Result object, got {type(other)}")

    for field_name in self.__annotations__.keys():
        current_value = getattr(self, field_name)
        other_value = getattr(other, field_name)
        if self._is_empty_value(other_value):
            continue

        if isinstance(current_value, TrainingDynamics):
            # Store the merged dynamics. These branches must be mutually
            # exclusive: otherwise the generic replacement below would
            # overwrite the merge result with ``other_value``.
            merged = self._update_traindynamics(current_value, other_value)
            setattr(self, field_name, merged)
        elif isinstance(current_value, LossRegistry):
            # Merge loss by loss; the registry itself is updated in place.
            for key, value in other_value.losses():
                if value is None:
                    continue
                if not isinstance(value, TrainingDynamics):
                    raise ValueError(
                        f"Expected TrainingDynamics object, got {type(value)}"
                    )
                updated_dynamic = self._update_traindynamics(
                    current_value=current_value.get(key=key), other_value=value
                )
                current_value.set(key=key, value=updated_dynamic)
        else:
            # For all other types - replace with other value
            setattr(self, field_name, other_value)

SingleCellDataReader

Reader for multi-modal single-cell data.

Source code in src/autoencodix/utils/_screader.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
class SingleCellDataReader:
    """Reader for multi-modal single-cell data."""

    @staticmethod
    def read_data(
        config: DefaultConfig,
    ) -> Dict[str, MuData]:  # ty: ignore[invalid-type-form]
        """Read the configured single-cell modalities from ``.h5ad`` files.

        Only entries of ``config.data_config.data_info`` flagged as
        ``is_single_cell`` are read. When ``config.requires_paired`` is set,
        every modality is subset to the sorted cells shared by all
        modalities, and the aligned modalities are combined into one MuData
        object.

        Args:
            config: Configuration object containing data paths and parameters.

        Returns:
            A dict with the single key ``"multi_sc"``. If
            ``config.requires_paired`` is set, the value is one MuData object
            holding the aligned modalities. Otherwise the value is a dict
            mapping each modality key to its AnnData object as read from disk.

        Raises:
            ValueError: If paired data is required but no single-cell
                modality is configured.
        """
        modalities: Dict[str, AnnData] = {
            mod_key: sc.read_h5ad(mod_info.file_path)
            for mod_key, mod_info in config.data_config.data_info.items()
            if mod_info.is_single_cell
        }

        if not config.requires_paired:
            return {"multi_sc": modalities}

        if not modalities:
            # set.intersection() without arguments would fail with an
            # unhelpful TypeError, so report the real problem instead.
            raise ValueError(
                "Paired data was requested (requires_paired=True) but no "
                "single-cell modality is configured in data_config.data_info."
            )

        common_cells_sorted = sorted(
            set.intersection(*(set(adata.obs_names) for adata in modalities.values()))
        )

        # Subset EACH modality individually with the sorted common cells so
        # that every modality ends up in the same cell order.
        aligned_modalities = {
            mod_key: adata[common_cells_sorted].copy()
            for mod_key, adata in modalities.items()
        }
        mdata = md.MuData(aligned_modalities)

        print(f"Number of common cells: {len(common_cells_sorted)}")

        # Clean obs column names: keep only the text after the last ":"
        # (removes the modality prefixes) ...
        mdata.obs.columns = [name.rsplit(":", 1)[-1] for name in mdata.obs.columns]

        # ... then drop columns that became duplicates, keeping the first.
        mdata.obs = mdata.obs.loc[:, ~mdata.obs.columns.duplicated(keep="first")]

        return {"multi_sc": mdata}

read_data(config) staticmethod

Read multiple single-cell modalities into MuData object(s).

Args: config: Configuration object containing data paths and parameters.

Returns:

Type Description
Dict[str, MuData]

For non-paired translation: a dict of dicts, with {'multi_sc': DataDict} as the outer dict and modality keys mapped to AnnData objects in the inner dict.

Dict[str, MuData]

For paired translation and non translation cases: dict with "multi_sc" as key and mudata as value

Source code in src/autoencodix/utils/_screader.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
@staticmethod
def read_data(
    config: DefaultConfig,
) -> Dict[str, MuData]:  # ty: ignore[invalid-type-form]
    """Read multiple single-cell modalities into MuData object(s).

    Args:
    config: Configuration object containing data paths and parameters.

    Returns:
        For non-paired translation: Dict of Dicts with {'multi_sc': DataDict} as outer dict and with modalty keys and mudata obj as inner dict.
        For paired translation and non translation cases: dict with "multi_sc" as key and mudata as value
    """
    modalities: Dict[str, AnnData] = {}

    for mod_key, mod_info in config.data_config.data_info.items():
        if not mod_info.is_single_cell:
            continue
        adata = sc.read_h5ad(mod_info.file_path)
        modalities[mod_key] = adata

    # if config.requires_paired:
    #     mdata = md.MuData(modalities)
    #     common_cells = list(
    #         set.intersection(
    #             *(set(adata.obs_names) for adata in modalities.values())
    #         )
    #     )
    #     print(f"Number of common cells: {len(common_cells)}")
    #     mdata = mdata[common_cells]
    #     return {"multi_sc": mdata}

    if config.requires_paired:
        common_cells_set = set.intersection(
            *(set(adata.obs_names) for adata in modalities.values())
        )
        common_cells_sorted = sorted(list(common_cells_set))

        # Subset EACH modality individually with the sorted common cells
        # This ensures each modality is aligned to the same order
        aligned_modalities = {}
        for mod_key, adata in modalities.items():
            aligned_modalities[mod_key] = adata[common_cells_sorted].copy()
        mdata = md.MuData(aligned_modalities)

        print(f"Number of common cells: {len(common_cells_sorted)}")

        # Clean obs_names: remove modality prefixes
        cleaned_names = [
            name.split(":")[-1] if ":" in name else name
            for name in mdata.obs.columns
        ]
        mdata.obs.columns = cleaned_names

        # Remove duplicate columns from obs
        mdata.obs = mdata.obs.loc[:, ~mdata.obs.columns.duplicated(keep="first")]

        return {"multi_sc": mdata}
    return {"multi_sc": modalities}