
segger.training

Training module for Segger.

Contains the implementation of the Segger model using Graph Neural Networks.

LitSegger

LitSegger(learning_rate=0.001, **kwargs)

Bases: LightningModule

LitSegger is a PyTorch Lightning module for training and validating the Segger model.

Attributes

model : Segger
    The Segger model wrapped with PyTorch Geometric's to_hetero for heterogeneous graph support.
validation_step_outputs : list
    A list to store outputs from the validation steps.
criterion : torch.nn.Module
    The loss function used for training, specifically BCEWithLogitsLoss.

Initializes the LitSegger module with the given parameters.

Parameters

learning_rate : float
    The learning rate for the optimizer.
**kwargs : dict
    Keyword arguments for initializing the module. Specific parameters depend on whether the module is initialized with new parameters or with existing components.

Source code in src/segger/training/train.py
def __init__(self, learning_rate: float = 1e-3, **kwargs):
    """
    Initializes the LitSegger module with the given parameters.

    Parameters
    ----------
    learning_rate : float
        The learning rate for the optimizer.
    **kwargs : dict
        Keyword arguments for initializing the module. Specific parameters
        depend on whether the module is initialized with new parameters or components.
    """
    super().__init__()
    new_args = inspect.getfullargspec(self.from_new)[0][1:]
    cmp_args = inspect.getfullargspec(self.from_components)[0][1:]

    # Initialize with new parameters
    if set(kwargs.keys()) == set(new_args):
        self.from_new(**kwargs)

    # Initialize with existing components
    elif set(kwargs.keys()) == set(cmp_args):
        self.from_components(**kwargs)

    # Handle invalid arguments
    else:
        raise ValueError(
            f"Supplied kwargs do not match either constructor. Should be one of '{new_args}' or '{cmp_args}'."
        )

    self.validation_step_outputs = []
    self.criterion = torch.nn.BCEWithLogitsLoss()
    self.learning_rate = learning_rate
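
A minimal usage sketch (values hypothetical): constructing LitSegger with the full set of `from_new` keyword arguments. The feature sizes below are placeholders; the actual values depend on your dataset.

import torch
from segger.training.train import LitSegger

litsegger = LitSegger(
    learning_rate=1e-3,
    is_token_based=1,                         # token-based 'tx' embeddings
    num_node_features={"tx": 500, "bd": 22},  # hypothetical feature sizes
    init_emb=16,
    hidden_channels=32,
    out_channels=32,
    heads=3,
    num_mid_layers=3,
    aggr="sum",
)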

configure_optimizers

configure_optimizers()

Configures the optimizer for training.

Returns

torch.optim.Optimizer
    The optimizer for training.

Source code in src/segger/training/train.py
def configure_optimizers(self) -> torch.optim.Optimizer:
    """
    Configures the optimizer for training.

    Returns
    -------
    torch.optim.Optimizer
        The optimizer for training.
    """
    optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
    return optimizer

forward

forward(batch)

Forward pass for the batch of data.

Parameters

batch : SpatialTranscriptomicsDataset
    The batch of data, including node features and edge indices.

Returns

torch.Tensor
    The output of the model.

Source code in src/segger/training/train.py
def forward(self, batch: SpatialTranscriptomicsDataset) -> torch.Tensor:
    """
    Forward pass for the batch of data.

    Parameters
    ----------
    batch : SpatialTranscriptomicsDataset
        The batch of data, including node features and edge indices.

    Returns
    -------
    torch.Tensor
        The output of the model.
    """
    z = self.model(batch.x_dict, batch.edge_index_dict)
    edge_label_index = batch["tx", "belongs", "bd"].edge_label_index
    output = self.model.decode(z, edge_label_index)
    return output
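
A quick inference sketch (assumes the `litsegger` module from above and a PyG HeteroData `batch` carrying `edge_label_index` on the ('tx', 'belongs', 'bd') edge type): the output is one raw logit per candidate edge, so apply a sigmoid to obtain probabilities.

import torch

logits = litsegger(batch)      # raw scores for candidate tx -> bd edges
probs = torch.sigmoid(logits)  # probability each transcript belongs to each candidate boundary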

from_components

from_components(model)

Initializes the LitSegger module with existing Segger components.

Parameters

model : Segger
    The Segger model to be used.

Source code in src/segger/training/train.py
def from_components(self, model: Segger):
    """
    Initializes the LitSegger module with existing Segger components.

    Parameters
    ----------
    model : Segger
        The Segger model to be used.
    """
    self.model = model

from_new

from_new(is_token_based, num_node_features, init_emb, hidden_channels, out_channels, heads, num_mid_layers, aggr)

Initializes the LitSegger module with new parameters.

Parameters

is_token_based : int
    Whether the model is using token-based embeddings or scRNAseq embeddings.
num_node_features : dict[str, int]
    Number of node features for each node type.
init_emb : int
    Initial embedding size.
hidden_channels : int
    Number of hidden channels.
out_channels : int
    Number of output channels.
heads : int
    Number of attention heads.
num_mid_layers : int
    Number of hidden layers (excluding first and last layers).
aggr : str
    Aggregation method for heterogeneous graph conversion.

Source code in src/segger/training/train.py
def from_new(
    self,
    is_token_based: int,
    num_node_features: dict[str, int],
    init_emb: int,
    hidden_channels: int,
    out_channels: int,
    heads: int,
    num_mid_layers: int,
    aggr: str,
):
    """
    Initializes the LitSegger module with new parameters.

    Parameters
    ----------
    is_token_based : int
        Whether the model is using token-based embeddings or scRNAseq embeddings.
    num_node_features : dict[str, int]
        Number of node features for each node type.
    init_emb : int
        Initial embedding size.
    hidden_channels : int
        Number of hidden channels.
    out_channels : int
        Number of output channels.
    heads : int
        Number of attention heads.
    aggr : str
        Aggregation method for heterogeneous graph conversion.
    num_mid_layers : int
        Number of hidden layers (excluding first and last layers).
    """
    # Create the Segger model
    self.model = Segger(
        is_token_based=is_token_based,
        num_node_features=num_node_features,
        init_emb=init_emb,
        hidden_channels=hidden_channels,
        out_channels=out_channels,
        heads=heads,
        num_mid_layers=num_mid_layers,
    )
    # Save hyperparameters
    self.save_hyperparameters()

training_step

training_step(batch, batch_idx)

Defines the training step.

Parameters

batch : Any
    The batch of data.
batch_idx : int
    The index of the batch.

Returns

torch.Tensor
    The loss value for the current training step.

Source code in src/segger/training/train.py
def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
    """
    Defines the training step.

    Parameters
    ----------
    batch : Any
        The batch of data.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    torch.Tensor
        The loss value for the current training step.
    """
    # Get edge labels
    edge_label_index = batch["tx", "belongs", "bd"].edge_label_index
    edge_label = batch["tx", "belongs", "bd"].edge_label

    # Forward pass to get the logits
    z = self.model(batch.x_dict, batch.edge_index_dict)
    output = self.model.decode(z, edge_label_index)

    # Compute binary cross-entropy loss with logits (no sigmoid here)
    loss = self.criterion(output, edge_label)

    # Log the training loss
    self.log("train_loss", loss, prog_bar=True, batch_size=batch.num_graphs)
    return loss

validation_step

validation_step(batch, batch_idx)

Defines the validation step.

Parameters

batch : Any
    The batch of data.
batch_idx : int
    The index of the batch.

Returns

torch.Tensor
    The loss value for the current validation step.

Source code in src/segger/training/train.py
def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
    """
    Defines the validation step.

    Parameters
    ----------
    batch : Any
        The batch of data.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    torch.Tensor
        The loss value for the current validation step.
    """
    # Get edge labels
    edge_label_index = batch["tx", "belongs", "bd"].edge_label_index
    edge_label = batch["tx", "belongs", "bd"].edge_label

    # Forward pass to get the logits
    z = self.model(batch.x_dict, batch.edge_index_dict)
    output = self.model.decode(z, edge_label_index)

    # Compute binary cross-entropy loss with logits (no sigmoid here)
    loss = self.criterion(output, edge_label)

    # Apply sigmoid to logits for AUROC and F1 metrics
    out_values_prob = torch.sigmoid(output)

    # Compute metrics
    auroc = torchmetrics.AUROC(task="binary").to(self.device)
    auroc_res = auroc(out_values_prob, edge_label)

    f1 = F1Score(task="binary").to(self.device)
    f1_res = f1(out_values_prob, edge_label)

    # Log validation metrics
    self.log("validation_loss", loss, batch_size=batch.num_graphs)
    self.log("validation_auroc", auroc_res, prog_bar=True, batch_size=batch.num_graphs)
    self.log("validation_f1", f1_res, prog_bar=True, batch_size=batch.num_graphs)

    return loss
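
With the steps above defined, training is driven by a standard Lightning Trainer. A minimal sketch (`train_loader` and `val_loader` are placeholders for your own DataLoaders; the import path depends on your Lightning version):

from lightning import Trainer  # or: from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=100, accelerator="auto")
trainer.fit(litsegger, train_dataloaders=train_loader, val_dataloaders=val_loader)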

Segger

Segger(is_token_based, num_node_features, init_emb=16, hidden_channels=32, num_mid_layers=3, out_channels=32, heads=3)

Bases: Module

Initializes the Segger model.

Parameters:

is_token_based : int
    Whether the model is using token-based embeddings or scRNAseq embeddings. Required.
num_node_features : dict[str, int]
    Number of node features for each node type. Required.
init_emb : int
    Initial embedding size for both 'tx' and boundary (non-token) nodes. Default: 16.
hidden_channels : int
    Number of hidden channels. Default: 32.
num_mid_layers : int
    Number of hidden layers (excluding first and last layers). Default: 3.
out_channels : int
    Number of output channels. Default: 32.
heads : int
    Number of attention heads. Default: 3.
Source code in src/segger/models/segger_model.py
def __init__(
    self,
    is_token_based: int,
    num_node_features: dict[str, int],
    init_emb: int = 16,
    hidden_channels: int = 32,
    num_mid_layers: int = 3,
    out_channels: int = 32,
    heads: int = 3,
):
    """
    Initializes the Segger model.

    Args:
        is_token_based (int)   : Whether the model is using token-based embeddings or scRNAseq embeddings.
        num_node_features (dict[str, int]): Number of node features for each node type.
        init_emb (int)         : Initial embedding size for both 'tx' and boundary (non-token) nodes.
        hidden_channels (int)  : Number of hidden channels.
        num_mid_layers (int)   : Number of hidden layers (excluding first and last layers).
        out_channels (int)     : Number of output channels.
        heads (int)            : Number of attention heads.
    """
    super().__init__()

    # Initialize node embeddings
    if is_token_based:
        # Using token-based embeddings for transcript ('tx') nodes
        self.node_init = nn.ModuleDict(
            {
                "tx": nn.Embedding(num_node_features["tx"], init_emb),
                "bd": nn.Linear(num_node_features["bd"], init_emb),
            }
        )
    else:
        # Using scRNAseq embeddings (i.e. prior biological knowledge) for transcript ('tx') nodes
        self.node_init = nn.ModuleDict(
            {
                "tx": nn.Linear(num_node_features["tx"], init_emb),
                "bd": nn.Linear(num_node_features["bd"], init_emb),
            }
        )

    # First GATv2Conv layer
    self.conv1 = SkipGAT(init_emb, hidden_channels, heads)

    # Middle GATv2Conv layers
    self.num_mid_layers = num_mid_layers
    if num_mid_layers > 0:
        self.conv_mid_layers = nn.ModuleList()
        for _ in range(num_mid_layers):
            self.conv_mid_layers.append(SkipGAT(heads * hidden_channels, hidden_channels, heads))

    # Last GATv2Conv layer
    self.conv_last = SkipGAT(heads * hidden_channels, out_channels, heads)

    # Finalize node embeddings
    self.node_final = HeteroDictLinear(heads * out_channels, out_channels, types=("tx", "bd"))
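
A minimal sketch of constructing the model directly (sizes hypothetical): a token-based variant where 'tx' features are token indices over a 500-gene vocabulary and 'bd' nodes carry 22 numeric features.

from segger.models.segger_model import Segger

model = Segger(
    is_token_based=1,
    num_node_features={"tx": 500, "bd": 22},  # hypothetical sizes
    init_emb=16,
    hidden_channels=32,
    num_mid_layers=3,
    out_channels=32,
    heads=3,
)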

decode

decode(z_dict, edge_index)

Decode the node embeddings to predict edge values.

Parameters:

z_dict : dict[str, Tensor]
    Node embeddings for each node type. Required.
edge_index : EdgeIndex
    Edge label indices. Required.

Returns:

Tensor
    Predicted edge values.

Source code in src/segger/models/segger_model.py
def decode(
    self,
    z_dict: dict[str, Tensor],
    edge_index: Tensor,
) -> Tensor:
    """
    Decode the node embeddings to predict edge values.

    Args:
        z_dict (dict[str, Tensor]): Node embeddings for each node type.
        edge_index (EdgeIndex): Edge label indices.

    Returns:
        Tensor: Predicted edge values.
    """
    z_left = z_dict["tx"][edge_index[0]]
    z_right = z_dict["bd"][edge_index[1]]
    return (z_left * z_right).sum(dim=-1)
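
The score for each candidate edge is simply the dot product of its endpoint embeddings. A sketch with random tensors (shapes hypothetical):

import torch

z_dict = {"tx": torch.randn(10, 32), "bd": torch.randn(4, 32)}  # hypothetical embeddings
edge_index = torch.tensor([[0, 1, 2], [0, 0, 3]])               # tx -> bd candidate pairs
scores = (z_dict["tx"][edge_index[0]] * z_dict["bd"][edge_index[1]]).sum(dim=-1)  # shape: (3,)
# equivalent to: model.decode(z_dict, edge_index)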

forward

forward(x_dict, edge_index_dict)

Forward pass for the Segger model.

Parameters:

x_dict : dict[str, Tensor]
    Node features for each node type. Required.
edge_index_dict : dict[str, Tensor]
    Edge indices for each edge type. Required.
Source code in src/segger/models/segger_model.py
def forward(
    self,
    x_dict: dict[str, Tensor],
    edge_index_dict: dict[str, Tensor],
) -> dict[str, Tensor]:
    """
    Forward pass for the Segger model.

    Args:
        x_dict (dict[str, Tensor]): Node features for each node type.
        edge_index_dict (dict[str, Tensor]): Edge indices for each edge type.
    """

    x_dict = {key: self.node_init[key](x) for key, x in x_dict.items()}

    x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()}

    x_dict = self.conv1(x_dict, edge_index_dict)

    if self.num_mid_layers > 0:
        for i in range(self.num_mid_layers):
            x_dict = self.conv_mid_layers[i](x_dict, edge_index_dict)

    x_dict = self.conv_last(x_dict, edge_index_dict)

    x_dict = self.node_final(x_dict)

    return x_dict

SpatialTranscriptomicsDataset

SpatialTranscriptomicsDataset(root, transform=None, pre_transform=None, pre_filter=None)

Bases: InMemoryDataset

A dataset class for handling spatial transcriptomics data.

Attributes:

root : str
    The root directory where the dataset is stored.
transform : callable
    A function/transform that takes in a Data object and returns a transformed version.
pre_transform : callable
    A function/transform that takes in a Data object and returns a transformed version.
pre_filter : callable
    A function that takes in a Data object and returns a boolean indicating whether to keep it.

Initialize the SpatialTranscriptomicsDataset.

Parameters:

root : str
    Root directory where the dataset is stored. Required.
transform : callable, optional
    A function/transform that takes in a Data object and returns a transformed version. Defaults to None.
pre_transform : callable, optional
    A function/transform that takes in a Data object and returns a transformed version. Defaults to None.
pre_filter : callable, optional
    A function that takes in a Data object and returns a boolean indicating whether to keep it. Defaults to None.
Source code in src/segger/data/utils.py
def __init__(
    self, root: str, transform: Callable = None, pre_transform: Callable = None, pre_filter: Callable = None
):
    """Initialize the SpatialTranscriptomicsDataset.

    Args:
        root (str): Root directory where the dataset is stored.
        transform (callable, optional): A function/transform that takes in a Data object and returns a transformed version. Defaults to None.
        pre_transform (callable, optional): A function/transform that takes in a Data object and returns a transformed version. Defaults to None.
        pre_filter (callable, optional): A function that takes in a Data object and returns a boolean indicating whether to keep it. Defaults to None.
    """
    super().__init__(root, transform, pre_transform, pre_filter)
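
A minimal usage sketch (path hypothetical): the dataset expects preprocessed tiles under `<root>/processed`; wrap it in a PyTorch Geometric DataLoader for batched training.

from torch_geometric.loader import DataLoader

dataset = SpatialTranscriptomicsDataset(root="data_tiles")  # hypothetical root directory
loader = DataLoader(dataset, batch_size=4, shuffle=True)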

processed_file_names property

processed_file_names

Return a list of processed file names in the processed directory.

Returns:

List[str]
    List of processed file names.

raw_file_names property

raw_file_names

Return a list of raw file names in the raw directory.

Returns:

List[str]
    List of raw file names.

download

download()

Download the raw data. This method should be overridden if you need to download the data.

Source code in src/segger/data/utils.py
def download(self) -> None:
    """Download the raw data. This method should be overridden if you need to download the data."""
    pass

get

get(idx)

Get a processed data object.

Parameters:

idx : int
    Index of the data object to retrieve. Required.

Returns:

Data
    The processed data object.

Source code in src/segger/data/utils.py
def get(self, idx: int) -> Data:
    """Get a processed data object.

    Args:
        idx (int): Index of the data object to retrieve.

    Returns:
        Data: The processed data object.
    """
    data = torch.load(os.path.join(self.processed_dir, self.processed_file_names[idx]))
    data["tx"].x = data["tx"].x.to_dense()
    return data

len

len()

Return the number of processed files.

Returns:

int
    Number of processed files.

Source code in src/segger/data/utils.py
def len(self) -> int:
    """Return the number of processed files.

    Returns:
        int: Number of processed files.
    """
    return len(self.processed_file_names)

process

process()

Process the raw data and save it to the processed directory. This method should be overridden if you need to process the data.

Source code in src/segger/data/utils.py
def process(self) -> None:
    """Process the raw data and save it to the processed directory. This method should be overridden if you need to process the data."""
    pass