segger.data.transcript_embedding
The `transcript_embedding` module provides utilities for encoding transcript features as numerical representations for machine learning models. It converts gene names and transcript labels into embeddings that can be used in graph neural networks.
TranscriptEmbedding
Bases: torch.nn.Module
Utility class for handling transcript embeddings in PyTorch, structured so that the embeddings can optionally be made learnable in the future.
By default, each gene name is encoded as its integer index.
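As a rough, framework-free illustration of index-based encoding, the sketch below maps each unique class name to an integer position and encodes labels by lookup. It assumes indices follow sorted class order (as scikit-learn's `LabelEncoder` does); the ordering actually used by `TranscriptEmbedding` is not stated here, and `build_index` and `encode` are hypothetical helpers.

```python
from typing import Dict, List, Sequence


def build_index(classes: Sequence[str]) -> Dict[str, int]:
    """Map each unique class name to an integer index.

    Assumption: indices follow sorted class order, as with scikit-learn's
    LabelEncoder; TranscriptEmbedding may order classes differently.
    """
    return {name: i for i, name in enumerate(sorted(classes))}


def encode(labels: Sequence[str], index: Dict[str, int]) -> List[int]:
    # An unknown label raises KeyError instead of being silently dropped
    return [index[label] for label in labels]


index = build_index(["GENE1", "GENE2", "GENE3", "GENE4"])
print(encode(["GENE1", "GENE3", "GENE2"], index))  # [0, 2, 1]
```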
Source code in src/segger/data/transcript_embedding.py
__init__
__init__(classes, weights=None)
Initialize the TranscriptEmbedding module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| classes | ArrayLike | A 1D array of unique class names. | required |
| weights | Optional[DataFrame] | Optional DataFrame of per-class weights, indexed by class name. | None |
_check_inputs (staticmethod)
_check_inputs(classes, weights)
Check input arguments for validity.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| classes | ArrayLike | A 1D array of unique class names. | required |
| weights | Union[DataFrame, None] | DataFrame of per-class weights indexed by class name, or None. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If classes is not 1D or contains duplicates, or if the weights DataFrame is missing entries for some classes. |
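The three validation rules listed above can be restated in plain Python. This is an illustration, not segger's code: `check_inputs` is a hypothetical stand-in, and a dict keyed by class name substitutes for the weights DataFrame, whose index plays the same role.

```python
from typing import Mapping, Optional, Sequence


def check_inputs(
    classes: Sequence[str],
    weights: Optional[Mapping[str, Sequence[float]]],
) -> None:
    """Hypothetical re-statement of the documented validation rules."""
    # Rule 1: classes must be 1D, i.e. contain no nested sequences
    if any(isinstance(c, (list, tuple)) for c in classes):
        raise ValueError("classes must be a 1D array of names")
    # Rule 2: class names must be unique
    if len(set(classes)) != len(classes):
        raise ValueError("classes contains duplicate names")
    # Rule 3: if weights are given, every class needs an entry
    if weights is not None:
        missing = [c for c in classes if c not in weights]
        if missing:
            raise ValueError(f"weights missing entries for: {missing}")


check_inputs(["GENE1", "GENE2"], None)  # passes
check_inputs(["GENE1", "GENE2"], {"GENE1": [0.1], "GENE2": [0.2]})  # passes
```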
embed
embed(classes)
Embed transcript classes into numerical representations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| classes | ArrayLike | Array of class names to embed. | required |
Returns:
| Type | Description |
|---|---|
| Union[LongTensor, torch.Tensor] | If no weights were provided, the integer indices of the classes; otherwise, the corresponding rows of the weights (embedded representations). |
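A dependency-free sketch of the two return modes of `embed`, assuming the weight rows are stored in class-index order. Selecting rows by index is the operation `torch.nn.functional.embedding(indices, weight)` computes; whether `embed` calls that function internally is not stated here, and `embed_sketch` is a hypothetical name.

```python
from typing import Dict, List, Optional, Sequence, Union


def embed_sketch(
    labels: Sequence[str],
    index: Dict[str, int],
    weight_rows: Optional[List[List[float]]] = None,
) -> Union[List[int], List[List[float]]]:
    """Return indices, or the matching weight rows if weights are given."""
    indices = [index[label] for label in labels]
    if weight_rows is None:
        return indices
    # Row lookup by index, the same operation as F.embedding(indices, weights)
    return [weight_rows[i] for i in indices]


index = {"GENE1": 0, "GENE2": 1, "GENE3": 2, "GENE4": 3}
rows = [[0.1, 0.5], [0.2, 0.6], [0.3, 0.7], [0.4, 0.8]]  # one row per class
print(embed_sketch(["GENE1", "GENE3"], index))        # [0, 2]
print(embed_sketch(["GENE1", "GENE3"], index, rows))  # [[0.1, 0.5], [0.3, 0.7]]
```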
Overview
The `TranscriptEmbedding` class encodes transcript features in one of two modes: index-based encoding, where each class name maps to an integer index, and weighted embedding, where each class name maps to its row in a user-supplied weights DataFrame. Because it subclasses torch.nn.Module, it can be composed into larger PyTorch models.
Usage Examples
Basic Index-based Encoding
```python
from segger.data.transcript_embedding import TranscriptEmbedding

# Unique gene names define the set of classes
gene_names = ["GENE1", "GENE2", "GENE3", "GENE4"]

# Without weights, each label is encoded as its integer index
embedding = TranscriptEmbedding(classes=gene_names)

# Encode transcript labels
transcript_labels = ["GENE1", "GENE3", "GENE2"]
encoded = embedding.embed(transcript_labels)
# Returns: tensor([0, 2, 1])
```
Weighted Embeddings
```python
import pandas as pd
from segger.data.transcript_embedding import TranscriptEmbedding

gene_names = ["GENE1", "GENE2", "GENE3", "GENE4"]

# One row of weights per gene, indexed by gene name
weights_df = pd.DataFrame({
    'weight1': [0.1, 0.2, 0.3, 0.4],
    'weight2': [0.5, 0.6, 0.7, 0.8]
}, index=gene_names)

# Initialize embedding with weights
embedding = TranscriptEmbedding(classes=gene_names, weights=weights_df)

# Each label is replaced by its row of weights
transcript_labels = ["GENE1", "GENE3"]
encoded = embedding.embed(transcript_labels)
# Returns: tensor([[0.1, 0.5], [0.3, 0.7]])
```
Integration with PyTorch
```python
import torch
import torch.nn.functional as F
from segger.data.transcript_embedding import TranscriptEmbedding

gene_names = ["GENE1", "GENE2", "GENE3", "GENE4"]

# Use in a neural network
class TranscriptEncoder(torch.nn.Module):
    def __init__(self, gene_names):
        super().__init__()
        self.num_classes = len(gene_names)
        self.embedding = TranscriptEmbedding(gene_names)
        self.projection = torch.nn.Linear(self.num_classes, 128)

    def forward(self, transcript_labels):
        # Without weights, embed() returns integer indices (a LongTensor),
        # so convert them to float one-hot vectors before the linear layer
        indices = self.embedding.embed(transcript_labels)
        one_hot = F.one_hot(indices, num_classes=self.num_classes).float()
        return self.projection(one_hot)

# Initialize and use
encoder = TranscriptEncoder(gene_names)
output = encoder(["GENE1", "GENE2"])  # shape: (2, 128)
```
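When index encodings feed a linear layer, they are typically one-hot encoded first. Multiplying a one-hot vector by a weight matrix with one row per class simply selects one row, so a bias-free linear projection of one-hot indices is equivalent to an embedding lookup such as `torch.nn.Embedding`. That is one plausible route toward the learnable embeddings mentioned in the class description, not a documented segger feature. The check below demonstrates the equivalence with plain Python lists; `one_hot` and `vec_matmul` are illustrative helpers.

```python
from typing import List


def one_hot(index: int, num_classes: int) -> List[float]:
    # All zeros except a single 1.0 at position `index`
    return [1.0 if i == index else 0.0 for i in range(num_classes)]


def vec_matmul(vec: List[float], matrix: List[List[float]]) -> List[float]:
    # Computes vec @ matrix, where matrix has one row per class
    return [
        sum(vec[r] * matrix[r][c] for r in range(len(matrix)))
        for c in range(len(matrix[0]))
    ]


W = [[0.1, 0.5], [0.2, 0.6], [0.3, 0.7], [0.4, 0.8]]  # 4 classes x 2 dims
for idx in range(len(W)):
    # The product picks out row `idx`, i.e. an embedding lookup
    assert vec_matmul(one_hot(idx, len(W)), W) == W[idx]
```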