Skip to content

segger.data.io

MerscopeKeys

Bases: Enum

Keys for MERSCOPE data (Vizgen platform).

MerscopeSample

MerscopeSample(transcripts_df=None, transcripts_radius=10, boundaries_graph=False, embedding_df=None, verbose=True)

Bases: SpatialTranscriptomicsSample

Source code in src/segger/data/io.py
1115
1116
1117
1118
1119
1120
1121
1122
1123
def __init__(
    self,
    transcripts_df: dd.DataFrame = None,
    transcripts_radius: int = 10,
    boundaries_graph: bool = False,
    embedding_df: pd.DataFrame = None,
    verbose: bool = True,
):
    """Initialize a MERSCOPE (Vizgen) sample.

    Args:
        transcripts_df (dd.DataFrame, optional): Transcript data.
        transcripts_radius (int, optional): Radius for transcripts in the analysis.
        boundaries_graph (bool, optional): Whether to include boundaries graph information.
        embedding_df (pd.DataFrame, optional): Optional per-feature embedding table.
        verbose (bool, optional): Whether to print progress messages.
    """
    # Forward ``verbose`` explicitly: it used to be dropped here, so the base
    # class default (True) always applied regardless of the caller's choice.
    super().__init__(
        transcripts_df,
        transcripts_radius,
        boundaries_graph,
        embedding_df,
        MerscopeKeys,
        verbose=verbose,
    )

filter_transcripts

filter_transcripts(transcripts_df, min_qv=20.0)

Filters transcripts based on specific criteria for Merscope using Dask.

Parameters:

Name Type Description Default
transcripts_df

dd.DataFrame The Dask DataFrame containing transcript data.

required
min_qv

float, optional The minimum quality value threshold for filtering transcripts.

20.0

Returns:

Type Description
DataFrame

dd.DataFrame The filtered Dask DataFrame.

Source code in src/segger/data/io.py
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
def filter_transcripts(self, transcripts_df: dd.DataFrame, min_qv: float = 20.0) -> dd.DataFrame:
    """
    Keep only the MERSCOPE transcripts whose quality value reaches ``min_qv``.

    Parameters:
        transcripts_df : dd.DataFrame
            Transcript table; must contain the quality-value column named by
            ``self.keys.QUALITY_VALUE``.
        min_qv : float, optional
            Inclusive lower bound on the quality value (default 20.0).

    Returns:
        dd.DataFrame
            The rows of ``transcripts_df`` that pass the threshold.
    """
    # Only the quality threshold is applied for MERSCOPE at the moment;
    # platform-specific rules can be added here later.
    quality = transcripts_df[self.keys.QUALITY_VALUE.value]
    passes_threshold = quality >= min_qv
    return transcripts_df[passes_threshold]

SpatialDataKeys

Bases: Enum

Keys for transcript data loaded from a SpatialData object.

SpatialDataSample

SpatialDataSample(transcripts_df=None, transcripts_radius=10, boundaries_graph=False, embedding_df=None, feature_name=None, verbose=True)

Bases: SpatialTranscriptomicsSample

Source code in src/segger/data/io.py
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
def __init__(
    self,
    transcripts_df: dd.DataFrame = None,
    transcripts_radius: int = 10,
    boundaries_graph: bool = False,
    embedding_df: pd.DataFrame = None,
    feature_name: str | None = None,
    verbose: bool = True,
):
    """Initialize a sample whose transcripts come from a SpatialData object.

    Args:
        transcripts_df (dd.DataFrame, optional): Transcript data.
        transcripts_radius (int, optional): Radius for transcripts in the analysis.
        boundaries_graph (bool, optional): Whether to include boundaries graph information.
        embedding_df (pd.DataFrame, optional): Optional per-feature embedding table.
        feature_name (str, optional): Name of the column holding feature (gene) names.
            Currently required.
        verbose (bool, optional): Whether to print progress messages.

    Raises:
        ValueError: If ``feature_name`` is None (automatic detection is not implemented).
    """
    if feature_name is None:
        raise ValueError(
            "the automatic determination of a feature_name from a SpatialData object is not enabled yet"
        )

    # HACK (luca): overwrite the enum member's value in place so that
    # ``SpatialDataKeys.FEATURE_NAME.value`` yields the caller's column name.
    # This mutates the shared enum member, so it affects every SpatialDataSample;
    # switching the keys from enums to dataclasses was proposed to avoid this.
    SpatialDataKeys.FEATURE_NAME._value_ = feature_name

    super().__init__(
        transcripts_df,
        transcripts_radius,
        boundaries_graph,
        embedding_df,
        SpatialDataKeys,
        verbose=verbose,
    )

filter_transcripts

filter_transcripts(transcripts_df, min_qv=20.0)

Filters transcripts by quality value and removes control and unassigned codewords (prefixes such as `NegControlProbe_` and `BLANK_`) for SpatialData samples using Dask.

Parameters:

Name Type Description Default
transcripts_df DataFrame

The Dask DataFrame containing transcript data.

required
min_qv float

The minimum quality value threshold for filtering transcripts.

20.0

Returns:

Type Description
DataFrame

dd.DataFrame: The filtered Dask DataFrame.

Source code in src/segger/data/io.py
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
def filter_transcripts(self, transcripts_df: dd.DataFrame, min_qv: float = 20.0) -> dd.DataFrame:
    """
    Filters transcripts by quality value and removes control/unassigned codewords for SpatialData samples.

    A transcript is kept when its quality value is at least ``min_qv`` and its
    feature name does not start with one of the control or unassigned codeword
    prefixes listed below (``NegControlProbe_``, ``BLANK_``, ...). Feature names
    stored as ``bytes`` are decoded as UTF-8 before matching. The input frame is
    not modified; a new (lazily filtered, for Dask) frame is returned.

    Parameters:
        transcripts_df (dd.DataFrame): The Dask DataFrame containing transcript data.
        min_qv (float, optional): The minimum quality value threshold for filtering transcripts.

    Returns:
        dd.DataFrame: The filtered Dask DataFrame.
    """
    filter_codewords = (
        "NegControlProbe_",
        "antisense_",
        "NegControlCodeword_",
        "BLANK_",
        "DeprecatedCodeword_",
        "UnassignedCodeword_",
    )
    feature_col = self.keys.FEATURE_NAME.value
    qv_col = self.keys.QUALITY_VALUE.value

    # Feature names may arrive as bytes (object dtype); decode them so the
    # string-prefix filter below can match. ``assign`` returns a new frame, so
    # the caller's DataFrame is left untouched (item assignment used to mutate it).
    if pd.api.types.is_object_dtype(transcripts_df[feature_col]):
        decoded = transcripts_df[feature_col].apply(lambda x: x.decode("utf-8") if isinstance(x, bytes) else x)
        transcripts_df = transcripts_df.assign(**{feature_col: decoded})

    mask_quality = transcripts_df[qv_col] >= min_qv
    # ``str.startswith`` accepts a tuple, so every prefix is tested in one pass.
    mask_codewords = ~transcripts_df[feature_col].str.startswith(filter_codewords)

    # Return the filtered frame (still lazy when the input is a Dask DataFrame).
    return transcripts_df[mask_quality & mask_codewords]

SpatialTranscriptomicsDataset

SpatialTranscriptomicsDataset(root, transform=None, pre_transform=None, pre_filter=None)

Bases: InMemoryDataset

A dataset class for handling spatial transcriptomics data.

Attributes:

Name Type Description
root str

The root directory where the dataset is stored.

transform callable

A function/transform that takes in a Data object and returns a transformed version.

pre_transform callable

A function/transform that takes in a Data object and returns a transformed version.

pre_filter callable

A function that takes in a Data object and returns a boolean indicating whether to keep it.

Initialize the SpatialTranscriptomicsDataset.

Parameters:

Name Type Description Default
root str

Root directory where the dataset is stored.

required
transform callable

A function/transform that takes in a Data object and returns a transformed version. Defaults to None.

None
pre_transform callable

A function/transform that takes in a Data object and returns a transformed version. Defaults to None.

None
pre_filter callable

A function that takes in a Data object and returns a boolean indicating whether to keep it. Defaults to None.

None
Source code in src/segger/data/utils.py
399
400
401
402
403
404
405
406
407
408
409
410
def __init__(
    self, root: str, transform: Callable = None, pre_transform: Callable = None, pre_filter: Callable = None
):
    """Create the dataset rooted at ``root``.

    All arguments are handed unchanged to the ``InMemoryDataset`` base class.

    Args:
        root (str): Root directory where the dataset is stored.
        transform (callable, optional): Function/transform applied to a Data object. Defaults to None.
        pre_transform (callable, optional): Function/transform applied to a Data object. Defaults to None.
        pre_filter (callable, optional): Predicate on a Data object deciding whether to keep it. Defaults to None.
    """
    super().__init__(
        root=root,
        transform=transform,
        pre_transform=pre_transform,
        pre_filter=pre_filter,
    )

processed_file_names property

processed_file_names

Return a list of processed file names in the processed directory.

Returns:

Type Description
List[str]

List[str]: List of processed file names.

raw_file_names property

raw_file_names

Return a list of raw file names in the raw directory.

Returns:

Type Description
List[str]

List[str]: List of raw file names.

download

download()

Download the raw data. This method should be overridden if you need to download the data.

Source code in src/segger/data/utils.py
430
431
432
def download(self) -> None:
    """Hook for fetching the raw data; it does nothing here.

    Override it in a subclass when the raw files must be downloaded.
    """
    return None

get

get(idx)

Get a processed data object.

Parameters:

Name Type Description Default
idx int

Index of the data object to retrieve.

required

Returns:

Name Type Description
Data Data

The processed data object.

Source code in src/segger/data/utils.py
446
447
448
449
450
451
452
453
454
455
456
457
def get(self, idx: int) -> Data:
    """Load one processed graph from disk.

    Args:
        idx (int): Position in ``processed_file_names`` of the file to load.

    Returns:
        Data: The loaded object, with the transcript features ``data["tx"].x``
        converted to a dense tensor.
    """
    directory = self.processed_dir
    file_name = self.processed_file_names[idx]
    data = torch.load(os.path.join(directory, file_name))
    # The stored transcript features may be sparse; callers always get a dense tensor.
    tx_store = data["tx"]
    tx_store.x = tx_store.x.to_dense()
    return data

len

len()

Return the number of processed files.

Returns:

Name Type Description
int int

Number of processed files.

Source code in src/segger/data/utils.py
438
439
440
441
442
443
444
def len(self) -> int:
    """Count the processed files that make up this dataset.

    Returns:
        int: Number of entries in ``processed_file_names``.
    """
    file_names = self.processed_file_names
    return len(file_names)

process

process()

Process the raw data and save it to the processed directory. This method should be overridden if you need to process the data.

Source code in src/segger/data/utils.py
434
435
436
def process(self) -> None:
    """Hook for turning raw data into processed files; it does nothing here.

    Override it in a subclass when the raw data must be processed and saved
    to the processed directory.
    """
    return None

SpatialTranscriptomicsKeys

Bases: Enum

Unified keys for spatial transcriptomics data, supporting multiple platforms.

SpatialTranscriptomicsSample

SpatialTranscriptomicsSample(transcripts_df=None, transcripts_radius=10, boundaries_graph=False, embedding_df=None, keys=None, verbose=True)

Bases: ABC

Initialize the SpatialTranscriptomicsSample class.

Parameters:

Name Type Description Default
transcripts_df DataFrame

A DataFrame containing transcript data.

None
transcripts_radius int

Radius for transcripts in the analysis.

10
boundaries_graph bool

Whether to include boundaries (e.g., nucleus, cell) graph information.

False
keys Dict

The enum class containing key mappings specific to the dataset.

None
Source code in src/segger/data/io.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def __init__(
    self,
    transcripts_df: pd.DataFrame = None,
    transcripts_radius: int = 10,
    boundaries_graph: bool = False,
    embedding_df: pd.DataFrame = None,
    keys: Dict = None,
    verbose: bool = True,
):
    """Store the shared configuration of a spatial transcriptomics sample.

    Args:
        transcripts_df (pd.DataFrame, optional): A DataFrame containing transcript data.
        transcripts_radius (int, optional): Radius for transcripts in the analysis.
        boundaries_graph (bool, optional): Whether to include boundaries (e.g., nucleus, cell) graph information.
        embedding_df (pd.DataFrame, optional): Optional per-feature embedding table.
        keys (Dict, optional): Key mapping for the platform; the subclasses shown pass an
            Enum class (e.g. ``MerscopeKeys``) whose members' ``.value`` are column names.
        verbose (bool, optional): Whether to print progress messages.
    """
    # Input tables.
    self.transcripts_df = transcripts_df
    self.embedding_df = embedding_df
    # Graph-construction settings.
    self.transcripts_radius = transcripts_radius
    self.boundaries_graph = boundaries_graph
    # Platform-specific column-name mapping.
    self.keys = keys
    # Name of the embedding currently selected; starts as "token".
    self.current_embedding = "token"
    self.verbose = verbose

build_pyg_data_from_tile

build_pyg_data_from_tile(boundaries_df, transcripts_df, r_tx=5.0, k_tx=3, method='kd_tree', gpu=False, workers=1, scale_boundaries=1.0)

Builds PyG data from a tile of boundaries and transcripts data using Dask utilities for efficient processing.

Parameters:

Name Type Description Default
boundaries_df DataFrame

Dask DataFrame containing boundaries data (e.g., nucleus, cell).

required
transcripts_df DataFrame

Dask DataFrame containing transcripts data.

required
r_tx float

Radius for building the transcript-to-transcript graph.

5.0
k_tx int

Number of nearest neighbors for the tx-tx graph.

3
method str

Method for computing edge indices (e.g., 'kd_tree', 'faiss').

'kd_tree'
gpu bool

Whether to use GPU acceleration for edge index computation.

False
workers int

Number of workers to use for parallel processing.

1
scale_boundaries float

The factor by which to scale the boundary polygons. Default is 1.0.

1.0

Returns:

Name Type Description
HeteroData HeteroData

PyG Heterogeneous Data object.

Source code in src/segger/data/io.py
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
def build_pyg_data_from_tile(
    self,
    boundaries_df: dd.DataFrame,
    transcripts_df: dd.DataFrame,
    r_tx: float = 5.0,
    k_tx: int = 3,
    method: str = "kd_tree",
    gpu: bool = False,
    workers: int = 1,
    scale_boundaries: float = 1.0,
) -> HeteroData:
    """
    Builds PyG data from a tile of boundaries and transcripts data using Dask utilities for efficient processing.

    The resulting graph contains:
        - ``"bd"`` nodes: one per boundary polygon with a non-null geometry
          (``id``, centroid ``pos`` and shape-descriptor features ``x``).
        - ``"tx"`` nodes: one per transcript (``id``, 2-D ``pos``, integer
          ``token`` and node features ``x``).
        - ``("tx", "belongs", "bd")`` edges: transcripts flagged as overlapping
          a boundary whose cell ID is one of the ``"bd"`` nodes.
        - ``("tx", "neighbors", "tx")`` edges: produced by ``get_edge_index``
          with ``k_tx`` neighbours within distance ``r_tx``.

    Side effect: a ``"token"`` column is written into ``transcripts_df``.

    Parameters:
        boundaries_df (dd.DataFrame): Dask DataFrame containing boundaries data (e.g., nucleus, cell).
        transcripts_df (dd.DataFrame): Dask DataFrame containing transcripts data.
        r_tx (float): Radius for building the transcript-to-transcript graph.
        k_tx (int): Number of nearest neighbors for the tx-tx graph.
        method (str, optional): Method for computing edge indices (e.g., 'kd_tree', 'faiss').
        gpu (bool, optional): Whether to use GPU acceleration for edge index computation.
        workers (int, optional): Number of workers to use for parallel processing.
        scale_boundaries (float, optional): Passed on as ``scale_factor`` when the boundary polygons are built. Default is 1.0.

    Returns:
        HeteroData: PyG Heterogeneous Data object.

    Raises:
        ValueError: If any boundary centroid is NaN; the message holds the affected boundary IDs.
    """
    data = HeteroData()

    # --- Boundary ("bd") nodes -------------------------------------------------
    if self.verbose:
        print("Computing boundaries geometries...")
    bd_gdf = self.compute_boundaries_geometries(boundaries_df, scale_factor=scale_boundaries)
    bd_gdf = bd_gdf[bd_gdf["geometry"].notnull()]

    cell_id_col = self.keys.CELL_ID.value
    data["bd"].id = bd_gdf[cell_id_col].values
    data["bd"].pos = torch.as_tensor(bd_gdf[["centroid_x", "centroid_y"]].values.astype(float))

    if data["bd"].pos.isnan().any():
        raise ValueError(data["bd"].id[data["bd"].pos.isnan().any(1)])

    # Columns 0-3 are expected to be geometry, cell id and centroid x/y; the
    # remaining columns are the computed shape descriptors used as features.
    bd_x = bd_gdf.iloc[:, 4:]
    data["bd"].x = torch.as_tensor(bd_x.to_numpy(), dtype=torch.float32)

    # --- Transcript ("tx") nodes ---------------------------------------------
    if self.verbose:
        print("Preparing transcript features and positions...")
    tx_xy = transcripts_df[[self.keys.TRANSCRIPTS_X.value, self.keys.TRANSCRIPTS_Y.value]].to_numpy()
    data["tx"].id = torch.as_tensor(transcripts_df[self.keys.TRANSCRIPTS_ID.value].values.astype(int))
    data["tx"].pos = torch.tensor(tx_xy, dtype=torch.float32)

    if self.verbose:
        print("Preparing transcript embeddings..")
    token_encoding = self.tx_encoder.transform(transcripts_df[self.keys.FEATURE_NAME.value])
    transcripts_df["token"] = token_encoding  # Store the integer tokens in the 'token' column
    data["tx"].token = torch.as_tensor(token_encoding).int()

    # Node features: rows of the per-gene embedding table when one is
    # available, otherwise the integer token ids themselves.
    use_embedding_table = self.embedding_df is not None and not self.embedding_df.empty
    if use_embedding_table:
        embeddings = delayed(lambda df: self.embedding_df.loc[df[self.keys.FEATURE_NAME.value].values].values)(
            transcripts_df
        )
    else:
        embeddings = token_encoding
    if hasattr(embeddings, "compute"):
        embeddings = embeddings.compute()
    x_features = torch.as_tensor(embeddings)
    # Embedding rows are expected to be real-valued feature vectors. Casting
    # them with ``.int()`` (as before) truncated every value, so keep them as
    # float32 and cast only token ids to int.
    data["tx"].x = x_features.float() if use_embedding_table else x_features.int()

    # --- tx -> bd membership edges -------------------------------------------
    if self.keys.OVERLAPS_BOUNDARY.value not in transcripts_df.columns:
        if self.verbose:
            print("Computing overlaps for transcripts...")
        transcripts_df = self.compute_transcript_overlap_with_boundaries(
            transcripts_df, polygons_gdf=bd_gdf, scale_factor=1.0
        )

    if self.verbose:
        print("Connecting transcripts with boundaries...")
    overlaps = transcripts_df[self.keys.OVERLAPS_BOUNDARY.value].values
    valid_cell_ids = bd_gdf[cell_id_col].values
    ind = np.where(overlaps & transcripts_df[cell_id_col].isin(valid_cell_ids))[0]

    # Map each overlapping transcript's cell ID to the row position of that
    # boundary in bd_gdf. np.searchsorted needs sorted input, and bd_gdf is not
    # guaranteed to be ordered by cell ID, so search through a stable argsort
    # and translate the sorted positions back to original row positions.
    bd_ids = np.asarray(valid_cell_ids)
    bd_order = np.argsort(bd_ids, kind="stable")
    tx_cell_ids = transcripts_df.iloc[ind][cell_id_col].to_numpy()
    bd_positions = bd_order[np.searchsorted(bd_ids, tx_cell_ids, sorter=bd_order)]
    tx_bd_edge_index = np.column_stack((ind, bd_positions))

    data["tx", "belongs", "bd"].edge_index = torch.as_tensor(tx_bd_edge_index.T, dtype=torch.long)

    # --- tx -> tx neighbour edges --------------------------------------------
    if self.verbose:
        print("Computing tx-tx edges...")
    tx_positions = transcripts_df[[self.keys.TRANSCRIPTS_X.value, self.keys.TRANSCRIPTS_Y.value]].values
    delayed_tx_edge_index = delayed(get_edge_index)(
        tx_positions, tx_positions, k=k_tx, dist=r_tx, method=method, gpu=gpu, workers=workers
    )
    tx_edge_index = delayed_tx_edge_index.compute()

    data["tx", "neighbors", "tx"].edge_index = torch.as_tensor(tx_edge_index.T, dtype=torch.long)

    if self.verbose:
        print("Finished building PyG data for the tile.")
    return data

compute_boundaries_geometries

compute_boundaries_geometries(boundaries_df=None, polygons_gdf=None, scale_factor=1.0, area=True, convexity=True, elongation=True, circularity=True)

Computes geometries for boundaries (e.g., nuclei, cells) from the dataframe using Dask.

Parameters:

Name Type Description Default
boundaries_df DataFrame

The dataframe containing boundaries data. Required if polygons_gdf is not provided.

None
polygons_gdf GeoDataFrame

Precomputed Dask GeoDataFrame containing boundary polygons. If None, will compute from boundaries_df.

None
scale_factor float

The factor by which to scale the polygons (default is 1.0).

1.0
area bool

Whether to compute area.

True
convexity bool

Whether to compute convexity.

True
elongation bool

Whether to compute elongation.

True
circularity bool

Whether to compute circularity.

True

Returns:

Type Description
GeoDataFrame

dgpd.GeoDataFrame: A GeoDataFrame containing computed geometries.

Source code in src/segger/data/io.py
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
def compute_boundaries_geometries(
    self,
    boundaries_df: dd.DataFrame = None,
    polygons_gdf: dgpd.GeoDataFrame = None,
    scale_factor: float = 1.0,
    area: bool = True,
    convexity: bool = True,
    elongation: bool = True,
    circularity: bool = True,
) -> dgpd.GeoDataFrame:
    """
    Attach shape descriptors to boundary polygons, building the polygons first if needed.

    Descriptor columns (each added only when its flag is True, in this order):
        - ``area``: polygon area.
        - ``convexity``: convex-hull area divided by polygon area.
        - ``elongation``: squared perimeter of the minimum rotated rectangle
          divided by that rectangle's area.
        - ``circularity``: polygon area divided by the squared minimum
          bounding radius.

    Parameters:
        boundaries_df (dd.DataFrame, optional): Boundary vertices; used only when polygons_gdf is None.
        polygons_gdf (dgpd.GeoDataFrame, optional): Precomputed polygons. When given, the descriptor
            columns are written into this object.
        scale_factor (float, optional): Forwarded to ``generate_and_scale_polygons`` (default 1.0).
        area, convexity, elongation, circularity (bool, optional): Which descriptors to compute.

    Returns:
        dgpd.GeoDataFrame: The polygons with descriptor columns and a fresh 0..n-1 index.

    Raises:
        ValueError: If neither input is given, or if there are no polygons.
    """

    def _log(message):
        if self.verbose:
            print(message)

    if polygons_gdf is None:
        if boundaries_df is None:
            raise ValueError("Both boundaries_df and polygons_gdf cannot be None. Provide at least one.")
        _log(
            f"No precomputed polygons provided. Computing polygons from boundaries with a scale factor of {scale_factor}."
        )
        polygons_gdf = self.generate_and_scale_polygons(boundaries_df, scale_factor)

    if polygons_gdf.shape[0] == 0:
        raise ValueError("No valid polygons were generated from the boundaries.")
    _log("Polygons are available. Proceeding with geometrical computations.")

    geoms = polygons_gdf.geometry

    if area:
        _log("Computing area...")
        polygons_gdf["area"] = geoms.area
    if convexity:
        _log("Computing convexity...")
        polygons_gdf["convexity"] = geoms.convex_hull.area / geoms.area
    if elongation:
        _log("Computing elongation...")
        rect = geoms.minimum_rotated_rectangle()
        polygons_gdf["elongation"] = (rect.length * rect.length) / rect.area
    if circularity:
        _log("Computing circularity...")
        radius = polygons_gdf.minimum_bounding_radius()
        polygons_gdf["circularity"] = geoms.area / (radius * radius)

    _log("Geometrical computations completed.")
    return polygons_gdf.reset_index(drop=True)

compute_transcript_overlap_with_boundaries

compute_transcript_overlap_with_boundaries(transcripts_df, boundaries_df=None, polygons_gdf=None, scale_factor=1.0)

Computes the overlap of transcript locations with scaled boundary polygons and assigns corresponding cell IDs to the transcripts using Dask.

Parameters:

Name Type Description Default
transcripts_df DataFrame

Dask DataFrame containing transcript data.

required
boundaries_df DataFrame

Dask DataFrame containing boundary data. Required if polygons_gdf is not provided.

None
polygons_gdf GeoDataFrame

Precomputed Dask GeoDataFrame containing boundary polygons. If None, will compute from boundaries_df.

None
scale_factor float

The factor by which to scale the boundary polygons. Default is 1.0.

1.0

Returns:

Type Description
DataFrame

dd.DataFrame: The updated DataFrame with overlap information and assigned cell IDs.

Source code in src/segger/data/io.py
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
def compute_transcript_overlap_with_boundaries(
    self,
    transcripts_df: dd.DataFrame,
    boundaries_df: dd.DataFrame = None,
    polygons_gdf: dgpd.GeoDataFrame = None,
    scale_factor: float = 1.0,
) -> dd.DataFrame:
    """
    Computes the overlap of transcript locations with scaled boundary polygons
    and assigns corresponding cell IDs to the transcripts.

    Each transcript point is tested against the polygons in row order. The
    first polygon containing it sets the overlap flag to ``True`` and the cell
    ID to that polygon's cell ID. Transcripts outside every polygon get
    ``False`` and ``None``. Any existing cell-ID column is overwritten.
    Polygons with a missing geometry are ignored.

    NOTE: the cost grows with (#transcripts x #polygons). A spatial index
    would be the next step if this becomes a bottleneck.

    Parameters:
        transcripts_df (dd.DataFrame): Transcript data (Dask or pandas DataFrame).
        boundaries_df (dd.DataFrame, optional): Dask DataFrame containing boundary data. Required if polygons_gdf is not provided.
        polygons_gdf (dgpd.GeoDataFrame, optional): Precomputed GeoDataFrame containing boundary polygons. If None, will compute from boundaries_df.
        scale_factor (float, optional): Forwarded to ``generate_and_scale_polygons`` when polygons are built here. Default is 1.0.

    Returns:
        dd.DataFrame: The updated DataFrame with overlap information and assigned cell IDs.
        A pandas input gives a pandas result, so callers can keep using pandas-only
        operations such as ``.iloc``. A Dask input gives a Dask result computed
        partition by partition.

    Raises:
        ValueError: If both boundaries_df and polygons_gdf are None, or if no polygons are available.
    """
    if polygons_gdf is None:
        if boundaries_df is None:
            raise ValueError("Both boundaries_df and polygons_gdf cannot be None. Provide at least one.")
        polygons_gdf = self.generate_and_scale_polygons(boundaries_df, scale_factor)

    if polygons_gdf.empty:
        raise ValueError("No valid polygons were generated from the boundaries.")
    if self.verbose:
        print("Polygons are available. Proceeding with overlap computation.")

    x_col = self.keys.TRANSCRIPTS_X.value
    y_col = self.keys.TRANSCRIPTS_Y.value
    cell_id_col = self.keys.CELL_ID.value
    overlap_col = self.keys.OVERLAPS_BOUNDARY.value

    # Materialize (geometry, cell_id) pairs once rather than re-iterating the
    # GeoDataFrame row by row for every transcript.
    candidates = [
        (geom, cid) for geom, cid in zip(polygons_gdf["geometry"], polygons_gdf[cell_id_col]) if geom is not None
    ]

    def _first_containing(x, y):
        """Return (True, cell_id) of the first polygon containing (x, y), else (False, None)."""
        point = Point(x, y)
        for geom, cid in candidates:
            if geom.contains(point):
                return True, cid
        return False, None

    def _assign_overlaps(df):
        """Evaluate overlaps eagerly for one (pandas) partition, once per row."""
        hits = [_first_containing(x, y) for x, y in zip(df[x_col], df[y_col])]
        return df.assign(
            **{
                overlap_col: [hit[0] for hit in hits],
                cell_id_col: [hit[1] for hit in hits],
            }
        )

    if self.verbose:
        print("Starting overlap computation for transcripts with the boundary polygons.")

    # Previously each cell was filled with an un-computed ``Delayed`` object, and
    # pandas input was converted to Dask (after a wasted full ``.compute()``).
    # Compute real values instead, and keep the input's DataFrame flavour.
    if isinstance(transcripts_df, pd.DataFrame):
        return _assign_overlaps(transcripts_df)
    return transcripts_df.map_partitions(_assign_overlaps)

create_scaled_polygon staticmethod

create_scaled_polygon(group, scale_factor, keys)

Static method to create and scale a polygon from boundary vertices and return a GeoDataFrame.

Parameters:

Name Type Description Default
group DataFrame

Group of boundary coordinates (for a specific cell).

required
scale_factor float

The distance passed to `Polygon.buffer` to expand each polygon. Despite the name, it is a buffer distance, not a multiplicative scale.

required
keys Dict or Enum

A collection of keys to access column names for 'cell_id', 'vertex_x', and 'vertex_y'.

required

Returns:

Type Description
GeoDataFrame

gpd.GeoDataFrame: A GeoDataFrame containing the scaled Polygon and cell_id.

Source code in src/segger/data/io.py
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
@staticmethod
def create_scaled_polygon(group: pd.DataFrame, scale_factor: float, keys) -> gpd.GeoDataFrame:
    """
    Build one cell's polygon from its boundary vertices, grow it, and wrap it in a GeoDataFrame.

    Parameters:
        group (pd.DataFrame): Boundary vertices of a single cell.
        scale_factor (float): Distance handed to ``Polygon.buffer``. The outline is
            expanded by this distance; its size is not multiplied.
        keys (Dict or Enum): Mapping indexed with ``"vertex_x"``, ``"vertex_y"`` and
            ``"cell_id"`` to obtain the corresponding column names.

    Returns:
        gpd.GeoDataFrame: A one-row frame holding the buffered polygon and the cell ID.
        The geometry is ``None`` in three cases: fewer than 3 vertices, an invalid or
        empty raw polygon, or an invalid or empty buffered polygon.
    """
    xs = group[keys["vertex_x"]]
    ys = group[keys["vertex_y"]]
    cell_id = group[keys["cell_id"]].iloc[0]

    def _usable(shape) -> bool:
        # Keep a geometry only if shapely reports it valid and non-empty.
        return shape.is_valid and not shape.is_empty

    geometry = None
    # At least three vertices are needed to enclose an area.
    if len(xs) >= 3:
        outline = Polygon(zip(xs, ys))
        if _usable(outline):
            grown = outline.buffer(scale_factor)
            if _usable(grown):
                geometry = grown

    # NOTE(review): the vertex coordinates are presumably not lon/lat, so the
    # EPSG:4326 label may be nominal; confirm before relying on CRS operations.
    return gpd.GeoDataFrame(
        {"geometry": [geometry], keys["cell_id"]: [cell_id]}, geometry="geometry", crs="EPSG:4326"
    )

filter_transcripts abstractmethod

filter_transcripts(transcripts_df, min_qv=20.0)

Abstract method to filter transcripts based on dataset-specific criteria.

Parameters:

Name Type Description Default
transcripts_df DataFrame

The dataframe containing transcript data.

required
min_qv float

The minimum quality value threshold for filtering transcripts.

20.0

Returns:

Type Description
DataFrame

pd.DataFrame: The filtered dataframe.

Source code in src/segger/data/io.py
63
64
65
66
67
68
69
70
71
72
73
74
75
@abstractmethod
def filter_transcripts(self, transcripts_df: pd.DataFrame, min_qv: float = 20.0) -> pd.DataFrame:
    """
    Drop transcripts that fail the dataset's quality criteria.

    Every concrete sample class must implement this. The base version has no
    body and returns None.

    Parameters:
        transcripts_df (pd.DataFrame): Transcript table to filter.
        min_qv (float, optional): Inclusive minimum quality value (default 20.0).

    Returns:
        pd.DataFrame: The transcripts that pass the filter.
    """

generate_and_scale_polygons

generate_and_scale_polygons(boundaries_df, scale_factor=1.0)

Generate and scale polygons from boundary coordinates using Dask. The per-cell polygon construction is delegated to the static method `create_scaled_polygon`.

Parameters:

Name Type Description Default
boundaries_df DataFrame

DataFrame containing boundary coordinates.

required
scale_factor float

The factor by which to scale the polygons (default is 1.0).

1.0

Returns:

Type Description
GeoDataFrame

dgpd.GeoDataFrame: A GeoDataFrame containing scaled Polygon objects and their centroids.

Source code in src/segger/data/io.py
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
def generate_and_scale_polygons(self, boundaries_df: dd.DataFrame, scale_factor: float = 1.0) -> dgpd.GeoDataFrame:
    """
    Generate and scale polygons from boundary coordinates using Dask.

    Boundary vertices are grouped by cell ID and each group is turned into a
    (scaled) polygon by ``self.create_scaled_polygon``. Centroid coordinates are
    then added and duplicate rows dropped.

    Parameters:
        boundaries_df (dd.DataFrame): DataFrame containing boundary coordinates.
        scale_factor (float, optional): The factor by which to scale the polygons (default is 1.0).

    Returns:
        dgpd.GeoDataFrame: A GeoDataFrame containing scaled Polygon objects plus
            ``centroid_x`` / ``centroid_y`` columns, with duplicate rows removed.
    """
    # Resolve dataset-specific column names once; the same mapping is handed to every group.
    cell_id_column = self.keys.CELL_ID.value
    polygon_keys = {
        "vertex_x": self.keys.BOUNDARIES_VERTEX_X.value,
        "vertex_y": self.keys.BOUNDARIES_VERTEX_Y.value,
        "cell_id": cell_id_column,
    }

    create_polygon = self.create_scaled_polygon
    # A lambda binds the keyword arguments for each per-cell group passed by groupby.apply.
    polygons_ddf = boundaries_df.groupby(cell_id_column).apply(
        lambda group: create_polygon(group=group, scale_factor=scale_factor, keys=polygon_keys)
    )

    if self.verbose:
        print("Adding centroids to the polygons...")
    # Build the centroid series once and reuse it for both coordinates instead of
    # constructing the centroid geometry twice.
    centroids = polygons_ddf.geometry.centroid
    polygons_ddf["centroid_x"] = centroids.x
    polygons_ddf["centroid_y"] = centroids.y

    return polygons_ddf.drop_duplicates()

load_boundaries

load_boundaries(path, file_format='parquet', x_min=None, x_max=None, y_min=None, y_max=None)

Load boundaries data lazily using Dask, filtering by the specified bounding box.

Parameters:

Name Type Description Default
path Path

Path to the boundaries file.

required
file_format str

Format of the file to load. Only 'parquet' is currently supported; any other value raises a ValueError.

'parquet'
x_min float

Minimum X-coordinate for the bounding box.

None
x_max float

Maximum X-coordinate for the bounding box.

None
y_min float

Minimum Y-coordinate for the bounding box.

None
y_max float

Maximum Y-coordinate for the bounding box.

None

Returns:

Type Description
DataFrame

dd.DataFrame: The filtered boundaries DataFrame.

Source code in src/segger/data/io.py
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
def load_boundaries(
    self,
    path: Path,
    file_format: str = "parquet",
    x_min: float = None,
    x_max: float = None,
    y_min: float = None,
    y_max: float = None,
) -> dd.DataFrame:
    """
    Load boundaries data lazily using Dask, filtering by the specified bounding box.

    Two layouts are handled:
      * a file with a ``geometry`` column is read with ``dgpd.read_parquet(bbox=...)``
        and every Polygon is expanded into one row per exterior vertex, using the
        row index as the cell ID;
      * otherwise the vertex X/Y and cell-ID columns are read with Parquet row
        filters for the bounding box, and cell IDs are cast to ``str`` (missing IDs stay ``None``).

    Any bound left as ``None`` falls back to the corresponding ``self.x_min`` /
    ``self.x_max`` / ``self.y_min`` / ``self.y_max`` value.

    Parameters:
        path (Path): Path to the boundaries file.
        file_format (str, optional): Format of the file to load. Only 'parquet' is supported.
        x_min (float, optional): Minimum X-coordinate for the bounding box.
        x_max (float, optional): Maximum X-coordinate for the bounding box.
        y_min (float, optional): Minimum Y-coordinate for the bounding box.
        y_max (float, optional): Maximum Y-coordinate for the bounding box.

    Returns:
        dd.DataFrame: The filtered boundaries DataFrame (vertex X, vertex Y, cell ID).

    Raises:
        ValueError: If ``file_format`` is not ``"parquet"``, or (when the affected
            partition is processed) if a geometry is not a ``Polygon``.
    """
    if file_format != "parquet":
        raise ValueError(f"Unsupported file format: {file_format}")

    self.boundaries_path = path

    # Fall back to the stored bounds only when a bound is not given; an explicit
    # ``is None`` test keeps a legitimate 0.0 bound from being silently replaced.
    x_min = self.x_min if x_min is None else x_min
    x_max = self.x_max if x_max is None else x_max
    y_min = self.y_min if y_min is None else y_min
    y_max = self.y_max if y_max is None else y_max

    id_col = self.keys.CELL_ID.value
    x_col = self.keys.BOUNDARIES_VERTEX_X.value
    y_col = self.keys.BOUNDARIES_VERTEX_Y.value

    columns = set(dd.read_parquet(path).columns)
    if "geometry" in columns:
        bbox = (x_min, y_min, x_max, y_max)
        # TODO: check that SpatialData objects write the "bbox covering metadata" to the parquet file
        gdf = dgpd.read_parquet(path, bbox=bbox)

        def expand_polygon(row):
            """Return one {cell id, x, y} record per exterior vertex of the row's polygon."""
            polygon = row["geometry"]
            if polygon.geom_type != "Polygon":
                # Instead of expanding the gdf and then having code later to recreate it (when computing
                # the pyg graph) we could directly return a Dask GeoDataFrame, removing the need for
                # this else block.
                raise ValueError(f"Unsupported geometry type: {polygon.geom_type}")
            return [{id_col: row.name, x_col: x, y_col: y} for x, y in polygon.exterior.coords]

        def process_partition(df):
            """Flatten the vertices of every polygon in one partition into a DataFrame."""
            vertices = [vertex for _, row in df.iterrows() for vertex in expand_polygon(row)]
            # Explicit columns keep an empty partition consistent with ``meta`` below.
            return pd.DataFrame(vertices, columns=[id_col, x_col, y_col])

        boundaries_df = gdf.map_partitions(process_partition, meta={id_col: str, x_col: float, y_col: float})
    else:
        # Parquet row filters restrict loading to the bounding box.
        filters = [
            (x_col, ">=", x_min),
            (x_col, "<=", x_max),
            (y_col, ">=", y_min),
            (y_col, "<=", y_max),
        ]
        boundaries_df = dd.read_parquet(path, columns=[x_col, y_col, id_col], filters=filters)

        # Convert the cell IDs to strings lazily, keeping missing IDs as None.
        boundaries_df[id_col] = boundaries_df[id_col].apply(
            lambda x: str(x) if pd.notnull(x) else None, meta=(id_col, "object")
        )

    if self.verbose:
        print(f"Loaded boundaries from '{path}' within bounding box ({x_min}, {x_max}, {y_min}, {y_max}).")

    return boundaries_df

load_transcripts

load_transcripts(base_path=None, sample=None, transcripts_filename=None, path=None, file_format='parquet', x_min=None, x_max=None, y_min=None, y_max=None)

Load transcripts from a Parquet file using Dask for efficient chunked processing, only within the specified bounding box, and return the filtered DataFrame with integer token embeddings.

Parameters:

Name Type Description Default
base_path Path

The base directory path where samples are stored.

None
sample str

The sample name or identifier.

None
transcripts_filename str

The filename of the transcripts file (default is derived from the dataset keys).

None
path Path

Specific path to the transcripts file.

None
file_format str

Format of the file to load (default is 'parquet').

'parquet'
x_min float

Minimum X-coordinate for the bounding box.

None
x_max float

Maximum X-coordinate for the bounding box.

None
y_min float

Minimum Y-coordinate for the bounding box.

None
y_max float

Maximum Y-coordinate for the bounding box.

None

Returns:

Type Description
DataFrame

dd.DataFrame: The filtered transcripts DataFrame.

Source code in src/segger/data/io.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
def load_transcripts(
    self,
    base_path: Path = None,
    sample: str = None,
    transcripts_filename: str = None,
    path: Path = None,
    file_format: str = "parquet",
    x_min: float = None,
    x_max: float = None,
    y_min: float = None,
    y_max: float = None,
) -> dd.DataFrame:
    """
    Load transcripts inside a bounding box from a Parquet file and apply dataset filters.

    The file is read with Dask using Parquet row filters for the bounding box and is then
    materialised with ``.compute()``. The object returned is therefore an in-memory
    pandas DataFrame, despite the ``dd.DataFrame`` annotation.

    Processing steps:
      * transcript and cell IDs are cast to ``str`` (missing values stay ``None``);
      * object-dtype feature names are decoded from UTF-8 bytes (other values via ``str``);
      * ``self.filter_transcripts`` is applied with its default arguments;
      * if ``self.embedding_df`` is set and non-empty, transcripts whose feature name is
        not in its index are dropped;
      * the overlaps-boundary column, when present, is cast to ``bool``.

    Any bound left as ``None`` falls back to ``self.x_min`` / ``self.x_max`` /
    ``self.y_min`` / ``self.y_max``. Sets ``self.transcripts_path``.

    Parameters:
        base_path (Path, optional): The base directory path where samples are stored.
        sample (str, optional): The sample name or identifier.
        transcripts_filename (str, optional): The filename of the transcripts file (default is derived from the dataset keys).
        path (Path, optional): Specific path to the transcripts file; takes precedence over base_path/sample.
        file_format (str, optional): Format of the file to load (default is 'parquet').
        x_min (float, optional): Minimum X-coordinate for the bounding box.
        x_max (float, optional): Maximum X-coordinate for the bounding box.
        y_min (float, optional): Minimum Y-coordinate for the bounding box.
        y_max (float, optional): Maximum Y-coordinate for the bounding box.

    Returns:
        The filtered transcripts DataFrame (computed, i.e. pandas).

    Raises:
        ValueError: If ``file_format`` is not ``"parquet"``.
    """
    if file_format != "parquet":
        raise ValueError("This version only supports parquet files with Dask.")

    # Set the file path for transcripts
    transcripts_filename = transcripts_filename or self.keys.TRANSCRIPTS_FILE.value
    file_path = path or (base_path / sample / transcripts_filename)
    self.transcripts_path = file_path

    # Fall back to the stored bounds only when a bound is not given; ``is None`` keeps
    # a legitimate 0.0 bound from being replaced.
    x_min = self.x_min if x_min is None else x_min
    x_max = self.x_max if x_max is None else x_max
    y_min = self.y_min if y_min is None else y_min
    y_max = self.y_max if y_max is None else y_max

    id_col = self.keys.TRANSCRIPTS_ID.value
    x_col = self.keys.TRANSCRIPTS_X.value
    y_col = self.keys.TRANSCRIPTS_Y.value
    feature_col = self.keys.FEATURE_NAME.value
    cell_col = self.keys.CELL_ID.value
    qv_col = self.keys.QUALITY_VALUE.value
    overlap_col = self.keys.OVERLAPS_BOUNDARY.value

    # Check for available columns in the file's metadata (without loading the data)
    available_columns = dd.read_parquet(file_path, meta_only=True).columns

    # Optional columns are read only when the file provides them.
    columns_to_read = [id_col, x_col, y_col, feature_col, cell_col]
    columns_to_read += [col for col in (qv_col, overlap_col) if col in available_columns]

    # Use filters to only load data within the specified bounding box
    filters = [
        (x_col, ">=", x_min),
        (x_col, "<=", x_max),
        (y_col, ">=", y_min),
        (y_col, "<=", y_max),
    ]
    transcripts_df = dd.read_parquet(file_path, columns=columns_to_read, filters=filters).compute()

    def to_str_or_none(value):
        return str(value) if pd.notnull(value) else None

    transcripts_df[id_col] = transcripts_df[id_col].apply(to_str_or_none)
    transcripts_df[cell_col] = transcripts_df[cell_col].apply(to_str_or_none)

    # Decode bytes with UTF-8 (as set_metadata does) rather than astype(str), which would
    # turn b"GENE" into the literal text "b'GENE'" and break gene-name lookups.
    if pd.api.types.is_object_dtype(transcripts_df[feature_col]):
        transcripts_df[feature_col] = transcripts_df[feature_col].apply(
            lambda x: x.decode("utf-8") if isinstance(x, bytes) else str(x)
        )

    # Apply dataset-specific filtering (e.g., quality filtering for Xenium)
    transcripts_df = self.filter_transcripts(transcripts_df)

    # Keep only genes present in the embedding, if one is configured.
    if self.embedding_df is not None and not self.embedding_df.empty:
        initial_count = len(transcripts_df)
        transcripts_df = transcripts_df[transcripts_df[feature_col].isin(self.embedding_df.index)]
        if self.verbose:
            print(f"Dropped {initial_count - len(transcripts_df)} transcripts not found in embedding.")

    # ``assign`` avoids writing into a filtered slice (SettingWithCopy).
    if overlap_col in transcripts_df.columns:
        transcripts_df = transcripts_df.assign(**{overlap_col: transcripts_df[overlap_col].astype(bool)})

    return transcripts_df

save_dataset_for_segger

save_dataset_for_segger(processed_dir, x_size=1000, y_size=1000, d_x=900, d_y=900, margin_x=None, margin_y=None, compute_labels=True, r_tx=5, k_tx=3, val_prob=0.1, test_prob=0.2, neg_sampling_ratio_approx=5, sampling_rate=1, num_workers=1, scale_boundaries=1.0, method='kd_tree', gpu=False, workers=1)

Saves the dataset for Segger in a processed format using Dask for parallel and lazy processing.

Parameters:

Name Type Description Default
processed_dir Path

Directory to save the processed dataset.

required
x_size float

Width of each tile.

1000
y_size float

Height of each tile.

1000
d_x float

Step size in the x direction for tiles.

900
d_y float

Step size in the y direction for tiles.

900
margin_x float

Margin in the x direction to include transcripts.

None
margin_y float

Margin in the y direction to include transcripts.

None
compute_labels bool

Whether to compute edge labels for tx_belongs_bd edges.

True
r_tx float

Radius for building the transcript-to-transcript graph.

5
k_tx int

Number of nearest neighbors for the tx-tx graph.

3
val_prob float

Probability of assigning a tile to the validation set.

0.1
test_prob float

Probability of assigning a tile to the test set.

0.2
neg_sampling_ratio_approx float

Approximate ratio of negative samples.

5
sampling_rate float

Rate of sampling tiles.

1
num_workers int

Number of workers to use for parallel processing.

1
scale_boundaries float

The factor by which to scale the boundary polygons. Default is 1.0.

1.0
method str

Method for computing edge indices (e.g., 'kd_tree', 'faiss').

'kd_tree'
gpu bool

Whether to use GPU acceleration for edge index computation.

False
workers int

Number of workers to use to compute the neighborhood graph (per tile).

1
Source code in src/segger/data/io.py
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
def save_dataset_for_segger(
    self,
    processed_dir: Path,
    x_size: float = 1000,
    y_size: float = 1000,
    d_x: float = 900,
    d_y: float = 900,
    margin_x: float = None,
    margin_y: float = None,
    compute_labels: bool = True,
    r_tx: float = 5,
    k_tx: int = 3,
    val_prob: float = 0.1,
    test_prob: float = 0.2,
    neg_sampling_ratio_approx: float = 5,
    sampling_rate: float = 1,
    num_workers: int = 1,
    scale_boundaries: float = 1.0,
    method: str = "kd_tree",
    gpu: bool = False,
    workers: int = 1,
) -> None:
    """
    Tile the sample and write each processed tile under ``processed_dir``.

    The steps are:
      1. prepare the output directories (``self._prepare_directories``);
      2. compute the X/Y tile ranges from the step sizes (``self._get_ranges``);
      3. build one parameter set per tile (``self._generate_tile_params``);
      4. run ``self._process_tile`` on every parameter set as Dask delayed tasks,
         under a progress bar.

    Parameters:
        processed_dir (Path): Directory to save the processed dataset.
        x_size (float, optional): Width of each tile.
        y_size (float, optional): Height of each tile.
        d_x (float, optional): Step size in the x direction for tiles.
        d_y (float, optional): Step size in the y direction for tiles.
        margin_x (float, optional): Margin in the x direction to include transcripts.
        margin_y (float, optional): Margin in the y direction to include transcripts.
        compute_labels (bool, optional): Whether to compute edge labels for tx_belongs_bd edges.
        r_tx (float, optional): Radius for building the transcript-to-transcript graph.
        k_tx (int, optional): Number of nearest neighbors for the tx-tx graph.
        val_prob (float, optional): Probability of assigning a tile to the validation set.
        test_prob (float, optional): Probability of assigning a tile to the test set.
        neg_sampling_ratio_approx (float, optional): Approximate ratio of negative samples.
        sampling_rate (float, optional): Rate of sampling tiles.
        num_workers (int, optional): Number of Dask workers used to process tiles in parallel.
        scale_boundaries (float, optional): The factor by which to scale the boundary polygons. Default is 1.0.
        method (str, optional): Method for computing edge indices (e.g., 'kd_tree', 'faiss').
        gpu (bool, optional): Whether to use GPU acceleration for edge index computation.
        workers (int, optional): Number of workers to use to compute the neighborhood graph (per tile).
    """
    self._prepare_directories(processed_dir)

    x_range, y_range = self._get_ranges(d_x, d_y)

    # The helper takes its arguments positionally; keep this exact order.
    tile_params = self._generate_tile_params(
        x_range, y_range,
        x_size, y_size,
        margin_x, margin_y,
        compute_labels,
        r_tx, k_tx,
        val_prob, test_prob,
        neg_sampling_ratio_approx,
        sampling_rate,
        processed_dir,
        scale_boundaries,
        method,
        gpu,
        workers,
    )

    if self.verbose:
        print("Starting tile processing...")

    # Wrap the tile processor once, then create one lazy task per tile.
    lazy_process_tile = delayed(self._process_tile)
    pending_tiles = [lazy_process_tile(params) for params in tile_params]

    with ProgressBar():
        dask.compute(*pending_tiles, num_workers=num_workers)

    if self.verbose:
        print("Tile processing completed.")

set_embedding

set_embedding(embedding_name)

Set the current embedding type for the transcripts.

Parameters:

Name Type Description Default
embedding_name

The name of the embedding to use (str).

required
Source code in src/segger/data/io.py
402
403
404
405
406
407
408
409
410
411
412
413
414
def set_embedding(self, embedding_name: str) -> None:
    """
    Select which embedding is used for the transcripts.

    Parameters:
        embedding_name : str
            Key of the desired embedding in ``self.embeddings_dict``.

    Raises:
        ValueError: If ``embedding_name`` is not a key of ``self.embeddings_dict``;
            ``self.current_embedding`` is left unchanged in that case.
    """
    if embedding_name not in self.embeddings_dict:
        raise ValueError(f"Embedding {embedding_name} not found in embeddings_dict.")
    self.current_embedding = embedding_name

set_file_paths

set_file_paths(transcripts_path, boundaries_path)

Set the paths for the transcript and boundary files.

Parameters:

Name Type Description Default
transcripts_path Path

Path to the Parquet file containing transcripts data.

required
boundaries_path Path

Path to the Parquet file containing boundaries data.

required
Source code in src/segger/data/io.py
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def set_file_paths(self, transcripts_path: Path, boundaries_path: Path) -> None:
    """
    Record the locations of the transcript and boundary Parquet files.

    Parameters:
        transcripts_path (Path): Path to the Parquet file containing transcripts data.
        boundaries_path (Path): Path to the Parquet file containing boundaries data.

    When ``self.verbose`` is true, both paths are echoed to stdout.
    """
    self.transcripts_path, self.boundaries_path = transcripts_path, boundaries_path

    if self.verbose:
        print(f"Set transcripts file path to {transcripts_path}")
        print(f"Set boundaries file path to {boundaries_path}")

set_metadata

set_metadata()

Set metadata for the transcript dataset: the bounding-box limits and the sorted unique gene names, with control and blank codewords excluded. The gene names are also encoded as integer tokens with a LabelEncoder, and the gene-to-token and token-to-gene lookup tables are stored for later mapping.

Source code in src/segger/data/io.py
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
def set_metadata(self) -> None:
    """
    Set metadata for the transcript dataset: bounding-box limits and unique gene names.

    Only the X, Y and feature-name columns are read from ``self.transcripts_path``.
    The table is scanned in fixed-size, non-overlapping chunks. Feature names are
    decoded from UTF-8 bytes when needed, and control/blank codewords (prefixes in
    ``filter_codewords``) are excluded from the gene set. The sorted gene names are
    then encoded as integer tokens with a ``LabelEncoder``.

    Sets:
        self.x_min, self.x_max, self.y_min, self.y_max: Coordinate extents of the transcripts.
        self.unique_genes: Sorted list of gene names.
        self.tx_encoder: The fitted LabelEncoder.
        self.gene_to_token_map / self.token_to_gene_map: Gene name <-> integer token lookups.
    """
    # Get the column names for X, Y, and feature names from the class's keys
    x_col = self.keys.TRANSCRIPTS_X.value
    y_col = self.keys.TRANSCRIPTS_Y.value
    feature_col = self.keys.FEATURE_NAME.value

    # Read only the three columns used below rather than the whole table.
    table = pq.read_table(self.transcripts_path, columns=[x_col, y_col, feature_col])

    # Running min/max for the bounding box
    x_min, x_max, y_min, y_max = float("inf"), float("-inf"), float("inf"), float("-inf")

    gene_set = set()

    # Prefixes of codewords that are not real genes
    filter_codewords = (
        "NegControlProbe_",
        "antisense_",
        "NegControlCodeword_",
        "BLANK_",
        "DeprecatedCodeword_",
        "UnassignedCodeword_",
    )

    row_group_size = 4_000_000
    for start in range(0, len(table), row_group_size):
        # pyarrow's Table.slice takes (offset, length). Passing the chunk length (not an
        # end index) keeps chunks disjoint, so each row is scanned exactly once.
        chunk = table.slice(start, row_group_size)

        # Update the bounding box values (min/max)
        x_values = chunk[x_col].to_pandas()
        y_values = chunk[y_col].to_pandas()

        x_min = min(x_min, x_values.min())
        x_max = max(x_max, x_values.max())
        y_min = min(y_min, y_values.min())
        y_max = max(y_max, y_values.max())

        # Convert feature values (gene names) to strings, decoding raw bytes as UTF-8
        feature_values = (
            chunk[feature_col]
            .to_pandas()
            .apply(
                lambda x: x.decode("utf-8") if isinstance(x, bytes) else str(x),
            )
        )

        # Drop unwanted codewords and record the remaining unique genes
        filtered_genes = feature_values[~feature_values.str.startswith(filter_codewords)]
        gene_set.update(filtered_genes.unique())

    # Set bounding box limits
    self.x_min = x_min
    self.x_max = x_max
    self.y_min = y_min
    self.y_max = y_max

    if self.verbose:
        print(
            f"Bounding box limits set: x_min={self.x_min}, x_max={self.x_max}, y_min={self.y_min}, y_max={self.y_max}"
        )

    # Sorted list for a consistent token ordering
    self.unique_genes = sorted(gene_set)
    if self.verbose:
        print(f"Extracted {len(self.unique_genes)} unique gene names for integer tokenization.")

    # Fit a LabelEncoder to convert unique genes into integer tokens
    self.tx_encoder = LabelEncoder()
    self.tx_encoder.fit(self.unique_genes)

    # Store the gene-name -> integer-token mapping
    self.gene_to_token_map = dict(
        zip(self.tx_encoder.classes_, self.tx_encoder.transform(self.tx_encoder.classes_))
    )

    if self.verbose:
        print("Integer tokens have been computed and stored based on unique gene names.")

    # Reverse mapping for lookup purposes (token -> gene)
    self.token_to_gene_map = {v: k for k, v in self.gene_to_token_map.items()}

    if self.verbose:
        print("Lookup tables (gene_to_token_map and token_to_gene_map) have been created.")

XeniumKeys

Bases: Enum

Keys for 10X Genomics Xenium formatted dataset.

XeniumSample

XeniumSample(transcripts_df=None, transcripts_radius=10, boundaries_graph=False, embedding_df=None, verbose=True)

Bases: SpatialTranscriptomicsSample

Source code in src/segger/data/io.py
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
def __init__(
    self,
    transcripts_df: dd.DataFrame = None,
    transcripts_radius: int = 10,
    boundaries_graph: bool = False,
    embedding_df: pd.DataFrame = None,
    verbose: bool = True,
):
    """Initialise a Xenium sample by delegating to the base class with ``XeniumKeys``."""
    # Positional order expected by the base-class constructor; only ``verbose`` is passed by name.
    base_args = (transcripts_df, transcripts_radius, boundaries_graph, embedding_df, XeniumKeys)
    super().__init__(*base_args, verbose=verbose)

filter_transcripts

filter_transcripts(transcripts_df, min_qv=20.0)

Filters transcripts based on quality value and removes unwanted transcripts for Xenium using Dask.

Parameters:

Name Type Description Default
transcripts_df DataFrame

The Dask DataFrame containing transcript data.

required
min_qv float

The minimum quality value threshold for filtering transcripts.

20.0

Returns:

Type Description
DataFrame

dd.DataFrame: The filtered Dask DataFrame.

Source code in src/segger/data/io.py
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
def filter_transcripts(self, transcripts_df: dd.DataFrame, min_qv: float = 20.0) -> dd.DataFrame:
    """
    Filters transcripts based on quality value and removes unwanted transcripts for Xenium.

    A transcript is kept when its quality value is >= ``min_qv`` and its feature name
    does not start with one of the control/blank codeword prefixes. Object-dtype
    feature names are decoded from UTF-8 bytes first. A missing feature name is
    treated as "not a control codeword". The input DataFrame is not modified.
    Only column operations shared by Dask and pandas DataFrames are used.

    Parameters:
        transcripts_df (dd.DataFrame): The DataFrame containing transcript data.
        min_qv (float, optional): The minimum quality value threshold for filtering transcripts.

    Returns:
        dd.DataFrame: The filtered DataFrame (same kind as the input).
    """
    feature_col = self.keys.FEATURE_NAME.value
    qv_col = self.keys.QUALITY_VALUE.value

    filter_codewords = (
        "NegControlProbe_",
        "antisense_",
        "NegControlCodeword_",
        "BLANK_",
        "DeprecatedCodeword_",
        "UnassignedCodeword_",
    )

    # Raw Parquet strings may arrive as bytes; decode them so the prefix test works.
    # ``assign`` returns a new frame, leaving the caller's DataFrame untouched.
    if pd.api.types.is_object_dtype(transcripts_df[feature_col]):
        decoded = transcripts_df[feature_col].apply(lambda x: x.decode("utf-8") if isinstance(x, bytes) else x)
        transcripts_df = transcripts_df.assign(**{feature_col: decoded})

    mask_quality = transcripts_df[qv_col] >= min_qv

    # na=False: a missing name yields False instead of NaN, which would make ``~``
    # raise on an object column (or leave NA in the mask for string dtypes).
    mask_codewords = ~transcripts_df[feature_col].str.startswith(filter_codewords, na=False)

    return transcripts_df[mask_quality & mask_codewords]

calculate_gene_celltype_abundance_embedding

calculate_gene_celltype_abundance_embedding(adata, celltype_column)

Calculate the cell type abundance embedding for each gene based on the fraction of cells in each cell type that express the gene (non-zero expression).

Parameters:

Name Type Description Default
adata AnnData

An AnnData object containing gene expression data and cell type information.

required
celltype_column str

The column name in adata.obs that contains the cell type information.

required

Returns:

Type Description
DataFrame

pd.DataFrame: A DataFrame where rows are genes and columns are cell types, with each value representing the fraction of cells in that cell type expressing the gene.

Example

>>> adata = AnnData(...)  # Load your scRNA-seq AnnData object
>>> celltype_column = 'celltype_major'
>>> abundance_df = calculate_gene_celltype_abundance_embedding(adata, celltype_column)
>>> abundance_df.head()

Source code in src/segger/data/utils.py
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
def calculate_gene_celltype_abundance_embedding(adata: "ad.AnnData", celltype_column: str) -> pd.DataFrame:
    """Calculate the cell type abundance embedding for each gene based on the fraction of cells in each cell type
    that express the gene (non-zero expression).

    Parameters:
        adata (ad.AnnData): An AnnData object containing gene expression data and cell type information.
            Only ``adata.X``, ``adata.obs``, ``adata.obs_names`` and ``adata.var_names`` are read. ``adata.X``
            may be a dense array or any matrix exposing ``toarray()`` (e.g. a SciPy sparse matrix).
        celltype_column (str): The column name in `adata.obs` that contains the cell type information.
            Both plain and categorical (``pd.Categorical``) columns are supported.

    Returns:
        pd.DataFrame: A DataFrame where rows are genes and columns are cell types (sorted by label), with each
            value representing the fraction of cells in that cell type expressing the gene.

    Example:
        >>> adata = AnnData(...)  # Load your scRNA-seq AnnData object
        >>> celltype_column = 'celltype_major'
        >>> abundance_df = calculate_gene_celltype_abundance_embedding(adata, celltype_column)
        >>> abundance_df.head()
    """
    # Densify sparse matrices; the "expressed" test below needs element-wise comparison.
    expression_data = adata.X.toarray() if hasattr(adata.X, "toarray") else np.asarray(adata.X)
    # np.asarray turns a pandas Categorical (the usual dtype of AnnData obs columns) into a plain
    # array. A Categorical has no ``reshape``, which made the previous one-hot-encoding code crash.
    cell_types = np.asarray(adata.obs[celltype_column])
    # Boolean cells x genes mask: True where the gene has non-zero expression in the cell.
    expressed = pd.DataFrame(expression_data > 0, index=adata.obs_names, columns=adata.var_names)
    # The mean of a boolean column within a group is the fraction of that group's cells expressing the gene.
    # sort=True orders cell types by label; dropna=False keeps cells with a missing label as their own
    # group instead of silently discarding them.
    abundance = expressed.groupby(cell_types, sort=True, dropna=False).mean()
    # Transpose to genes (rows) x cell types (columns).
    return abundance.T

compute_transcript_metrics

compute_transcript_metrics(df, qv_threshold=30, cell_id_col='cell_id')

Computes various metrics for a given dataframe of transcript data filtered by quality value threshold.

Parameters:

Name Type Description Default
df DataFrame

The dataframe containing transcript data.

required
qv_threshold float

The quality value threshold for filtering transcripts.

30
cell_id_col str

The name of the column representing the cell ID.

'cell_id'

Returns:

Type Description
Dict[str, Any]

Dict[str, Any]: A dictionary containing various transcript metrics:
- 'percent_assigned' (float): The percentage of assigned transcripts.
- 'percent_cytoplasmic' (float): The percentage of cytoplasmic transcripts among assigned transcripts.
- 'percent_nucleus' (float): The percentage of nucleus transcripts among assigned transcripts.
- 'percent_non_assigned_cytoplasmic' (float): The percentage of non-assigned cytoplasmic transcripts.
- 'gene_metrics' (pd.DataFrame): A dataframe containing gene-level metrics.

Source code in src/segger/data/utils.py
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
def compute_transcript_metrics(
    df: pd.DataFrame, qv_threshold: float = 30, cell_id_col: str = "cell_id"
) -> Dict[str, Any]:
    """
    Computes various metrics for a given dataframe of transcript data filtered by quality value threshold.

    Only transcripts with ``qv`` strictly greater than ``qv_threshold`` are considered. A transcript is
    "assigned" when its ``cell_id_col`` value is not ``-1``. It is "cytoplasmic" when its
    ``overlaps_nucleus`` value is not ``1``. Percentages whose denominator would be empty are reported
    as ``0.0``.

    Parameters:
        df (pd.DataFrame): The dataframe containing transcript data (needs ``qv``, ``feature_name``,
            ``overlaps_nucleus`` and ``cell_id_col`` columns).
        qv_threshold (float): The quality value threshold for filtering transcripts.
        cell_id_col (str): The name of the column representing the cell ID.

    Returns:
        Dict[str, Any]: A dictionary containing various transcript metrics:
            - 'percent_assigned' (float): The percentage of assigned transcripts.
            - 'percent_cytoplasmic' (float): The percentage of cytoplasmic transcripts among assigned transcripts.
            - 'percent_nucleus' (float): The percentage of nucleus transcripts among assigned transcripts.
            - 'percent_non_assigned_cytoplasmic' (float): The percentage of non-assigned cytoplasmic transcripts.
            - 'gene_metrics' (pd.DataFrame): A dataframe with columns ``feature_name``, ``percent_assigned``
              (share of the gene's transcripts that are assigned) and ``percent_cytoplasmic`` (share of all
              cytoplasmic transcripts that belong to the gene).
    """

    def _percent(count: int, total: int) -> float:
        # Exact percentage; 0.0 instead of a ZeroDivisionError when there is nothing to divide by.
        return count / total * 100 if total else 0.0

    df_filtered = df[df["qv"] > qv_threshold]
    # -1 marks a transcript that is not assigned to any cell.
    is_assigned = df_filtered[cell_id_col] != -1
    assigned_transcripts = df_filtered[is_assigned]
    non_assigned_transcripts = df_filtered[~is_assigned]
    cytoplasmic_transcripts = assigned_transcripts[assigned_transcripts["overlaps_nucleus"] != 1]
    non_assigned_cytoplasmic = non_assigned_transcripts[non_assigned_transcripts["overlaps_nucleus"] != 1]

    percent_assigned = _percent(len(assigned_transcripts), len(df_filtered))
    percent_cytoplasmic = _percent(len(cytoplasmic_transcripts), len(assigned_transcripts))
    percent_nucleus = _percent(len(assigned_transcripts) - len(cytoplasmic_transcripts), len(assigned_transcripts))
    percent_non_assigned_cytoplasmic = _percent(len(non_assigned_cytoplasmic), len(non_assigned_transcripts))

    # Per gene: share of the gene's filtered transcripts that are assigned. Every gene in the denominator
    # has at least one transcript; genes with no assigned transcripts become NaN and are zero-filled below.
    gene_percent_assigned = (
        assigned_transcripts.groupby("feature_name").size() / df_filtered.groupby("feature_name").size() * 100
    ).reset_index(name="percent_assigned")
    # Per gene: share of all cytoplasmic transcripts that belong to the gene.
    gene_percent_cytoplasmic = (
        cytoplasmic_transcripts.groupby("feature_name").size() / len(cytoplasmic_transcripts) * 100
        if len(cytoplasmic_transcripts)
        else cytoplasmic_transcripts.groupby("feature_name").size().astype(float)
    ).reset_index(name="percent_cytoplasmic")
    gene_metrics = pd.merge(gene_percent_assigned, gene_percent_cytoplasmic, on="feature_name", how="outer").fillna(0)

    return {
        "percent_assigned": percent_assigned,
        "percent_cytoplasmic": percent_cytoplasmic,
        "percent_nucleus": percent_nucleus,
        "percent_non_assigned_cytoplasmic": percent_non_assigned_cytoplasmic,
        "gene_metrics": gene_metrics,
    }

create_anndata

create_anndata(df, panel_df=None, min_transcripts=5, cell_id_col='cell_id', qv_threshold=30, min_cell_area=10.0, max_cell_area=1000.0)

Generates an AnnData object from a dataframe of segmented transcriptomics data.

Parameters:

Name Type Description Default
df DataFrame

The dataframe containing segmented transcriptomics data.

required
panel_df Optional[DataFrame]

The dataframe containing panel information.

None
min_transcripts int

The minimum number of transcripts required for a cell to be included.

5
cell_id_col str

The column name representing the cell ID in the input dataframe.

'cell_id'
qv_threshold float

The quality value threshold for filtering transcripts.

30
min_cell_area float

The minimum cell area to include a cell.

10.0
max_cell_area float

The maximum cell area to include a cell.

1000.0

Returns:

Type Description
AnnData

ad.AnnData: The generated AnnData object containing the transcriptomics data and metadata.

Source code in src/segger/data/utils.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
def create_anndata(
    df: pd.DataFrame,
    panel_df: Optional[pd.DataFrame] = None,
    min_transcripts: int = 5,
    cell_id_col: str = "cell_id",
    qv_threshold: float = 30,
    min_cell_area: float = 10.0,
    max_cell_area: float = 1000.0,
) -> ad.AnnData:
    """
    Generates an AnnData object from a dataframe of segmented transcriptomics data.

    Parameters:
        df (pd.DataFrame): The dataframe containing segmented transcriptomics data. Reads the
            ``cell_id_col``, ``feature_name``, ``x_location`` and ``y_location`` columns.
        panel_df (Optional[pd.DataFrame]): The dataframe containing panel information (``gene`` and
            ``ensembl`` columns). When given, the genes of the result (and their order) follow the panel,
            sorted by gene name; panel genes without transcripts get zero counts.
        min_transcripts (int): The minimum number of transcripts required for a cell to be included.
        cell_id_col (str): The column name representing the cell ID in the input dataframe. Rows whose
            cell ID is ``-1`` (compared as strings) are treated as unassigned and dropped.
        qv_threshold (float): The quality value threshold for filtering transcripts. Currently unused:
            no quality filtering is applied here, so pre-filter ``df`` if needed.
        min_cell_area (float): The minimum cell area to include a cell.
        max_cell_area (float): The maximum cell area to include a cell.

    Returns:
        ad.AnnData: The generated AnnData object containing the transcriptomics data and metadata. ``obs``
            holds ``transcripts``, ``unique_transcripts``, ``cell_centroid_x``, ``cell_centroid_y`` and
            ``cell_area`` (the area of the convex hull of the cell's transcript locations).
    """
    # Drop unassigned transcripts; comparing as strings handles both integer and string cell IDs.
    assigned = df[df[cell_id_col].astype(str) != "-1"]

    # Cell x gene count matrix, restricted to cells with enough transcripts.
    counts = assigned.rename(columns={cell_id_col: "cell", "feature_name": "gene"})[["cell", "gene"]].pivot_table(
        index="cell", columns="gene", aggfunc="size", fill_value=0
    )
    counts = counts[counts.sum(axis=1) >= min_transcripts]

    # Per-cell geometry, restricted to cells with enough transcripts and an area inside the bounds.
    cell_records = []
    for cell_id, cell_data in assigned.groupby(cell_id_col):
        if len(cell_data) < min_transcripts:
            continue
        points = cell_data[["x_location", "y_location"]].to_numpy()
        # For 2-D input scipy's ConvexHull reports the enclosed area as ``volume``; ``area`` is the
        # perimeter. Qhull needs at least 3 points, so smaller cells are given an area of 0.
        cell_area = ConvexHull(points, qhull_options="QJ").volume if len(points) >= 3 else 0.0
        if cell_area < min_cell_area or cell_area > max_cell_area:
            continue
        cell_records.append(
            {
                "cell": cell_id,
                "cell_centroid_x": cell_data["x_location"].mean(),
                "cell_centroid_y": cell_data["y_location"].mean(),
                "cell_area": cell_area,
            }
        )
    # Explicit columns keep set_index working when no cell passes the filters.
    cell_summary = pd.DataFrame(
        cell_records, columns=["cell", "cell_centroid_x", "cell_centroid_y", "cell_area"]
    ).set_index("cell")

    if panel_df is not None:
        panel_df = panel_df.sort_values("gene")
        # Align columns to the panel: missing panel genes become zero columns, non-panel genes are dropped.
        counts = counts.reindex(columns=panel_df["gene"].tolist(), fill_value=0)
        var_df = (
            panel_df[["gene", "ensembl"]]
            .rename(columns={"ensembl": "gene_ids"})
            .assign(feature_types="Gene Expression", genome="Unknown")
            .set_index("gene")
        )
    else:
        # Build var from the count columns themselves so its order always matches the matrix columns.
        var_df = pd.DataFrame(
            {"feature_types": "Gene Expression", "genome": "Unknown"},
            index=pd.Index(counts.columns.to_numpy(), name="gene"),
        )

    # Keep cells that passed both filters, in the (deterministic) row order of the count matrix.
    cells = counts.index.intersection(cell_summary.index)
    counts = counts.loc[cells]
    cell_summary = cell_summary.loc[cells]

    adata = ad.AnnData(counts.to_numpy())
    adata.var = var_df
    adata.obs_names = counts.index.values.tolist()
    adata.obs["transcripts"] = counts.sum(axis=1).to_numpy()
    adata.obs["unique_transcripts"] = (counts > 0).sum(axis=1).to_numpy()
    # cell_summary rows are already in the same order as counts, so attach them positionally. This avoids
    # depending on AnnData leaving non-string cell IDs in obs_names untouched.
    for column in cell_summary.columns:
        adata.obs[column] = cell_summary[column].to_numpy()
    return adata

filter_transcripts

filter_transcripts(transcripts_df, min_qv=20.0)

Filters transcripts based on quality value and removes unwanted transcripts.

Parameters:

Name Type Description Default
transcripts_df DataFrame

The dataframe containing transcript data.

required
min_qv float

The minimum quality value threshold for filtering transcripts.

20.0

Returns:

Type Description
DataFrame

pd.DataFrame: The filtered dataframe.

Source code in src/segger/data/utils.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def filter_transcripts(
    transcripts_df: pd.DataFrame,
    min_qv: float = 20.0,
) -> pd.DataFrame:
    """
    Filters transcripts based on quality value and removes unwanted transcripts.

    A transcript is kept when its ``qv`` is at least ``min_qv`` and its ``feature_name`` does not start
    with one of the negative-control, blank, deprecated or unassigned codeword prefixes listed below.
    Feature names stored as ``bytes`` are decoded as UTF-8 for the prefix test. Missing feature names are
    not treated as control codewords.

    Parameters:
        transcripts_df (pd.DataFrame): The dataframe containing transcript data (``qv`` and
            ``feature_name`` columns).
        min_qv (float): The minimum quality value threshold for filtering transcripts.

    Returns:
        pd.DataFrame: The filtered dataframe, i.e. the selected rows of ``transcripts_df`` with their
            original values (the input is not modified).
    """
    filter_codewords = (
        "NegControlProbe_",
        "antisense_",
        "NegControlCodeword_",
        "BLANK_",
        "DeprecatedCodeword_",
        "UnassignedCodeword_",
    )
    feature_names = transcripts_df["feature_name"]
    if pd.api.types.is_object_dtype(feature_names):
        # Some readers yield feature names as raw bytes, which the .str accessor rejects for startswith.
        feature_names = feature_names.map(lambda name: name.decode("utf-8") if isinstance(name, bytes) else name)
    mask = transcripts_df["qv"].ge(min_qv)
    # na=False keeps the mask boolean (so ``~`` works) and does not treat missing names as controls.
    mask &= ~feature_names.str.startswith(filter_codewords, na=False)
    return transcripts_df[mask]

format_time

format_time(elapsed)

Format elapsed time to h:m:s.

Parameters:

elapsed : float Elapsed time in seconds.

Returns:

str Formatted time in h:m:s.

Source code in src/segger/data/utils.py
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
def format_time(elapsed: float) -> str:
    """
    Format elapsed time to h:m:s.

    Hours are not folded into days, so durations of 24 hours or more keep the ``h:mm:ss`` form
    (e.g. ``90000`` -> ``"25:00:00"``). ``str(timedelta)`` would produce ``"1 day, 1:00:00"`` here.
    Fractional seconds are truncated toward zero, and negative durations get a leading ``-``.

    Parameters:
    ----------
    elapsed : float
        Elapsed time in seconds.

    Returns:
    -------
    str
        Formatted time in h:mm:ss, e.g. ``"0:01:05"``.
    """
    total_seconds = int(elapsed)
    sign = "-" if total_seconds < 0 else ""
    minutes, seconds = divmod(abs(total_seconds), 60)
    hours, minutes = divmod(minutes, 60)
    return f"{sign}{hours}:{minutes:02d}:{seconds:02d}"

get_edge_index

get_edge_index(coords_1, coords_2, k=5, dist=10, method='kd_tree', workers=1)

Computes edge indices using KD-Tree.

Parameters:

Name Type Description Default
coords_1 ndarray

First set of coordinates.

required
coords_2 ndarray

Second set of coordinates.

required
k int

Number of nearest neighbors.

5
dist int

Distance threshold.

10
method str

The method to use. Only 'kd_tree' is supported now.

'kd_tree'

Returns:

Type Description
Tensor

torch.Tensor: Edge indices.

Source code in src/segger/data/utils.py
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
def get_edge_index(
    coords_1: np.ndarray,
    coords_2: np.ndarray,
    k: int = 5,
    dist: int = 10,
    method: str = "kd_tree",
    workers: int = 1,
) -> torch.Tensor:
    """
    Dispatch edge-index construction to the requested neighbor-search backend.

    Parameters:
        coords_1 (np.ndarray): First set of coordinates.
        coords_2 (np.ndarray): Second set of coordinates.
        k (int, optional): Number of nearest neighbors.
        dist (int, optional): Distance threshold.
        method (str, optional): Backend name. Only 'kd_tree' is supported now.
        workers (int, optional): Number of workers forwarded to the 'kd_tree' backend.

    Returns:
        torch.Tensor: Edge indices produced by the selected backend.

    Raises:
        ValueError: If ``method`` is not 'kd_tree'.
    """
    # Guard clause: reject unsupported backends before doing any work.
    if method != "kd_tree":
        msg = f"Unknown method {method}. The only supported method is 'kd_tree' now."
        raise ValueError(msg)
    return get_edge_index_kdtree(coords_1, coords_2, k=k, dist=dist, workers=workers)

get_edge_index_kdtree

get_edge_index_kdtree(coords_1, coords_2, k=5, dist=10, workers=1)

Computes edge indices using KDTree.

Parameters:

Name Type Description Default
coords_1 ndarray

First set of coordinates.

required
coords_2 ndarray

Second set of coordinates.

required
k int

Number of nearest neighbors.

5
dist int

Distance threshold.

10

Returns:

Type Description
Tensor

torch.Tensor: Edge indices.

Source code in src/segger/data/utils.py
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
def get_edge_index_kdtree(
    coords_1: np.ndarray, coords_2: np.ndarray, k: int = 5, dist: int = 10, workers: int = 1
) -> torch.Tensor:
    """
    Computes edge indices using KDTree.

    A KD-tree is built on ``coords_1`` and queried with every point of ``coords_2`` for its ``k`` nearest
    neighbors. Each (query, neighbor) pair closer than ``dist`` becomes one edge.

    Parameters:
        coords_1 (np.ndarray): First set of coordinates (the points the tree is built on). A torch
            tensor is also accepted and moved to CPU.
        coords_2 (np.ndarray): Second set of coordinates (the query points). A torch tensor is also
            accepted. Presumably shaped ``(num_points, num_dims)``.
        k (int, optional): Number of nearest neighbors.
        dist (int, optional): Distance threshold; only neighbors strictly closer than this are kept.
        workers (int, optional): Number of parallel workers passed to ``cKDTree.query``.

    Returns:
        torch.Tensor: ``(num_edges, 2)`` long tensor. Column 0 is the row index into ``coords_2`` and
            column 1 is the row index into ``coords_1``. Edges are grouped by query point, in the neighbor
            order returned by ``cKDTree.query``. The shape is ``(0, 2)`` when no pair is within ``dist``.
    """
    if isinstance(coords_1, torch.Tensor):
        coords_1 = coords_1.cpu().numpy()
    if isinstance(coords_2, torch.Tensor):
        coords_2 = coords_2.cpu().numpy()
    tree = cKDTree(coords_1)
    d_kdtree, idx_out = tree.query(coords_2, k=k, distance_upper_bound=dist, workers=workers)
    # With k=1, cKDTree.query returns 1-D arrays; add a neighbor axis so indexing below is uniform.
    if d_kdtree.ndim == 1:
        d_kdtree = d_kdtree[:, np.newaxis]
        idx_out = idx_out[:, np.newaxis]
    # Missing neighbors come back with infinite distance, so this comparison drops them too.
    # np.nonzero walks the mask row by row, keeping edges grouped per query point.
    query_idx, neighbor_slot = np.nonzero(d_kdtree < dist)
    edges = np.stack((query_idx, idx_out[query_idx, neighbor_slot]), axis=1)
    return torch.tensor(edges, dtype=torch.long).contiguous()

get_xy_extents

get_xy_extents(filepath, x, y)

Get the bounding box of the x and y coordinates from a Parquet file.

Parameters

filepath : str The path to the Parquet file.
x : str The name of the column representing the x-coordinate.
y : str The name of the column representing the y-coordinate.

Returns

tuple The bounding box of the x and y coordinates, as (x_min, y_min, x_max, y_max).

Source code in src/segger/data/utils.py
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
def get_xy_extents(
    filepath,
    x: str,
    y: str,
) -> Tuple[float, float, float, float]:
    """
    Get the bounding box of the x and y coordinates from a Parquet file.

    Only the Parquet footer is read. The extents come from the per-row-group min/max column
    statistics, so no column data is loaded.

    Parameters
    ----------
    filepath : str
        The path to the Parquet file.
    x : str
        The name of the column representing the x-coordinate.
    y : str
        The name of the column representing the y-coordinate.

    Returns
    -------
    tuple
        ``(x_min, y_min, x_max, y_max)`` of the x and y coordinates.

    Raises
    ------
    KeyError
        If ``x`` or ``y`` is not a column of the file.
    ValueError
        If the file has no row groups, or a row group lacks min/max statistics for ``x`` or ``y``.
    """
    metadata = pq.read_metadata(filepath)
    # Position of each column in the row-group column list, looked up once outside the loop.
    schema_idx = {name: i for i, name in enumerate(metadata.schema.names)}
    column_idx = {x: schema_idx[x], y: schema_idx[y]}

    if metadata.num_row_groups == 0:
        raise ValueError(f"Parquet file {filepath} has no row groups; cannot determine extents.")

    # Start from +/- infinity so any coordinate, including negative ones, updates the bounds.
    x_min = y_min = float("inf")
    x_max = y_max = float("-inf")
    for i in range(metadata.num_row_groups):
        group = metadata.row_group(i)
        stats = {}
        for name, idx in column_idx.items():
            column_stats = group.column(idx).statistics
            if column_stats is None or not column_stats.has_min_max:
                raise ValueError(f"Row group {i} of {filepath} has no min/max statistics for column {name!r}.")
            stats[name] = column_stats
        x_min = min(x_min, stats[x].min)
        x_max = max(x_max, stats[x].max)
        y_min = min(y_min, stats[y].min)
        y_max = max(y_max, stats[y].max)
    return x_min, y_min, x_max, y_max