Training the segger Model¶
The Model¶
The segger model is a graph neural network designed to handle heterogeneous graphs with two primary node types: transcripts and nuclei or cell boundaries. It leverages attention-based graph convolutional layers to compute node embeddings and relationships in spatial transcriptomics data. The architecture includes an initial embedding layer for node feature transformation, multiple graph attention layers (GATv2Conv), and residual linear connections.
Model Architecture¶
- Input Node Features:
  For input node features \(\mathbf{x}\), the model distinguishes between transcript nodes and boundary (or nucleus) nodes.
  - Transcript Nodes: If \(\mathbf{x}\) is 1-dimensional (e.g., for tokenized transcript data), the model applies an embedding layer:
    \[
    \mathbf{h}_i = \mathbf{E}[i],
    \]
    where \(i\) is the transcript token index and \(\mathbf{E}\) is a learnable embedding matrix.
  - Nuclei or Cell Boundary Nodes: If \(\mathbf{x}\) has multiple dimensions, the model applies a linear transformation:
    \[
    \mathbf{h} = \mathbf{W}\mathbf{x},
    \]
    where \(\mathbf{W}\) is a learnable weight matrix.
- Graph Attention Layers (GATv2Conv):
  The node embeddings are updated through multiple attention-based layers. The update for a node \(i\) at layer \(l+1\) is given by:
  \[
  \mathbf{h}_i^{(l+1)} = \sum_{j \in \mathcal{N}(i)} \alpha_{ij}\,\mathbf{W}^{(l)}\mathbf{h}_j^{(l)},
  \]
  where:
  - \(\alpha_{ij}\) is the attention coefficient between node \(i\) and node \(j\), computed as:
    \[
    \alpha_{ij} = \operatorname{softmax}_j\!\left(\mathbf{a}^{\top}\,\mathrm{LeakyReLU}\!\left(\mathbf{W}\left[\mathbf{h}_i \,\Vert\, \mathbf{h}_j\right]\right)\right),
    \]
  - \(\mathbf{a}\) is a learnable attention vector.
- Residual Linear Connections:
  After each attention layer, a residual connection is added via a linear transformation to stabilize learning:
  \[
  \mathbf{h}_i^{(l+1)} \leftarrow \mathbf{h}_i^{(l+1)} + \mathbf{W}_{res}\,\mathbf{h}_i^{(l)},
  \]
  where \(\mathbf{W}_{res}\) is a residual weight matrix.
- L2 Normalization:
  Finally, the embeddings are normalized using L2 normalization:
  \[
  \mathbf{h}_i \leftarrow \frac{\mathbf{h}_i}{\lVert \mathbf{h}_i \rVert_2},
  \]
  ensuring the final node embeddings have unit norm. A code sketch of the full forward pass follows this list.
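Putting these steps together, the forward pass can be sketched as follows. This is a minimal illustrative reimplementation, not the actual segger source: the class name SeggerSketch, the 2-dimensional boundary features, and the exact layer wiring (two GATv2Conv layers, no middle layers) are assumptions.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch_geometric.nn import GATv2Conv

class SeggerSketch(nn.Module):
    """Illustrative sketch of the architecture described above (not the real segger)."""

    def __init__(self, num_tx_tokens=5000, init_emb=32, hidden_channels=64,
                 out_channels=128, heads=4):
        super().__init__()
        # Embedding layer for 1-D tokenized transcript inputs; linear
        # transformation for multi-dimensional boundary features.
        self.tx_embedding = nn.Embedding(num_tx_tokens, init_emb)
        self.bd_linear = nn.Linear(2, init_emb)  # assumes 2-D boundary features
        self.conv1 = GATv2Conv(init_emb, hidden_channels, heads=heads)
        self.res1 = nn.Linear(init_emb, hidden_channels * heads)
        self.conv2 = GATv2Conv(hidden_channels * heads, out_channels, heads=heads)
        self.res2 = nn.Linear(hidden_channels * heads, out_channels * heads)

    def forward(self, x, edge_index):
        # h_i = E[i] for token indices; h = W x for multi-dimensional features.
        h = self.tx_embedding(x) if x.dim() == 1 else self.bd_linear(x)
        # Attention layers, each with a residual linear connection.
        h = F.relu(self.conv1(h, edge_index) + self.res1(h))
        h = self.conv2(h, edge_index) + self.res2(h)
        # L2 normalization of the final embeddings.
        return F.normalize(h, p=2, dim=-1)
```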
Heterogeneous Graph Transformation¶
In the next step, the segger model is transformed into a heterogeneous graph neural network using PyTorch Geometric's to_hetero function. This transformation enables the model to handle distinct node and edge types (transcripts and nuclei or cell boundaries) with separate mechanisms for modeling their relationships.
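As a sketch, the transformation might look like the following. The node and edge type names ("tx", "bd", and the relations) are assumptions for illustration; in practice the metadata comes from the constructed heterogeneous graph object.

```python
from torch_geometric.nn import to_hetero

# Assumed metadata: node types and edge types of the heterogeneous graph.
metadata = (
    ["tx", "bd"],  # transcript and boundary (nucleus/cell) node types
    [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")],
)

# `model` is a homogeneous segger instance (see Usage below).
hetero_model = to_hetero(model, metadata, aggr="sum")
```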
Usage¶
To instantiate and run the segger model:

```python
model = segger(
    num_tx_tokens=5000,  # Number of unique 'tx' tokens
    init_emb=32,         # Initial embedding dimension
    hidden_channels=64,  # Number of hidden channels
    num_mid_layers=2,    # Number of middle layers
    out_channels=128,    # Number of output channels
    heads=4,             # Number of attention heads
)
output = model(x, edge_index)
```
Once transformed to a heterogeneous model and trained using PyTorch Lightning, the model can efficiently learn relationships between transcripts and nuclei or cell boundaries.
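For illustration, a call on the transformed model passes feature and edge-index dictionaries keyed by node and edge type; the tensor names below (tx_features, belongs_edge_index, etc.) are hypothetical placeholders.

```python
# Hypothetical inputs, keyed by the assumed node/edge types from above.
out_dict = hetero_model(
    {"tx": tx_features, "bd": bd_features},        # x_dict
    {("tx", "belongs", "bd"): belongs_edge_index,  # edge_index_dict
     ("tx", "neighbors", "tx"): neighbors_edge_index},
)
# out_dict["tx"] / out_dict["bd"]: final L2-normalized embeddings per node type.
```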
Training the heterogeneous GNN with pytorch-lightning¶
The training module makes use of PyTorch Lightning for efficient and scalable training, alongside PyTorch Geometric for processing the graph-based data. The module is built to handle multi-GPU setups and allows the flexibility to adjust hyperparameters, aggregation methods, and embedding sizes.
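The following sketch shows how such a Lightning setup could look; LitSegger is a hypothetical wrapper (the actual training module is defined in scripts/train_model.py), and the loss shown is a placeholder, not the real training objective.

```python
import torch
import pytorch_lightning as pl

class LitSegger(pl.LightningModule):
    """Hypothetical Lightning wrapper around the heterogeneous segger model."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        out = self.model(batch.x_dict, batch.edge_index_dict)
        loss = out["tx"].pow(2).mean()  # placeholder; the real loss is task-specific
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Trainer flags mirroring the example command below.
trainer = pl.Trainer(
    accelerator="cuda", strategy="auto", precision="16-mixed",
    devices=4, max_epochs=100, default_root_dir="./models/clean2",
)
# trainer.fit(LitSegger(hetero_model), train_loader, val_loader)
```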
The SpatialTranscriptomicsDataset class is used to load and manage spatial transcriptomics data stored as PyTorch Geometric Data objects. It inherits from InMemoryDataset to load preprocessed datasets, ensuring efficient in-memory data handling during the training and validation phases.
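A minimal sketch of such a dataset, assuming the standard InMemoryDataset pattern of collating preprocessed tiles into a single file; the class name, attribute names, and file layout are illustrative, not the actual SpatialTranscriptomicsDataset.

```python
import torch
from torch_geometric.data import InMemoryDataset

class TileDataset(InMemoryDataset):
    """Illustrative in-memory dataset of preprocessed tile graphs."""

    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        # Load the collated tiles produced by process().
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        return ["data.pt"]  # all tiles collated into a single file

    def process(self):
        data_list = []  # populate with one (Hetero)Data object per preprocessed tile
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
```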
Example Training Command¶
```bash
python scripts/train_model.py \
  --train_dir path/to/train/tiles \
  --val_dir path/to/val/tiles \
  --batch_size_train 4 \
  --batch_size_val 4 \
  --num_tx_tokens 500 \
  --init_emb 8 \
  --hidden_channels 64 \
  --out_channels 16 \
  --heads 4 \
  --mid_layers 1 \
  --aggr sum \
  --accelerator cuda \
  --strategy auto \
  --precision 16-mixed \
  --devices 4 \
  --epochs 100 \
  --default_root_dir ./models/clean2
```
The scripts/train_model.py script can be found in the GitHub repository. This example trains the segger model on 4 GPUs with a batch size of 4 for both training and validation, using 16-bit mixed precision.