segger.data.pyg_dataset

The pyg_dataset module provides a PyTorch Geometric (PyG) dataset for spatial transcriptomics data, so that tiles produced by Segger's spatial data processing can be loaded directly into PyTorch-based machine learning workflows.

pyg_dataset

STPyGDataset

Bases: InMemoryDataset

A dataset class, built on PyG's InMemoryDataset, for training on spatial transcriptomics data.

Source code in src/segger/data/pyg_dataset.py
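import glob
import os
from pathlib import Path
from typing import Callable, List, Optional

import torch
from torch_geometric.data import Data, InMemoryDataset


class STPyGDataset(InMemoryDataset):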
    """An in-memory dataset class for handling training using spatial
    transcriptomics data.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
    ):
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self) -> List[str]:
        """Return a list of raw file names in the raw directory.

        Returns:
            List[str]: List of raw file names.
        """
        return os.listdir(self.raw_dir)

    @property
    def processed_file_names(self) -> List[str]:
        """Return a list of processed file names in the processed directory.

        Returns:
            List[str]: List of processed file names.
        """
        paths = glob.glob(f"{self.processed_dir}/tiles_x*_y*_*_*.pt")
        # paths = paths.append(paths = glob.glob(f'{self.processed_dir}/tiles_x*_y*_*_*.pt'))
        file_names = list(map(os.path.basename, paths))
        return file_names

    def len(self) -> int:
        """Return the number of processed files.

        Returns:
            int: Number of processed files.
        """
        return len(self.processed_file_names)

    def get(self, idx: int) -> Data:
        """Get a processed data object.

        Args:
            idx: Index of the data object to retrieve.

        Returns:
            Data: The processed data object.
        """
        filepath = Path(self.processed_dir) / self.processed_file_names[idx]
        data = torch.load(filepath)
        # Work around an issue in PyG's RandomLinkSplit: tensor dimensions
        # are inconsistent when the graph contains only one edge.
        if hasattr(data["tx", "belongs", "bd"], "edge_label_index"):
            if data["tx", "belongs", "bd"].edge_label_index.dim() == 1:
                data["tx", "belongs", "bd"].edge_label_index = data[
                    "tx", "belongs", "bd"
                ].edge_label_index.unsqueeze(1)
                data["tx", "belongs", "bd"].edge_label = data[
                    "tx", "belongs", "bd"
                ].edge_label.unsqueeze(0)
            assert data["tx", "belongs", "bd"].edge_label_index.dim() == 2
        return data

processed_file_names property

processed_file_names

Return a list of processed file names in the processed directory.

Returns:

List[str]: List of processed file names.
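To make the file discovery concrete, here is a minimal sketch that applies the same glob pattern, tiles_x*_y*_*_*.pt, to a temporary directory using only the standard library. The file names and the matching_tiles helper are hypothetical and chosen for illustration; they are not taken from a real Segger run. The sketch sorts the result for readability, whereas the property itself returns names in whatever order glob yields them.

```python
import glob
import os
import tempfile

# Hypothetical file names, for illustration only.
candidates = [
    "tiles_x0_y0_200_200.pt",    # matches the pattern
    "tiles_x200_y0_200_200.pt",  # matches the pattern
    "pre_filter.pt",             # wrong prefix, ignored
    "tiles_x0_y0.pt",            # too few underscore-separated fields, ignored
]


def matching_tiles(processed_dir):
    """Return base names of files matching the tile pattern, sorted."""
    paths = glob.glob(f"{processed_dir}/tiles_x*_y*_*_*.pt")
    return sorted(map(os.path.basename, paths))


with tempfile.TemporaryDirectory() as tmp:
    # Create empty placeholder files so glob has something to find.
    for name in candidates:
        open(os.path.join(tmp, name), "w").close()
    matched = matching_tiles(tmp)

print(matched)
print(len(matched))  # the count that len() would report for this directory
```

Because the property runs the glob on every access, files added to or removed from the processed directory change both the reported length and the mapping from index to file.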

raw_file_names property

raw_file_names

Return a list of raw file names in the raw directory.

Returns:

List[str]: List of raw file names.

get

get(idx)

Get a processed data object.

Parameters:

idx (int, required): Index of the data object to retrieve.

Returns:

Data: The processed data object.

Source code in src/segger/data/pyg_dataset.py
def get(self, idx: int) -> Data:
    """Get a processed data object.

    Args:
        idx: Index of the data object to retrieve.

    Returns:
        Data: The processed data object.
    """
    filepath = Path(self.processed_dir) / self.processed_file_names[idx]
    data = torch.load(filepath)
    # Work around an issue in PyG's RandomLinkSplit: tensor dimensions
    # are inconsistent when the graph contains only one edge.
    if hasattr(data["tx", "belongs", "bd"], "edge_label_index"):
        if data["tx", "belongs", "bd"].edge_label_index.dim() == 1:
            data["tx", "belongs", "bd"].edge_label_index = data[
                "tx", "belongs", "bd"
            ].edge_label_index.unsqueeze(1)
            data["tx", "belongs", "bd"].edge_label = data[
                "tx", "belongs", "bd"
            ].edge_label.unsqueeze(0)
        assert data["tx", "belongs", "bd"].edge_label_index.dim() == 2
    return data
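The shape fix in get can be tried on standalone tensors. The sketch below assumes that, for a tile with exactly one labeled edge, edge_label_index arrives as a one-dimensional tensor of shape (2,) and edge_label as a scalar. Those collapsed shapes, the tensor values, and the restore_edge_dims helper are assumptions made for illustration, not output from a real Segger tile.

```python
import torch

# Assumed collapsed shapes for a tile with exactly one labeled edge
# (illustrative values, not taken from a real Segger tile).
edge_label_index = torch.tensor([3, 7])  # shape (2,): one (source, target) pair
edge_label = torch.tensor(1.0)           # shape (): a single scalar label


def restore_edge_dims(index, label):
    """Mirror the shape fix applied in get() on standalone tensors."""
    if index.dim() == 1:
        index = index.unsqueeze(1)  # (2,) -> (2, 1)
        label = label.unsqueeze(0)  # () -> (1,)
    assert index.dim() == 2
    return index, label


fixed_index, fixed_label = restore_edge_dims(edge_label_index, edge_label)
print(tuple(fixed_index.shape), tuple(fixed_label.shape))

# A tensor that already holds several edges passes through unchanged.
multi_index = torch.tensor([[0, 1, 2], [5, 6, 7]])
multi_label = torch.tensor([1.0, 0.0, 1.0])
same_index, same_label = restore_edge_dims(multi_index, multi_label)
```

After the fix, the index has the usual (2, num_edges) layout, so downstream code that indexes its rows does not need a special case for single-edge tiles.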

len

len()

Return the number of processed files.

Returns:

int: Number of processed files.

Source code in src/segger/data/pyg_dataset.py
def len(self) -> int:
    """Return the number of processed files.

    Returns:
        int: Number of processed files.
    """
    return len(self.processed_file_names)

Overview

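The STPyGDataset class extends PyTorch Geometric's InMemoryDataset to give machine learning pipelines a standard interface to spatial transcriptomics data. It discovers the processed tile files (tiles_x*_y*_*_*.pt) in the processed directory and loads one tile from disk on each get call. Before returning a tile, it restores the two-dimensional shape of edge_label_index on the ("tx", "belongs", "bd") edge type, which can collapse to one dimension when the graph contains only one edge.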