Segger Model Training¶
Overview¶
Training the Segger model involves optimizing a Graph Neural Network for transcript-to-cell link prediction in spatial transcriptomics data. The training process leverages PyTorch Lightning for scalable multi-GPU training, with specialized data handling and validation strategies designed for heterogeneous graphs.
Training Framework¶
PyTorch Lightning Integration¶
Segger uses PyTorch Lightning for training orchestration, providing:
- Multi-GPU Training: Automatic data parallel training across multiple devices
- Mixed Precision: Support for 16-bit mixed precision training
- Distributed Training: Multi-node training capabilities
- Automatic Logging: Built-in metrics tracking and visualization
- Checkpoint Management: Automatic model saving and restoration
Training Architecture¶
Training Tiles → Data Loaders → Segger Model → Link Prediction → Loss Computation → Optimization

- Training Tiles: spatial graphs, with a held-out validation set
- Data Loaders: mini-batching and GPU transfer
- Segger Model: GNN forward pass with attention
- Link Prediction: similarity scores for candidate edges
- Loss Computation: binary cross-entropy, monitored with AUROC
- Optimization: AdamW weight updates
Data Preparation¶
Training Data Structure¶
Training data consists of spatial tiles represented as PyTorch Geometric graphs:
# Each tile contains:
data = {
    "tx": {                       # Transcript nodes
        "id": transcript_ids,
        "pos": spatial_coordinates,
        "x": feature_vectors
    },
    "bd": {                       # Boundary nodes
        "id": boundary_ids,
        "pos": centroid_coordinates,
        "x": geometric_features
    },
    ("tx", "neighbors", "tx"): {  # Transcript proximity edges
        "edge_index": neighbor_connections
    },
    ("tx", "belongs", "bd"): {    # Transcript-boundary edges
        "edge_index": containment_relationships,
        "edge_label": edge_labels  # 1 = positive, 0 = negative
    }
}
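To make the schema concrete, here is a toy tile written with plain Python lists so it runs without dependencies. All values are invented for illustration; a real pipeline stores these fields as tensors in a PyTorch Geometric `HeteroData` object, with edge types as `(src, relation, dst)` tuples and `edge_index` as a 2×E tensor rather than the (src, dst) pairs used here for readability.

```python
# Toy tile mirroring the heterogeneous-graph schema above (values invented).
tile = {
    "tx": {                       # 3 transcript nodes
        "id": [101, 102, 103],
        "pos": [[0.0, 0.0], [1.0, 0.5], [2.0, 1.0]],  # x, y coordinates
        "x": [[0.2], [0.7], [0.1]],                    # per-transcript features
    },
    "bd": {                       # 1 boundary (cell) node
        "id": [7],
        "pos": [[1.0, 0.5]],      # centroid coordinates
        "x": [[12.0, 3.1]],       # geometric features, e.g. area, perimeter
    },
    ("tx", "neighbors", "tx"): {  # transcript proximity edges
        "edge_index": [(0, 1), (1, 2)],
    },
    ("tx", "belongs", "bd"): {    # transcript-to-cell candidate edges
        "edge_index": [(0, 0), (1, 0)],
        "edge_label": [1, 0],     # 1 = positive, 0 = negative
    },
}

num_tx = len(tile["tx"]["id"])
num_candidate_edges = len(tile[("tx", "belongs", "bd")]["edge_index"])
```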
Data Splitting Strategy¶
Tiles are partitioned into training, validation, and test sets:
# Recommended split ratios
train_ratio = 0.7 # 70% for training
val_ratio = 0.2 # 20% for validation
test_ratio = 0.1 # 10% for testing
# Spatial-aware splitting ensures:
# - No information leakage between splits
# - Representative spatial coverage in each split
# - Balanced distribution of cell types
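A minimal sketch of the tile-level split in plain Python, using the ratios above. Note that this simple random shuffle does not attempt the spatial stratification mentioned in the comments; the function and variable names are illustrative, not Segger's API.

```python
import random

def split_tiles(tile_ids, train_ratio=0.7, val_ratio=0.2, seed=0):
    """Shuffle tile ids and partition them into train/val/test sets.
    The test set receives whatever remains after train and val."""
    rng = random.Random(seed)
    ids = list(tile_ids)
    rng.shuffle(ids)
    n_train = int(train_ratio * len(ids))
    n_val = int(val_ratio * len(ids))
    train = ids[:n_train]
    val = ids[n_train:n_train + n_val]
    test = ids[n_train + n_val:]
    return train, val, test

# 100 tiles → 70 train, 20 validation, 10 test; each tile lands in one split
train_ids, val_ids, test_ids = split_tiles(range(100))
```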
Negative Edge Sampling¶
To handle class imbalance, negative edges are sampled during training:
# Sample negative edges at 1:5 ratio (positive:negative)
neg_sampling_ratio = 5
# Negative edges represent:
# - Transcripts assigned to wrong cells
# - Random transcript-cell pairs
# - Spatially distant but transcriptionally similar pairs
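The idea can be sketched in plain Python: draw random transcript-cell pairs, keep those not in the positive set, and stop at the 1:5 ratio. Names here are illustrative rather than Segger's API; on tensors, utilities such as `torch_geometric.utils.negative_sampling` would normally do this.

```python
import random

def sample_negative_edges(pos_edges, num_tx, num_bd, ratio=5, seed=0):
    """Randomly sample (tx, bd) pairs that are not positive edges,
    drawing `ratio` negatives per positive edge."""
    rng = random.Random(seed)
    positive = set(pos_edges)
    target = ratio * len(pos_edges)
    negatives = set()
    while len(negatives) < target:
        pair = (rng.randrange(num_tx), rng.randrange(num_bd))
        if pair not in positive:   # reject known positive links
            negatives.add(pair)
    return sorted(negatives)

# 2 positive transcript-cell links → 10 sampled negatives (1:5 ratio)
neg_edges = sample_negative_edges([(0, 0), (1, 0)], num_tx=50, num_bd=10)
```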
Training Configuration¶
Model Parameters¶
Key training parameters based on the Segger paper:
# Architecture configuration
model_config = {
    'num_tx_tokens': 5000,   # Vocabulary size (adjust for dataset)
    'init_emb': 16,          # Initial embedding dimension
    'hidden_channels': 64,   # Hidden layer size
    'num_mid_layers': 3,     # Number of intermediate GAT layers
    'out_channels': 32,      # Output dimension
    'heads': 4               # Number of attention heads
}

# Training configuration
training_config = {
    'learning_rate': 0.001,  # Initial learning rate
    'batch_size': 2,         # Batch size per GPU
    'max_epochs': 200,       # Maximum training epochs
    'weight_decay': 1e-5,    # L2 regularization
    'patience': 10           # Early stopping patience
}
Hardware Configuration¶
# GPU configuration
gpu_config = {
    'accelerator': 'cuda',   # Use CUDA acceleration
    'devices': 4,            # Number of GPUs
    'strategy': 'ddp',       # Distributed data parallel
    'precision': '16-mixed'  # Mixed precision training
}

# Memory optimization
memory_config = {
    'gradient_clip_val': 1.0,      # Gradient clipping
    'accumulate_grad_batches': 1,  # Gradient accumulation
    'num_workers': 4               # Data loading workers
}
Training Process¶
Training Loop¶
The training process follows this sequence:
# Training loop (PyTorch Lightning handles this automatically)
for epoch in range(max_epochs):
    # Training phase
    model.train()
    for batch in train_loader:
        optimizer.zero_grad()
        # Forward pass
        embeddings = model(batch.x, batch.edge_index)
        # Link prediction
        scores = model.decode(embeddings, batch.edge_label_index)
        # Loss computation
        loss = criterion(scores, batch.edge_label)
        # Backward pass and weight update
        loss.backward()
        optimizer.step()

    # Validation phase
    model.eval()
    for batch in val_loader:
        with torch.no_grad():
            embeddings = model(batch.x, batch.edge_index)
            scores = model.decode(embeddings, batch.edge_label_index)
            val_loss = criterion(scores, batch.edge_label)
            # Compute metrics
            auroc = compute_auroc(scores, batch.edge_label)
            f1 = compute_f1(scores, batch.edge_label)
Loss Function¶
The model uses binary cross-entropy loss for link prediction:
# Binary cross-entropy loss on raw similarity scores (logits)
criterion = nn.BCEWithLogitsLoss()

# Loss computation:
#   L = -Σ_{(t_i, c_j)} [ y_ij · log σ(s_ij) + (1 − y_ij) · log(1 − σ(s_ij)) ]
# Where:
#   y_ij: ground-truth label (1 for positive, 0 for negative)
#   s_ij: raw similarity score (logit) from the model
#   σ(s_ij): sigmoid activation converting the score to a probability
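Working the formula through for two edges (a positive scored with logit 2.0 and a negative scored with logit −1.0) shows how the per-edge terms combine. This plain-Python sketch mirrors what `BCEWithLogitsLoss` computes with its default mean reduction; the helper names are illustrative.

```python
import math

def sigmoid(s):
    """σ(s) = 1 / (1 + e^{-s})."""
    return 1.0 / (1.0 + math.exp(-s))

def bce_with_logits(logits, labels):
    """Mean binary cross-entropy, matching the formula above and
    torch.nn.BCEWithLogitsLoss with the default 'mean' reduction."""
    terms = []
    for s, y in zip(logits, labels):
        p = sigmoid(s)
        terms.append(-(y * math.log(p) + (1 - y) * math.log(1 - p)))
    return sum(terms) / len(terms)

# One confident positive (logit 2.0) and one confident negative (logit -1.0):
# -log σ(2.0) ≈ 0.127, -log(1 - σ(-1.0)) ≈ 0.313, mean ≈ 0.220
loss = bce_with_logits([2.0, -1.0], [1, 0])
```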
Optimization¶
Training uses the AdamW optimizer (Adam with decoupled weight decay) with learning rate scheduling:
# Optimizer configuration
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
    betas=(0.9, 0.999)
)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=max_epochs,
    eta_min=1e-6
)
Validation and Monitoring¶
Validation Metrics¶
The model is evaluated using:
AUROC (Area Under ROC Curve)¶
from sklearn.metrics import auc, roc_curve

def compute_auroc(scores, labels):
    """Compute Area Under the ROC Curve for link prediction."""
    fpr, tpr, _ = roc_curve(labels, scores)
    return auc(fpr, tpr)
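For intuition, AUROC can also be computed without scikit-learn via its ranking interpretation: the probability that a randomly chosen positive edge scores higher than a randomly chosen negative one, with ties counted as half. A small O(P·N) sketch (the function name is illustrative):

```python
def auroc_by_ranking(scores, labels):
    """AUROC as the fraction of (positive, negative) pairs in which the
    positive edge receives the higher score; ties contribute 0.5."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))

# Perfectly separated scores give AUROC = 1.0
auc_value = auroc_by_ranking([0.9, 0.8, 0.3, 0.1], [1, 1, 0, 0])
```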
F1 Score¶
import torch
from sklearn.metrics import f1_score

def compute_f1(scores, labels):
    """Compute F1 score for link prediction (scores are raw logits)."""
    predictions = (torch.sigmoid(scores) > 0.5).float()
    return f1_score(labels, predictions)
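To see what the metric rewards, here is a pure-Python version built directly from the true/false positive counts (names are illustrative, not Segger's API):

```python
def f1_from_counts(preds, labels):
    """F1 for binary predictions, computed from precision and recall
    with the positive class encoded as 1."""
    tp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 1)
    fp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 0)
    fn = sum(1 for p, y in zip(preds, labels) if p == 0 and y == 1)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

# 2 true positives, 1 false positive, 1 false negative:
# precision = 2/3, recall = 2/3, so F1 = 2/3
f1 = f1_from_counts([1, 1, 1, 0, 0], [1, 1, 0, 1, 0])
```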
Training Monitoring¶
PyTorch Lightning provides automatic logging:
# Metrics logged automatically
self.log('train_loss', train_loss, on_step=True, on_epoch=True)
self.log('val_loss', val_loss, on_epoch=True)
self.log('val_auroc', val_auroc, on_epoch=True)
self.log('val_f1', val_f1, on_epoch=True)
# Learning rate logging
self.log('lr', self.optimizer.param_groups[0]['lr'], on_epoch=True)
Early Stopping¶
Training stops automatically when validation performance plateaus:
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Early stopping callback
early_stopping = EarlyStopping(
    monitor='val_auroc',
    mode='max',
    patience=patience,
    verbose=True
)

# Model checkpoint callback
checkpoint_callback = ModelCheckpoint(
    monitor='val_auroc',
    mode='max',
    save_top_k=3,
    filename='segger-{epoch:02d}-{val_auroc:.3f}'
)
Multi-GPU Training¶
Data Parallel Strategy¶
Segger supports distributed training across multiple GPUs:
# Distributed training configuration
trainer = pl.Trainer(
    accelerator='cuda',
    devices=4,               # Use 4 GPUs
    strategy='ddp',          # Distributed data parallel
    precision='16-mixed',    # Mixed precision
    max_epochs=max_epochs,
    callbacks=[early_stopping, checkpoint_callback]
)
Batch Size Scaling¶
# Effective batch size = batch_size × num_gpus
effective_batch_size = batch_size * num_gpus
# Example: batch_size=2, num_gpus=4
# Effective batch size = 8
# Adjust learning rate for larger effective batch size
scaled_lr = base_lr * (effective_batch_size / 32) # Linear scaling rule
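Plugging the configuration above into the rule (batch size 2, four GPUs, base learning rate 1e-3, reference batch size 32) gives a concrete number:

```python
base_lr = 1e-3
batch_size, num_gpus = 2, 4

# Effective batch size grows linearly with the number of GPUs
effective_batch_size = batch_size * num_gpus           # 8 graphs per step

# Linear scaling rule relative to a reference batch size of 32:
# smaller effective batch → proportionally smaller learning rate
scaled_lr = base_lr * (effective_batch_size / 32)      # 2.5e-4
```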
Memory Management¶
# Memory optimization techniques (illustrative settings)
memory_config = {
    'gradient_checkpointing': True,   # Trade compute for memory
    'find_unused_parameters': False,  # Optimize DDP communication
    'sync_batchnorm': False,          # Not needed for GNNs
    'deterministic': False            # Allow non-deterministic operations
}
Training Strategies¶
Learning Rate Scheduling¶
Cosine Annealing¶
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=max_epochs,
    eta_min=1e-6
)
Warmup + Cosine¶
import math

# Warm up for the first 10% of training, then decay with a cosine schedule
warmup_epochs = int(0.1 * max_epochs)

def get_lr_multiplier(epoch):
    if epoch < warmup_epochs:
        # Linear warmup from 0 to 1
        return epoch / warmup_epochs
    # Cosine decay from 1 toward 0
    progress = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
    return 0.5 * (1 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, get_lr_multiplier)
Regularization Techniques¶
Weight Decay¶
# Decoupled weight decay (AdamW's variant of L2 regularization)
optimizer = torch.optim.AdamW(
    model.parameters(),
    weight_decay=1e-5
)
Dropout (Optional)¶
# Add dropout to the attention outputs if needed
class SeggerWithDropout(Segger):
    def __init__(self, *args, dropout=0.1, **kwargs):
        super().__init__(*args, **kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        # Apply dropout after the attention layers
        x = super().forward(x, edge_index)
        x = self.dropout(x)
        return x
Performance Optimization¶
Mixed Precision Training¶
# Enable mixed precision for faster training
trainer = pl.Trainer(
    precision='16-mixed'  # 16-bit automatic mixed precision
)
# Automatic mixed precision provides:
# - Faster training (roughly 1.5-2x speedup)
# - Lower memory usage
# - Maintained numerical stability
Gradient Accumulation¶
# Accumulate gradients over multiple batches
trainer = pl.Trainer(
    accumulate_grad_batches=4  # Effective batch size = batch_size × 4
)
# Useful when:
# - GPU memory is limited
# - A large effective batch size is desired
# - Training stability is important
Data Loading Optimization¶
# Optimize data loading
dataloader_config = {
    'num_workers': 4,            # Parallel data-loading workers
    'pin_memory': True,          # Faster host-to-GPU transfer
    'persistent_workers': True,  # Keep workers alive between epochs
    'prefetch_factor': 2         # Batches prefetched per worker
}
Troubleshooting¶
Common Training Issues¶
Training Instability¶
# Solutions:
# 1. Reduce learning rate
learning_rate = 0.0001 # Reduce from 0.001
# 2. Add gradient clipping
trainer = pl.Trainer(gradient_clip_val=1.0)
# 3. Check data quality and normalization
Memory Errors¶
# Solutions:
# 1. Reduce batch size
batch_size = 1  # Reduce from 2

# 2. Enable gradient checkpointing on the model (trades compute for memory).
#    Note: pl.Trainer(enable_checkpointing=...) controls saving model
#    checkpoints, not gradient checkpointing.

# 3. Use mixed precision
trainer = pl.Trainer(precision='16-mixed')
Poor Convergence¶
# Solutions:
# 1. Check learning rate schedule
# 2. Verify data preprocessing
# 3. Adjust model architecture
# 4. Check for data leakage
Performance Monitoring¶
# Monitor training progress
class TrainingMonitor(pl.Callback):
def on_train_epoch_end(self, trainer, pl_module):
# Log training metrics
train_loss = trainer.callback_metrics['train_loss']
print(f"Epoch {trainer.current_epoch}: Train Loss = {train_loss:.4f}")
def on_validation_epoch_end(self, trainer, pl_module):
# Log validation metrics
val_auroc = trainer.callback_metrics['val_auroc']
val_f1 = trainer.callback_metrics['val_f1']
print(f"Validation: AUROC = {val_auroc:.4f}, F1 = {val_f1:.4f}")
Best Practices¶
Training Configuration¶
- Start with Default Parameters: Use recommended settings from the Segger paper
- Monitor Validation Metrics: Focus on AUROC and F1 score, not just loss
- Use Early Stopping: Prevent overfitting with patience-based stopping
- Enable Mixed Precision: Use 16-bit training for speed and memory efficiency
Data Preparation¶
- Quality Control: Filter low-quality transcripts and boundaries
- Spatial Validation: Ensure train/val/test splits are spatially representative
- Feature Normalization: Normalize transcript and boundary features
- Negative Sampling: Use appropriate negative sampling ratios
Hardware Utilization¶
- Multi-GPU Training: Scale training across multiple GPUs
- Memory Optimization: Use mixed precision and gradient checkpointing
- Data Loading: Optimize data loading with multiple workers
- Batch Size: Use largest batch size that fits in memory
Future Enhancements¶
Planned training improvements include:
- Advanced Scheduling: More sophisticated learning rate schedules
- Automated Hyperparameter Tuning: Integration with Optuna or similar tools
- Curriculum Learning: Progressive difficulty training strategies
- Multi-task Training: Joint training on multiple objectives
- Federated Learning: Distributed training across multiple institutions