
segger.data.sample

The sample module is the core of the Segger data processing framework, providing comprehensive classes for handling spatial transcriptomics data. This module contains the main classes that orchestrate the entire data processing pipeline from raw data to machine learning-ready graphs.

sample

STInMemoryDataset

STInMemoryDataset(sample, extents, margin=10)

A class for handling in-memory representations of ST data.

This class is used to load and manage ST sample data from parquet files, filter boundaries and transcripts, and provide spatial tiling for further analysis. The class also pre-loads KDTrees for efficient spatial queries.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| sample | STSampleParquet | The ST sample containing paths to the data files. | required |
| extents | Polygon | The polygon defining the spatial extents for the dataset. | required |
| margin | int | The margin to buffer around the extents when filtering data. | 10 |

Attributes:

| Name | Type | Description |
|------|------|-------------|
| sample | STSampleParquet | The ST sample from which the data is loaded. |
| extents | Polygon | The spatial extents of the dataset. |
| margin | int | The buffer margin around the extents for filtering. |
| transcripts | DataFrame | The filtered transcripts within the dataset extents. |
| boundaries | DataFrame | The filtered boundaries within the dataset extents. |
| kdtree_tx | KDTree | The KDTree for fast spatial queries on the transcripts. |

Raises:

| Type | Description |
|------|-------------|
| ValueError | If the transcripts or boundaries could not be loaded or filtered. |

Initialize the STInMemoryDataset instance by loading transcripts and boundaries from parquet files and pre-loading a KDTree for fast spatial queries.

Source code in src/segger/data/sample.py
def __init__(
    self,
    sample: STSampleParquet,
    extents: shapely.Polygon,
    margin: int = 10,
):
    """Initialize the STInMemoryDataset instance by loading transcripts
    and boundaries from parquet files and pre-loading a KDTree for fast
    spatial queries.

    Args:
        sample: The ST sample containing paths to the data files.
        extents: The polygon defining the spatial extents for the dataset.
        margin: The margin to buffer around the extents when filtering data. Defaults to 10.
    """
    # Set properties
    self.sample = sample
    self.extents = extents
    self.margin = margin
    self.settings = self.sample.settings

    # Load data from parquet files
    self._load_transcripts(self.sample._transcripts_filepath)
    self._load_boundaries(self.sample._boundaries_filepath)

    # Pre-load KDTrees
    self.kdtree_tx = KDTree(
        self.transcripts[self.settings.transcripts.xy], leafsize=100
    )
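
Because the KDTree over transcript coordinates is built once at construction, repeated spatial queries against the tile data are cheap. A minimal sketch of such a query, assuming dataset is an already-constructed STInMemoryDataset and using hypothetical coordinates:

import numpy as np

# Find all transcripts within 25 units of a query point
query_point = np.array([1500.0, 2300.0])  # hypothetical x/y location
idx = dataset.kdtree_tx.query_ball_point(query_point, r=25.0)
nearby = dataset.transcripts.iloc[idx]
print(f"{len(nearby)} transcripts within 25 units")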

STSampleParquet

STSampleParquet(base_dir, n_workers=1, scale_factor=1.0, sample_type=None, weights=None)

A class to manage spatial transcriptomics data stored in parquet files.

This class provides methods for loading, processing, and saving data related to ST samples. It supports parallel processing and efficient handling of transcript and boundary data.

Initialize the STSampleParquet instance.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| base_dir | PathLike | The base directory containing the ST data. | required |
| n_workers | Optional[int] | The number of workers for parallel processing. | 1 |
| scale_factor | Optional[float] | The scale factor used to expand the boundary extents during spatial queries. | 1.0 |
| sample_type | str | The sample type of the raw data, e.g., 'xenium' or 'merscope'. | None |
| weights | DataFrame | DataFrame containing weights for transcript embedding. | None |

Raises:

| Type | Description |
|------|-------------|
| FileNotFoundError | If the base directory does not exist or the required files are missing. |

Source code in src/segger/data/sample.py
def __init__(
    self,
    base_dir: os.PathLike,
    n_workers: Optional[int] = 1,
    scale_factor: Optional[float] = 1.0,
    sample_type: str = None,
    weights: pd.DataFrame = None,
):
    """Initialize the STSampleParquet instance.

    Args:
        base_dir: The base directory containing the ST data.
        n_workers: The number of workers for parallel processing. Defaults to 1.
        sample_type: The sample type of the raw data, e.g., 'xenium' or 'merscope'. Defaults to None.
        weights: DataFrame containing weights for transcript embedding. Defaults to None.
        scale_factor: The scale factor used to expand the boundary extents
            during spatial queries. Defaults to 1.0.

    Raises:
        FileNotFoundError: If the base directory does not exist or the required files are
            missing.
    """
    # Setup paths and resource constraints
    self._base_dir = Path(base_dir)
    self.settings = utils.load_settings(sample_type)
    transcripts_fn = self.settings.transcripts.filename
    self._transcripts_filepath = self._base_dir / transcripts_fn
    boundaries_fn = self.settings.boundaries.filename
    self._boundaries_filepath = self._base_dir / boundaries_fn
    self.n_workers = n_workers
    # Set the boundary scale factor before checking whether overlap
    # information needs to be recomputed (defaults to 1.0)
    self.settings.boundaries.scale_factor = scale_factor
    nuclear_column = getattr(self.settings.transcripts, "nuclear_column", None)
    if nuclear_column is None or self.settings.boundaries.scale_factor != 1.0:
        print(
            "Boundary-transcript overlap information has not been pre-computed. "
            "It will be calculated during tile generation."
        )

    # Ensure transcript IDs exist
    utils.ensure_transcript_ids(
        self._transcripts_filepath,
        self.settings.transcripts.x,
        self.settings.transcripts.y,
        self.settings.transcripts.id,
    )

    # Setup logging
    logging.basicConfig(level=logging.INFO)
    self.logger = logging.getLogger(f"STSample@{base_dir}")

    # Internal caches
    self._extents = None
    self._transcripts_metadata = None
    self._boundaries_metadata = None

    # Setup default embedding for transcripts
    self._emb_genes = None
    if weights is not None:
        self._emb_genes = weights.index.to_list()
    classes = self.transcripts_metadata["feature_names"]
    self._transcript_embedding = TranscriptEmbedding(np.array(classes), weights)

boundaries_metadata cached property

boundaries_metadata

Retrieve metadata for the boundaries stored in the sample.

Returns:

| Type | Description |
|------|-------------|
| dict | Metadata dictionary for boundaries including column sizes. |

Raises:

| Type | Description |
|------|-------------|
| FileNotFoundError | If the boundaries parquet file does not exist. |

extents cached property

extents

Get the combined extents (bounding box) of the transcripts and boundaries.

Returns:

| Type | Description |
|------|-------------|
| Polygon | The bounding box covering all transcripts and boundaries. |

n_transcripts property

n_transcripts

Get the total number of transcripts in the sample.

Returns:

| Type | Description |
|------|-------------|
| int | The number of transcripts. |

transcripts_metadata cached property

transcripts_metadata

Retrieve metadata for the transcripts stored in the sample.

Returns:

| Type | Description |
|------|-------------|
| dict | Metadata dictionary for transcripts including column sizes and feature names. |

Raises:

| Type | Description |
|------|-------------|
| FileNotFoundError | If the transcript parquet file does not exist. |

save

save(data_dir, k_bd=3, dist_bd=15.0, k_tx=3, dist_tx=5.0, k_tx_ex=100, dist_tx_ex=20, tile_size=None, tile_width=None, tile_height=None, neg_sampling_ratio=5.0, frac=1.0, val_prob=0.1, test_prob=0.2, mutually_exclusive_genes=None)

Save the tiles of the sample as PyTorch geometric datasets.

See documentation for 'STTile' for more information on dataset contents.

Note: This function requires either 'tile_size' OR both 'tile_width' and 'tile_height' to be provided.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| data_dir | PathLike | The directory where the dataset should be saved. | required |
| k_bd | int | Number of nearest neighbors for boundary nodes. | 3 |
| dist_bd | float | Maximum distance for boundary neighbors. | 15.0 |
| k_tx | int | Number of nearest neighbors for transcript nodes. | 3 |
| dist_tx | float | Maximum distance for transcript neighbors. | 5.0 |
| k_tx_ex | int | Number of nearest neighbors for transcript exclusion. | 100 |
| dist_tx_ex | float | Maximum distance for transcript exclusion. | 20 |
| tile_size | Optional[int] | If provided, specifies the size of the tile. Overrides tile_width and tile_height. | None |
| tile_width | Optional[int] | Width of the tiles in pixels. Ignored if tile_size is provided. | None |
| tile_height | Optional[int] | Height of the tiles in pixels. Ignored if tile_size is provided. | None |
| neg_sampling_ratio | float | Ratio of negative samples. | 5.0 |
| frac | float | Fraction of the dataset to process. | 1.0 |
| val_prob | float | Proportion of data to use for the validation split. | 0.1 |
| test_prob | float | Proportion of data to use for the test split. | 0.2 |
| mutually_exclusive_genes | Optional[List] | List of mutually exclusive gene pairs. | None |

Raises:

| Type | Description |
|------|-------------|
| ValueError | If the 'frac' parameter is greater than 1.0 or if the calculated number of tiles is zero. |
| AssertionError | If the specified directory structure is not properly set up. |

Source code in src/segger/data/sample.py
def save(
    self,
    data_dir: os.PathLike,
    k_bd: int = 3,
    dist_bd: float = 15.0,
    k_tx: int = 3,
    dist_tx: float = 5.0,
    k_tx_ex: int = 100,
    dist_tx_ex: float = 20,
    tile_size: Optional[int] = None,
    tile_width: Optional[int] = None,
    tile_height: Optional[int] = None,
    neg_sampling_ratio: float = 5.0,
    frac: float = 1.0,
    val_prob: float = 0.1,
    test_prob: float = 0.2,
    mutually_exclusive_genes: Optional[List] = None,
):
    """Save the tiles of the sample as PyTorch geometric datasets.

    See documentation for 'STTile' for more information on dataset contents.

    Note: This function requires either 'tile_size' OR both 'tile_width' and
    'tile_height' to be provided.

    Args:
        data_dir: The directory where the dataset should be saved.
        k_bd: Number of nearest neighbors for boundary nodes. Defaults to 3.
        dist_bd: Maximum distance for boundary neighbors. Defaults to 15.0.
        k_tx: Number of nearest neighbors for transcript nodes. Defaults to 3.
        dist_tx: Maximum distance for transcript neighbors. Defaults to 5.0.
        k_tx_ex: Number of nearest neighbors for transcript exclusion. Defaults to 100.
        dist_tx_ex: Maximum distance for transcript exclusion. Defaults to 20.
        tile_size: If provided, specifies the size of the tile. Overrides `tile_width`
            and `tile_height`. Defaults to None.
        tile_width: Width of the tiles in pixels. Ignored if `tile_size` is provided. Defaults to None.
        tile_height: Height of the tiles in pixels. Ignored if `tile_size` is provided. Defaults to None.
        neg_sampling_ratio: Ratio of negative samples. Defaults to 5.0.
        frac: Fraction of the dataset to process. Defaults to 1.0.
        val_prob: Proportion of data to use for the validation split. Defaults to 0.1.
        test_prob: Proportion of data to use for the test split. Defaults to 0.2.
        mutually_exclusive_genes: List of mutually exclusive gene pairs. Defaults to None.

    Raises:
        ValueError: If the 'frac' parameter is greater than 1.0 or if the calculated
            number of tiles is zero.
        AssertionError: If the specified directory structure is not properly set up.
    """
    # Check inputs
    try:
        if frac > 1:
            msg = f"Arg 'frac' should be <= 1.0, but got {frac}."
            raise ValueError(msg)
        if tile_size is not None:
            n_tiles = self.n_transcripts / tile_size / self.n_workers * frac
            if int(n_tiles) == 0:
                msg = f"Sampling parameters would yield 0 total tiles."
                raise ValueError(msg)
    # Propagate errors to logging
    except Exception as e:
        self.logger.error(str(e), exc_info=True)
        raise e

    # Setup directory structure to save tiles
    data_dir = Path(data_dir)
    STSampleParquet._setup_directory(data_dir)

    # Function to parallelize over workers
    def func(region):
        xm = STInMemoryDataset(sample=self, extents=region)
        tiles = xm._tile(tile_width, tile_height, tile_size)
        if frac < 1:
            tiles = random.sample(tiles, int(len(tiles) * frac))
        for tile in tiles:
            # Choose training, test, or validation datasets
            data_type = np.random.choice(
                a=["train_tiles", "test_tiles", "val_tiles"],
                p=[1 - (test_prob + val_prob), test_prob, val_prob],
            )
            xt = STTile(dataset=xm, extents=tile)
            pyg_data = xt.to_pyg_dataset(
                k_bd=k_bd,
                dist_bd=dist_bd,
                k_tx=k_tx,
                dist_tx=dist_tx,
                k_tx_ex=k_tx_ex,
                dist_tx_ex=dist_tx_ex,
                neg_sampling_ratio=neg_sampling_ratio,
                mutually_exclusive_genes=mutually_exclusive_genes,
            )
            if pyg_data is not None:
                if pyg_data["tx", "belongs", "bd"].edge_index.numel() == 0:
                    # this tile is only for testing
                    data_type = "test_tiles"
                filepath = data_dir / data_type / "processed" / f"{xt.uid}.pt"
                torch.save(pyg_data, filepath)

    # TODO: Add Dask backend
    regions = self._get_balanced_regions()
    outs = []
    for region in regions:
        outs.append(func(region))
    return outs
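
Since save() accepts either tile_size or an explicit tile_width/tile_height pair, a call with explicit dimensions looks like the sketch below. The read-back step relies only on the directory layout used above (each tile is written to <split>/processed/<uid>.pt); the concrete paths are otherwise hypothetical:

import torch
from pathlib import Path

# Tile by explicit dimensions instead of transcript count
sample.save(
    data_dir="./processed_data",
    tile_width=200,
    tile_height=200,
    val_prob=0.1,
    test_prob=0.2,
)

# Each tile is serialized as <uid>.pt under <split>/processed/
tile_paths = list(Path("./processed_data/train_tiles/processed").glob("*.pt"))
# Newer PyTorch versions may require weights_only=False to deserialize HeteroData
pyg_data = torch.load(tile_paths[0])
print(pyg_data)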

save_debug

save_debug(data_dir, k_bd=3, dist_bd=15.0, k_tx=3, dist_tx=5.0, k_tx_ex=100, dist_tx_ex=20, tile_width=None, tile_height=None, neg_sampling_ratio=5.0, frac=1.0, val_prob=0.1, test_prob=0.2)

Debug version of save method that processes tiles sequentially and prints detailed information about the process.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| data_dir | PathLike | The directory where the dataset should be saved. | required |
| k_bd | int | Number of nearest neighbors for boundary nodes. | 3 |
| dist_bd | float | Maximum distance for boundary neighbors. | 15.0 |
| k_tx | int | Number of nearest neighbors for transcript nodes. | 3 |
| dist_tx | float | Maximum distance for transcript neighbors. | 5.0 |
| k_tx_ex | int | Number of nearest neighbors for transcript exclusion. | 100 |
| dist_tx_ex | float | Maximum distance for transcript exclusion. | 20 |
| tile_width | Optional[float] | Width of the tiles in pixels. | None |
| tile_height | Optional[float] | Height of the tiles in pixels. | None |
| neg_sampling_ratio | float | Ratio of negative samples. | 5.0 |
| frac | float | Fraction of the dataset to process. | 1.0 |
| val_prob | float | Proportion of data to use for the validation split. | 0.1 |
| test_prob | float | Proportion of data to use for the test split. | 0.2 |
Source code in src/segger/data/sample.py
def save_debug(
    self,
    data_dir: os.PathLike,
    k_bd: int = 3,
    dist_bd: float = 15.0,
    k_tx: int = 3,
    dist_tx: float = 5.0,
    k_tx_ex: int = 100,
    dist_tx_ex: float = 20,
    tile_width: Optional[float] = None,
    tile_height: Optional[float] = None,
    neg_sampling_ratio: float = 5.0,
    frac: float = 1.0,
    val_prob: float = 0.1,
    test_prob: float = 0.2,
):
    """Debug version of save method that processes tiles sequentially and prints
    detailed information about the process.

    Args:
        data_dir: The directory where the dataset should be saved.
        k_bd: Number of nearest neighbors for boundary nodes. Defaults to 3.
        dist_bd: Maximum distance for boundary neighbors. Defaults to 15.0.
        k_tx: Number of nearest neighbors for transcript nodes. Defaults to 3.
        dist_tx: Maximum distance for transcript neighbors. Defaults to 5.0.
        k_tx_ex: Number of nearest neighbors for transcript exclusion. Defaults to 100.
        dist_tx_ex: Maximum distance for transcript exclusion. Defaults to 20.
        tile_width: Width of the tiles in pixels. Defaults to None.
        tile_height: Height of the tiles in pixels. Defaults to None.
        neg_sampling_ratio: Ratio of negative samples. Defaults to 5.0.
        frac: Fraction of the dataset to process. Defaults to 1.0.
        val_prob: Proportion of data to use for the validation split. Defaults to 0.1.
        test_prob: Proportion of data to use for the test split. Defaults to 0.2.
    """
    print("\n=== Starting Debug Tile Generation ===")
    print(f"Parameters:")
    print(f"- k_bd: {k_bd} (boundary neighbors)")
    print(f"- dist_bd: {dist_bd} (boundary distance)")
    print(f"- k_tx: {k_tx} (transcript neighbors)")
    print(f"- dist_tx: {dist_tx} (transcript distance)")
    print(f"- tile_width: {tile_width}")
    print(f"- tile_height: {tile_height}")
    print(f"- frac: {frac}")
    print(f"- val_prob: {val_prob}")
    print(f"- test_prob: {test_prob}")

    # Setup directory structure to save tiles
    data_dir = Path(data_dir)
    STSampleParquet._setup_directory(data_dir)
    print(f"\nOutput directory: {data_dir}")

    # Get regions to process
    regions = self._get_balanced_regions()
    print(f"\nTotal regions to process: {len(regions)}")
    print("Region bounds:")
    for i, region in enumerate(regions):
        print(f"Region {i+1}: {region.bounds}")

    # Process each region sequentially
    for region_idx, region in enumerate(regions):
        print(f"\n=== Processing Region {region_idx + 1}/{len(regions)} ===")
        print(f"Region bounds: {region.bounds}")

        xm = STInMemoryDataset(sample=self, extents=region)
        tiles = xm._tile(tile_width, tile_height, None)
        print(f"Generated {len(tiles)} tiles for this region")

        if frac < 1:
            tiles = random.sample(tiles, int(len(tiles) * frac))
            print(f"After sampling: {len(tiles)} tiles")

        # Process each tile
        for tile_idx, tile in enumerate(tiles):
            print(f"\n--- Processing Tile {tile_idx + 1}/{len(tiles)} ---")
            print(f"Tile bounds: {tile.bounds}")

            # Choose training, test, or validation datasets
            data_type = np.random.choice(
                a=["train_tiles", "test_tiles", "val_tiles"],
                p=[1 - (test_prob + val_prob), test_prob, val_prob],
            )
            print(f"Assigned to: {data_type}")

            xt = STTile(dataset=xm, extents=tile)
            print(f"Tile UID: {xt.uid}")

            pyg_data = xt.to_pyg_dataset(
                k_bd=k_bd,
                dist_bd=dist_bd,
                k_tx=k_tx,
                dist_tx=dist_tx,
                k_tx_ex=k_tx_ex,
                dist_tx_ex=dist_tx_ex,
                neg_sampling_ratio=neg_sampling_ratio,
            )

            if pyg_data is not None:
                if pyg_data["tx", "belongs", "bd"].edge_index.numel() == 0:
                    data_type = "test_tiles"
                    print("No tx-belongs-bd edges found, reassigning to test_tiles")

                filepath = data_dir / data_type / "processed" / f"{xt.uid}.pt"
                torch.save(pyg_data, filepath)
                print(f"Saved to: {filepath}")

                # Print some statistics about the generated data
                print(f"Data statistics:")
                print(f"- Number of transcripts: {pyg_data['tx'].num_nodes}")
                print(f"- Number of boundaries: {pyg_data['bd'].num_nodes}")
                print(
                    f"- Number of tx-tx edges: {pyg_data['tx', 'neighbors', 'tx'].edge_index.shape[1]}"
                )
                print(
                    f"- Number of tx-bd edges: {pyg_data['tx', 'neighbors', 'bd'].edge_index.shape[1]}"
                )
                # print(f"- Number of tx-belongs-bd edges: {pyg_data['tx', 'belongs', 'bd'].edge_index.shape[1]}")
            else:
                print("Skipping tile - no valid data generated")

    print("\n=== Debug Tile Generation Completed ===")

set_transcript_embedding

set_transcript_embedding(weights)

Set the transcript embedding for the sample.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| weights | DataFrame | A DataFrame containing the weights for each transcript. | required |

Raises:

| Type | Description |
|------|-------------|
| ValueError | If the provided weights do not match the number of transcript features. |

Source code in src/segger/data/sample.py
def set_transcript_embedding(self, weights: pd.DataFrame):
    """Set the transcript embedding for the sample.

    Args:
        weights: A DataFrame containing the weights for each transcript.

    Raises:
        ValueError: If the provided weights do not match the number of transcript
            features.
    """
    classes = self._transcripts_metadata["feature_names"]
    self._transcript_embedding = TranscriptEmbedding(classes, weights)
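
The weights are expected to be indexed by feature (gene) name so they can be matched against the sample's feature_names metadata (the constructor reads weights.index for exactly this purpose). A minimal sketch with a hypothetical two-dimensional embedding:

import numpy as np
import pandas as pd

genes = sample.transcripts_metadata["feature_names"]

# Hypothetical 2-dimensional embedding, one row per gene
weights = pd.DataFrame(
    np.random.rand(len(genes), 2),
    index=genes,
    columns=["dim_0", "dim_1"],
)
sample.set_transcript_embedding(weights)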

STTile

STTile(dataset, extents)

A class representing a tile of an ST sample.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| dataset | STInMemoryDataset | The ST dataset containing data. | required |
| extents | Polygon | The extents of the tile in the sample. | required |

Attributes:

| Name | Type | Description |
|------|------|-------------|
| boundaries | DataFrame | Filtered boundaries within the tile extents. |
| transcripts | DataFrame | Filtered transcripts within the tile extents. |

Initialize an STTile instance.

Note

The boundaries and transcripts attributes are cached to avoid the overhead of filtering when tiles are instantiated. This is particularly useful in multiprocessing settings where generating tiles in parallel could lead to high overhead.

Internal attributes:

- _boundaries: Cached DataFrame of filtered boundaries. Initially set to None.
- _transcripts: Cached DataFrame of filtered transcripts. Initially set to None.

Source code in src/segger/data/sample.py
def __init__(
    self,
    dataset: STInMemoryDataset,
    extents: shapely.Polygon,
):
    """Initialize a STTile instance.

    Args:
        dataset: The ST dataset containing data.
        extents: The extents of the tile in the sample.

    Note:
        The `boundaries` and `transcripts` attributes are cached to avoid the
        overhead of filtering when tiles are instantiated. This is particularly
        useful in multiprocessing settings where generating tiles in parallel
        could lead to high overhead.

    Internal Args:
        _boundaries: Cached DataFrame of filtered boundaries. Initially set to None.
        _transcripts: Cached DataFrame of filtered transcripts. Initially set to None.
    """
    self.dataset = dataset
    self.extents = extents
    self.margin = dataset.margin
    self.settings = self.dataset.settings

    # Internal caches for filtered data
    self._boundaries = None
    self._transcripts = None

boundaries cached property

boundaries

Return the filtered boundaries within the tile extents, cached for efficiency.

The boundaries are computed only once and cached. If the boundaries have not been computed yet, they are computed using get_filtered_boundaries().

Returns:

| Type | Description |
|------|-------------|
| DataFrame | A DataFrame containing the filtered boundaries within the tile extents. |

transcripts cached property

transcripts

Return the filtered transcripts within the tile extents, cached for efficiency.

The transcripts are computed only once and cached. If the transcripts have not been computed yet, they are computed using get_filtered_transcripts().

Returns:

| Type | Description |
|------|-------------|
| DataFrame | A DataFrame containing the filtered transcripts within the tile extents. |

uid property

uid

Generate a unique identifier for the tile based on its extents.

This UID is particularly useful for saving or indexing tiles in distributed processing environments.

The UID is constructed using the minimum and maximum x and y coordinates of the tile's bounding box, representing its position and size in the sample.

Returns:

| Type | Description |
|------|-------------|
| str | A unique identifier string in the format 'x=<x_min>_y=<y_min>_w=<width>_h=<height>', where <x_min> and <y_min> are the minimum x- and y-coordinates of the tile's extents, and <width> and <height> are the width and height of the tile. |

Example

If the tile's extents are bounded by (x_min, y_min) = (100, 200) and (x_max, y_max) = (150, 250), the generated UID would be: 'x=100_y=200_w=50_h=50'
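
A minimal sketch of deriving such a UID from a tile's bounding box (this mirrors the documented format; the actual property implementation may differ):

import shapely

extents = shapely.box(100, 200, 150, 250)
x_min, y_min, x_max, y_max = map(int, extents.bounds)
uid = f"x={x_min}_y={y_min}_w={x_max - x_min}_h={y_max - y_min}"
print(uid)  # x=100_y=200_w=50_h=50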

canonical_edges

canonical_edges(edge_index)

Sort edge indices to ensure canonical ordering.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| edge_index | Tensor | The edge index tensor to sort. | required |

Returns:

| Type | Description |
|------|-------------|
| Tensor | The sorted edge index tensor. |

Source code in src/segger/data/sample.py
def canonical_edges(edge_index):
    """Sort edge indices to ensure canonical ordering.

    Args:
        edge_index: The edge index tensor to sort.

    Returns:
        torch.Tensor: The sorted edge index tensor.
    """
    return torch.sort(edge_index, dim=0)[0]
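
Sorting each column means the directed edges (0, 3) and (3, 0) collapse to the same canonical edge, which is handy for de-duplicating undirected edge lists. A small example, assuming canonical_edges is accessible as shown above:

import torch

edge_index = torch.tensor([[0, 3, 2],
                           [3, 0, 1]])
print(canonical_edges(edge_index))
# tensor([[0, 0, 1],
#         [3, 3, 2]])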

get_boundary_props

get_boundary_props(area=True, convexity=True, elongation=True, circularity=True)

Compute geometric properties of boundary polygons.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| area | bool | If True, compute the area of each boundary polygon. | True |
| convexity | bool | If True, compute the convexity of each boundary polygon. | True |
| elongation | bool | If True, compute the elongation of each boundary polygon. | True |
| circularity | bool | If True, compute the circularity of each boundary polygon. | True |

Returns:

| Type | Description |
|------|-------------|
| Tensor | A tensor containing the computed properties for each boundary polygon. |

Note

The intention is for this function to simplify testing new strategies for 'bd' node representations. You can just change the function body to return another torch.Tensor without worrying about changes to the rest of the code.

Source code in src/segger/data/sample.py
def get_boundary_props(
    self,
    area: bool = True,
    convexity: bool = True,
    elongation: bool = True,
    circularity: bool = True,
) -> torch.Tensor:
    """Compute geometric properties of boundary polygons.

    Args:
        area: If True, compute the area of each boundary polygon. Defaults to True.
        convexity: If True, compute the convexity of each boundary polygon. Defaults to True.
        elongation: If True, compute the elongation of each boundary polygon. Defaults to True.
        circularity: If True, compute the circularity of each boundary polygon. Defaults to True.

    Returns:
        torch.Tensor: A tensor containing the computed properties for each boundary
            polygon.

    Note:
        The intention is for this function to simplify testing new strategies
        for 'bd' node representations. You can just change the function body to
        return another torch.Tensor without worrying about changes to the rest
        of the code.
    """
    # Get polygons from coordinates
    # Use getattr to check settings for a configured geometry column
    geometry_column = getattr(self.settings.boundaries, "geometry", None)
    if geometry_column and geometry_column in self.boundaries.columns:
        polygons = self.boundaries[geometry_column]
    else:
        # Fall back to the default 'geometry' column
        polygons = self.boundaries["geometry"]
    # Geometric properties of polygons
    props = self.get_polygon_props(polygons, area, convexity, elongation, circularity)
    props = torch.as_tensor(props.values).float()

    return props

get_filtered_boundaries

get_filtered_boundaries()

Filter the boundaries in the sample to include only those within the specified tile extents.

Returns:

| Type | Description |
|------|-------------|
| DataFrame | A DataFrame containing the filtered boundaries within the tile extents. |

Source code in src/segger/data/sample.py
def get_filtered_boundaries(self) -> pd.DataFrame:
    """Filter the boundaries in the sample to include only those within
    the specified tile extents.

    Returns:
        pd.DataFrame: A DataFrame containing the filtered boundaries within the tile
            extents.
    """
    filtered_boundaries = utils.filter_boundaries(
        boundaries=self.dataset.boundaries,
        inset=self.extents,
        outset=self.extents.buffer(self.margin, join_style="mitre"),
        x=self.settings.boundaries.x,
        y=self.settings.boundaries.y,
        label=self.settings.boundaries.label,
    )
    return filtered_boundaries

get_filtered_transcripts

get_filtered_transcripts()

Filter the transcripts in the sample to include only those within the specified tile extents.

Returns:

| Type | Description |
|------|-------------|
| DataFrame | A DataFrame containing the filtered transcripts within the tile extents. |

Source code in src/segger/data/sample.py
def get_filtered_transcripts(self) -> pd.DataFrame:
    """Filter the transcripts in the sample to include only those within
    the specified tile extents.

    Returns:
        pd.DataFrame: A DataFrame containing the filtered transcripts within the tile
            extents.
    """

    # Buffer tile bounds to include transcripts around boundary
    outset = self.extents.buffer(self.margin, join_style="mitre")
    xmin, ymin, xmax, ymax = outset.bounds

    # Get transcripts inside buffered region
    x, y = self.settings.transcripts.xy
    mask = self.dataset.transcripts[x].between(xmin, xmax)
    mask &= self.dataset.transcripts[y].between(ymin, ymax)
    filtered_transcripts = self.dataset.transcripts[mask]

    return filtered_transcripts

get_kdtree_edge_index staticmethod

get_kdtree_edge_index(index_coords, query_coords, k, max_distance)

Compute the k-nearest neighbor edge indices using a KDTree.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| index_coords | ndarray | An array of shape (n_samples, n_features) representing the coordinates of the points to be indexed. | required |
| query_coords | ndarray | An array of shape (m_samples, n_features) representing the coordinates of the query points. | required |
| k | int | The number of nearest neighbors to find for each query point. | required |
| max_distance | float | The maximum distance to consider for neighbors. | required |

Returns:

| Type | Description |
|------|-------------|
| Tensor | A tensor of shape (2, n_edges) containing the edge indices. Each column represents an edge between two points, where the first row contains the source indices and the second row contains the target indices. |

Source code in src/segger/data/sample.py
@staticmethod
def get_kdtree_edge_index(
    index_coords: np.ndarray,
    query_coords: np.ndarray,
    k: int,
    max_distance: float,
) -> torch.Tensor:
    """Compute the k-nearest neighbor edge indices using a KDTree.

    Args:
        index_coords: An array of shape (n_samples, n_features) representing the
            coordinates of the points to be indexed.
        query_coords: An array of shape (m_samples, n_features) representing the
            coordinates of the query points.
        k: The number of nearest neighbors to find for each query point.
        max_distance: The maximum distance to consider for neighbors.

    Returns:
        torch.Tensor: An array of shape (2, n_edges) containing the edge indices. Each
            column represents an edge between two points, where the first row
            contains the source indices and the second row contains the target
            indices.
    """
    # KDTree search (distance_upper_bound bounds the neighbor search radius)
    tree = KDTree(index_coords)
    dist, idx = tree.query(query_coords, k=k, distance_upper_bound=max_distance)

    # To sparse adjacency
    edge_index = np.argwhere(dist != np.inf).T
    edge_index[1] = idx[dist != np.inf]
    edge_index = torch.tensor(edge_index, dtype=torch.long).contiguous()

    return edge_index
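
A small self-contained usage sketch with random 2-D coordinates (the point count, k, and distance are arbitrary illustration values):

import numpy as np
from segger.data.sample import STTile

rng = np.random.default_rng(0)
coords = rng.uniform(0, 100, size=(50, 2))

# Index and query the same point set to build self-neighbor edges
edge_index = STTile.get_kdtree_edge_index(coords, coords, k=3, max_distance=10.0)
print(edge_index.shape)  # torch.Size([2, n_edges]) with n_edges <= 50 * 3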

get_polygon_props staticmethod

get_polygon_props(polygons, area=True, convexity=True, elongation=True, circularity=True)

Compute geometric properties of polygons.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| polygons | GeoSeries | A GeoSeries containing polygon geometries. | required |
| area | bool | If True, compute the area of each polygon. | True |
| convexity | bool | If True, compute the convexity of each polygon. | True |
| elongation | bool | If True, compute the elongation of each polygon. | True |
| circularity | bool | If True, compute the circularity of each polygon. | True |

Returns:

| Type | Description |
|------|-------------|
| DataFrame | A DataFrame containing the computed properties for each polygon. |

Source code in src/segger/data/sample.py
@staticmethod
def get_polygon_props(
    polygons: gpd.GeoSeries,
    area: bool = True,
    convexity: bool = True,
    elongation: bool = True,
    circularity: bool = True,
) -> pd.DataFrame:
    """Compute geometric properties of polygons.

    Args:
        polygons: A GeoSeries containing polygon geometries.
        area: If True, compute the area of each polygon. Defaults to True.
        convexity: If True, compute the convexity of each polygon. Defaults to True.
        elongation: If True, compute the elongation of each polygon. Defaults to True.
        circularity: If True, compute the circularity of each polygon. Defaults to True.

    Returns:
        pd.DataFrame: A DataFrame containing the computed properties for each polygon.
    """
    props = pd.DataFrame(index=polygons.index, dtype=float)
    if area:
        props["area"] = polygons.area
    if convexity:
        props["convexity"] = polygons.convex_hull.area / polygons.area
    if elongation:
        rects = polygons.minimum_rotated_rectangle()
        props["elongation"] = rects.area / polygons.envelope.area
    if circularity:
        r = polygons.minimum_bounding_radius()
        props["circularity"] = polygons.area / r**2

    return props
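
A toy example with two shapely geometries (the resulting property values are purely illustrative):

import geopandas as gpd
import shapely
from segger.data.sample import STTile

polygons = gpd.GeoSeries([
    shapely.box(0, 0, 2, 1),          # a 2x1 rectangle
    shapely.Point(0, 0).buffer(1.0),  # an approximate unit circle
])
props = STTile.get_polygon_props(polygons)
print(props)  # columns: area, convexity, elongation, circularity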

get_transcript_props

get_transcript_props()

Encode transcript features in a sparse format.

Returns:

| Type | Description |
|------|-------------|
| Tensor | A sparse tensor containing the encoded transcript features. |

Note

The intention is for this function to simplify testing new strategies for 'tx' node representations. For example, the encoder can be any type of encoder that transforms the transcript labels into a numerical matrix (in sparse format).

Source code in src/segger/data/sample.py
def get_transcript_props(self) -> torch.Tensor:
    """Encode transcript features in a sparse format.

    Returns:
        torch.Tensor: A sparse tensor containing the encoded transcript features.

    Note:
        The intention is for this function to simplify testing new strategies
        for 'tx' node representations. For example, the encoder can be any type
        of encoder that transforms the transcript labels into a numerical
        matrix (in sparse format).
    """
    # Encode transcript features in sparse format
    embedding = self.dataset.sample._transcript_embedding
    label = self.settings.transcripts.label
    props = embedding.embed(self.transcripts[label])

    return props

to_pyg_dataset

to_pyg_dataset(neg_sampling_ratio=10, k_bd=3, dist_bd=15, k_tx=3, dist_tx=5, k_tx_ex=100, dist_tx_ex=20, area=True, convexity=True, elongation=True, circularity=True, mutually_exclusive_genes=None)

Convert the sample data to a PyG HeteroData object.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| neg_sampling_ratio | float | Ratio of negative samples. | 10 |
| k_bd | int | Number of nearest neighbors for boundary nodes. | 3 |
| dist_bd | float | Maximum distance for boundary neighbors. | 15 |
| k_tx | int | Number of nearest neighbors for transcript nodes. | 3 |
| dist_tx | float | Maximum distance for transcript neighbors. | 5 |
| k_tx_ex | int | Number of nearest neighbors for transcript exclusion. | 100 |
| dist_tx_ex | float | Maximum distance for transcript exclusion. | 20 |
| area | bool | If True, compute area of boundary polygons. | True |
| convexity | bool | If True, compute convexity of boundary polygons. | True |
| elongation | bool | If True, compute elongation of boundary polygons. | True |
| circularity | bool | If True, compute circularity of boundary polygons. | True |
| mutually_exclusive_genes | Optional[List] | List of mutually exclusive gene pairs. | None |

Returns:

| Type | Description |
|------|-------------|
| HeteroData | A PyTorch Geometric HeteroData object containing the sample data. |

Source code in src/segger/data/sample.py
def to_pyg_dataset(
    self,
    neg_sampling_ratio: float = 10,
    k_bd: int = 3,
    dist_bd: float = 15,
    k_tx: int = 3,
    dist_tx: float = 5,
    k_tx_ex: int = 100,
    dist_tx_ex: float = 20,
    area: bool = True,
    convexity: bool = True,
    elongation: bool = True,
    circularity: bool = True,
    mutually_exclusive_genes: Optional[List] = None,
) -> HeteroData:
    """Convert the sample data to a PyG HeteroData object.

    Args:
        neg_sampling_ratio: Ratio of negative samples. Defaults to 10.
        k_bd: Number of nearest neighbors for boundary nodes. Defaults to 3.
        dist_bd: Maximum distance for boundary neighbors. Defaults to 15.
        k_tx: Number of nearest neighbors for transcript nodes. Defaults to 3.
        dist_tx: Maximum distance for transcript neighbors. Defaults to 5.
        k_tx_ex: Number of nearest neighbors for transcript exclusion. Defaults to 100.
        dist_tx_ex: Maximum distance for transcript exclusion. Defaults to 20.
        area: If True, compute area of boundary polygons. Defaults to True.
        convexity: If True, compute convexity of boundary polygons. Defaults to True.
        elongation: If True, compute elongation of boundary polygons. Defaults to True.
        circularity: If True, compute circularity of boundary polygons. Defaults to True.
        mutually_exclusive_genes: List of mutually exclusive gene pairs. Defaults to None.

    Returns:
        HeteroData: A PyTorch Geometric HeteroData object containing the sample data.
    """
    # Initialize an empty HeteroData object
    pyg_data = HeteroData()

    # Set up Transcript nodes
    # Get transcript IDs - use getattr to safely check for id attribute
    transcript_id_column = getattr(self.settings.transcripts, "id", None)
    if transcript_id_column is None:
        raise ValueError(
            "Transcript IDs not found in DataFrame. Please run add_transcript_ids() "
            "as a preprocessing step before creating the dataset."
        )

    # Assign IDs to PyG data
    pyg_data["tx"].id = torch.tensor(
        self.transcripts[transcript_id_column].values, dtype=torch.long
    )
    pyg_data["tx"].pos = torch.tensor(
        self.transcripts[self.settings.transcripts.xyz].values,
        dtype=torch.float32,
    )
    pyg_data["tx"].x = self.get_transcript_props()



    # Set up Transcript-Transcript neighbor edges
    nbrs_edge_idx = self.get_kdtree_edge_index(
        self.transcripts[self.settings.transcripts.xyz],
        self.transcripts[self.settings.transcripts.xyz],
        k=k_tx,
        max_distance=dist_tx,
    )

    # If there are no tx-neighbors-tx edges, skip saving tile
    if nbrs_edge_idx.shape[1] == 0:
        return None

    pyg_data["tx", "neighbors", "tx"].edge_index = nbrs_edge_idx


    if mutually_exclusive_genes is not None:
        # Get potential repulsive edges (k-nearest neighbors within distance)
        # --- Step 1: Get repulsive edges (mutually exclusive genes) ---
        repels_edge_idx = self.get_kdtree_edge_index(
            self.transcripts[self.settings.transcripts.xyz],
            self.transcripts[self.settings.transcripts.xyz],
            k=k_tx_ex,
            max_distance=dist_tx_ex,
        )
        gene_ids = self.transcripts[self.settings.transcripts.label].tolist()

        # Filter repels_edge_idx to only keep mutually exclusive gene pairs
        src_genes = [gene_ids[i] for i in repels_edge_idx[0].tolist()]
        dst_genes = [gene_ids[i] for i in repels_edge_idx[1].tolist()]
        mask = [
            (a != b) and tuple(sorted((a, b))) in mutually_exclusive_genes
            for a, b in zip(src_genes, dst_genes)
        ]
        repels_edge_idx = repels_edge_idx[:, torch.tensor(mask)]

        # --- Step 2: Get attractive edges (same gene, at least one node in repels) ---
        # Nodes involved in repels (for filtering nbrs_edge_idx)
        repels_nodes = torch.cat([repels_edge_idx[0], repels_edge_idx[1]]).unique()

        # Filter nbrs_edge_idx: keep edges where (1) same gene AND (2) at least one node in repels
        attractive_mask = torch.zeros(nbrs_edge_idx.shape[1], dtype=torch.bool)
        for i, (src, dst) in enumerate(nbrs_edge_idx.t().tolist()):
            if (src != dst) and (gene_ids[src] == gene_ids[dst]) and (src in repels_nodes or dst in repels_nodes):
                attractive_mask[i] = True
        attractive_edge_idx = nbrs_edge_idx[:, attractive_mask]

        # --- Step 3: Combine repels (label=0) and attractive (label=1) edges ---
        edge_label_index = torch.cat([repels_edge_idx, attractive_edge_idx], dim=1)
        edge_label = torch.cat([
            torch.zeros(repels_edge_idx.shape[1], dtype=torch.long),  # 0 for repels
            torch.ones(attractive_edge_idx.shape[1], dtype=torch.long)  # 1 for attracts
        ])

        # --- Step 4: Store in PyG data object ---
        pyg_data["tx", "attracts", "tx"].edge_label_index = edge_label_index
        pyg_data["tx", "attracts", "tx"].edge_label = edge_label


    # Set up Boundary nodes
    # Check if boundaries have geometries
    geometry_column = getattr(self.settings.boundaries, 'geometry', None)
    if geometry_column and geometry_column in self.boundaries.columns:
        polygons = gpd.GeoSeries(self.boundaries[geometry_column], index=self.boundaries.index)
    else:
        # Fallback: compute polygons
        polygons = utils.get_polygons_from_xy(
            self.boundaries,
            x=self.settings.boundaries.x,
            y=self.settings.boundaries.y,
            label=self.settings.boundaries.label,
            scale_factor=self.settings.boundaries.scale_factor,
        )

    # Ensure self.boundaries is a GeoDataFrame with correct geometry
    self.boundaries = gpd.GeoDataFrame(self.boundaries.copy(), geometry=polygons)
    centroids = polygons.centroid.get_coordinates()
    pyg_data["bd"].id = polygons.index.to_numpy()
    pyg_data["bd"].pos = torch.tensor(centroids.values, dtype=torch.float32)
    pyg_data["bd"].x = self.get_boundary_props(
        area, convexity, elongation, circularity
    )

    # Set up Boundary-Transcript neighbor edges
    dist = np.sqrt(polygons.area.max()) * 10  # heuristic distance
    nbrs_edge_idx = self.get_kdtree_edge_index(
        centroids,
        self.transcripts[self.settings.transcripts.xy],
        k=k_bd,
        max_distance=dist,
    )
    pyg_data["tx", "neighbors", "bd"].edge_index = nbrs_edge_idx

    # If there are no tx-neighbors-bd edges, we put the tile automatically in test set
    if nbrs_edge_idx.numel() == 0:
        # logging.warning(f"No tx-neighbors-bd edges found in tile {self.uid}.")
        pyg_data["tx", "belongs", "bd"].edge_index = torch.tensor(
            [], dtype=torch.long
        )
        return pyg_data

    # Now we identify and split the tx-belongs-bd edges
    edge_type = ("tx", "belongs", "bd")

    # Find nuclear transcripts
    tx_cell_ids = self.transcripts[self.settings.boundaries.id]
    cell_ids_map = {idx: i for (i, idx) in enumerate(polygons.index)}

    # Get nuclear column and value from settings
    nuclear_column = getattr(self.settings.transcripts, "nuclear_column", None)
    nuclear_value = getattr(self.settings.transcripts, "nuclear_value", None)

    if nuclear_column is None or self.settings.boundaries.scale_factor != 1.0:
        is_nuclear = utils.compute_nuclear_transcripts(
            polygons=polygons,
            transcripts=self.transcripts,
            x_col=self.settings.transcripts.x,
            y_col=self.settings.transcripts.y,
            nuclear_column=nuclear_column,
            nuclear_value=nuclear_value,
        )
    else:
        is_nuclear = self.transcripts[nuclear_column].eq(nuclear_value)
    is_nuclear &= tx_cell_ids.isin(polygons.index)


    # Set up overlap edges
    row_idx = np.where(is_nuclear)[0]
    col_idx = tx_cell_ids.iloc[row_idx].map(cell_ids_map)
    blng_edge_idx = torch.tensor(np.stack([row_idx, col_idx])).long()
    pyg_data[edge_type].edge_index = blng_edge_idx

    # If there are no tx-belongs-bd edges, flag tile as test only (cannot be used for training)
    if blng_edge_idx.numel() == 0:
        return pyg_data

    # If there are tx-bd edges, add negative edges for training
    pos_edges = blng_edge_idx  # shape (2, num_pos)
    num_pos = pos_edges.shape[1]

    # Negative edges (tx-neighbors-bd) - EXCLUDE positives
    neg_candidates = nbrs_edge_idx  # shape (2, num_candidates)

    # --- Fast Negative Filtering (PyTorch-only) ---
    # Reshape edges for broadcasting: (2, num_pos) vs (2, num_candidates, 1)
    pos_expanded = pos_edges.unsqueeze(2)  # shape (2, num_pos, 1)
    neg_expanded = neg_candidates.unsqueeze(1)  # shape (2, 1, num_candidates)

    # Compare all edges in one go (broadcasting)
    matches = (pos_expanded == neg_expanded).all(dim=0)  # shape (num_pos, num_candidates)
    is_negative = ~matches.any(dim=0)  # shape (num_candidates,)

    # Filter negatives
    neg_edges = neg_candidates[:, is_negative]  # shape (2, num_filtered_neg)
    num_neg = neg_edges.shape[1]

    # --- Combine and label ---
    edge_label_index = torch.cat([neg_edges, pos_edges], dim=1)
    edge_label = torch.cat([
        torch.zeros(num_neg, dtype=torch.float),
        torch.ones(num_pos, dtype=torch.float)
    ])

    mask = edge_label_index[0].unsqueeze(1) == blng_edge_idx[0].unsqueeze(0)
    mask = torch.nonzero(torch.any(mask, 1)).squeeze()
    edge_label_index = edge_label_index[:, mask]
    edge_label = edge_label[mask]

    pyg_data[edge_type].edge_label_index = edge_label_index
    pyg_data[edge_type].edge_label = edge_label

    return pyg_data

Usage Examples

Basic Data Loading

from segger.data.sample import STSampleParquet

# Load a spatial transcriptomics sample
sample = STSampleParquet(
    base_dir="/path/to/xenium/data",
    n_workers=4,
    sample_type="xenium"
)

# Get sample information
print(f"Transcripts: {sample.n_transcripts}")
print(f"Spatial extents: {sample.extents}")
print(f"Feature names: {sample.transcripts_metadata['feature_names'][:5]}")

Spatial Tiling and Processing

# Save processed tiles
sample.save(
    data_dir="./processed_data",
    tile_size=1000,  # 1000 transcripts per tile
    k_bd=3,          # 3 boundary neighbors
    k_tx=5,          # 5 transcript neighbors
    dist_bd=15.0,    # 15 pixel boundary distance
    dist_tx=5.0,     # 5 pixel transcript distance
    frac=0.8,        # Process 80% of data
    val_prob=0.1,    # 10% validation
    test_prob=0.2    # 20% test
)

In-Memory Dataset Processing

from segger.data.sample import STInMemoryDataset

# Create dataset for a specific region
dataset = STInMemoryDataset(
    sample=sample,
    extents=region_polygon,
    margin=10
)

# Generate tiles
tiles = dataset._tile(100, 100, None)  # tile_width, tile_height, tile_size

print(f"Generated {len(tiles)} tiles")

Individual Tile Processing

from segger.data.sample import STTile

# Process individual tile
tile = STTile(dataset=dataset, extents=tile_polygon)

# Get tile data
transcripts = tile.transcripts
boundaries = tile.boundaries

# Convert to PyG format
pyg_data = tile.to_pyg_dataset(
    k_bd=3,
    dist_bd=15,
    k_tx=5,
    dist_tx=5,
    area=True,
    convexity=True,
    elongation=True,
    circularity=True
)

print(f"Tile UID: {tile.uid}")
print(f"Transcripts: {len(transcripts)}")
print(f"Boundaries: {len(boundaries)}")