segger.training

Training module for Segger.

Contains the implementation of the Segger model using Graph Neural Networks.

LitSegger

LitSegger(**kwargs)

Bases: LightningModule

LitSegger is a PyTorch Lightning module for training and validating the Segger model.

Attributes

model : Segger
    The Segger model wrapped with PyTorch Geometric's to_hetero for heterogeneous graph support.
validation_step_outputs : list
    A list to store outputs from the validation steps.
criterion : torch.nn.Module
    The loss function used for training, specifically BCEWithLogitsLoss.

Initializes the LitSegger module with the given parameters.

Parameters

**kwargs : dict
    Keyword arguments for initializing the module. Specific parameters depend on whether the module is initialized with new parameters or components.

Source code in src/segger/training/train.py
def __init__(self, **kwargs):
    """
    Initializes the LitSegger module with the given parameters.

    Parameters
    ----------
    **kwargs : dict
        Keyword arguments for initializing the module. Specific parameters
        depend on whether the module is initialized with new parameters or components.
    """
    super().__init__()
    new_args = inspect.getfullargspec(self.from_new)[0][1:]
    cmp_args = inspect.getfullargspec(self.from_components)[0][1:]

    # Initialize with new parameters (ensure num_tx_tokens is passed here)
    if set(kwargs.keys()) == set(new_args):
        self.from_new(**kwargs)

    # Initialize with existing components
    elif set(kwargs.keys()) == set(cmp_args):
        self.from_components(**kwargs)

    # Handle invalid arguments
    else:
        raise ValueError(
            f"Supplied kwargs do not match either constructor. Should be one of '{new_args}' or '{cmp_args}'."
        )

    self.validation_step_outputs = []
    self.criterion = torch.nn.BCEWithLogitsLoss()
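
The dispatch above compares the supplied keyword names with the argument names of each alternative constructor and requires an exact match. A minimal sketch of the same pattern follows. The class `TwoWayInit` and its arguments `width` and `depth` are hypothetical stand-ins, not part of Segger.

```python
import inspect


class TwoWayInit:
    """Hypothetical toy class mirroring the two-constructor dispatch."""

    def __init__(self, **kwargs):
        # Argument names of each alternative constructor, without `self`.
        new_args = inspect.getfullargspec(self.from_new).args[1:]
        cmp_args = inspect.getfullargspec(self.from_components).args[1:]

        if set(kwargs) == set(new_args):
            self.from_new(**kwargs)
        elif set(kwargs) == set(cmp_args):
            self.from_components(**kwargs)
        else:
            raise ValueError(
                f"kwargs must match exactly one of {new_args} or {cmp_args}"
            )

    def from_new(self, width, depth):
        # Build a fresh (toy) model from hyperparameters.
        self.model = {"width": width, "depth": depth}

    def from_components(self, model):
        # Reuse an existing (toy) model.
        self.model = model


built = TwoWayInit(width=8, depth=2)
wrapped = TwoWayInit(model={"width": 4, "depth": 1})

try:
    TwoWayInit(width=8)  # partial argument set: matches neither constructor
    error = None
except ValueError as exc:
    error = str(exc)
```

Because the comparison is on exact sets of names, a partial argument list is rejected. For LitSegger this means every argument of from_new has to be supplied, even those for which Segger itself has a default.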

configure_optimizers

configure_optimizers()

Configures the optimizer for training.

Returns

torch.optim.Optimizer
    An Adam optimizer over all module parameters, with a learning rate of 1e-3.

Source code in src/segger/training/train.py
def configure_optimizers(self) -> torch.optim.Optimizer:
    """
    Configures the optimizer for training.

    Returns
    -------
    torch.optim.Optimizer
        The optimizer for training.
    """
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    return optimizer

forward

forward(batch)

Forward pass for the batch of data.

Parameters

batch : SpatialTranscriptomicsDataset
    The batch of data, including node features and edge indices.

Returns

torch.Tensor
    A matrix of raw scores (logits) with one row per 'tx' node and one column per 'bd' node.

Source code in src/segger/training/train.py
def forward(self, batch: SpatialTranscriptomicsDataset) -> torch.Tensor:
    """
    Forward pass for the batch of data.

    Parameters
    ----------
    batch : SpatialTranscriptomicsDataset
        The batch of data, including node features and edge indices.

    Returns
    -------
    torch.Tensor
        The output of the model.
    """
    z = self.model(batch.x_dict, batch.edge_index_dict)
    output = torch.matmul(z["tx"], z["bd"].t())  # (num_tx, num_bd) matrix of pairwise scores
    return output
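
The product of the 'tx' embeddings with the transposed 'bd' embeddings yields one score per transcript/boundary pair. The sketch below reproduces that arithmetic in plain Python so it runs without PyTorch. The helper `matmul_transpose` and the toy embeddings are illustrative assumptions.

```python
def matmul_transpose(a, b):
    """Return a @ b.T for two lists of equal-length row vectors."""
    return [
        [sum(x * y for x, y in zip(row_a, row_b)) for row_b in b]
        for row_a in a
    ]


# Toy embeddings: 3 transcripts and 2 boundaries, each 2-dimensional.
z_tx = [[1.0, 0.0], [0.5, 0.5], [0.0, 2.0]]
z_bd = [[1.0, 1.0], [0.0, -1.0]]

# scores[i][j] is the dot product of transcript i with boundary j.
scores = matmul_transpose(z_tx, z_bd)
```

The result has shape (number of transcripts, number of boundaries), which is why the training and validation steps can index it directly with an edge label index.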

from_components

from_components(model)

Initializes the LitSegger module with existing Segger components.

Parameters

model : Segger
    The Segger model to be used.

Source code in src/segger/training/train.py
def from_components(self, model: Segger):
    """
    Initializes the LitSegger module with existing Segger components.

    Parameters
    ----------
    model : Segger
        The Segger model to be used.
    """
    self.model = model

from_new

from_new(num_tx_tokens, init_emb, hidden_channels, out_channels, heads, num_mid_layers, aggr, metadata)

Initializes the LitSegger module with new parameters.

Parameters

num_tx_tokens : int
    Number of unique 'tx' tokens for embedding (this must be passed here).
init_emb : int
    Initial embedding size.
hidden_channels : int
    Number of hidden channels.
out_channels : int
    Number of output channels.
heads : int
    Number of attention heads.
aggr : str
    Aggregation method for heterogeneous graph conversion.
num_mid_layers : int
    Number of hidden layers (excluding first and last layers).
metadata : Union[Tuple, Metadata]
    Metadata for heterogeneous graph structure.

Source code in src/segger/training/train.py
def from_new(
    self,
    num_tx_tokens: int,
    init_emb: int,
    hidden_channels: int,
    out_channels: int,
    heads: int,
    num_mid_layers: int,
    aggr: str,
    metadata: Union[Tuple, Metadata],
):
    """
    Initializes the LitSegger module with new parameters.

    Parameters
    ----------
    num_tx_tokens : int
        Number of unique 'tx' tokens for embedding (this must be passed here).
    init_emb : int
        Initial embedding size.
    hidden_channels : int
        Number of hidden channels.
    out_channels : int
        Number of output channels.
    heads : int
        Number of attention heads.
    aggr : str
        Aggregation method for heterogeneous graph conversion.
    num_mid_layers: int
        Number of hidden layers (excluding first and last layers).
    metadata : Union[Tuple, Metadata]
        Metadata for heterogeneous graph structure.
    """
    # Create the Segger model (ensure num_tx_tokens is passed here)
    model = Segger(
        num_tx_tokens=num_tx_tokens,  # This is required and must be passed here
        init_emb=init_emb,
        hidden_channels=hidden_channels,
        out_channels=out_channels,
        heads=heads,
        num_mid_layers=num_mid_layers,
    )
    # Convert model to handle heterogeneous graphs
    model = to_hetero(model, metadata=metadata, aggr=aggr)
    self.model = model
    # Save hyperparameters
    self.save_hyperparameters()

training_step

training_step(batch, batch_idx)

Defines the training step.

Parameters

batch : Any
    The batch of data.
batch_idx : int
    The index of the batch.

Returns

torch.Tensor
    The loss value for the current training step.

Source code in src/segger/training/train.py
def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
    """
    Defines the training step.

    Parameters
    ----------
    batch : Any
        The batch of data.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    torch.Tensor
        The loss value for the current training step.
    """
    # Forward pass to get the logits
    z = self.model(batch.x_dict, batch.edge_index_dict)
    output = torch.matmul(z["tx"], z["bd"].t())

    # Get edge labels and logits
    edge_label_index = batch["tx", "belongs", "bd"].edge_label_index
    out_values = output[edge_label_index[0], edge_label_index[1]]
    edge_label = batch["tx", "belongs", "bd"].edge_label

    # Compute binary cross-entropy loss with logits (no sigmoid here)
    loss = self.criterion(out_values, edge_label)

    # Log the training loss
    self.log("train_loss", loss, prog_bar=True, batch_size=batch.num_graphs)
    return loss
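
The training step picks one logit per labelled edge out of the score matrix and feeds those logits, without a sigmoid, to the binary cross-entropy loss. The following dependency-free sketch shows both operations. The function `bce_with_logits` is a hand-written illustration of the standard numerically stable formula with mean reduction, not the PyTorch implementation, and the numbers are made up.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy computed directly on raw logits."""
    total = 0.0
    for x, y in zip(logits, targets):
        # Stable form of -[y*log(s) + (1-y)*log(1-s)] with s = sigmoid(x).
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


# Hypothetical tx-by-bd logit matrix (2 transcripts, 2 boundaries).
scores = [[2.0, -1.0], [0.0, 3.0]]

# Row 0 holds transcript indices, row 1 holds boundary indices.
edge_label_index = [[0, 1, 1], [0, 0, 1]]
edge_label = [1.0, 0.0, 1.0]

# Gather one logit per labelled (transcript, boundary) pair.
logits = [scores[i][j] for i, j in zip(*edge_label_index)]
loss = bce_with_logits(logits, edge_label)
```

Working on logits rather than probabilities avoids evaluating log(0) when a sigmoid saturates, which is the usual reason for preferring a with-logits loss.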

validation_step

validation_step(batch, batch_idx)

Defines the validation step.

Parameters

batch : Any
    The batch of data.
batch_idx : int
    The index of the batch.

Returns

torch.Tensor
    The loss value for the current validation step.

Source code in src/segger/training/train.py
def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
    """
    Defines the validation step.

    Parameters
    ----------
    batch : Any
        The batch of data.
    batch_idx : int
        The index of the batch.

    Returns
    -------
    torch.Tensor
        The loss value for the current validation step.
    """
    # Forward pass to get the logits
    z = self.model(batch.x_dict, batch.edge_index_dict)
    output = torch.matmul(z["tx"], z["bd"].t())

    # Get edge labels and logits
    edge_label_index = batch["tx", "belongs", "bd"].edge_label_index
    out_values = output[edge_label_index[0], edge_label_index[1]]
    edge_label = batch["tx", "belongs", "bd"].edge_label

    # Compute binary cross-entropy loss with logits (no sigmoid here)
    loss = self.criterion(out_values, edge_label)

    # Apply sigmoid to logits for AUROC and F1 metrics
    out_values_prob = torch.sigmoid(out_values)

    # Compute metrics
    auroc = torchmetrics.AUROC(task="binary").to(self.device)  # keep the metric on the same device as the inputs, like F1Score below
    auroc_res = auroc(out_values_prob, edge_label)

    f1 = F1Score(task="binary").to(self.device)
    f1_res = f1(out_values_prob, edge_label)

    # Log validation metrics
    self.log("validation_loss", loss, batch_size=batch.num_graphs)
    self.log("validation_auroc", auroc_res, prog_bar=True, batch_size=batch.num_graphs)
    self.log("validation_f1", f1_res, prog_bar=True, batch_size=batch.num_graphs)

    return loss
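
For the metrics, the validation step first converts logits to probabilities with a sigmoid. A binary F1 score then needs hard predictions, which are obtained by thresholding. The sketch below walks through that arithmetic in plain Python. The helper `binary_f1` is a hand-written illustration, and the threshold of 0.5 with a strict comparison is an assumption about how the probabilities are binarized.

```python
import math


def sigmoid(x):
    """Map a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def binary_f1(probs, targets, threshold=0.5):
    """F1 = 2*TP / (2*TP + FP + FN) after thresholding the probabilities."""
    # Assumption: a probability strictly above the threshold counts as positive.
    preds = [1 if p > threshold else 0 for p in probs]
    tp = sum(1 for p, t in zip(preds, targets) if p == 1 and t == 1)
    fp = sum(1 for p, t in zip(preds, targets) if p == 1 and t == 0)
    fn = sum(1 for p, t in zip(preds, targets) if p == 0 and t == 1)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


logits = [2.0, -1.0, 0.5, -3.0]   # made-up edge logits
targets = [1, 1, 0, 0]            # made-up edge labels

probs = [sigmoid(x) for x in logits]
f1 = binary_f1(probs, targets)    # one TP, one FP, one FN
```

AUROC, by contrast, depends only on how the scores rank positives against negatives, so it is unaffected by the choice of threshold.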

Segger

Segger(num_tx_tokens, init_emb=16, hidden_channels=32, num_mid_layers=3, out_channels=32, heads=3)

Bases: Module

Initializes the Segger model.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| num_tx_tokens | int | Number of unique 'tx' tokens for embedding. | required |
| init_emb | int | Initial embedding size for both 'tx' and boundary (non-token) nodes. | 16 |
| hidden_channels | int | Number of hidden channels. | 32 |
| num_mid_layers | int | Number of hidden layers (excluding first and last layers). | 3 |
| out_channels | int | Number of output channels. | 32 |
| heads | int | Number of attention heads. | 3 |

Source code in src/segger/models/segger_model.py
def __init__(
    self,
    num_tx_tokens: int,
    init_emb: int = 16,
    hidden_channels: int = 32,
    num_mid_layers: int = 3,
    out_channels: int = 32,
    heads: int = 3,
):
    """
    Initializes the Segger model.

    Args:
        num_tx_tokens (int)  : Number of unique 'tx' tokens for embedding.
        init_emb (int)       : Initial embedding size for both 'tx' and boundary (non-token) nodes.
        hidden_channels (int): Number of hidden channels.
        num_mid_layers (int) : Number of hidden layers (excluding first and last layers).
        out_channels (int)   : Number of output channels.
        heads (int)          : Number of attention heads.
    """
    super().__init__()

    # Embedding for 'tx' (transcript) nodes
    self.tx_embedding = Embedding(num_tx_tokens, init_emb)

    # Linear layer for boundary (non-token) nodes
    self.lin0 = Linear(-1, init_emb, bias=False)

    # First GATv2Conv layer
    self.conv_first = GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False)
    # self.lin_first = Linear(-1, hidden_channels * heads)

    # Middle GATv2Conv layers
    self.num_mid_layers = num_mid_layers
    if num_mid_layers > 0:
        self.conv_mid_layers = torch.nn.ModuleList()
        # self.lin_mid_layers = torch.nn.ModuleList()
        for _ in range(num_mid_layers):
            self.conv_mid_layers.append(GATv2Conv((-1, -1), hidden_channels, heads=heads, add_self_loops=False))
            # self.lin_mid_layers.append(Linear(-1, hidden_channels * heads))

    # Last GATv2Conv layer
    self.conv_last = GATv2Conv((-1, -1), out_channels, heads=heads, add_self_loops=False)
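
The constructor stacks one first layer, num_mid_layers middle layers, and one last layer, each with several attention heads. As a worked example of the resulting feature widths, the sketch below assumes that every attention layer concatenates its heads (the PyTorch Geometric default when `concat` is not passed, as is the case here), so each layer emits `channels * heads` features. The helper `segger_layer_widths` is hypothetical and only does the arithmetic.

```python
def segger_layer_widths(
    init_emb=16, hidden_channels=32, num_mid_layers=3, out_channels=32, heads=3
):
    """List (layer name, output width) pairs under the head-concatenation assumption."""
    widths = [("input embedding", init_emb)]
    widths.append(("conv_first", hidden_channels * heads))
    for k in range(num_mid_layers):
        widths.append((f"conv_mid_layers[{k}]", hidden_channels * heads))
    widths.append(("conv_last", out_channels * heads))
    return widths


# Widths for the default constructor arguments.
default_widths = segger_layer_widths()
```

Under that assumption the final node embeddings have `out_channels * heads` features rather than `out_channels`, which matters when sizing anything that consumes them.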

decode

decode(z, edge_index)

Decode the node embeddings to predict edge values.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| z | Tensor | Node embeddings. | required |
| edge_index | EdgeIndex | Edge label indices. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Predicted edge values. |

Source code in src/segger/models/segger_model.py
def decode(self, z: Tensor, edge_index: Union[Tensor]) -> Tensor:
    """
    Decode the node embeddings to predict edge values.

    Args:
        z (Tensor): Node embeddings.
        edge_index (EdgeIndex): Edge label indices.

    Returns:
        Tensor: Predicted edge values.
    """
    return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
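
decode scores each edge by the dot product of the embeddings at its two endpoints. A plain-Python sketch of the same computation, with made-up embeddings, is shown below. The function is a list-based illustration, not the tensor implementation above.

```python
def decode(z, edge_index):
    """Return one dot-product score per edge (source row, destination row)."""
    src, dst = edge_index
    return [
        sum(a * b for a, b in zip(z[i], z[j]))
        for i, j in zip(src, dst)
    ]


# Three nodes with 2-dimensional embeddings.
z = [[1.0, 2.0], [0.0, 1.0], [3.0, -1.0]]

# Two edges: node 0 -> node 1 and node 2 -> node 0.
edge_index = [[0, 2], [1, 0]]

edge_scores = decode(z, edge_index)
```

Unlike the full score matrix built in LitSegger, this computes only the requested pairs, and it indexes a single embedding tensor for both endpoints.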

forward

forward(x, edge_index)

Forward pass for the Segger model.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Node features. | required |
| edge_index | Tensor | Edge indices. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Output node embeddings. |

Source code in src/segger/models/segger_model.py
def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
    """
    Forward pass for the Segger model.

    Args:
        x (Tensor): Node features.
        edge_index (Tensor): Edge indices.

    Returns:
        Tensor: Output node embeddings.
    """
    x = torch.nan_to_num(x, nan=0)
    is_one_dim = (x.ndim == 1) * 1
    # x = x[:, None]
    x = self.tx_embedding(((x.sum(1) * is_one_dim).int())) * is_one_dim + self.lin0(x.float()) * (1 - is_one_dim)
    # First layer
    x = x.relu()
    x = self.conv_first(x, edge_index)  # + self.lin_first(x)
    x = x.relu()

    # Middle layers
    if self.num_mid_layers > 0:
        for conv_mid in self.conv_mid_layers:
            x = conv_mid(x, edge_index)  # + lin_mid(x)
            x = x.relu()

    # Last layer
    x = self.conv_last(x, edge_index)  # + self.lin_last(x)

    return x

SpatialTranscriptomicsDataset

SpatialTranscriptomicsDataset(root, transform=None, pre_transform=None, pre_filter=None)

Bases: InMemoryDataset

A dataset class for handling spatial transcriptomics data.

Attributes:

| Name | Type | Description |
|---|---|---|
| root | str | The root directory where the dataset is stored. |
| transform | callable | A function/transform that takes in a Data object and returns a transformed version. Applied on every access. |
| pre_transform | callable | A function/transform that takes in a Data object and returns a transformed version. Applied once, before the data is saved to disk. |
| pre_filter | callable | A function that takes in a Data object and returns a boolean indicating whether to keep it. |

Initialize the SpatialTranscriptomicsDataset.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| root | str | Root directory where the dataset is stored. | required |
| transform | callable | A function/transform that takes in a Data object and returns a transformed version. Applied on every access. | None |
| pre_transform | callable | A function/transform that takes in a Data object and returns a transformed version. Applied once, before the data is saved to disk. | None |
| pre_filter | callable | A function that takes in a Data object and returns a boolean indicating whether to keep it. | None |

Source code in src/segger/data/utils.py
def __init__(
    self, root: str, transform: Callable = None, pre_transform: Callable = None, pre_filter: Callable = None
):
    """Initialize the SpatialTranscriptomicsDataset.

    Args:
        root (str): Root directory where the dataset is stored.
        transform (callable, optional): A function/transform that takes in a Data object and returns a transformed version. Defaults to None.
        pre_transform (callable, optional): A function/transform that takes in a Data object and returns a transformed version. Defaults to None.
        pre_filter (callable, optional): A function that takes in a Data object and returns a boolean indicating whether to keep it. Defaults to None.
    """
    super().__init__(root, transform, pre_transform, pre_filter)

processed_file_names property

processed_file_names

Return a list of processed file names in the processed directory.

Returns:

| Type | Description |
|---|---|
| List[str] | List of processed file names. |

raw_file_names property

raw_file_names

Return a list of raw file names in the raw directory.

Returns:

| Type | Description |
|---|---|
| List[str] | List of raw file names. |

download

download()

Download the raw data. This method should be overridden if you need to download the data.

Source code in src/segger/data/utils.py
def download(self) -> None:
    """Download the raw data. This method should be overridden if you need to download the data."""
    pass

get

get(idx)

Get a processed data object.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| idx | int | Index of the data object to retrieve. | required |

Returns:

| Type | Description |
|---|---|
| Data | The processed data object. |

Source code in src/segger/data/utils.py
def get(self, idx: int) -> Data:
    """Get a processed data object.

    Args:
        idx (int): Index of the data object to retrieve.

    Returns:
        Data: The processed data object.
    """
    data = torch.load(os.path.join(self.processed_dir, self.processed_file_names[idx]))
    data["tx"].x = data["tx"].x.to_dense()
    return data

len

len()

Return the number of processed files.

Returns:

| Type | Description |
|---|---|
| int | Number of processed files. |

Source code in src/segger/data/utils.py
def len(self) -> int:
    """Return the number of processed files.

    Returns:
        int: Number of processed files.
    """
    return len(self.processed_file_names)
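
Together, get and len describe a dataset in which every sample lives in its own processed file: the length is the number of files and an index selects one file to load. The sketch below illustrates that pattern with the standard library only. The class `FileBackedDataset`, the `.pkl` suffix, the sorted directory listing, and the use of pickle are all assumptions for illustration. The actual processed_file_names implementation is not shown in this page.

```python
import os
import pickle
import tempfile


class FileBackedDataset:
    """Hypothetical one-file-per-sample dataset."""

    def __init__(self, processed_dir):
        self.processed_dir = processed_dir

    @property
    def processed_file_names(self):
        # Assumption: samples are the .pkl files in the directory, in sorted order.
        return sorted(
            f for f in os.listdir(self.processed_dir) if f.endswith(".pkl")
        )

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        path = os.path.join(self.processed_dir, self.processed_file_names[idx])
        with open(path, "rb") as fh:
            return pickle.load(fh)


with tempfile.TemporaryDirectory() as tmp:
    # Write three toy samples, one file each.
    for i in range(3):
        with open(os.path.join(tmp, f"tile_{i}.pkl"), "wb") as fh:
            pickle.dump({"tile": i}, fh)

    ds = FileBackedDataset(tmp)
    n_tiles = ds.len()
    second = ds.get(1)
```

Loading lazily per index keeps memory use proportional to one sample at a time, at the cost of a disk read on every access.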

process

process()

Process the raw data and save it to the processed directory. This method should be overridden if you need to process the data.

Source code in src/segger/data/utils.py
def process(self) -> None:
    """Process the raw data and save it to the processed directory. This method should be overridden if you need to process the data."""
    pass