Skip to content

segger.validation

This module handles validation utilities for the Segger tool.

Submodules

API Documentation

annotate_query_with_reference

annotate_query_with_reference(reference_adata, query_adata, transfer_column)

Annotate query AnnData object using a scRNA-seq reference atlas.

  • reference_adata: ad.AnnData Reference AnnData object containing the scRNA-seq atlas data.
  • query_adata: ad.AnnData Query AnnData object containing the data to be annotated.
  • transfer_column: str The name of the column in the reference atlas's obs to transfer to the query dataset.
  • Returns — query_adata: ad.AnnData Annotated query AnnData object with transferred labels and UMAP coordinates from the reference.
Source code in src/segger/validation/utils.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
def annotate_query_with_reference(
    reference_adata: ad.AnnData, query_adata: ad.AnnData, transfer_column: str
) -> ad.AnnData:
    """Annotate query AnnData object using a scRNA-seq reference atlas.

    Both objects are restricted to their shared genes (in reference order). The
    query's raw counts are stored in ``layers["raw"]`` and summed per gene into
    ``var["raw_counts"]``; the query is then normalized (``target_sum=1e4``) and
    log-transformed. PCA, a neighbor graph and a UMAP are computed on the
    reference, and ``sc.tl.ingest`` maps the query onto it, transferring
    ``transfer_column``. The caller's objects are not modified; working copies are
    used.

    Args:
    - reference_adata: ad.AnnData
        Reference AnnData object containing the scRNA-seq atlas data.
    - query_adata: ad.AnnData
        Query AnnData object containing the data to be annotated.
    - transfer_column: str
        The name of the column in the reference atlas's `obs` to transfer to the query dataset.

    Returns:
    - query_adata: ad.AnnData
        Annotated query AnnData object with transferred labels and UMAP coordinates from the reference.

    Raises:
    - ValueError
        If the two objects share no genes.
    """
    # Index.intersection keeps a deterministic order, unlike a Python set.
    common_genes = reference_adata.var_names.intersection(query_adata.var_names)
    if len(common_genes) == 0:
        raise ValueError("reference_adata and query_adata share no genes.")

    # Work on real copies rather than views, so writing layers/obsm does not
    # trigger implicit view-to-copy conversions.
    reference_adata = reference_adata[:, common_genes].copy()

    # `.raw` is NOT subset along var when the AnnData is sliced, so the raw
    # matrix has to be restricted to the shared genes explicitly.
    raw_counts = query_adata.raw[:, common_genes].X if query_adata.raw is not None else None
    query_adata = query_adata[:, common_genes].copy()
    # Copy X so the in-place normalization below cannot alter the stored counts.
    query_adata.layers["raw"] = raw_counts if raw_counts is not None else query_adata.X.copy()
    # A sparse `.sum(axis=0)` returns a (1, n_genes) matrix; flatten it for obs/var storage.
    query_adata.var["raw_counts"] = np.asarray(query_adata.layers["raw"].sum(axis=0)).ravel()

    sc.pp.normalize_total(query_adata, target_sum=1e4)
    sc.pp.log1p(query_adata)

    sc.pp.pca(reference_adata)
    sc.pp.neighbors(reference_adata)
    sc.tl.umap(reference_adata)

    # ingest writes the transferred labels and the embedding into query_adata in place.
    sc.tl.ingest(query_adata, reference_adata, obs=transfer_column)
    return query_adata

calculate_contamination

calculate_contamination(adata, markers, radius=15, n_neighs=10, celltype_column='celltype_major', num_cells=10000)

Calculate normalized contamination from neighboring cells of different cell types based on positive markers.

  • adata: ad.AnnData Annotated data object with raw counts and cell type information.
  • markers: dict Dictionary where keys are cell types and values are dictionaries containing: 'positive': list of top x% highly expressed genes 'negative': list of top x% lowly expressed genes.
  • radius: float, default=15 Radius for spatial neighbor calculation.
  • n_neighs: int, default=10 Maximum number of neighbors to consider.
  • celltype_column: str, default='celltype_major' Column name in the AnnData object representing cell types.
  • num_cells: int, default=10000 Number of cells to randomly select for the calculation.
  • Returns — contamination_df: pd.DataFrame DataFrame containing the normalized level of contamination from each cell type to each other cell type.
Source code in src/segger/validation/utils.py
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
def calculate_contamination(
    adata: ad.AnnData,
    markers: Dict[str, Dict[str, List[str]]],
    radius: float = 15,
    n_neighs: int = 10,
    celltype_column: str = "celltype_major",
    num_cells: int = 10000,
) -> pd.DataFrame:
    """Calculate normalized contamination from neighboring cells of different cell types based on positive markers.

    For each randomly selected cell whose type is a key of ``markers``, and for
    each of that type's positive markers present in ``adata.var_names``, the
    marker's raw count summed over the cell and its spatial neighbors is used as
    a denominator. Every neighbor of a different cell type (also present in
    ``markers``) that does not list the marker among its own positive markers
    contributes ``neighbor_count / denominator`` to
    ``contamination[source_type][neighbor_type]``. Each accumulated value is
    finally divided by (number of contributions + 1).

    Requires ``adata.obs["cell_centroid_x"]``, ``adata.obs["cell_centroid_y"]``
    and ``adata.layers["raw"]``. Writes ``adata.obsm["spatial"]``, and
    ``sq.gr.spatial_neighbors`` stores its graph on ``adata`` (this function reads
    ``adata.obsp["spatial_connectivities"]``).

    Args:
    - adata: ad.AnnData
        Annotated data object with raw counts and cell type information.
    - markers: dict
        Dictionary where keys are cell types and values are dictionaries containing:
            'positive': list of top x% highly expressed genes
            'negative': list of top x% lowly expressed genes.
    - radius: float, default=15
        Radius for spatial neighbor calculation.
    - n_neighs: int, default=10
        Passed through as ``n_neighs`` to ``sq.gr.spatial_neighbors``.
    - celltype_column: str, default='celltype_major'
        Column name in the AnnData object representing cell types.
    - num_cells: int, default=10000
        Number of cells to randomly select for the calculation.

    Returns:
    - contamination_df: pd.DataFrame
        DataFrame containing the normalized level of contamination from each cell type to each other cell type.
        Rows are source cell types, columns are target (neighbor) cell types.

    Raises:
    - ValueError
        If ``celltype_column`` is not in ``adata.obs``.
    """
    if celltype_column not in adata.obs:
        raise ValueError(f"Column '{celltype_column}' must be present in adata.obs.")

    marker_sets = {ct: set(markers[ct]["positive"]) for ct in markers}
    type_names = list(marker_sets)

    adata.obsm["spatial"] = adata.obs[["cell_centroid_x", "cell_centroid_y"]].copy().to_numpy()
    sq.gr.spatial_neighbors(adata, radius=radius, n_neighs=n_neighs, coord_type="generic")
    neighbor_rows = adata.obsp["spatial_connectivities"].tolil().rows

    raw_layer = adata.layers["raw"]
    # Accept both sparse and dense raw layers.
    raw_counts = raw_layer.toarray() if hasattr(raw_layer, "toarray") else np.asarray(raw_layer)
    # Positional access; avoids Series[int] label/position ambiguity.
    cell_types = adata.obs[celltype_column].to_numpy()
    # Hoist gene-name -> column lookups out of the hot loop.
    gene_index = {gene: i for i, gene in enumerate(adata.var_names)}

    selected_cells = np.random.choice(adata.n_obs, size=min(num_cells, adata.n_obs), replace=False)
    contamination = {ct: {ct2: 0.0 for ct2 in type_names} for ct in type_names}
    contributions = {ct: {ct2: 0 for ct2 in type_names} for ct in type_names}

    for cell_idx in selected_cells:
        cell_type = cell_types[cell_idx]
        if cell_type not in marker_sets:
            # No marker definition for this type: nothing to attribute.
            continue
        neighbor_idx = np.asarray(neighbor_rows[cell_idx], dtype=int)
        for marker in marker_sets[cell_type]:
            col = gene_index.get(marker)
            if col is None:
                continue
            # Denominator is specific to THIS marker: its count in the cell plus all neighbors.
            total_in_neighborhood = raw_counts[cell_idx, col] + raw_counts[neighbor_idx, col].sum()
            if total_in_neighborhood <= 0:
                continue
            for nb in neighbor_idx:
                neighbor_type = cell_types[nb]
                if neighbor_type == cell_type or neighbor_type not in marker_sets:
                    continue
                if marker in marker_sets[neighbor_type]:
                    # Shared markers cannot indicate contamination.
                    continue
                contamination[cell_type][neighbor_type] += raw_counts[nb, col] / total_in_neighborhood
                contributions[cell_type][neighbor_type] += 1

    contamination_df = pd.DataFrame(contamination).T
    contributions_df = pd.DataFrame(contributions).T
    # +1 avoids division by zero for type pairs that never neighbored.
    normalized = contamination_df / (contributions_df + 1)
    # Set names after the arithmetic so alignment cannot drop them.
    normalized.index.name = "Source Cell Type"
    normalized.columns.name = "Target Cell Type"
    return normalized

calculate_sensitivity

calculate_sensitivity(adata, purified_markers, max_cells_per_type=1000)

Calculate the sensitivity of the purified markers for each cell type.

  • adata: AnnData Annotated data object containing gene expression data.
  • purified_markers: dict Dictionary where keys are cell types and values are dictionaries whose 'positive' entry is the list of purified marker genes.
  • max_cells_per_type: int, default=1000 Maximum number of cells to consider per cell type.
  • Returns — sensitivity_results: dict Dictionary with cell types as keys and lists of sensitivity values for each cell.
Source code in src/segger/validation/utils.py
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
def calculate_sensitivity(
    adata: ad.AnnData,
    purified_markers: Dict[str, Dict[str, List[str]]],
    max_cells_per_type: int = 1000,
    cell_type_column: str = "celltype_major",
) -> Dict[str, List[float]]:
    """Calculate the sensitivity of the purified markers for each cell type.

    For every cell of a given type (randomly down-sampled to at most
    ``max_cells_per_type``), sensitivity is the number of the type's positive
    markers detected (count > 0) in that cell divided by the number of listed
    markers. Markers absent from ``adata.var_names`` count as not detected.

    Args:
    - adata: AnnData
        Annotated data object containing gene expression data.
    - purified_markers: dict
        Dictionary where keys are cell types and values are dictionaries whose
        'positive' entry is the list of purified marker genes.
    - max_cells_per_type: int, default=1000
        Maximum number of cells to consider per cell type.
    - cell_type_column: str, default='celltype_major'
        Column name in `adata.obs` that specifies cell types.

    Returns:
    - sensitivity_results: dict
        Dictionary with cell types as keys and lists of sensitivity values for each cell.
    """
    sensitivity_results: Dict[str, List[float]] = {cell_type: [] for cell_type in purified_markers}
    for cell_type, marker_sets in purified_markers.items():
        markers = marker_sets["positive"]
        subset = adata[adata.obs[cell_type_column] == cell_type]
        if subset.n_obs > max_cells_per_type:
            cell_indices = np.random.choice(subset.n_obs, max_cells_per_type, replace=False)
            subset = subset[cell_indices]
        if not markers:
            sensitivity_results[cell_type] = [0] * subset.n_obs
            continue
        gene_idx = subset.var_names.get_indexer(markers)
        # get_indexer returns -1 for unknown genes, which would silently select the
        # last gene; drop those so they count as "not detected".
        gene_idx = gene_idx[gene_idx >= 0]
        if gene_idx.size == 0:
            sensitivity_results[cell_type] = [0.0] * subset.n_obs
            continue
        # Works for dense and sparse X; sparse sums return a (n, 1) matrix.
        detected = np.asarray((subset.X[:, gene_idx] > 0).sum(axis=1)).ravel()
        sensitivity_results[cell_type] = (detected / len(markers)).tolist()
    return sensitivity_results

compute_MECR

compute_MECR(adata, gene_pairs)

Compute the Mutually Exclusive Co-expression Rate (MECR) for each gene pair in an AnnData object.

  • adata: AnnData Annotated data object containing gene expression data.
  • gene_pairs: List[Tuple[str, str]] List of tuples representing gene pairs to evaluate.
  • Returns — mecr_dict: Dict[Tuple[str, str], float] Dictionary where keys are gene pairs (tuples) and values are MECR values.
Source code in src/segger/validation/utils.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
def compute_MECR(adata: ad.AnnData, gene_pairs: List[Tuple[str, str]]) -> Dict[Tuple[str, str], float]:
    """Compute the Mutually Exclusive Co-expression Rate (MECR) for each gene pair in an AnnData object.

    For a pair (g1, g2), MECR = fraction of cells expressing both genes divided by
    the fraction expressing at least one (0 when neither is expressed anywhere).
    A gene is "expressed" in a cell when its value is > 0.

    Args:
    - adata: AnnData
        Annotated data object containing gene expression data.
    - gene_pairs: List[Tuple[str, str]]
        List of tuples representing gene pairs to evaluate. Every gene must be
        in ``adata.var_names``, otherwise the lookup fails.

    Returns:
    - mecr_dict: Dict[Tuple[str, str], float]
        Dictionary where keys are gene pairs (tuples) and values are MECR values.
    """
    # Densify only the genes that appear in some pair, not the full matrix.
    genes = list(dict.fromkeys(gene for pair in gene_pairs for gene in pair))
    detected = adata[:, genes].to_df() > 0
    mecr_dict = {}
    for gene1, gene2 in gene_pairs:
        expr_gene1 = detected[gene1]
        expr_gene2 = detected[gene2]
        both_expressed = (expr_gene1 & expr_gene2).mean()
        at_least_one_expressed = (expr_gene1 | expr_gene2).mean()
        mecr_dict[(gene1, gene2)] = both_expressed / at_least_one_expressed if at_least_one_expressed > 0 else 0
    return mecr_dict

compute_clustering_scores

compute_clustering_scores(adata, cell_type_column='celltype_major', use_pca=True)

Compute the Calinski-Harabasz and Silhouette scores for an AnnData object based on the assigned cell types.

  • adata: AnnData Annotated data object containing gene expression data and cell type assignments.
  • cell_type_column: str, default='celltype_major' Column name in adata.obs that specifies cell types.
  • use_pca: bool, default=True Whether to use PCA components as features. If False, use the raw data.
  • Returns — ch_score: float The Calinski-Harabasz score.
  • Returns — sh_score: float The Silhouette score.
Source code in src/segger/validation/utils.py
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
def compute_clustering_scores(
    adata: ad.AnnData,
    cell_type_column: str = "celltype_major",
    use_pca: bool = True,
    num_cells: int = 10000,
) -> Tuple[float, float]:
    """Compute the Calinski-Harabasz and Silhouette scores for an AnnData object based on the assigned cell types.

    Scores are computed on a random sample of at most ``num_cells`` cells.

    Args:
    - adata: AnnData
        Annotated data object containing gene expression data and cell type assignments.
    - cell_type_column: str, default='celltype_major'
        Column name in `adata.obs` that specifies cell types.
    - use_pca: bool, default=True
        Whether to use PCA components as features. When True and
        ``adata.obsm["X_pca"]`` exists it is used; otherwise ``adata.X`` is used.
    - num_cells: int, default=10000
        Maximum number of cells to sample (capped at ``adata.n_obs``).

    Returns:
    - ch_score: float
        The Calinski-Harabasz score.
    - sh_score: float
        The Silhouette score.

    Raises:
    - ValueError
        If ``cell_type_column`` is not in ``adata.obs``.
    """
    if cell_type_column not in adata.obs:
        raise ValueError(f"Column '{cell_type_column}' must be present in adata.obs.")
    # The use_pca flag used to be ignored; honour it when an embedding is available.
    if use_pca and "X_pca" in adata.obsm:
        features = np.asarray(adata.obsm["X_pca"])
    else:
        features = adata.X
    # Sampling without replacement fails if asked for more cells than exist.
    cell_indices = np.random.choice(adata.n_obs, min(num_cells, adata.n_obs), replace=False)
    features = features[cell_indices, :]
    labels = adata.obs[cell_type_column].iloc[cell_indices]
    ch_score = calinski_harabasz_score(features, labels)
    sh_score = silhouette_score(features, labels)
    return ch_score, sh_score

compute_neighborhood_metrics

compute_neighborhood_metrics(adata, radius=10, celltype_column='celltype_major', n_neighs=20, subset_size=10000)

Compute neighborhood entropy and number of neighbors for each cell in the AnnData object.

  • adata: AnnData Annotated data object containing spatial information and cell type assignments.
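  • radius: float, default=10 Radius for spatial neighbor calculation.
  • celltype_column: str, default='celltype_major' Column name in adata.obs that specifies cell types.
  • n_neighs: int, default=20 Number of neighbors passed to squidpy's spatial-neighbor computation.
  • subset_size: int, default=10000 Number of cells to randomly select for the calculation.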
Source code in src/segger/validation/utils.py
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
def compute_neighborhood_metrics(
    adata: ad.AnnData,
    radius: float = 10,
    celltype_column: str = "celltype_major",
    n_neighs: int = 20,
    subset_size: int = 10000,
) -> None:
    """Compute neighborhood entropy and number of neighbors for a random subset of cells in the AnnData object.

    The spatial neighbor graph is computed for the whole dataset with
    ``sq.gr.spatial_neighbors``. For each sampled cell, the cell types of its
    neighbors (read from ``adata.obsp["spatial_distances"]``) give the Shannon
    entropy of their distribution (0 when the cell has no neighbors) and the
    neighbor count. Results are written to ``adata.obs["neighborhood_entropy"]``
    and ``adata.obs["number_of_neighbors"]``; unsampled cells receive NaN.

    Args:
    - adata: AnnData
        Annotated data object containing spatial information and cell type assignments.
    - radius: float, default=10
        Radius for spatial neighbor calculation.
    - celltype_column: str, default='celltype_major'
        Column name in `adata.obs` that specifies cell types.
    - n_neighs: int, default=20
        Passed through as ``n_neighs`` to ``sq.gr.spatial_neighbors``.
    - subset_size: int, default=10000
        Number of cells to randomly select for the calculation (capped at ``adata.n_obs``).
    """
    subset_size = min(subset_size, adata.n_obs)
    subset_indices = np.random.choice(adata.n_obs, subset_size, replace=False)
    # The graph is built for the entire dataset; only the subset is summarized.
    sq.gr.spatial_neighbors(adata, radius=radius, coord_type="generic", n_neighs=n_neighs)
    neighbors = adata.obsp["spatial_distances"].tolil().rows
    cell_type_series = adata.obs[celltype_column]

    entropy_full = np.full(adata.n_obs, np.nan)
    neighbors_full = np.full(adata.n_obs, np.nan)
    for cell_index in subset_indices:
        neighbor_cell_types = cell_type_series.iloc[neighbors[cell_index]]
        total_neighbors = len(neighbor_cell_types)
        neighbors_full[cell_index] = total_neighbors
        if total_neighbors > 0:
            entropy_full[cell_index] = entropy(neighbor_cell_types.value_counts() / total_neighbors)
        else:
            entropy_full[cell_index] = 0

    adata.obs["neighborhood_entropy"] = entropy_full
    adata.obs["number_of_neighbors"] = neighbors_full

compute_quantized_mecr_area

compute_quantized_mecr_area(adata, gene_pairs, quantiles=10)

Compute the average MECR, variance of MECR, and average cell area for quantiles of cell areas.

  • adata: AnnData Annotated data object containing gene expression data.
  • gene_pairs: List[Tuple[str, str]] List of tuples representing gene pairs to evaluate.
  • quantiles: int, default=10 Number of quantiles to divide the data into.
  • Returns — quantized_data: pd.DataFrame DataFrame containing quantile information, average MECR, variance of MECR, average area, and number of cells.
Source code in src/segger/validation/utils.py
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
def compute_quantized_mecr_area(
    adata: sc.AnnData, gene_pairs: List[Tuple[str, str]], quantiles: int = 10
) -> pd.DataFrame:
    """Compute the average MECR, variance of MECR, and average cell area for quantiles of cell areas.

    Cells are binned into ``quantiles`` equal-frequency bins of
    ``adata.obs["cell_area"]``; the bin label is written to
    ``adata.obs["quantile"]`` (in place). If the area values are so heavily tied
    that the bin edges are not unique, cells are binned by their area rank
    instead (ties broken by order of appearance).

    Args:
    - adata: AnnData
        Annotated data object containing gene expression data.
    - gene_pairs: List[Tuple[str, str]]
        List of tuples representing gene pairs to evaluate.
    - quantiles: int, default=10
        Number of quantiles to divide the data into.

    Returns:
    - quantized_data: pd.DataFrame
        DataFrame containing quantile information, average MECR, variance of MECR, average area, and number of cells.
    """
    cell_area = adata.obs["cell_area"]
    try:
        adata.obs["quantile"] = pd.qcut(cell_area, quantiles, labels=False)
    except ValueError:
        # qcut rejects duplicate bin edges; ranking breaks ties so the bins stay equal-sized.
        adata.obs["quantile"] = pd.qcut(cell_area.rank(method="first"), quantiles, labels=False)
    quantized_data = []
    for quantile in range(quantiles):
        cells_in_quantile = adata.obs["quantile"] == quantile
        mecr_values = list(compute_MECR(adata[cells_in_quantile, :], gene_pairs).values())
        quantized_data.append(
            {
                "quantile": quantile / quantiles,
                "average_mecr": np.mean(mecr_values),
                "variance_mecr": np.var(mecr_values),
                "average_area": adata.obs.loc[cells_in_quantile, "cell_area"].mean(),
                "num_cells": cells_in_quantile.sum(),
            }
        )
    return pd.DataFrame(quantized_data)

compute_quantized_mecr_counts

compute_quantized_mecr_counts(adata, gene_pairs, quantiles=10)

Compute the average MECR, variance of MECR, and average transcript counts for quantiles of transcript counts.

  • adata: AnnData Annotated data object containing gene expression data.
  • gene_pairs: List[Tuple[str, str]] List of tuples representing gene pairs to evaluate.
  • quantiles: int, default=10 Number of quantiles to divide the data into.
  • Returns — quantized_data: pd.DataFrame DataFrame containing quantile information, average MECR, variance of MECR, average counts, and number of cells.
Source code in src/segger/validation/utils.py
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
def compute_quantized_mecr_counts(
    adata: sc.AnnData, gene_pairs: List[Tuple[str, str]], quantiles: int = 10
) -> pd.DataFrame:
    """Compute the average MECR, variance of MECR, and average transcript counts for quantiles of transcript counts.

    Cells are binned into ``quantiles`` equal-frequency bins of
    ``adata.obs["transcripts"]``; the bin label is written to
    ``adata.obs["quantile"]`` (in place). Integer counts are often heavily tied,
    which makes ``pd.qcut`` reject the bin edges; in that case cells are binned by
    their count rank instead (ties broken by order of appearance).

    Args:
    - adata: AnnData
        Annotated data object containing gene expression data.
    - gene_pairs: List[Tuple[str, str]]
        List of tuples representing gene pairs to evaluate.
    - quantiles: int, default=10
        Number of quantiles to divide the data into.

    Returns:
    - quantized_data: pd.DataFrame
        DataFrame containing quantile information, average MECR, variance of MECR, average counts, and number of cells.
    """
    transcripts = adata.obs["transcripts"]
    try:
        adata.obs["quantile"] = pd.qcut(transcripts, quantiles, labels=False)
    except ValueError:
        # qcut rejects duplicate bin edges; ranking breaks ties so the bins stay equal-sized.
        adata.obs["quantile"] = pd.qcut(transcripts.rank(method="first"), quantiles, labels=False)
    quantized_data = []
    for quantile in range(quantiles):
        cells_in_quantile = adata.obs["quantile"] == quantile
        mecr_values = list(compute_MECR(adata[cells_in_quantile, :], gene_pairs).values())
        quantized_data.append(
            {
                "quantile": quantile / quantiles,
                "average_mecr": np.mean(mecr_values),
                "variance_mecr": np.var(mecr_values),
                "average_counts": adata.obs.loc[cells_in_quantile, "transcripts"].mean(),
                "num_cells": cells_in_quantile.sum(),
            }
        )
    return pd.DataFrame(quantized_data)

compute_transcript_density

compute_transcript_density(adata)

Compute the transcript density for each cell in the AnnData object.

  • adata: AnnData Annotated data object containing transcript and cell area information.
Source code in src/segger/validation/utils.py
436
437
438
439
440
441
442
443
444
445
446
447
448
def compute_transcript_density(adata: ad.AnnData) -> None:
    """Compute the transcript density for each cell in the AnnData object.

    The per-cell transcript count is read from ``adata.obs["transcript_counts"]``
    when that column exists, otherwise from ``adata.obs["transcripts"]``. It is
    divided by ``adata.obs["cell_area"]``, and the result is stored in
    ``adata.obs["transcript_density"]`` (in place).

    Args:
    - adata: AnnData
        Annotated data object containing transcript and cell area information.

    Raises:
    - KeyError
        If neither transcript-count column, or the ``cell_area`` column, is present.
    """
    # Only a missing column should trigger the fallback; a bare `except:` also
    # hid unrelated errors (even KeyboardInterrupt).
    try:
        transcript_counts = adata.obs["transcript_counts"]
    except KeyError:
        transcript_counts = adata.obs["transcripts"]
    adata.obs["transcript_density"] = transcript_counts / adata.obs["cell_area"]

draw_umap

draw_umap(adata, column='leiden')

Draw UMAP plots for the given AnnData object.

Parameters:

Name Type Description Default
adata AnnData

The AnnData object containing the data.

required
column str

The column to color the UMAP plot by.

'leiden'
Source code in src/segger/validation/xenium_explorer.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
def draw_umap(adata, column: str = "leiden") -> None:
    """Draw UMAP plots for the given AnnData object.

    Three figures are shown in sequence: one colored by ``column``, then
    KRT5/KRT7, then ACTA2/PTPRC. The gene panels cap the color scale at the
    95th percentile (``vmax="p95"``).

    Args:
        adata (AnnData): The AnnData object containing the data.
        column (str): The column to color the UMAP plot by.
    """
    percentile_cap = {"vmax": "p95"}
    panels = (
        ([column], {}),
        (["KRT5", "KRT7"], percentile_cap),
        (["ACTA2", "PTPRC"], percentile_cap),
    )
    for colors, plot_kwargs in panels:
        sc.pl.umap(adata, color=colors, **plot_kwargs)
        plt.show()

find_markers

find_markers(adata, cell_type_column, pos_percentile=5, neg_percentile=10, percentage=50)

Identify positive and negative markers for each cell type based on gene expression and filter by expression percentage.

  • adata: AnnData Annotated data object containing gene expression data.
  • cell_type_column: str Column name in adata.obs that specifies cell types.
  • pos_percentile: float, default=5 Percentile threshold to determine top x% expressed genes.
  • neg_percentile: float, default=10 Percentile threshold to determine top x% lowly expressed genes.
  • percentage: float, default=50 Minimum percentage of cells expressing the marker within a cell type for it to be considered.
  • Returns — markers: dict Dictionary where keys are cell types and values are dictionaries containing: 'positive': list of top x% highly expressed genes; 'negative': list of top x% lowly expressed genes.
Source code in src/segger/validation/utils.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def find_markers(
    adata: ad.AnnData,
    cell_type_column: str,
    pos_percentile: float = 5,
    neg_percentile: float = 10,
    percentage: float = 50,
) -> Dict[str, Dict[str, List[str]]]:
    """Identify positive and negative markers for each cell type based on gene expression and filter by expression percentage.

    For each (non-missing) cell type, the mean expression of every gene within
    that type is computed. Genes at or above the ``100 - pos_percentile``
    percentile of those means are positive candidates, and they are kept only if
    detected (> 0) in at least ``percentage`` % of the type's cells. Genes at or
    below the ``neg_percentile`` percentile are negative markers.

    Args:
    - adata: AnnData
        Annotated data object containing gene expression data.
    - cell_type_column: str
        Column name in `adata.obs` that specifies cell types.
    - pos_percentile: float, default=5
        Percentile threshold to determine top x% expressed genes.
    - neg_percentile: float, default=10
        Percentile threshold to determine top x% lowly expressed genes.
    - percentage: float, default=50
        Minimum percentage of cells expressing the marker within a cell type for it to be considered.

    Returns:
    - markers: dict
        Dictionary where keys are cell types and values are dictionaries containing:
            'positive': list of top x% highly expressed genes
            'negative': list of top x% lowly expressed genes.
    """
    markers: Dict[str, Dict[str, List[str]]] = {}
    # NOTE(review): the ranking written to adata.uns by this call is not used
    # below; it is kept because callers may rely on that side effect.
    sc.tl.rank_genes_groups(adata, groupby=cell_type_column)
    genes = adata.var_names
    labels = adata.obs[cell_type_column]
    # Missing labels never match `==`, so they would yield an empty subset and
    # meaningless percentiles; skip them.
    for cell_type in labels.dropna().unique():
        subset = adata[labels == cell_type]
        mean_expression = np.asarray(subset.X.mean(axis=0)).flatten()
        cutoff_high = np.percentile(mean_expression, 100 - pos_percentile)
        cutoff_low = np.percentile(mean_expression, neg_percentile)
        pos_indices = np.where(mean_expression >= cutoff_high)[0]
        neg_indices = np.where(mean_expression <= cutoff_low)[0]
        # Fraction of the type's cells in which each positive candidate is detected.
        expr_frac = np.asarray((subset.X[:, pos_indices] > 0).mean(axis=0)).flatten()
        valid_pos_indices = pos_indices[expr_frac >= (percentage / 100)]
        markers[cell_type] = {
            "positive": list(genes[valid_pos_indices]),
            "negative": list(genes[neg_indices]),
        }
    return markers

find_mutually_exclusive_genes

find_mutually_exclusive_genes(adata, markers, cell_type_column)

Identify mutually exclusive genes based on expression criteria.

  • adata: AnnData Annotated data object containing gene expression data.
  • markers: dict Dictionary where keys are cell types and values are dictionaries containing: 'positive': list of top x% highly expressed genes 'negative': list of top x% lowly expressed genes.
  • cell_type_column: str Column name in adata.obs that specifies cell types.
  • Returns — exclusive_pairs: list List of mutually exclusive gene pairs.
Source code in src/segger/validation/utils.py
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
def find_mutually_exclusive_genes(
    adata: ad.AnnData, markers: Dict[str, Dict[str, List[str]]], cell_type_column: str
) -> List[Tuple[str, str]]:
    """Identify mutually exclusive genes based on expression criteria.

    A positive marker of a cell type counts as exclusive when it is detected
    (> 0) in more than 20% of that type's cells and in fewer than 5% of all other
    cells. Each exclusive gene of one cell type is paired with every exclusive
    gene of each later cell type (pairs of cell types follow ``markers`` order).

    Args:
    - adata: AnnData
        Annotated data object containing gene expression data.
    - markers: dict
        Dictionary where keys are cell types and values are dictionaries containing:
            'positive': list of top x% highly expressed genes
            'negative': list of top x% lowly expressed genes.
    - cell_type_column: str
        Column name in `adata.obs` that specifies cell types.

    Returns:
    - exclusive_pairs: list
        List of mutually exclusive gene pairs.
    """
    exclusive_genes: Dict[str, List[str]] = {}
    for cell_type, marker_sets in markers.items():
        positive_markers = marker_sets["positive"]
        exclusive_genes[cell_type] = []
        if not positive_markers:
            continue
        # The masks depend only on the cell type; build them once per type.
        cell_type_mask = adata.obs[cell_type_column] == cell_type
        non_cell_type_mask = ~cell_type_mask
        for gene in positive_markers:
            gene_expr = adata[:, gene].X
            if (gene_expr[cell_type_mask] > 0).mean() > 0.2 and (gene_expr[non_cell_type_mask] > 0).mean() < 0.05:
                exclusive_genes[cell_type].append(gene)
    # NOTE(review): a gene that qualifies for several cell types is kept for each
    # of them (the former "unique genes" filter never removed anything).
    return [
        (gene1, gene2)
        for key1, key2 in combinations(exclusive_genes, 2)
        for gene1 in exclusive_genes[key1]
        for gene2 in exclusive_genes[key2]
    ]

generate_experiment_file

generate_experiment_file(template_path, output_path, cells_name='seg_cells', analysis_name='seg_analysis')

Generate the experiment file for Xenium.

Parameters:

Name Type Description Default
template_path str

The path to the template file.

required
output_path str

The path to the output file.

required
cells_name str

The name of the cells file.

'seg_cells'
analysis_name str

The name of the analysis file.

'seg_analysis'
Source code in src/segger/validation/xenium_explorer.py
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
def generate_experiment_file(
    template_path: str, output_path: str, cells_name: str = "seg_cells", analysis_name: str = "seg_analysis"
) -> None:
    """Generate the experiment file for Xenium.

    Reads the JSON template and removes the ``morphology_filepath`` and
    ``morphology_focus_filepath`` image entries and the
    ``cell_features_zarr_filepath`` entry. It points ``cells_zarr_filepath`` /
    ``analysis_zarr_filepath`` at ``{cells_name}.zarr.zip`` /
    ``{analysis_name}.zarr.zip`` and writes the result as indented JSON.

    Args:
        template_path (str): The path to the template file.
        output_path (str): The path to the output file.
        cells_name (str): The name of the cells file.
        analysis_name (str): The name of the analysis file.
    """
    import json

    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(template_path, encoding="utf-8") as f:
        experiment = json.load(f)

    # A template may not carry every entry; removing an absent key is a no-op
    # rather than a KeyError.
    images = experiment["images"]
    images.pop("morphology_filepath", None)
    images.pop("morphology_focus_filepath", None)

    explorer_files = experiment["xenium_explorer_files"]
    explorer_files["cells_zarr_filepath"] = f"{cells_name}.zarr.zip"
    explorer_files.pop("cell_features_zarr_filepath", None)
    explorer_files["analysis_zarr_filepath"] = f"{analysis_name}.zarr.zip"

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(experiment, f, indent=2)

get_flatten_version

get_flatten_version(polygons, max_value=21)

Get the flattened version of polygon vertices.

Parameters:

Name Type Description Default
polygons List[ndarray]

List of polygon vertices.

required
max_value int

The maximum number of vertices to keep.

21

Returns:

Type Description
ndarray

np.ndarray: The flattened array of polygon vertices.

Source code in src/segger/validation/xenium_explorer.py
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
def get_flatten_version(polygons: List[np.ndarray], max_value: int = 21) -> np.ndarray:
    """Lay polygon vertices out as fixed-width flat rows.

    Each polygon becomes one output row of ``max_value + 1`` consecutive
    ``(x, y)`` pairs:

    * an empty polygon yields a row of zeros;
    * a polygon with fewer than ``max_value`` vertices is padded by repeating its
      first vertex;
    * otherwise up to ``max_value + 1`` vertices are copied, and the last slot is
      then set to the first vertex.

    Each polygon is assumed to be a ``(k, 2)`` vertex array (TODO confirm against callers).

    Args:
        polygons (List[np.ndarray]): Vertex arrays, one per polygon.
        max_value (int): Vertex budget; rows hold ``max_value + 1`` vertices.

    Returns:
        np.ndarray: Float array of shape ``(len(polygons), 2 * (max_value + 1))``.
    """
    width = max_value + 1
    flat = np.zeros((len(polygons), 2 * width))
    for row, vertices in tqdm(enumerate(polygons), total=len(polygons)):
        count = len(vertices)
        if not count:
            flat[row] = np.zeros(2 * width)
            continue
        if count >= max_value:
            fixed = np.zeros((width, 2))
            keep = min(count, width)
            fixed[:keep] = vertices[:keep]
            # The final slot always ends up holding the first vertex.
            fixed[-1] = vertices[0]
        else:
            filler = np.tile(vertices[0], (width - count, 1))
            fixed = np.concatenate((vertices, filler), axis=0)
        flat[row] = fixed.flatten()
    return flat

get_indices_indptr

get_indices_indptr(input_array)

Get the indices and indptr arrays for sparse matrix representation.

Parameters:

Name Type Description Default
input_array ndarray

The input array containing cluster labels.

required

Returns:

Type Description
Tuple[ndarray, ndarray]

Tuple[np.ndarray, np.ndarray]: The indices and indptr arrays.

Source code in src/segger/validation/xenium_explorer.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def get_indices_indptr(input_array: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Build CSR-style ``indices``/``indptr`` arrays that group positions by cluster label.

    Label ``0`` is treated as unassigned.
    - For every non-zero label, in sorted order, the positions holding that label
      are appended to ``indices`` in ascending order.
    - ``indptr[k]`` is the offset in ``indices`` where the k-th label's run starts.
    - One zero is appended to ``indices`` for every unassigned element, so
      ``len(indices) == len(input_array)``.

    Labels need not be contiguous or start at 1. ``indptr`` is indexed by the rank
    of each label among the sorted non-zero labels.

    Args:
        input_array (np.ndarray): 1-D array of cluster labels.

    Returns:
        Tuple[np.ndarray, np.ndarray]: ``(indices, indptr)``, both ``uint32``.
    """
    labels = np.asarray(input_array)
    positions = np.flatnonzero(labels != 0)
    assigned_labels = labels[positions]
    # A stable sort keeps positions ascending within each label (same order a per-label np.where scan gives).
    order = np.argsort(assigned_labels, kind="stable")
    _, counts = np.unique(assigned_labels, return_counts=True)
    # Exclusive prefix sum: the start offset of each label's run inside `indices`.
    indptr = (np.cumsum(counts) - counts).astype(np.uint32)
    n_unassigned = labels.size - positions.size
    indices = np.concatenate((positions[order], np.zeros(n_unassigned, dtype=positions.dtype)))
    return indices.astype(np.uint32), indptr

get_leiden_umap

get_leiden_umap(adata, draw=False)

Perform Leiden clustering and UMAP visualization on the given AnnData object.

Parameters:

Name Type Description Default
adata AnnData

The AnnData object containing the data.

required
draw bool

Whether to draw the UMAP plots.

False

Returns:

Name Type Description
AnnData

The AnnData object with Leiden clustering and UMAP results.

Source code in src/segger/validation/xenium_explorer.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
def get_leiden_umap(adata, draw: bool = False):
    """Filter and normalise an AnnData object, then compute UMAP and Leiden clusters.

    Steps, applied through scanpy's default (in-place) API calls:
    1. Drop cells with fewer than 5 genes and genes detected in fewer than 5 cells.
    2. Total-count normalise and log1p-transform.
    3. Build a 10-neighbour graph on 30 PCs.
    4. Compute a UMAP embedding and Leiden clusters.

    Args:
        adata (AnnData): The AnnData object containing the data.
        draw (bool): Whether to draw the UMAP coloured by ``leiden``.

    Returns:
        AnnData: The same AnnData object with neighbour graph, UMAP and ``leiden`` labels added.
    """
    sc.pp.filter_cells(adata, min_genes=5)
    sc.pp.filter_genes(adata, min_cells=5)

    # A previously computed, never-used "top 30 genes by mean" table was removed here:
    # it failed on sparse X, whose column mean is a 2-D np.matrix.
    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)
    sc.pp.neighbors(adata, n_neighbors=10, n_pcs=30)
    sc.tl.umap(adata)
    sc.tl.leiden(adata)

    if draw:
        draw_umap(adata, "leiden")

    return adata

get_median_expression_table

get_median_expression_table(adata, column='leiden')

Get the median expression table for the given AnnData object.

Parameters:

Name Type Description Default
adata AnnData

The AnnData object containing the data.

required
column str

The column to group by.

'leiden'

Returns:

Type Description
DataFrame

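pd.DataFrame: The median expression table (clusters as rows, genes as columns), returned wrapped in a pandas Styler with a green background gradient.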

Source code in src/segger/validation/xenium_explorer.py
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
def get_median_expression_table(adata, column: str = "leiden", genes=None) -> pd.DataFrame:
    """Tabulate the median expression of selected genes per group of ``adata.obs[column]``.

    Args:
        adata (AnnData): The AnnData object containing the data. ``X`` may be dense or scipy-sparse.
        column (str): The ``obs`` column to group by.
        genes (Optional[List[str]]): Genes to report. Defaults to the built-in marker list below.

    Returns:
        pd.DataFrame: The median expression table, with groups as rows (``Cluster_<label>``)
        and genes as columns. It is returned wrapped in a pandas ``Styler`` with a green
        background gradient.

    Raises:
        KeyError: If a requested gene is not in ``adata.var_names``.
    """
    if genes is None:
        genes = [
            "GATA3",
            "ACTA2",
            "KRT7",
            "KRT8",
            "KRT5",
            "AQP1",
            "SERPINA3",
            "PTGDS",
            "CXCR4",
            "SFRP1",
            "ENAH",
            "MYH11",
            "SVIL",
            "KRT14",
            "CD4",
        ]
    top_genes = list(genes)
    top_gene_indices = [adata.var_names.get_loc(gene) for gene in top_genes]

    clusters = adata.obs[column]
    cluster_data = {}

    for cluster in clusters.unique():
        cluster_expression = adata[clusters == cluster].X[:, top_gene_indices]
        # X may be a scipy sparse matrix; densify the (small) selected gene subset.
        if hasattr(cluster_expression, "toarray"):
            cluster_expression = cluster_expression.toarray()
        # DataFrame.median skips NaN per column, like the per-gene pd.Series.median it replaces.
        gene_medians = pd.DataFrame(np.asarray(cluster_expression)).median(axis=0).tolist()
        cluster_data[f"Cluster_{cluster}"] = gene_medians

    def _column_order(name: str):
        # Numeric labels sort numerically first; any other labels follow in lexical order.
        label = name[len("Cluster_"):]
        try:
            return (0, int(label), "")
        except ValueError:
            return (1, 0, label)

    cluster_expression_df = pd.DataFrame(cluster_data, index=top_genes)
    sorted_columns = sorted(cluster_expression_df.columns.values, key=_column_order)
    cluster_expression_df = cluster_expression_df[sorted_columns]
    return cluster_expression_df.T.style.background_gradient(cmap="Greens")

load_segmentations

load_segmentations(segmentation_paths)

Load segmentation data from provided paths and handle special cases like separating 'segger' into 'segger_n0' and 'segger_n1'.

Args: segmentation_paths (Dict[str, Path]): Dictionary mapping segmentation method names to their file paths.

Returns: Dict[str, sc.AnnData]: Dictionary mapping segmentation method names to loaded AnnData objects.

Source code in src/segger/validation/utils.py
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
def load_segmentations(segmentation_paths: Dict[str, Path]) -> Dict[str, sc.AnnData]:
    """Read every segmentation result and index it by method name.

    The ``"segger"`` result is also split by cell-name suffix:
    - cells whose name ends in ``"-nx"`` go to ``"segger_n0"``;
    - all other cells go to ``"segger_n1"``.

    The unsplit ``"segger"`` object is kept too.

    Args:
        segmentation_paths (Dict[str, Path]): Mapping from method name to the file passed to ``sc.read``.

    Returns:
        Dict[str, sc.AnnData]: Mapping from method name (plus the derived segger splits)
        to the loaded AnnData. The split entries are subsets of the full segger object.
    """
    loaded = {}
    for name, file_path in segmentation_paths.items():
        data = sc.read(file_path)
        if name == "segger":
            nx_cells, other_cells = [], []
            for cell in data.obs_names:
                (nx_cells if cell.endswith("-nx") else other_cells).append(cell)
            loaded["segger_n1"] = data[other_cells, :]
            loaded["segger_n0"] = data[nx_cells, :]
        loaded[name] = data
    return loaded

plot_cell_area

plot_cell_area(segmentations_dict, output_path, palette)

Plot the cell area (log2) for each segmentation method.

Args: segmentations_dict (Dict[str, sc.AnnData]): Dictionary mapping segmentation method names to loaded AnnData objects. output_path (Path): Path to the directory where the plot will be saved. palette (Dict[str, str]): Dictionary mapping segmentation method names to color codes.

Source code in src/segger/validation/utils.py
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
def plot_cell_area(segmentations_dict: Dict[str, sc.AnnData], output_path: Path, palette: Dict[str, str]) -> None:
    """Plot per-cell area for each segmentation method as violins and save the data as CSV.

    Methods whose ``obs`` lacks a ``cell_area`` column are skipped.

    The plotted value is ``cell_area + 1`` on a linear y-axis limited to 5-100.
    Despite the "(log2)" label, no log transform is applied here. A dashed line
    marks the 10X-nucleus median when that method is present and has ``cell_area``.

    Args:
        segmentations_dict (Dict[str, sc.AnnData]): Mapping from segmentation method name to AnnData.
        output_path (Path): Directory receiving ``cell_area_log2_data.csv`` and the violin-plot PDF.
        palette (Dict[str, str]): Mapping from segmentation method name to colour.
    """
    value_col = "Cell Area (log2)"
    frames = []
    for method, adata in segmentations_dict.items():
        if "cell_area" not in adata.obs.columns:
            continue
        method_area = adata.obs["cell_area"] + 1
        frames.append(
            pd.DataFrame({"Segmentation Method": [method] * len(method_area), value_col: method_area.values})
        )
    # Concatenate once rather than growing a frame inside the loop.
    violin_data = (
        pd.concat(frames, axis=0) if frames else pd.DataFrame({"Segmentation Method": [], value_col: []})
    )
    violin_data.to_csv(output_path / "cell_area_log2_data.csv", index=True)
    # Plot the violin plots
    plt.figure(figsize=(4, 6))
    ax = sns.violinplot(x="Segmentation Method", y=value_col, data=violin_data, palette=palette)
    ax.set(ylim=(5, 100))
    # Dashed reference line at the 10X-nucleus median, guarded like the loop above.
    if "10X-nucleus" in segmentations_dict and "cell_area" in segmentations_dict["10X-nucleus"].obs.columns:
        median_10X_nucleus_area = np.median(segmentations_dict["10X-nucleus"].obs["cell_area"] + 1)
        ax.axhline(y=median_10X_nucleus_area, color="gray", linestyle="--", linewidth=1.5, label="10X-nucleus Median")
    # Set plot titles and labels
    plt.title("")
    plt.xlabel("Segmentation Method")
    plt.ylabel(value_col)
    plt.xticks(rotation=0)
    # Save the figure as a PDF
    plt.savefig(output_path / "cell_area_log2_violin_plot.pdf", bbox_inches="tight")
    plt.show()

plot_cell_counts

plot_cell_counts(segmentations_dict, output_path, palette)

Plot the number of cells per segmentation method and save the cell count data as a CSV.

Args: segmentations_dict (Dict[str, sc.AnnData]): Dictionary mapping segmentation method names to loaded AnnData objects. output_path (Path): Path to the directory where the plot will be saved. palette (Dict[str, str]): Dictionary mapping segmentation method names to color codes.

Source code in src/segger/validation/utils.py
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
def plot_cell_counts(segmentations_dict: Dict[str, sc.AnnData], output_path: Path, palette: Dict[str, str]) -> None:
    """Draw a bar chart of cell counts per segmentation method and export the counts to CSV.

    Args:
        segmentations_dict (Dict[str, sc.AnnData]): Mapping from segmentation method name to AnnData.
        output_path (Path): Directory receiving ``cell_counts_data.csv`` and ``cell_counts_bar_plot.pdf``.
        palette (Dict[str, str]): Mapping from method name to bar colour (missing methods use ``#333333``).
    """
    counts_by_method = {}
    for method_name, segmentation in segmentations_dict.items():
        counts_by_method[method_name] = segmentation.n_obs

    # One row per method, one "Number of Cells" column.
    counts_frame = pd.DataFrame(counts_by_method, index=["Number of Cells"]).T
    counts_frame.to_csv(output_path / "cell_counts_data.csv", index=True)

    bar_colors = [palette.get(name, "#333333") for name in counts_frame.index]
    ax = counts_frame.plot(kind="bar", stacked=False, color=bar_colors, figsize=(3, 6), width=0.9)

    baseline = counts_by_method.get("10X")
    if baseline is not None:
        ax.axhline(y=baseline, color="gray", linestyle="--", linewidth=1.5, label="10X Baseline")

    plt.title("Number of Cells per Segmentation Method")
    plt.xlabel("Segmentation Method")
    plt.ylabel("Number of Cells")
    plt.legend(title="", bbox_to_anchor=(1.05, 1), loc="upper left")

    plt.savefig(output_path / "cell_counts_bar_plot.pdf", bbox_inches="tight")
    plt.show()

plot_contamination_boxplots

plot_contamination_boxplots(boxplot_data, output_path, palette)

Plot boxplots for contamination values across different segmentation methods.

Args: boxplot_data (pd.DataFrame): DataFrame containing contamination data for all segmentation methods. output_path (Path): Path to the directory where the plot will be saved. palette (Dict[str, str]): Dictionary mapping segmentation method names to color codes.

Source code in src/segger/validation/utils.py
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
def plot_contamination_boxplots(boxplot_data: pd.DataFrame, output_path: Path, palette: Dict[str, str]) -> None:
    """Export contamination values to CSV and draw them as boxplots per source cell type.

    Boxes are grouped by ``Source Cell Type`` and coloured by ``Segmentation Method``.

    Args:
        boxplot_data (pd.DataFrame): Long-format table with ``Source Cell Type``, ``Contamination``
            and ``Segmentation Method`` columns.
        output_path (Path): Directory receiving ``contamination_box_results.csv`` and
            ``contamination_boxplots.pdf``.
        palette (Dict[str, str]): Mapping from segmentation method name to colour.
    """
    boxplot_data.to_csv(output_path / "contamination_box_results.csv", index=True)

    plt.figure(figsize=(14, 8))
    ax = sns.boxplot(
        data=boxplot_data,
        x="Source Cell Type",
        y="Contamination",
        hue="Segmentation Method",
        palette=palette,
    )
    ax.set(title="Neighborhood Contamination", xlabel="Source Cell Type", ylabel="Contamination")
    plt.legend(title="", bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.xticks(rotation=45, ha="right")

    plt.tight_layout()
    plt.savefig(output_path / "contamination_boxplots.pdf", bbox_inches="tight")
    plt.show()

plot_contamination_results

plot_contamination_results(contamination_results, output_path, palette)

Plot contamination results for each segmentation method.

Parameters:

Name Type Description Default
contamination_results Dict[str, DataFrame]

Dictionary of contamination data for each segmentation method.

required
output_path Path

Path to the directory where the plot will be saved.

required
palette Dict[str, str]

Dictionary mapping segmentation method names to color codes.

required
Source code in src/segger/validation/utils.py
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
def plot_contamination_results(
    contamination_results: Dict[str, pd.DataFrame], output_path: Path, palette: Dict[str, str]
) -> None:
    """Draw one annotated contamination heatmap per segmentation method.

    All matrices are also saved together to ``contamination_results.csv``, stacked
    row-wise with the method name as the outer index level.

    Args:
        contamination_results (Dict[str, pd.DataFrame]): Source-by-target contamination matrix per method.
        output_path (Path): Directory receiving the CSV and one ``{method}_contamination_matrix.pdf`` per method.
        palette (Dict[str, str]): Mapping from method name to colour. Not used by the heatmaps;
            kept for signature parity with the other plot helpers.
    """
    if contamination_results:
        # A dict has no ``to_csv``; stack the per-method matrices into one frame keyed by method.
        combined = pd.concat(contamination_results, names=["Segmentation Method"])
        combined.to_csv(output_path / "contamination_results.csv", index=True)
    for method, df in contamination_results.items():
        plt.figure(figsize=(10, 6))
        sns.heatmap(df, annot=True, cmap="coolwarm", linewidths=0.5)
        plt.title(f"Contamination Matrix for {method}")
        plt.xlabel("Target Cell Type")
        plt.ylabel("Source Cell Type")
        plt.tight_layout()
        plt.savefig(output_path / f"{method}_contamination_matrix.pdf", bbox_inches="tight")
        plt.show()

plot_counts_per_cell

plot_counts_per_cell(segmentations_dict, output_path, palette)

Plot the counts per cell (log2) for each segmentation method.

Args: segmentations_dict (Dict[str, sc.AnnData]): Dictionary mapping segmentation method names to loaded AnnData objects. output_path (Path): Path to the directory where the plot will be saved. palette (Dict[str, str]): Dictionary mapping segmentation method names to color codes.

Source code in src/segger/validation/utils.py
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
def plot_counts_per_cell(segmentations_dict: Dict[str, sc.AnnData], output_path: Path, palette: Dict[str, str]) -> None:
    """Plot transcripts per cell for each segmentation method as violins and save the data as CSV.

    Methods whose ``obs`` lacks a ``transcripts`` column are skipped.

    The plotted value is ``transcripts + 1`` on a linear y-axis limited to 5-300.
    Despite the "(log2)" label, no log transform is applied here. A dashed line
    marks the 10X-nucleus median when that method is present and has ``transcripts``.

    Args:
        segmentations_dict (Dict[str, sc.AnnData]): Mapping from segmentation method name to AnnData.
        output_path (Path): Directory receiving ``counts_per_cell_data.csv`` and the violin-plot PDF.
        palette (Dict[str, str]): Mapping from segmentation method name to colour.
    """
    value_col = "Counts per Cell (log2)"
    frames = []
    for method, adata in segmentations_dict.items():
        if "transcripts" not in adata.obs.columns:
            continue
        method_counts = adata.obs["transcripts"] + 1
        frames.append(
            pd.DataFrame({"Segmentation Method": [method] * len(method_counts), value_col: method_counts.values})
        )
    # Concatenate once rather than growing a frame inside the loop.
    violin_data = (
        pd.concat(frames, axis=0) if frames else pd.DataFrame({"Segmentation Method": [], value_col: []})
    )

    violin_data.to_csv(output_path / "counts_per_cell_data.csv", index=True)
    # Plot the violin plots
    plt.figure(figsize=(4, 6))
    ax = sns.violinplot(x="Segmentation Method", y=value_col, data=violin_data, palette=palette)
    ax.set(ylim=(5, 300))
    # Dashed reference line at the 10X-nucleus median, guarded like the loop above.
    if "10X-nucleus" in segmentations_dict and "transcripts" in segmentations_dict["10X-nucleus"].obs.columns:
        median_10X_nucleus = np.median(segmentations_dict["10X-nucleus"].obs["transcripts"] + 1)
        ax.axhline(y=median_10X_nucleus, color="gray", linestyle="--", linewidth=1.5, label="10X-nucleus Median")
    # Set plot titles and labels
    plt.title("")
    plt.xlabel("Segmentation Method")
    plt.ylabel(value_col)
    plt.xticks(rotation=0)
    # Save the figure as a PDF
    plt.savefig(output_path / "counts_per_cell_violin_plot.pdf", bbox_inches="tight")
    plt.show()

plot_entropy_boxplots

plot_entropy_boxplots(entropy_boxplot_data, output_path, palette)

Plot boxplots for neighborhood entropy across different segmentation methods by cell type.

Args: entropy_boxplot_data (pd.DataFrame): DataFrame containing neighborhood entropy data for all segmentation methods. output_path (Path): Path to the directory where the plot will be saved. palette (Dict[str, str]): Dictionary mapping segmentation method names to color codes.

Source code in src/segger/validation/utils.py
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
def plot_entropy_boxplots(entropy_boxplot_data: pd.DataFrame, output_path: Path, palette: Dict[str, str]) -> None:
    """Draw neighborhood-entropy boxplots per cell type, coloured by segmentation method.

    Args:
        entropy_boxplot_data (pd.DataFrame): Long-format table with ``Cell Type``,
            ``Neighborhood Entropy`` and ``Segmentation Method`` columns.
        output_path (Path): Directory receiving ``neighborhood_entropy_boxplots.pdf``.
        palette (Dict[str, str]): Mapping from segmentation method name to colour.
    """
    plt.figure(figsize=(14, 8))
    ax = sns.boxplot(
        data=entropy_boxplot_data,
        x="Cell Type",
        y="Neighborhood Entropy",
        hue="Segmentation Method",
        palette=palette,
    )
    ax.set(title="Neighborhood Entropy", xlabel="Cell Type", ylabel="Neighborhood Entropy")
    plt.legend(title="", bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.savefig(output_path / "neighborhood_entropy_boxplots.pdf", bbox_inches="tight")
    plt.show()

plot_gene_counts

plot_gene_counts(segmentations_dict, output_path, palette)

Plot the normalized gene counts for each segmentation method.

Args: segmentations_dict (Dict[str, sc.AnnData]): Dictionary mapping segmentation method names to loaded AnnData objects. output_path (Path): Path to the directory where the plot will be saved. palette (Dict[str, str]): Dictionary mapping segmentation method names to color codes.

Source code in src/segger/validation/utils.py
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
def plot_gene_counts(segmentations_dict: Dict[str, sc.AnnData], output_path: Path, palette: Dict[str, str]) -> None:
    """Box-plot per-gene total counts, normalised per gene, for each segmentation method.

    Each method's total count for a gene is divided by that gene's maximum total
    across all methods, so the best-capturing method scores 1. A dashed line marks
    the mean for ``10X`` when present.

    Args:
        segmentations_dict (Dict[str, sc.AnnData]): Mapping from segmentation method name to AnnData.
            ``X`` may be dense or scipy-sparse.
        output_path (Path): Directory receiving ``gene_counts_normalized_data.csv`` and the boxplot PDF.
        palette (Dict[str, str]): Mapping from segmentation method name to colour.
    """
    # Calculate total counts per gene for each segmentation method
    per_method = []
    for method, adata in segmentations_dict.items():
        # For sparse X, .sum(axis=0) is a 1 x n_genes np.matrix and .flatten() keeps it 2-D;
        # np.asarray(...).ravel() always yields the 1-D vector pd.Series needs.
        gene_counts = np.asarray(adata.X.sum(axis=0)).ravel()
        per_method.append(pd.Series(gene_counts, index=adata.var_names, name=method))
    total_counts_per_gene = pd.concat(per_method, axis=1) if per_method else pd.DataFrame()

    # Normalize by the maximum count per gene across all segmentations
    max_counts_per_gene = total_counts_per_gene.max(axis=1)
    normalized_counts_per_gene = total_counts_per_gene.divide(max_counts_per_gene, axis=0)

    # Prepare the data for the box plot
    frames = []
    for method in segmentations_dict.keys():
        method_counts = normalized_counts_per_gene[method]
        frames.append(
            pd.DataFrame(
                {"Segmentation Method": [method] * len(method_counts), "Normalized Counts": method_counts.values}
            )
        )
    boxplot_data = (
        pd.concat(frames, axis=0) if frames else pd.DataFrame({"Segmentation Method": [], "Normalized Counts": []})
    )

    boxplot_data.to_csv(output_path / "gene_counts_normalized_data.csv", index=True)

    # Plot the box plots
    plt.figure(figsize=(3, 6))
    sns.boxplot(x="Segmentation Method", y="Normalized Counts", data=boxplot_data, palette=palette, width=0.9)

    # Add a dashed line for the 10X baseline
    if "10X" in normalized_counts_per_gene:
        baseline_height = normalized_counts_per_gene["10X"].mean()
        plt.axhline(y=baseline_height, color="gray", linestyle="--", linewidth=1.5, label="10X Baseline")

    # Set plot titles and labels
    plt.title("")
    plt.xlabel("Segmentation Method")
    plt.ylabel("Normalized Counts")
    plt.xticks(rotation=0)

    # Save the figure as a PDF
    plt.savefig(output_path / "gene_counts_normalized_boxplot_by_method.pdf", bbox_inches="tight")
    plt.show()

plot_general_statistics_plots

plot_general_statistics_plots(segmentations_dict, output_path, palette)

Create a summary plot with all the general statistics subplots.

Args: segmentations_dict (Dict[str, sc.AnnData]): Dictionary mapping segmentation method names to loaded AnnData objects. output_path (Path): Path to the directory where the summary plot will be saved. palette (Dict[str, str]): Dictionary mapping segmentation method names to color codes.

Source code in src/segger/validation/utils.py
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
def plot_general_statistics_plots(
    segmentations_dict: Dict[str, sc.AnnData], output_path: Path, palette: Dict[str, str]
) -> None:
    """Run all general-statistics plot helpers inside a 3x2 subplot grid and save the summary.

    The panels run in this order: cell counts, percent assigned, gene counts,
    counts per cell, cell area, transcript density.

    NOTE(review): the visible helpers (e.g. ``plot_cell_area``, ``plot_gene_counts``)
    call ``plt.figure`` themselves, so the subplot selected here is probably not the
    axes they draw on. The saved ``general_statistics_plots.pdf`` may therefore not
    contain the panels. Confirm before relying on it.

    Args:
        segmentations_dict (Dict[str, sc.AnnData]): Mapping from segmentation method name to AnnData.
        output_path (Path): Directory receiving every helper's output plus the summary PDF.
        palette (Dict[str, str]): Mapping from segmentation method name to colour.
    """
    plt.figure(figsize=(15, 20))

    panels = (
        plot_cell_counts,
        plot_percent_assigned,
        plot_gene_counts,
        plot_counts_per_cell,
        plot_cell_area,
        plot_transcript_density,
    )
    for position, draw_panel in enumerate(panels, start=1):
        plt.subplot(3, 2, position)
        draw_panel(segmentations_dict, output_path, palette=palette)

    plt.tight_layout()
    plt.savefig(output_path / "general_statistics_plots.pdf", bbox_inches="tight")
    plt.show()

plot_mecr_results

plot_mecr_results(mecr_results, output_path, palette)

Plot the MECR (Mutually Exclusive Co-expression Rate) results for each segmentation method.

Args: mecr_results (Dict[str, Dict[Tuple[str, str], float]]): Dictionary of MECR results for each segmentation method. output_path (Path): Path to the directory where the plot will be saved. palette (Dict[str, str]): Dictionary mapping segmentation method names to color codes.

Source code in src/segger/validation/utils.py
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
def plot_mecr_results(
    mecr_results: Dict[str, Dict[Tuple[str, str], float]], output_path: Path, palette: Dict[str, str]
) -> None:
    """Box-plot MECR (Mutually Exclusive Co-expression Rate) values per segmentation method.

    The per-gene-pair values are flattened into a long table, written to
    ``mcer_box.csv`` and drawn as one box per method.

    Args:
        mecr_results (Dict[str, Dict[Tuple[str, str], float]]): For each method, MECR keyed by gene pair.
        output_path (Path): Directory receiving ``mcer_box.csv`` and ``mecr_results_boxplot.pdf``.
        palette (Dict[str, str]): Mapping from segmentation method name to colour.
    """
    rows = [
        {"Segmentation Method": method, "Gene Pair": f"{pair[0]} - {pair[1]}", "MECR": value}
        for method, per_pair in mecr_results.items()
        for pair, value in per_pair.items()
    ]
    table = pd.DataFrame(rows)
    table.to_csv(output_path / "mcer_box.csv", index=True)

    plt.figure(figsize=(3, 6))
    ax = sns.boxplot(x="Segmentation Method", y="MECR", data=table, palette=palette)
    ax.set(title="Mutually Exclusive Co-expression Rate (MECR)", xlabel="Segmentation Method", ylabel="MECR")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.savefig(output_path / "mecr_results_boxplot.pdf", bbox_inches="tight")
    plt.show()

plot_metric_comparison

plot_metric_comparison(ax, data, metric, label, method1, method2, output_path)

Plot a comparison of a specific metric between two methods and save the comparison data.

  • ax: plt.Axes Matplotlib axis to plot on.
  • data: pd.DataFrame DataFrame containing the data for plotting.
  • metric: str The metric to compare.
  • label: str Label for the metric.
  • method1: str The first method to compare.
  • method2: str The second method to compare.
  • output_path: Path Path to save the merged DataFrame as a CSV.
Source code in src/segger/validation/utils.py
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
def plot_metric_comparison(
    ax: plt.Axes, data: pd.DataFrame, metric: str, label: str, method1: str, method2: str, output_path: Path
) -> None:
    """Scatter one metric of ``method1`` against ``method2``, one point series per cell type.

    Rows of the two methods are inner-joined on ``celltype_major``, and the joined
    table is saved to CSV. A dashed ``y = x`` reference line runs from 0 to the
    largest plotted value.

    Args:
        ax (plt.Axes): Axes to draw on.
        data (pd.DataFrame): Long table with ``method``, ``celltype_major`` and ``metric`` columns.
        metric (str): Column to compare.
        label (str): Human-readable metric name used in the axis labels and title.
        method1 (str): Method shown on the x-axis.
        method2 (str): Method shown on the y-axis.
        output_path (Path): Directory receiving ``metric_comparison_{metric}_{method1}_vs_{method2}.csv``.
    """
    left = data[data["method"] == method1]
    right = data[data["method"] == method2]
    paired = pd.merge(left, right, on="celltype_major", suffixes=(f"_{method1}", f"_{method2}"))

    paired.to_csv(output_path / f"metric_comparison_{metric}_{method1}_vs_{method2}.csv", index=False)

    x_col = f"{metric}_{method1}"
    y_col = f"{metric}_{method2}"
    for cell_type in paired["celltype_major"].unique():
        subset = paired[paired["celltype_major"] == cell_type]
        ax.scatter(subset[x_col], subset[y_col], label=cell_type)

    upper = max(paired[x_col].max(), paired[y_col].max())
    ax.plot([0, upper], [0, upper], "k--", alpha=0.5)
    ax.set_xlabel(f"{label} ({method1})")
    ax.set_ylabel(f"{label} ({method2})")
    ax.set_title(f"{label}: {method1} vs {method2}")

plot_percent_assigned

plot_percent_assigned(segmentations_dict, output_path, palette)

Plot the percentage of assigned transcripts (normalized) for each segmentation method.

Args: segmentations_dict (Dict[str, sc.AnnData]): Dictionary mapping segmentation method names to loaded AnnData objects. output_path (Path): Path to the directory where the plot will be saved. palette (Dict[str, str]): Dictionary mapping segmentation method names to color codes.

Source code in src/segger/validation/utils.py
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
def plot_percent_assigned(
    segmentations_dict: Dict[str, sc.AnnData], output_path: Path, palette: Dict[str, str]
) -> None:
    """Violin-plot per-gene assigned counts for each method, as a percentage of the best method.

    Each method's total count for a gene is divided by that gene's maximum total
    across all methods and multiplied by 100. Genes yielding NaN are dropped. A
    dashed line marks the ``10X`` mean when that method is present.

    Args:
        segmentations_dict (Dict[str, sc.AnnData]): Mapping from segmentation method name to AnnData.
            ``X`` may be dense or scipy-sparse.
        output_path (Path): Directory receiving ``percent_assigned_normalized.csv`` and the violin-plot PDF.
        palette (Dict[str, str]): Mapping from segmentation method name to colour.
    """
    # Calculate total counts per gene for each segmentation method
    per_method = []
    for method, adata in segmentations_dict.items():
        # For sparse X, .sum(axis=0) is a 1 x n_genes np.matrix and .flatten() keeps it 2-D;
        # np.asarray(...).ravel() always yields the 1-D vector pd.Series needs.
        gene_counts = np.asarray(adata.X.sum(axis=0)).ravel()
        per_method.append(pd.Series(gene_counts, index=adata.var_names, name=method))
    total_counts_per_gene = pd.concat(per_method, axis=1) if per_method else pd.DataFrame()

    # Normalize by the maximum count per gene across all segmentations
    max_counts_per_gene = total_counts_per_gene.max(axis=1)
    percent_assigned_normalized = total_counts_per_gene.divide(max_counts_per_gene, axis=0) * 100

    # Prepare the data for the violin plot
    value_col = "Percent Assigned (Normalized)"
    frames = []
    for method in segmentations_dict.keys():
        method_data = percent_assigned_normalized[method].dropna()
        frames.append(
            pd.DataFrame({"Segmentation Method": [method] * len(method_data), value_col: method_data.values})
        )
    violin_data = (
        pd.concat(frames, axis=0) if frames else pd.DataFrame({"Segmentation Method": [], value_col: []})
    )

    violin_data.to_csv(output_path / "percent_assigned_normalized.csv", index=True)

    # Plot the violin plots
    plt.figure(figsize=(12, 8))
    ax = sns.violinplot(x="Segmentation Method", y=value_col, data=violin_data, palette=palette)

    # Add a dashed line for the 10X baseline
    if "10X" in segmentations_dict:
        baseline_height = percent_assigned_normalized["10X"].mean()
        ax.axhline(y=baseline_height, color="gray", linestyle="--", linewidth=1.5, label="10X Baseline")

    # Set plot titles and labels
    plt.title("")
    plt.xlabel("Segmentation Method")
    plt.ylabel(value_col)
    plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

    # Save the figure as a PDF
    plt.savefig(output_path / "percent_assigned_normalized_violin_plot.pdf", bbox_inches="tight")
    plt.show()

plot_quantized_mecr_area

plot_quantized_mecr_area(quantized_mecr_area, output_path, palette)

Plot the quantized MECR values against cell areas for each segmentation method, with point size proportional to the variance of MECR.

Args: quantized_mecr_area (Dict[str, pd.DataFrame]): Dictionary of quantized MECR area data for each segmentation method. output_path (Path): Path to the directory where the plot will be saved. palette (Dict[str, str]): Dictionary mapping segmentation method names to color codes.

Source code in src/segger/validation/utils.py
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
def plot_quantized_mecr_area(
    quantized_mecr_area: Dict[str, pd.DataFrame], output_path: Path, palette: Dict[str, str]
) -> None:
    """Plot average MECR against average cell area, one line per segmentation method.

    Each method's frame must provide ``average_area``, ``average_mecr`` and
    ``variance_mecr`` columns. A line (carrying the legend label) connects the
    rows. A bubble is drawn at every row with marker size
    ``variance_mecr * 1e5``, so point size tracks the MECR variance. Methods
    missing from ``palette`` are drawn in ``#333333``. The figure is saved as
    ``quantized_mecr_area_plot.pdf`` in ``output_path`` and then shown.

    Args:
        quantized_mecr_area (Dict[str, pd.DataFrame]): Dictionary of quantized MECR area data for each segmentation method.
        output_path (Path): Path to the directory where the plot will be saved.
        palette (Dict[str, str]): Dictionary mapping segmentation method names to color codes.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    for method_name, frame in quantized_mecr_area.items():
        colour = palette.get(method_name, "#333333")
        x_vals = frame["average_area"]
        y_vals = frame["average_mecr"]
        # Connecting line only (marker size 0); this artist provides the legend entry.
        ax.plot(x_vals, y_vals, color=colour, label=method_name, linestyle="-", marker="o", markersize=0)
        # Semi-transparent, white-edged bubbles sized by the MECR variance.
        ax.scatter(
            x_vals,
            y_vals,
            s=frame["variance_mecr"] * 1e5,
            color=colour,
            alpha=0.7,
            edgecolor="w",
            linewidth=0.5,
        )

    ax.set_title("Quantized MECR by Cell Area")
    ax.set_xlabel("Average Cell Area")
    ax.set_ylabel("Average MECR")
    # Legend sits outside the axes, anchored at the top right.
    ax.legend(title="", bbox_to_anchor=(1.05, 1), loc="upper left")
    fig.tight_layout()
    fig.savefig(output_path / "quantized_mecr_area_plot.pdf", bbox_inches="tight")
    plt.show()

plot_quantized_mecr_counts

plot_quantized_mecr_counts(quantized_mecr_counts, output_path, palette)

Plot the quantized MECR values against transcript counts for each segmentation method, with point size proportional to the variance of MECR.

Parameters:

Name Type Description Default
quantized_mecr_counts Dict[str, DataFrame]

Dictionary of quantized MECR count data for each segmentation method.

required
output_path Path

Path to the directory where the plot will be saved.

required
palette Dict[str, str]

Dictionary mapping segmentation method names to color codes.

required
Source code in src/segger/validation/utils.py
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
def plot_quantized_mecr_counts(
    quantized_mecr_counts: Dict[str, pd.DataFrame], output_path: Path, palette: Dict[str, str]
) -> None:
    """Plot average MECR against average transcript count, one line per segmentation method.

    Each method's frame must provide ``average_counts``, ``average_mecr`` and
    ``variance_mecr`` columns. A line (carrying the legend label) connects the
    rows. A bubble is drawn at every row with marker size
    ``variance_mecr * 1e5``, so point size tracks the MECR variance. Methods
    missing from ``palette`` are drawn in ``#333333``. The figure is saved as
    ``quantized_mecr_counts_plot.pdf`` in ``output_path`` and then shown.

    Args:
        quantized_mecr_counts (Dict[str, pd.DataFrame]): Dictionary of quantized MECR count data for each segmentation method.
        output_path (Path): Path to the directory where the plot will be saved.
        palette (Dict[str, str]): Dictionary mapping segmentation method names to color codes.
    """
    fig, ax = plt.subplots(figsize=(9, 6))
    for method_name, frame in quantized_mecr_counts.items():
        colour = palette.get(method_name, "#333333")
        x_vals = frame["average_counts"]
        y_vals = frame["average_mecr"]
        # Connecting line only (marker size 0); this artist provides the legend entry.
        ax.plot(x_vals, y_vals, color=colour, label=method_name, linestyle="-", marker="o", markersize=0)
        # Semi-transparent, white-edged bubbles sized by the MECR variance.
        ax.scatter(
            x_vals,
            y_vals,
            s=frame["variance_mecr"] * 1e5,
            color=colour,
            alpha=0.7,
            edgecolor="w",
            linewidth=0.5,
        )

    ax.set_title("Quantized MECR by Transcript Counts")
    ax.set_xlabel("Average Transcript Counts")
    ax.set_ylabel("Average MECR")
    # Legend sits outside the axes, anchored at the top right.
    ax.legend(title="", bbox_to_anchor=(1.05, 1), loc="upper left")
    fig.tight_layout()
    fig.savefig(output_path / "quantized_mecr_counts_plot.pdf", bbox_inches="tight")
    plt.show()

plot_sensitivity_boxplots

plot_sensitivity_boxplots(sensitivity_boxplot_data, output_path, palette)

Plot boxplots for sensitivity across different segmentation methods by cell type. Args: sensitivity_boxplot_data (pd.DataFrame): DataFrame containing sensitivity data for all segmentation methods. output_path (Path): Path to the directory where the plot will be saved. palette (Dict[str, str]): Dictionary mapping segmentation method names to color codes.

Source code in src/segger/validation/utils.py
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
def plot_sensitivity_boxplots(
    sensitivity_boxplot_data: pd.DataFrame, output_path: Path, palette: Dict[str, str]
) -> None:
    """Draw per-cell-type sensitivity boxplots with one box per segmentation method.

    The input table is first written (index included) to
    ``sensitivity_results.csv`` in ``output_path``. The figure is then saved
    there as ``sensitivity_boxplots.pdf`` and shown.

    Args:
        sensitivity_boxplot_data (pd.DataFrame): Long-format table with ``"Cell Type"``,
            ``"Sensitivity"`` and ``"Segmentation Method"`` columns.
        output_path (Path): Directory that receives the CSV and the PDF.
        palette (Dict[str, str]): Dictionary mapping segmentation method names to color codes.
    """
    csv_file = output_path / "sensitivity_results.csv"
    pdf_file = output_path / "sensitivity_boxplots.pdf"
    sensitivity_boxplot_data.to_csv(csv_file, index=True)

    fig, ax = plt.subplots(figsize=(14, 8))
    sns.boxplot(
        data=sensitivity_boxplot_data,
        x="Cell Type",
        y="Sensitivity",
        hue="Segmentation Method",
        palette=palette,
        ax=ax,
    )
    ax.set_title("Sensitivity Score")
    ax.set_xlabel("Cell Type")
    ax.set_ylabel("Sensitivity")
    ax.legend(title="", bbox_to_anchor=(1.05, 1), loc="upper left")
    # Rotated, right-aligned tick labels keep long cell-type names readable.
    plt.xticks(rotation=45, ha="right")
    fig.tight_layout()
    fig.savefig(pdf_file, bbox_inches="tight")
    plt.show()

plot_transcript_density

plot_transcript_density(segmentations_dict, output_path, palette)

Plot the transcript density (log2) for each segmentation method.

Args: segmentations_dict (Dict[str, sc.AnnData]): Dictionary mapping segmentation method names to loaded AnnData objects. output_path (Path): Path to the directory where the plot will be saved. palette (Dict[str, str]): Dictionary mapping segmentation method names to color codes.

Source code in src/segger/validation/utils.py
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
def plot_transcript_density(
    segmentations_dict: Dict[str, sc.AnnData], output_path: Path, palette: Dict[str, str]
) -> None:
    """Plot the transcript density (log2) for each segmentation method.

    The per-cell density is ``obs["transcripts"] / obs["cell_area"]``,
    transformed with ``log2(density + 1)``. It is computed for each method
    whose ``obs`` contains a ``cell_area`` column; other methods are skipped.
    The combined values are written to ``transcript_density_log2_data.csv``
    and drawn as one violin per method, saved as
    ``transcript_density_log2_violin_plot.pdf``. When a ``"10X-nucleus"``
    entry with a ``cell_area`` column is present, its median log2 density is
    drawn as a dashed gray reference line.

    Args:
        segmentations_dict (Dict[str, sc.AnnData]): Dictionary mapping segmentation method names to loaded AnnData objects.
        output_path (Path): Path to the directory where the CSV and the plot will be saved.
        palette (Dict[str, str]): Dictionary mapping segmentation method names to color codes.
    """

    def _log2_density(obs: pd.DataFrame) -> pd.Series:
        # +1 keeps zero-density cells finite under log2.
        return np.log2(obs["transcripts"] / obs["cell_area"] + 1)

    # Collect one frame per method and concatenate once. Calling pd.concat in
    # the loop is quadratic, and concatenating onto an empty seed frame is
    # deprecated in recent pandas.
    frames = []
    for method, adata in segmentations_dict.items():
        if "cell_area" not in adata.obs.columns:
            continue
        density_log2 = _log2_density(adata.obs)
        frames.append(
            pd.DataFrame(
                {
                    "Segmentation Method": [method] * len(density_log2),
                    "Transcript Density (log2)": density_log2.values,
                }
            )
        )
    if frames:
        violin_data = pd.concat(frames, axis=0)
    else:
        violin_data = pd.DataFrame({"Segmentation Method": [], "Transcript Density (log2)": []})

    violin_data.to_csv(output_path / "transcript_density_log2_data.csv", index=True)

    # Plot the violin plots
    plt.figure(figsize=(4, 6))
    ax = sns.violinplot(x="Segmentation Method", y="Transcript Density (log2)", data=violin_data, palette=palette)

    # Dashed reference line at the 10X-nucleus median. It is guarded on
    # cell_area, like the loop above, so a nucleus entry without areas no
    # longer raises KeyError.
    nucleus_adata = segmentations_dict.get("10X-nucleus")
    if nucleus_adata is not None and "cell_area" in nucleus_adata.obs.columns:
        median_10X_nucleus_density_log2 = np.median(_log2_density(nucleus_adata.obs))
        ax.axhline(
            y=median_10X_nucleus_density_log2, color="gray", linestyle="--", linewidth=1.5, label="10X-nucleus Median"
        )

    # Set plot titles and labels
    plt.title("")
    plt.xlabel("Segmentation Method")
    plt.ylabel("Transcript Density (log2)")
    plt.xticks(rotation=0)

    # Save the figure as a PDF
    plt.savefig(output_path / "transcript_density_log2_violin_plot.pdf", bbox_inches="tight")
    plt.show()

plot_umaps_with_scores

plot_umaps_with_scores(segmentations_dict, clustering_scores, output_path, palette)

Plot UMAPs colored by cell type for each segmentation method and display clustering scores in the title. Args: segmentations_dict (Dict[str, AnnData]): Dictionary of AnnData objects for each segmentation method. clustering_scores (Dict[str, Tuple[float, float]]): Dictionary of clustering scores for each method. output_path (Path): Path to the directory where the plots will be saved. palette (Dict[str, str]): Dictionary mapping segmentation method names to color codes.

Source code in src/segger/validation/utils.py
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
def plot_umaps_with_scores(
    segmentations_dict: Dict[str, sc.AnnData],
    clustering_scores: Dict[str, Tuple[float, float]],
    output_path: Path,
    palette: Dict[str, str],
    max_cells: int = 10000,
) -> None:
    """Plot UMAPs colored by cell type for each segmentation method and display clustering scores in the title.

    For each method, a copy of the AnnData is subsampled to at most
    ``max_cells`` cells and total-count normalized. A 5-neighbour graph and a
    UMAP (``spread=5``) are then computed, and the UMAP is colored by
    ``obs["celltype_major"]``. Calinski-Harabasz and silhouette scores are
    computed on the subsample via ``compute_clustering_scores`` and shown in
    the title. Each figure is saved as ``<method>_umap_with_scores.pdf`` in
    ``output_path`` and then shown.

    Args:
        segmentations_dict (Dict[str, AnnData]): Dictionary of AnnData objects for each segmentation method.
        clustering_scores (Dict[str, Tuple[float, float]]): Not read by this function; the
            scores are recomputed on the subsample. Kept for API compatibility.
        output_path (Path): Path to the directory where the plots will be saved.
        palette (Dict[str, str]): Passed to ``sc.pl.umap`` to color the ``celltype_major``
            categories. NOTE(review): documented elsewhere as keyed by segmentation
            method; confirm it actually covers cell-type labels.
        max_cells (int): Upper bound on the number of cells subsampled per method.
    """
    for method, adata in segmentations_dict.items():
        print(method)
        adata_copy = adata.copy()
        # sc.pp.subsample raises when n_obs exceeds the number of cells, so cap it.
        sc.pp.subsample(adata_copy, n_obs=min(max_cells, adata_copy.n_obs))
        sc.pp.normalize_total(adata_copy)
        sc.pp.neighbors(adata_copy, n_neighbors=5)
        sc.tl.umap(adata_copy, spread=5)
        # Draw into an explicit axes. Without ax=, sc.pl.umap opens its own
        # figure, leaving the previously created one empty and ignoring figsize.
        fig, ax = plt.subplots(figsize=(8, 6))
        sc.pl.umap(adata_copy, color="celltype_major", palette=palette, ax=ax, show=False)
        # Add clustering scores to the title
        ch_score, sh_score = compute_clustering_scores(adata_copy, cell_type_column="celltype_major")
        ax.set_title(f"{method} - UMAP\nCalinski-Harabasz: {ch_score:.2f}, Silhouette: {sh_score:.2f}")
        # Save the figure
        fig.savefig(output_path / f"{method}_umap_with_scores.pdf", bbox_inches="tight")
        plt.show()

save_cell_clustering

save_cell_clustering(merged, zarr_path, columns)

Save cell clustering information to a Zarr file.

Parameters:

Name Type Description Default
merged DataFrame

The merged dataframe containing cell clustering information.

required
zarr_path str

The path to the Zarr file.

required
columns List[str]

The list of columns to save.

required
Source code in src/segger/validation/xenium_explorer.py
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
def save_cell_clustering(merged: pd.DataFrame, zarr_path: str, columns: List[str]) -> None:
    """Save cell clustering information to a Zarr file.

    For every column in ``columns``, a group ``cell_groups/<i>`` is created.
    It holds ``indices`` and ``indptr`` datasets produced by
    ``get_indices_indptr`` from the per-cell integer labels.

    Labels are built from the column's string values. The string ``"nan"``
    (what missing values become after ``astype(str)``) maps to 0, and every
    other distinct value is numbered 1..K in sorted order. The
    ``cell_groups`` attributes record the grouping names and, per grouping,
    the class names in label order 1..K.

    Args:
        merged (pd.DataFrame): The merged dataframe containing cell clustering information, one row per cell.
        zarr_path (str): The path to the Zarr file. It is opened with ``mode="w"``, so existing content is replaced.
        columns (List[str]): The list of columns to save.
    """
    import zarr

    new_zarr = zarr.open(zarr_path, mode="w")
    new_zarr.create_group("/cell_groups")

    group_names = []
    for index, column in enumerate(columns):
        new_zarr["cell_groups"].create_group(index)
        labels = merged[column].astype(str)
        classes = [k for k in np.unique(labels) if k != "nan"]
        # Number every real class 1..K. The previous zip(range(1, len(classes)), ...)
        # was one element short whenever "nan" was absent. That silently dropped
        # the last class and made the integer conversion below fail.
        mapping_dict = {key: i for i, key in enumerate(classes, start=1)}
        mapping_dict["nan"] = 0

        # Every label is a key of mapping_dict, so map() produces no missing values.
        clusters = labels.map(mapping_dict).to_numpy(dtype=int)
        indices, indptr = get_indices_indptr(clusters)

        new_zarr["cell_groups"][index].create_dataset("indices", data=indices)
        new_zarr["cell_groups"][index].create_dataset("indptr", data=indptr)
        # Class names in label order 1..K; label 0 ("nan") is not listed.
        group_names.append(classes)

    new_zarr["cell_groups"].attrs.update(
        {
            "major_version": 1,
            "minor_version": 0,
            "number_groupings": len(columns),
            "grouping_names": columns,
            "group_names": group_names,
        }
    )
    new_zarr.store.close()

seg2explorer

seg2explorer(seg_df, source_path, output_dir, cells_filename='seg_cells', analysis_filename='seg_analysis', xenium_filename='seg_experiment.xenium', analysis_df=None, draw=False, cell_id_columns='seg_cell_id', area_low=10, area_high=100)

Convert seg output to a format compatible with Xenium explorer.

Parameters:

Name Type Description Default
seg_df DataFrame

The seg DataFrame.

required
source_path str

The source path.

required
output_dir str

The output directory.

required
cells_filename str

The filename for cells.

'seg_cells'
analysis_filename str

The filename for analysis.

'seg_analysis'
xenium_filename str

The filename for Xenium.

'seg_experiment.xenium'
analysis_df Optional[DataFrame]

The analysis DataFrame.

None
draw bool

Currently unused by the function; accepted for API compatibility.

False
cell_id_columns str

Name of the column in seg_df that holds the cell IDs (used to group transcripts into cells).

'seg_cell_id'
area_low float

The lower area threshold.

10
area_high float

The upper area threshold.

100
Source code in src/segger/validation/xenium_explorer.py
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
def seg2explorer(
    seg_df: pd.DataFrame,
    source_path: str,
    output_dir: str,
    cells_filename: str = "seg_cells",
    analysis_filename: str = "seg_analysis",
    xenium_filename: str = "seg_experiment.xenium",
    analysis_df: Optional[pd.DataFrame] = None,
    draw: bool = False,
    cell_id_columns: str = "seg_cell_id",
    area_low: float = 10,
    area_high: float = 100,
) -> None:
    """Convert seg output to a format compatible with Xenium explorer.

    Transcripts in ``seg_df`` are grouped by ``cell_id_columns``. Each group
    with at least 5 transcripts, and whose convex-hull ``area`` lies in
    ``[area_low, area_high]``, becomes a cell whose boundary is the convex
    hull of its ``(x_location, y_location)`` points. When at least 3 of its
    transcripts have ``overlaps_nucleus == 1``, their hull is the nucleus
    polygon; otherwise the nucleus polygon is empty.

    Three artefacts are written to ``output_dir``:

    * ``<cells_filename>.zarr.zip``: cell ids, polygons and mask values, with
      the attributes of ``<source_path>/cells.zarr.zip`` copied over.
    * ``<analysis_filename>.zarr.zip``: one cell grouping per non-ID column of
      ``analysis_df``. Labels are numbered from 1 in sorted order; missing or
      unmatched values get 0.
    * ``xenium_filename``: generated from ``<source_path>/experiment.xenium``.

    Args:
        seg_df (pd.DataFrame): Transcript table with ``x_location``, ``y_location``,
            ``z_location``, ``overlaps_nucleus`` and the ``cell_id_columns`` column.
        source_path (str): Directory providing the ``cells.zarr.zip`` and
            ``experiment.xenium`` templates.
        output_dir (str): Directory where the new files are written.
        cells_filename (str): Base name (without ``.zarr.zip``) of the cells store.
        analysis_filename (str): Base name (without ``.zarr.zip``) of the analysis store.
        xenium_filename (str): File name of the generated experiment file.
        analysis_df (Optional[pd.DataFrame]): Per-cell annotations keyed by
            ``cell_id_columns``; every other column becomes a grouping. If None,
            a single ``"default"`` grouping with the value ``"seg"`` is used.
        draw (bool): Currently unused.
        cell_id_columns (str): Name of the column in ``seg_df`` holding the cell IDs.
        area_low (float): Cells whose hull ``area`` is below this are skipped.
        area_high (float): Cells whose hull ``area`` is above this are skipped.
    """
    import zarr

    source_path = Path(source_path)
    storage = Path(output_dir)

    cell_id2old_id = {}
    cell_id = []
    cell_summary = []
    # Index 0 holds cell polygons, index 1 holds nucleus polygons.
    polygon_num_vertices = [[], []]
    polygon_vertices = [[], []]
    seg_mask_value = []

    grouped_by = seg_df.groupby(cell_id_columns)
    for cell_incremental_id, (seg_cell_id, seg_cell) in tqdm(enumerate(grouped_by), total=len(grouped_by)):
        # Too few transcripts to form a usable polygon.
        if len(seg_cell) < 5:
            continue

        cell_convex_hull = ConvexHull(seg_cell[["x_location", "y_location"]])
        # NOTE(review): for 2-D input, scipy's ConvexHull.area is the hull
        # *perimeter* (.volume is the enclosed area). The thresholds and
        # "cell_area"/"nucleus_area" below therefore use the perimeter. This is
        # left unchanged so existing thresholds keep their meaning; confirm intent.
        if not (area_low <= cell_convex_hull.area <= area_high):
            continue

        # IDs follow the group enumeration, so skipped groups leave gaps.
        uint_cell_id = cell_incremental_id + 1
        cell_id2old_id[uint_cell_id] = seg_cell_id

        seg_nucleous = seg_cell[seg_cell["overlaps_nucleus"] == 1]
        nucleus_convex_hull = (
            ConvexHull(seg_nucleous[["x_location", "y_location"]]) if len(seg_nucleous) >= 3 else None
        )

        cell_id.append(uint_cell_id)
        # NOTE(review): the nucleus_* fields repeat the cell values, and
        # cell_summary is never written to a store below.
        cell_summary.append(
            {
                "cell_centroid_x": seg_cell["x_location"].mean(),
                "cell_centroid_y": seg_cell["y_location"].mean(),
                "cell_area": cell_convex_hull.area,
                "nucleus_centroid_x": seg_cell["x_location"].mean(),
                "nucleus_centroid_y": seg_cell["y_location"].mean(),
                "nucleus_area": cell_convex_hull.area,
                "z_level": (seg_cell.z_location.mean() // 3).round(0) * 3,
            }
        )

        polygon_num_vertices[0].append(len(cell_convex_hull.vertices))
        polygon_vertices[0].append(seg_cell[["x_location", "y_location"]].values[cell_convex_hull.vertices])
        if nucleus_convex_hull is not None:
            polygon_num_vertices[1].append(len(nucleus_convex_hull.vertices))
            polygon_vertices[1].append(seg_nucleous[["x_location", "y_location"]].values[nucleus_convex_hull.vertices])
        else:
            # No nucleus polygon: zero vertices and an empty (0, 2) array.
            polygon_num_vertices[1].append(0)
            polygon_vertices[1].append(np.array([[], []]).T)
        seg_mask_value.append(uint_cell_id)

    cell_polygon_vertices = get_flatten_version(polygon_vertices[0], max_value=21)
    nucl_polygon_vertices = get_flatten_version(polygon_vertices[1], max_value=21)

    cells = {
        # Two columns: the cell id and a constant 1 (presumably the dataset suffix; TODO confirm).
        "cell_id": np.array([np.array(cell_id), np.ones(len(cell_id))], dtype=np.uint32).T,
        "cell_summary": pd.DataFrame(cell_summary).values.astype(np.float64),
        # Row 0 is nucleus, row 1 is cell, matching polygon_vertices below.
        # Counts are hull vertices + 1 (presumably the closing vertex; TODO
        # confirm against get_flatten_version).
        "polygon_num_vertices": np.array(
            [
                [x + 1 for x in polygon_num_vertices[1]],
                [x + 1 for x in polygon_num_vertices[0]],
            ],
            dtype=np.int32,
        ),
        "polygon_vertices": np.array([nucl_polygon_vertices, cell_polygon_vertices]).astype(np.float32),
        "seg_mask_value": np.array(seg_mask_value, dtype=np.int32),
    }

    existing_store = zarr.open(source_path / "cells.zarr.zip", mode="r")
    new_store = zarr.open(storage / f"{cells_filename}.zarr.zip", mode="w")

    for key in ("cell_id", "polygon_num_vertices", "polygon_vertices", "seg_mask_value"):
        new_store[key] = cells[key]

    new_store.attrs.update(existing_store.attrs)
    new_store.attrs["number_cells"] = len(cells["cell_id"])
    new_store.store.close()
    # The template store was previously never closed, leaking its file handle.
    existing_store.store.close()

    old_ids = [cell_id2old_id[i] for i in cell_id]
    if analysis_df is None:
        analysis_df = pd.DataFrame(old_ids, columns=[cell_id_columns])
        analysis_df["default"] = "seg"

    # Left-merge onto the exported cell order so groupings align with cell_id.
    zarr_df = pd.DataFrame(old_ids, columns=[cell_id_columns])
    clustering_df = pd.merge(zarr_df, analysis_df, how="left", on=cell_id_columns)

    clusters_names = [i for i in analysis_df.columns if i != cell_id_columns]
    # Per grouping: label -> 1..K in sorted order (0 is reserved for missing).
    clusters_dict = {
        cluster: {
            label: rank
            for rank, label in enumerate(sorted(np.unique(clustering_df[cluster].dropna())), start=1)
        }
        for cluster in clusters_names
    }

    new_zarr = zarr.open(storage / (analysis_filename + ".zarr.zip"), mode="w")
    new_zarr.create_group("/cell_groups")

    clusters = [[clusters_dict[cluster].get(x, 0) for x in list(clustering_df[cluster])] for cluster in clusters_names]

    for i, cluster_labels in enumerate(clusters):
        new_zarr["cell_groups"].create_group(i)
        indices, indptr = get_indices_indptr(np.array(cluster_labels))
        new_zarr["cell_groups"][i].create_dataset("indices", data=indices)
        new_zarr["cell_groups"][i].create_dataset("indptr", data=indptr)

    new_zarr["cell_groups"].attrs.update(
        {
            "major_version": 1,
            "minor_version": 0,
            "number_groupings": len(clusters_names),
            "grouping_names": clusters_names,
            "group_names": [
                [x[0] for x in sorted(clusters_dict[cluster].items(), key=lambda x: x[1])] for cluster in clusters_names
            ],
        }
    )

    new_zarr.store.close()
    generate_experiment_file(
        template_path=source_path / "experiment.xenium",
        output_path=storage / xenium_filename,
        cells_name=cells_filename,
        analysis_name=analysis_filename,
    )

str_to_uint32

str_to_uint32(cell_id_str)

Convert a string cell ID back to uint32 format.

Parameters:

Name Type Description Default
cell_id_str str

The cell ID in string format.

required

Returns:

Type Description
Tuple[int, int]

Tuple[int, int]: The cell ID in uint32 format and the dataset suffix.

Source code in src/segger/validation/xenium_explorer.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def str_to_uint32(cell_id_str: str) -> Tuple[int, int]:
    """Decode a string cell ID of the form ``"<letters>-<suffix>"`` into integers.

    Each letter of the part before the hyphen stands for one hexadecimal
    digit (``a`` -> ``0``, ..., ``p`` -> ``f``). The decoded hex string is
    parsed as the uint32 cell ID. The part after the hyphen is parsed as a
    decimal integer.

    Args:
        cell_id_str (str): The cell ID in string format.

    Returns:
        Tuple[int, int]: The cell ID in uint32 format and the dataset suffix.

    Raises:
        ValueError: If there is not exactly one hyphen, or a part does not parse.
        KeyError: If the prefix contains a character outside ``a``-``p``.
    """
    encoded, dataset_part = cell_id_str.split("-")
    letter_to_hex_digit = {chr(ord("a") + offset): f"{offset:x}" for offset in range(16)}
    hex_digits = "".join(letter_to_hex_digit[letter] for letter in encoded)
    return int(hex_digits, 16), int(dataset_part)