
segger.models.segger_model

The segger_model module contains the core Graph Neural Network architecture for spatial transcriptomics analysis. It implements the Segger class, an attention-based GNN designed for processing heterogeneous graphs with transcript and boundary nodes.

Core Classes

Segger

Bases: Module

Source code in src/segger/models/segger_model.py
class Segger(torch.nn.Module):
    def __init__(
        self,
        num_tx_tokens: int,
        init_emb: int = 16,
        hidden_channels: int = 32,
        num_mid_layers: int = 3,
        out_channels: int = 32,
        heads: int = 3,
    ):
        """
        Initializes the Segger model.

        Args:
            num_tx_tokens (int)  : Number of unique 'tx' tokens for embedding.
            init_emb (int)       : Initial embedding size for both 'tx' and boundary (non-token) nodes.
            hidden_channels (int): Number of hidden channels.
            num_mid_layers (int) : Number of hidden layers (excluding first and last layers).
            out_channels (int)   : Number of output channels.
            heads (int)          : Number of attention heads.
        """
        super().__init__()

        # Embedding for 'tx' (transcript) nodes
        self.tx_embedding = Embedding(num_tx_tokens, init_emb)

        # Linear layer for boundary (non-token) nodes
        self.lin0 = Linear(-1, init_emb, bias=False)

        # First GATv2Conv layer
        self.conv_first = GATv2Conv(
            (-1, -1), hidden_channels, heads=heads, add_self_loops=False
        )
        # self.lin_first = Linear(-1, hidden_channels * heads)

        # Middle GATv2Conv layers
        self.num_mid_layers = num_mid_layers
        if num_mid_layers > 0:
            self.conv_mid_layers = torch.nn.ModuleList()
            # self.lin_mid_layers = torch.nn.ModuleList()
            for _ in range(num_mid_layers):
                self.conv_mid_layers.append(
                    GATv2Conv(
                        (-1, -1), hidden_channels, heads=heads, add_self_loops=False
                    )
                )
                # self.lin_mid_layers.append(Linear(-1, hidden_channels * heads))

        # Last GATv2Conv layer
        self.conv_last = GATv2Conv(
            (-1, -1), out_channels, heads=heads, add_self_loops=False
        )
        # self.lin_last = Linear(-1, out_channels * heads)

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        """
        Forward pass for the Segger model.

        Args:
            x (Tensor): Node features.
            edge_index (Tensor): Edge indices.

        Returns:
            Tensor: Output node embeddings.
        """
        x = torch.nan_to_num(x, nan=0)
        is_one_dim = (x.ndim == 1) * 1
        x = x[:, None]
        x = self.tx_embedding(
            ((x.sum(-1) * is_one_dim).int())
        ) * is_one_dim + self.lin0(x.float()) * (1 - is_one_dim)
        x = x.squeeze()
        # First layer
        x = x.relu()
        x = self.conv_first(x, edge_index)  # + self.lin_first(x)
        x = x.relu()

        # Middle layers
        if self.num_mid_layers > 0:
            for conv_mid in self.conv_mid_layers:
                x = conv_mid(x, edge_index)  # + lin_mid(x)
                x = x.relu()

        # Last layer
        x = self.conv_last(x, edge_index)  # + self.lin_last(x)

        # x = F.normalize(x)

        return x

    def decode(self, z: Tensor, edge_index: Tensor) -> Tensor:
        """
        Decode the node embeddings to predict edge values.

        Args:
            z (Tensor): Node embeddings.
            edge_index (Tensor): Edge label indices.

        Returns:
            Tensor: Predicted edge values.
        """
        return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)

conv_first instance-attribute

conv_first = GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False)

conv_last instance-attribute

conv_last = GATv2Conv((-1, -1), out_channels, heads=heads, add_self_loops=False)

conv_mid_layers instance-attribute

conv_mid_layers = ModuleList()

lin0 instance-attribute

lin0 = Linear(-1, init_emb, bias=False)

num_mid_layers instance-attribute

num_mid_layers = num_mid_layers

tx_embedding instance-attribute

tx_embedding = Embedding(num_tx_tokens, init_emb)

__init__

__init__(num_tx_tokens, init_emb=16, hidden_channels=32, num_mid_layers=3, out_channels=32, heads=3)

Initializes the Segger model.

Parameters:

Name             Type  Description                                                            Default
num_tx_tokens    int   Number of unique 'tx' tokens for embedding.                            required
init_emb         int   Initial embedding size for both 'tx' and boundary (non-token) nodes.   16
hidden_channels  int   Number of hidden channels.                                             32
num_mid_layers   int   Number of hidden layers (excluding first and last layers).             3
out_channels     int   Number of output channels.                                             32
heads            int   Number of attention heads.                                             3
Source code in src/segger/models/segger_model.py
def __init__(
    self,
    num_tx_tokens: int,
    init_emb: int = 16,
    hidden_channels: int = 32,
    num_mid_layers: int = 3,
    out_channels: int = 32,
    heads: int = 3,
):
    """
    Initializes the Segger model.

    Args:
        num_tx_tokens (int)  : Number of unique 'tx' tokens for embedding.
        init_emb (int)       : Initial embedding size for both 'tx' and boundary (non-token) nodes.
        hidden_channels (int): Number of hidden channels.
        num_mid_layers (int) : Number of hidden layers (excluding first and last layers).
        out_channels (int)   : Number of output channels.
        heads (int)          : Number of attention heads.
    """
    super().__init__()

    # Embedding for 'tx' (transcript) nodes
    self.tx_embedding = Embedding(num_tx_tokens, init_emb)

    # Linear layer for boundary (non-token) nodes
    self.lin0 = Linear(-1, init_emb, bias=False)

    # First GATv2Conv layer
    self.conv_first = GATv2Conv(
        (-1, -1), hidden_channels, heads=heads, add_self_loops=False
    )
    # self.lin_first = Linear(-1, hidden_channels * heads)

    # Middle GATv2Conv layers
    self.num_mid_layers = num_mid_layers
    if num_mid_layers > 0:
        self.conv_mid_layers = torch.nn.ModuleList()
        # self.lin_mid_layers = torch.nn.ModuleList()
        for _ in range(num_mid_layers):
            self.conv_mid_layers.append(
                GATv2Conv(
                    (-1, -1), hidden_channels, heads=heads, add_self_loops=False
                )
            )
            # self.lin_mid_layers.append(Linear(-1, hidden_channels * heads))

    # Last GATv2Conv layer
    self.conv_last = GATv2Conv(
        (-1, -1), out_channels, heads=heads, add_self_loops=False
    )

decode

decode(z, edge_index)

Decode the node embeddings to predict edge values.

Parameters:

Name        Type    Description          Default
z           Tensor  Node embeddings.     required
edge_index  Tensor  Edge label indices.  required

Returns:

Type    Description
Tensor  Predicted edge values.

Source code in src/segger/models/segger_model.py
def decode(self, z: Tensor, edge_index: Tensor) -> Tensor:
    """
    Decode the node embeddings to predict edge values.

    Args:
        z (Tensor): Node embeddings.
        edge_index (Tensor): Edge label indices.

    Returns:
        Tensor: Predicted edge values.
    """
    return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
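As a hedged illustration, decode's scoring rule (a dot product of the two endpoint embeddings per candidate edge) can be reproduced with toy tensors; applying sigmoid to turn raw scores into probabilities is a common convention, though the source itself does not prescribe it:

```python
import torch

# Toy node embeddings standing in for the model output z (hypothetical values)
z = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
# Candidate edges 0-2 and 1-2 in COO layout, as decode() expects
edge_index = torch.tensor([[0, 1], [2, 2]])

# Same scoring rule as decode(): elementwise product summed over features
scores = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
probs = torch.sigmoid(scores)  # optional: map raw scores into (0, 1)
```

Here both candidate edges score 1.0, since each pair of endpoint embeddings has a dot product of 1.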

forward

forward(x, edge_index)

Forward pass for the Segger model.

Parameters:

Name        Type    Description     Default
x           Tensor  Node features.  required
edge_index  Tensor  Edge indices.   required

Returns:

Type    Description
Tensor  Output node embeddings.

Source code in src/segger/models/segger_model.py
def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
    """
    Forward pass for the Segger model.

    Args:
        x (Tensor): Node features.
        edge_index (Tensor): Edge indices.

    Returns:
        Tensor: Output node embeddings.
    """
    x = torch.nan_to_num(x, nan=0)
    is_one_dim = (x.ndim == 1) * 1
    x = x[:, None]
    x = self.tx_embedding(
        ((x.sum(-1) * is_one_dim).int())
    ) * is_one_dim + self.lin0(x.float()) * (1 - is_one_dim)
    x = x.squeeze()
    # First layer
    x = x.relu()
    x = self.conv_first(x, edge_index)  # + self.lin_first(x)
    x = x.relu()

    # Middle layers
    if self.num_mid_layers > 0:
        for conv_mid in self.conv_mid_layers:
            x = conv_mid(x, edge_index)  # + lin_mid(x)
            x = x.relu()

    # Last layer
    x = self.conv_last(x, edge_index)  # + self.lin_last(x)

    # x = F.normalize(x)

    return x

Overview

The Segger class implements a Graph Neural Network architecture specifically designed for spatial transcriptomics data. It uses Graph Attention Networks (GAT) with GATv2Conv layers to learn complex spatial relationships between transcripts and cellular boundaries.

Key Features

  • Heterogeneous Graph Processing: Automatically handles different node types (transcripts vs. boundaries)
  • Attention Mechanisms: GATv2Conv layers for learning spatial relationships
  • Configurable Architecture: Adjustable depth, width, and attention heads
  • PyTorch Integration: Native PyTorch module with full compatibility
  • Spatial Optimization: Designed specifically for spatial transcriptomics data

Architecture Details

Node Type Processing

The model automatically differentiates between node types based on input feature dimensions:

  1. Transcript Nodes (1D features): Processed through embedding layers
  2. Boundary Nodes (Multi-dimensional features): Processed through linear transformations
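A minimal sketch of this dispatch rule, as inferred from forward() (which keys the routing off x.ndim); the helper name pick_path is hypothetical:

```python
import torch

tx_ids = torch.tensor([3, 7, 1])   # 1-D input: transcript token IDs -> embedding path
bd_feats = torch.randn(4, 8)       # 2-D input: boundary feature rows -> linear path

def pick_path(x: torch.Tensor) -> str:
    # Mirrors the ndim check the model uses to route node types
    return "embedding" if x.ndim == 1 else "linear"
```

So pick_path(tx_ids) selects the embedding path and pick_path(bd_feats) the linear path; the real forward() applies both branches and blends them with 0/1 masks rather than branching.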

Layer Structure

Input Features → Node Type Detection → Feature Processing → GATv2Conv Layers → Output Embeddings
      ↓                  ↓                     ↓                    ↓                    ↓
 Transcripts /      Auto-routing        Embedding / Linear     Attention-based     Learned features /
 Boundaries         (1-D vs multi-D)    transformations        spatial learning    biological insights

Attention Mechanism

The model uses GATv2 attention (via GATv2Conv), which computes attention coefficients as:

α_ij = softmax_j(a^T LeakyReLU(W [h_i || h_j]))

Where:

  • α_ij is the attention coefficient between nodes i and j
  • a is a learnable attention vector
  • W is a learnable weight matrix
  • h_i, h_j are node feature vectors
  • || denotes concatenation

Unlike the original GAT, GATv2 applies the LeakyReLU before the attention vector a, which yields more expressive ("dynamic") attention.

Usage Examples

Basic Model Initialization

from segger.models.segger_model import Segger

# Initialize with default parameters
model = Segger(
    num_tx_tokens=5000,      # Number of unique transcript types
    init_emb=16,             # Initial embedding dimension
    hidden_channels=32,       # Hidden layer size
    num_mid_layers=3,        # Number of hidden layers
    out_channels=32,          # Output dimension
    heads=3                   # Number of attention heads
)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

Forward Pass

import torch

# Create sample data (2-D features take the boundary/linear path)
num_nodes = 1000
x = torch.randn(num_nodes, 64)  # Node features
edge_index = torch.randint(0, num_nodes, (2, 2000))  # Edge indices

# Forward pass
with torch.no_grad():
    output = model(x, edge_index)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Edge index shape: {edge_index.shape}")

Training Configuration

import torch.nn as nn
import torch.optim as optim

# Model configuration for large dataset
model = Segger(
    num_tx_tokens=10000,     # Large vocabulary
    init_emb=64,             # Larger embeddings
    hidden_channels=128,      # Wider layers
    num_mid_layers=5,        # Deeper architecture
    out_channels=256,         # Rich output features
    heads=8                   # More attention heads
)

# Optimizer and loss
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()

# Training loop
model.train()
for epoch in range(100):
    optimizer.zero_grad()

    # Forward pass
    out = model(x, edge_index)

    # Compute loss (example: node classification; `labels` is assumed
    # to be a LongTensor of per-node class IDs from your dataset)
    loss = criterion(out, labels)

    # Backward pass
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

Model Parameters

Required Parameters

  • num_tx_tokens (int): Number of unique transcript types in your dataset
      • Determines the size of the transcript embedding layer
      • Should match the number of unique genes/transcripts in your data
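For instance, one straightforward way to derive num_tx_tokens from a dataset's gene list (gene_names here is a hypothetical stand-in; substitute your own annotation source):

```python
# Hypothetical gene annotations; real values come from your dataset
gene_names = ["ACTB", "GAPDH", "EPCAM", "ACTB", "GAPDH"]

# One embedding row per unique transcript type
num_tx_tokens = len(set(gene_names))
```

With three unique genes in the toy list, the embedding table would hold three rows.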

Optional Parameters

  • init_emb (int, default=16): Initial embedding dimension
      • Used for both transcript embeddings and boundary feature transformation
      • Larger values provide more expressive features but increase memory usage

  • hidden_channels (int, default=32): Number of hidden channels
      • Size of intermediate layer representations
      • Affects model capacity and computational cost

  • num_mid_layers (int, default=3): Number of hidden GAT layers
      • More layers enable learning of more complex patterns
      • Balance between expressiveness and overfitting

  • out_channels (int, default=32): Output embedding dimension
      • Size of final node representations
      • Should match your downstream task requirements

  • heads (int, default=3): Number of attention heads
      • Multiple heads learn different types of relationships
      • More heads generally improve performance but increase computation
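One practical detail worth noting when sizing these parameters: PyTorch Geometric's GATv2Conv concatenates head outputs by default (concat=True), so each layer emits channels * heads features per node (an assumption based on PyG's documented default; verify against your installed version):

```python
# Assumes PyG's default concat=True for GATv2Conv
hidden_channels, heads = 32, 3

# Each head contributes hidden_channels features; concatenation stacks
# them, so the next layer sees hidden_channels * heads inputs per node
per_node_width = hidden_channels * heads
```

With the defaults above, downstream layers therefore receive 96-dimensional node features, not 32.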

Architecture Components

1. Input Processing Layer

# Automatic node type detection and processing (simplified; the actual
# forward() expresses this routing arithmetically with 0/1 masks)
if x.ndim == 1:  # Transcript nodes: integer token IDs
    x = self.tx_embedding(x.int())
else:  # Boundary nodes: dense feature vectors
    x = self.lin0(x.float())

2. Graph Attention Layers

# First attention layer
x = F.relu(x)
x = self.conv_first(x, edge_index)
x = F.relu(x)

# Middle attention layers
for conv_mid in self.conv_mid_layers:
    x = conv_mid(x, edge_index)
    x = F.relu(x)

# Final attention layer
x = self.conv_last(x, edge_index)

3. Output Processing

# Final embeddings can be used for various tasks
# - Node classification
# - Link prediction
# - Graph-level tasks
# - Downstream analysis

Performance Characteristics

Computational Complexity

  • Time Complexity: O(|E| × F × H) per layer
      • |E|: number of edges
      • F: feature dimension
      • H: number of attention heads

  • Memory Usage: scales with graph size and model parameters
      • Node features: O(|V| × F)
      • Edge attention: O(|E| × H)
      • Model parameters: O(F² × L × H)

Optimization Features

  • Efficient Attention: GATv2Conv optimized for sparse graphs
  • Memory Management: Automatic handling of different node types
  • PyTorch Optimization: Leverages PyTorch's optimized operations
  • GPU Acceleration: Full CUDA support for training and inference

Integration with PyTorch Geometric

The model is designed to work seamlessly with PyTorch Geometric:

from torch_geometric.data import Data
from torch_geometric.transforms import ToUndirected

# Create PyG data object
data = Data(x=x, edge_index=edge_index)

# Apply transformations
data = ToUndirected()(data)

# Process with model
output = model(data.x, data.edge_index)

Training Strategies

1. Learning Rate Scheduling

from torch.optim.lr_scheduler import CosineAnnealingLR

scheduler = CosineAnnealingLR(optimizer, T_max=100)
# Call once per epoch inside the training loop
scheduler.step()

2. Regularization

# Weight decay in optimizer
optimizer = optim.AdamW(model.parameters(), weight_decay=1e-5)

# Dropout (can be added to model if needed)
dropout = nn.Dropout(0.1)
x = dropout(x)

3. Early Stopping

# Monitor validation loss (`validate`, `val_loader`, and `max_epochs`
# are assumed to come from your training setup)
best_val_loss = float('inf')
patience = 10
patience_counter = 0

for epoch in range(max_epochs):
    # Training...
    val_loss = validate(model, val_loader)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print("Early stopping triggered")
        break

Best Practices

Model Architecture Selection

  • Small Datasets (< 10k nodes): Use fewer layers and smaller dimensions
  • Medium Datasets (10k-100k nodes): Balanced architecture with moderate complexity
  • Large Datasets (> 100k nodes): Deeper models with more attention heads

Training Configuration

  • Learning Rate: Start with 0.001 and adjust based on convergence
  • Batch Size: Use largest size that fits in memory
  • Regularization: Apply weight decay and consider dropout
  • Monitoring: Track both training and validation metrics

Data Preparation

  • Feature Normalization: Normalize input features for stable training
  • Graph Construction: Ensure proper edge construction for spatial relationships
  • Validation Strategy: Use spatial-aware validation splits
  • Data Augmentation: Consider spatial augmentations for robustness
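As one concrete recipe for the normalization bullet above (a sketch, assuming per-column z-scoring of boundary features; other schemes may suit your data better):

```python
import torch

# Toy boundary features: 3 nodes x 2 columns on very different scales
x = torch.tensor([[1.0, 10.0], [3.0, 30.0], [5.0, 50.0]])

# Per-column z-scoring: zero mean, unit variance per feature
mean, std = x.mean(dim=0), x.std(dim=0)
x_norm = (x - mean) / (std + 1e-8)  # epsilon guards near-constant columns
```

Fitting mean and std on the training split and reusing them at inference time keeps the transformation consistent.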

Common Use Cases

1. Cell Type Classification

# Train model for cell type prediction
# (note: with concatenated attention heads the raw output width is
#  out_channels * heads, so project or pool down to num_cell_types)
model = Segger(num_tx_tokens=5000, out_channels=num_cell_types)
# ... training ...
logits = model(x, edge_index)
cell_types = torch.argmax(logits, dim=1)

2. Spatial Relationship Learning

# Learn spatial relationships between transcripts and boundaries
embeddings = model(x, edge_index)
# Use embeddings for downstream analysis
similarity = torch.mm(embeddings, embeddings.t())

3. Tissue Architecture Analysis

# Analyze tissue-level patterns
model = Segger(num_tx_tokens=5000, out_channels=128)
embeddings = model(x, edge_index)
# Apply clustering or other analysis to embeddings

Troubleshooting

Common Issues

  1. Memory Errors: Reduce model size or batch size
  2. Training Instability: Lower learning rate or add regularization
  3. Poor Performance: Check data quality and feature engineering
  4. Slow Convergence: Adjust learning rate or model architecture

Performance Tips

  1. Use appropriate model size for your dataset
  2. Monitor training metrics to detect issues early
  3. Validate on held-out data to prevent overfitting
  4. Use mixed precision training for faster training on modern GPUs
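Tip 4 can be sketched with torch.cuda.amp; the toy model and data below are stand-ins for the real training setup, and the enabled flag makes the scaler a no-op on CPU so the pattern degrades gracefully:

```python
import torch
import torch.nn as nn

# Toy stand-ins; in practice these come from the training setup above
model = nn.Linear(8, 2)
x = torch.randn(16, 8)
labels = torch.randint(0, 2, (16,))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

use_cuda = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=use_cuda):  # fp16/bf16 forward on GPU
    loss = criterion(model(x), labels)

scaler.scale(loss).backward()  # loss scaling avoids fp16 gradient underflow
scaler.step(optimizer)
scaler.update()
```

On a CUDA device the forward pass runs in reduced precision while the scaler protects small gradients; on CPU the same code executes in full precision unchanged.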

Future Enhancements

Planned improvements include:

  • Additional Attention Types: Support for different attention mechanisms
  • Multi-modal Integration: Support for additional data types
  • Distributed Training: Multi-GPU and multi-node support
  • Model Compression: Efficient deployment of trained models
  • Interpretability Tools: Understanding learned spatial relationships

Dependencies

  • PyTorch: Core neural network functionality
  • PyTorch Geometric: Graph neural network operations
  • NumPy: Numerical operations (optional, for data preprocessing)

Contributing

Contributions to improve the Segger model are welcome:

  • Architecture Improvements: Better attention mechanisms and layer designs
  • Performance Optimization: Faster training and inference
  • Feature Extensions: Support for additional node and edge types
  • Testing: Comprehensive test coverage and validation