Skip to content

segger.models

Models module for Segger.

Contains the implementation of the Segger model using Graph Neural Networks.

Segger

Segger(is_token_based, num_node_features, init_emb=16, hidden_channels=32, num_mid_layers=3, out_channels=32, heads=3)

Bases: Module

Initializes the Segger model.

Parameters:

Name Type Description Default
is_token_based int

Whether transcript ('tx') nodes use token-based embeddings (integer token ids) rather than scRNAseq embeddings (dense feature vectors).

required
num_node_features dict[str, int]

Number of node features for each node type.

required
init_emb int

Initial embedding size for both 'tx' and boundary (non-token) nodes.

16
hidden_channels int

Number of hidden channels.

32
num_mid_layers int

Number of hidden layers (excluding first and last layers).

3
out_channels int

Number of output channels.

32
heads int

Number of attention heads.

3
Source code in src/segger/models/segger_model.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
def __init__(
    self,
    is_token_based: bool,
    num_node_features: dict[str, int],
    init_emb: int = 16,
    hidden_channels: int = 32,
    num_mid_layers: int = 3,
    out_channels: int = 32,
    heads: int = 3,
):
    """
    Initializes the Segger model.

    Args:
        is_token_based (bool): If truthy, transcript ('tx') node inputs are
            embedded with ``nn.Embedding`` (token-based embeddings); otherwise
            they are projected with ``nn.Linear`` (e.g. scRNAseq embeddings
            carrying prior biological knowledge).
        num_node_features (dict[str, int]): Input size per node type. Must
            contain 'tx' (number of embeddings when token-based, input feature
            width otherwise) and 'bd' (boundary input feature width).
        init_emb (int): Initial embedding size for both 'tx' and boundary
            ('bd') nodes.
        hidden_channels (int): Number of hidden channels per attention head.
        num_mid_layers (int): Number of hidden layers (excluding first and
            last layers). Values <= 0 produce no middle layers.
        out_channels (int): Number of output channels.
        heads (int): Number of attention heads.
    """
    super().__init__()

    # Only the 'tx' initializer depends on the input representation; the
    # boundary ('bd') initializer is always a linear projection. The 'tx'
    # layer is built first to keep parameter-initialization order stable.
    if is_token_based:
        tx_init = nn.Embedding(num_node_features["tx"], init_emb)
    else:
        tx_init = nn.Linear(num_node_features["tx"], init_emb)
    self.node_init = nn.ModuleDict(
        {
            "tx": tx_init,
            "bd": nn.Linear(num_node_features["bd"], init_emb),
        }
    )

    # First attention layer (SkipGAT is defined elsewhere in the project;
    # presumably GATv2-based, per the original comments — TODO confirm).
    self.conv1 = SkipGAT(init_emb, hidden_channels, heads)

    # Middle attention layers. The input width heads * hidden_channels
    # assumes SkipGAT concatenates its heads' outputs — TODO confirm.
    # The list is always created (empty when num_mid_layers <= 0) so the
    # attribute exists regardless of configuration; an empty ModuleList
    # adds no parameters, so the state_dict is unchanged.
    self.num_mid_layers = num_mid_layers
    self.conv_mid_layers = nn.ModuleList(
        SkipGAT(heads * hidden_channels, hidden_channels, heads)
        for _ in range(num_mid_layers)
    )

    # Last attention layer.
    self.conv_last = SkipGAT(heads * hidden_channels, out_channels, heads)

    # Per-node-type output projection from heads * out_channels down to
    # out_channels.
    self.node_final = HeteroDictLinear(heads * out_channels, out_channels, types=("tx", "bd"))

decode

decode(z_dict, edge_index)

Decode the node embeddings to predict edge values.

Parameters:

Name Type Description Default
z_dict dict[str, Tensor]

Node embeddings for each node type.

required
edge_index Tensor

Edge label indices: row 0 indexes 'tx' node embeddings and row 1 indexes 'bd' node embeddings.

required

Returns:

Name Type Description
Tensor Tensor

Predicted edge values.

Source code in src/segger/models/segger_model.py
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
def decode(
    self,
    z_dict: dict[str, Tensor],
    edge_index: Tensor,
) -> Tensor:
    """
    Decode the node embeddings to predict edge values.

    Each edge is scored by the dot product between the embedding of its
    'tx' endpoint and the embedding of its 'bd' endpoint.

    Args:
        z_dict (dict[str, Tensor]): Node embeddings for each node type; must
            contain 'tx' and 'bd' entries with compatible embedding sizes.
        edge_index (Tensor): Edge label indices. Row 0 indexes into
            ``z_dict['tx']`` and row 1 indexes into ``z_dict['bd']``.
            Tensor subclasses such as PyG's ``EdgeIndex`` are accepted.

    Returns:
        Tensor: Predicted edge values, one per edge (the element-wise product
            of the paired embeddings summed over the last dimension).
    """
    z_left = z_dict["tx"][edge_index[0]]
    z_right = z_dict["bd"][edge_index[1]]
    return (z_left * z_right).sum(dim=-1)

forward

forward(x_dict, edge_index_dict)

Forward pass for the Segger model.

Parameters:

Name Type Description Default
x_dict dict[str, Tensor]

Node features for each node type.

required
edge_index_dict dict[str, Tensor]

Edge indices for each edge type.

required
Source code in src/segger/models/segger_model.py
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
def forward(
    self,
    x_dict: dict[str, Tensor],
    edge_index_dict: dict[str, Tensor],
) -> dict[str, Tensor]:
    """
    Forward pass for the Segger model.

    Each node type is passed through its own initial layer from
    ``self.node_init`` and a leaky ReLU. The result then flows through the
    first, middle and last attention layers and finally ``self.node_final``.

    Args:
        x_dict (dict[str, Tensor]): Node features for each node type; every
            key must have a matching entry in ``self.node_init``.
        edge_index_dict (dict[str, Tensor]): Edge indices for each edge type,
            passed unchanged to every attention layer.

    Returns:
        dict[str, Tensor]: The output of ``self.node_final`` — final node
            embeddings keyed by node type.
    """
    # Type-specific initial embedding/projection followed by the activation.
    hidden = {
        node_type: F.leaky_relu(self.node_init[node_type](feats))
        for node_type, feats in x_dict.items()
    }

    hidden = self.conv1(hidden, edge_index_dict)

    # range() is empty when num_mid_layers <= 0, so conv_mid_layers is only
    # indexed when middle layers were actually configured.
    for layer_idx in range(self.num_mid_layers):
        hidden = self.conv_mid_layers[layer_idx](hidden, edge_index_dict)

    hidden = self.conv_last(hidden, edge_index_dict)

    return self.node_final(hidden)