Skip to content

segger.models

Models module for Segger.

Contains the implementation of the Segger model using Graph Neural Networks.

Models module for Segger.

Contains the implementation of the Segger model using Graph Neural Networks.

Segger

Segger(num_tx_tokens, init_emb=16, hidden_channels=32, num_mid_layers=3, out_channels=32, heads=3)

Bases: Module

Initializes the Segger model.

Parameters:

Name Type Description Default
num_tx_tokens int

Number of unique 'tx' tokens for embedding.

required
init_emb int

Initial embedding size for both 'tx' and boundary (non-token) nodes.

16
hidden_channels int

Number of hidden channels.

32
num_mid_layers int

Number of hidden layers (excluding first and last layers).

3
out_channels int

Number of output channels.

32
heads int

Number of attention heads.

3
Source code in src/segger/models/segger_model.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def __init__(
    self,
    num_tx_tokens: int,
    init_emb: int = 16,
    hidden_channels: int = 32,
    num_mid_layers: int = 3,
    out_channels: int = 32,
    heads: int = 3,
):
    """
    Build the layers of the Segger model.

    Args:
        num_tx_tokens (int): Number of entries in the 'tx' token embedding table.
        init_emb (int): Width of the initial node representation, used both as the
            embedding size and as the output size of the input linear layer.
        hidden_channels (int): Output channels given to the first and middle
            GATv2Conv layers.
        num_mid_layers (int): How many middle GATv2Conv layers to stack. The
            ``conv_mid_layers`` attribute is only created when this is positive.
        out_channels (int): Output channels given to the last GATv2Conv layer.
        heads (int): Number of attention heads, shared by every GATv2Conv layer.
    """
    super().__init__()

    def new_gat(width):
        # Every attention layer is configured identically apart from its output width.
        return GATv2Conv((-1, -1), width, heads=heads, add_self_loops=False)

    # Input stage: a lookup table for 'tx' (transcript) tokens and a bias-free
    # linear layer for boundary (non-token) nodes.
    # NOTE(review): the -1 input size presumably means "infer lazily" — confirm
    # against the Linear class imported by this module.
    self.tx_embedding = Embedding(num_tx_tokens, init_emb)
    self.lin0 = Linear(-1, init_emb, bias=False)

    # Attention stack: first layer, optional middle layers, last layer.
    # Layers are constructed in this order so parameter registration is stable.
    self.conv_first = new_gat(hidden_channels)

    self.num_mid_layers = num_mid_layers
    if num_mid_layers > 0:
        middle = [new_gat(hidden_channels) for _ in range(num_mid_layers)]
        self.conv_mid_layers = torch.nn.ModuleList(middle)

    self.conv_last = new_gat(out_channels)

decode

decode(z, edge_index)

Decode the node embeddings to predict edge values.

Parameters:

Name Type Description Default
z Tensor

Node embeddings.

required
edge_index EdgeIndex

Edge label indices.

required

Returns:

Name Type Description
Tensor Tensor

Predicted edge values.

Source code in src/segger/models/segger_model.py
87
88
89
90
91
92
93
94
95
96
97
98
def decode(self, z: Tensor, edge_index: Union[Tensor]) -> Tensor:
    """
    Score each edge as the dot product of its two endpoint embeddings.

    Args:
        z (Tensor): Node embeddings, indexed along the first dimension.
        edge_index (EdgeIndex): Edge label indices; ``edge_index[0]`` and
            ``edge_index[1]`` select the two endpoints of every edge from ``z``.

    Returns:
        Tensor: One predicted value per edge, obtained by summing the elementwise
            product of the endpoint embeddings over the last dimension.
    """
    first_idx, second_idx = edge_index[0], edge_index[1]
    pairwise = z[first_idx] * z[second_idx]
    return pairwise.sum(dim=-1)

forward

forward(x, edge_index)

Forward pass for the Segger model.

Parameters:

Name Type Description Default
x Tensor

Node features.

required
edge_index Tensor

Edge indices.

required

Returns:

Name Type Description
Tensor Tensor

Output node embeddings.

Source code in src/segger/models/segger_model.py
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
    """
    Embed the node features and pass them through the stacked conv layers.

    Args:
        x (Tensor): Node features. NaN entries are replaced with 0 before use.
        edge_index (Tensor): Edge indices, handed unchanged to every conv layer.

    Returns:
        Tensor: Output node embeddings, i.e. the result of the last conv layer
            (no activation is applied after it).
    """
    x = torch.nan_to_num(x, nan=0)

    # Integer gates instead of a Python branch: token_gate is 1 for 1-D input and
    # 0 otherwise. Both the embedding term and the linear term are always
    # evaluated; the gates decide which one contributes to the sum.
    token_gate = (x.ndim == 1) * 1
    dense_gate = 1 - token_gate

    # NOTE(review): ``x.sum(1)`` needs a second dimension, so a 1-D ``x`` would
    # raise here before the gate can take effect — confirm callers always pass
    # 2-D features.
    token_ids = (x.sum(1) * token_gate).int()
    embedded = self.tx_embedding(token_ids) * token_gate
    projected = self.lin0(x.float()) * dense_gate
    h = (embedded + projected).relu()

    # First layer
    h = self.conv_first(h, edge_index).relu()

    # Middle layers; the attribute only exists when num_mid_layers is positive.
    if self.num_mid_layers > 0:
        for layer in self.conv_mid_layers:
            h = layer(h, edge_index).relu()

    # Last layer
    return self.conv_last(h, edge_index)