Skip to content

segger.data.transcript_embedding

The transcript_embedding module provides utilities for encoding transcript features into numerical representations suitable for machine learning models. This module handles the conversion of gene names and transcript labels into embeddings that can be used in graph neural networks.

TranscriptEmbedding

Bases: Module

Utility class to handle transcript embeddings in PyTorch so that they are optionally learnable in the future.

By default (no weights), each gene name is encoded as its integer index in the sorted list of classes.

Source code in src/segger/data/transcript_embedding.py (lines 11–96)
class TranscriptEmbedding(torch.nn.Module):
    """Utility class to handle transcript embeddings in PyTorch so that they are
    optionally learnable in the future.

    Default behavior is to use the index of gene names: each class is encoded
    as its position in the encoder's sorted ``classes_`` array. When a
    ``weights`` DataFrame is supplied, ``embed`` instead returns the row of
    weights that belongs to each class.
    """

    @staticmethod
    def _check_inputs(
        classes: ArrayLike,
        weights: Union[pd.DataFrame, None],
    ):
        """Check input arguments for validity.

        Args:
            classes: A 1D array-like (list, tuple, NumPy array, pandas Index,
                ...) of unique class names.
            weights: DataFrame indexed by class name containing the weights
                for each class, or None when no weights are used.

        Raises:
            ValueError: If classes is not 1D, contains duplicates, or if weights DataFrame
                is missing entries for some classes.
        """
        import numpy as np

        # Normalize to an ndarray: plain lists/tuples are valid ArrayLike
        # values but have no ``.shape`` attribute.
        classes = np.asarray(classes)
        # Classes is a 1D array (0-d input such as a bare string is rejected
        # here too, instead of failing later on ``len``).
        if classes.ndim != 1:
            msg = (
                "'classes' should be a 1D array, got an array of shape "
                f"{classes.shape} instead."
            )
            raise ValueError(msg)
        # Items appear exactly once
        if len(classes) != len(set(classes)):
            msg = (
                "All embedding classes must be unique. One or more items in "
                "'classes' appears twice."
            )
            raise ValueError(msg)
        # All classes have an entry in weights
        if weights is not None:
            missing = set(classes).difference(weights.index)
            if len(missing) > 0:
                msg = (
                    f"Index of 'weights' DataFrame is missing {len(missing)} "
                    "entries compared to classes."
                )
                raise ValueError(msg)

    def __init__(
        self,
        classes: ArrayLike,
        weights: Optional[pd.DataFrame] = None,
    ):
        """Initialize the TranscriptEmbedding module.

        Args:
            classes: A 1D array-like of unique class names.
            weights: Optional DataFrame indexed by class name containing the
                weights for each class. Defaults to None, in which case
                ``embed`` returns integer class indices.

        Raises:
            ValueError: If ``_check_inputs`` rejects ``classes`` or ``weights``.
        """
        # Validate before initializing nn.Module state so bad input fails fast.
        TranscriptEmbedding._check_inputs(classes, weights)
        # Setup as PyTorch module
        super().__init__()
        self._encoder = LabelEncoder().fit(classes)
        if weights is None:
            self._weights = None
        else:
            # ``embed`` looks rows up by encoder index, i.e. by position in
            # ``self._encoder.classes_`` (sorted order). The weight rows must
            # follow that same order; using the caller's ``classes`` order
            # would hand each gene another gene's weights whenever
            # ``classes`` is not already sorted.
            self._weights = Tensor(weights.loc[self._encoder.classes_].values)

    def embed(self, classes: ArrayLike):
        """Embed transcript classes into numerical representations.

        Args:
            classes: Array of class names to embed, encoded with the encoder
                fitted on the classes given at construction time.

        Returns:
            Union[LongTensor, torch.Tensor]: If no weights provided, returns indices.
                If weights provided, returns embedded representations (one
                weight row per class).
        """
        indices = LongTensor(self._encoder.transform(classes))
        if self._weights is None:
            # Default: return the integer class indices themselves (they are
            # not one-hot encoded).
            return indices
        # Gather the weight row for each class index.
        return F.embedding(indices, self._weights)

_encoder instance-attribute

_encoder = LabelEncoder().fit(classes)

_weights instance-attribute

_weights = None (or a Tensor built from the weights DataFrame when one is given)

__init__

__init__(classes, weights=None)

Initialize the TranscriptEmbedding module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `classes` | `ArrayLike` | A 1D array of unique class names. | required |
| `weights` | `Optional[DataFrame]` | Optional DataFrame containing weights for each class. | `None` |
Source code in src/segger/data/transcript_embedding.py (lines 59–78)
def __init__(
    self,
    classes: ArrayLike,
    weights: Optional[pd.DataFrame] = None,
):
    """Initialize the TranscriptEmbedding module.

    Args:
        classes: A 1D array of unique class names.
        weights: Optional DataFrame indexed by class name containing the
            weights for each class. Defaults to None, in which case ``embed``
            returns integer class indices.

    Raises:
        ValueError: If ``_check_inputs`` rejects ``classes`` or ``weights``.
    """
    # check input arguments before initializing nn.Module state
    TranscriptEmbedding._check_inputs(classes, weights)
    # Setup as PyTorch module
    super(TranscriptEmbedding, self).__init__()
    self._encoder = LabelEncoder().fit(classes)
    if weights is None:
        self._weights = None
    else:
        # Rows must follow the encoder's ``classes_`` order (the indices that
        # ``embed`` looks up), not the caller's ``classes`` order; otherwise
        # unsorted ``classes`` would map genes to the wrong weight rows.
        self._weights = Tensor(weights.loc[self._encoder.classes_].values)

_check_inputs staticmethod

_check_inputs(classes, weights)

Check input arguments for validity.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `classes` | `ArrayLike` | A 1D array of unique class names. | required |
| `weights` | `Union[DataFrame, None]` | DataFrame containing weights for each class, or `None` if no weights are used. | required |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If `classes` is not 1D, contains duplicates, or if the `weights` DataFrame is missing entries for some classes. |

Source code in src/segger/data/transcript_embedding.py (lines 19–56)
@staticmethod
def _check_inputs(
    classes: ArrayLike,
    weights: Union[pd.DataFrame, None],
):
    """Check input arguments for validity.

    Args:
        classes: A 1D array-like (list, tuple, NumPy array, pandas Index, ...)
            of unique class names.
        weights: DataFrame indexed by class name containing the weights for
            each class, or None when no weights are used.

    Raises:
        ValueError: If classes is not 1D, contains duplicates, or if weights DataFrame
            is missing entries for some classes.
    """
    import numpy as np

    # Normalize to an ndarray: plain lists/tuples are valid ArrayLike values
    # but have no ``.shape`` attribute.
    classes = np.asarray(classes)
    # Classes is a 1D array (0-d input such as a bare string is rejected too)
    if classes.ndim != 1:
        msg = (
            "'classes' should be a 1D array, got an array of shape "
            f"{classes.shape} instead."
        )
        raise ValueError(msg)
    # Items appear exactly once
    if len(classes) != len(set(classes)):
        msg = (
            "All embedding classes must be unique. One or more items in "
            "'classes' appears twice."
        )
        raise ValueError(msg)
    # All classes have an entry in weights
    if weights is not None:
        missing = set(classes).difference(weights.index)
        if len(missing) > 0:
            msg = (
                f"Index of 'weights' DataFrame is missing {len(missing)} "
                "entries compared to classes."
            )
            raise ValueError(msg)

embed

embed(classes)

Embed transcript classes into numerical representations.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `classes` | `ArrayLike` | Array of class names to embed. | required |

Returns:

| Type | Description |
| --- | --- |
| `Union[LongTensor, torch.Tensor]` | If no weights were provided, the integer class indices (`LongTensor`); otherwise the embedded representations (one weight row per class). |

Source code in src/segger/data/transcript_embedding.py (lines 81–96)
def embed(self, classes: ArrayLike):
    """Map transcript class names to tensors.

    Args:
        classes: Array of class names to embed; they are encoded with the
            fitted ``self._encoder``.

    Returns:
        Union[LongTensor, torch.Tensor]: The integer class indices as a
            ``LongTensor`` when the module has no weights; otherwise the row of
            ``self._weights`` selected by each index (via ``F.embedding``).
    """
    codes = LongTensor(self._encoder.transform(classes))
    if self._weights is not None:
        # Weighted case: gather one weight row per encoded class.
        return F.embedding(codes, self._weights)
    # Unweighted case: the indices themselves are returned; they are not
    # one-hot encoded.
    return codes

Overview

The TranscriptEmbedding class is designed to handle transcript feature encoding in a flexible and extensible way. It supports both simple index-based encoding and weighted embeddings, making it suitable for various machine learning applications.

Usage Examples

Basic Index-based Encoding

from segger.data.transcript_embedding import TranscriptEmbedding
import numpy as np

# Create a 1D array of unique gene names
gene_names = np.array(["GENE1", "GENE2", "GENE3", "GENE4"])

# Initialize embedding without weights (index-based).
# Indices follow the sorted order of the gene names.
embedding = TranscriptEmbedding(classes=gene_names)

# Encode transcript labels
transcript_labels = ["GENE1", "GENE3", "GENE2"]
encoded = embedding.embed(transcript_labels)
# Returns: tensor([0, 2, 1])

Weighted Embeddings

import numpy as np
import pandas as pd

# Create weights DataFrame: one row per gene, indexed by gene name
weights_df = pd.DataFrame({
    'weight1': [0.1, 0.2, 0.3, 0.4],
    'weight2': [0.5, 0.6, 0.7, 0.8]
}, index=["GENE1", "GENE2", "GENE3", "GENE4"])

# Initialize embedding with weights
embedding = TranscriptEmbedding(
    classes=np.asarray(gene_names),
    weights=weights_df
)

# Encode transcript labels: each label is replaced by its row of weights
transcript_labels = ["GENE1", "GENE3"]
encoded = embedding.embed(transcript_labels)
# Returns: tensor([[0.1000, 0.5000], [0.3000, 0.7000]])

Integration with PyTorch

import numpy as np
import torch
import torch.nn.functional as F
from segger.data.transcript_embedding import TranscriptEmbedding


# Use in a neural network
class TranscriptEncoder(torch.nn.Module):
    """Project transcript gene labels to 128-dimensional features."""

    def __init__(self, gene_names):
        super().__init__()
        self.num_genes = len(gene_names)
        self.embedding = TranscriptEmbedding(np.asarray(gene_names))
        self.projection = torch.nn.Linear(self.num_genes, 128)

    def forward(self, transcript_labels):
        # Without weights, ``embed`` returns integer class indices
        # (a LongTensor), which nn.Linear cannot consume directly:
        # one-hot encode them as floats of width ``num_genes`` first.
        indices = self.embedding.embed(transcript_labels)
        one_hot = F.one_hot(indices, num_classes=self.num_genes).float()
        return self.projection(one_hot)


# Initialize and use
encoder = TranscriptEncoder(gene_names)
output = encoder(["GENE1", "GENE2"])  # shape: (2, 128)