segger.models
Models module for Segger.
Contains the implementation of the Segger model using Graph Neural Networks.
Segger
Segger(num_tx_tokens, init_emb=16, hidden_channels=32, num_mid_layers=3, out_channels=32, heads=3)
Bases: Module
Initializes the Segger model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
num_tx_tokens | int | Number of unique 'tx' tokens for embedding. | required |
init_emb | int | Initial embedding size for both 'tx' and boundary (non-token) nodes. | 16 |
hidden_channels | int | Number of hidden channels. | 32 |
num_mid_layers | int | Number of hidden layers (excluding first and last layers). | 3 |
out_channels | int | Number of output channels. | 32 |
heads | int | Number of attention heads. | 3 |
Source code in src/segger/models/segger_model.py
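To make the roles of `num_mid_layers` and `heads` concrete, the sketch below lists the layer stack these arguments imply. `sketch_layer_stack` and `LayerSpec` are hypothetical helpers, not part of segger. Only the layer count (one first layer, `num_mid_layers` middle layers, one last layer) comes from the parameter table. The widths rest on an assumption the docs do not state: that the outputs of the attention heads are concatenated, so a layer with C channels and H heads emits C * H features per node. Check `segger_model.py` before relying on those numbers.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LayerSpec:
    """One entry in the hypothetical layer plan: a stage name and its output width."""
    name: str
    out_dim: int


def sketch_layer_stack(
    init_emb: int = 16,
    hidden_channels: int = 32,
    num_mid_layers: int = 3,
    out_channels: int = 32,
    heads: int = 3,
) -> list[LayerSpec]:
    """Hypothetical: list the per-node output width of each stage.

    Defaults mirror the Segger constructor. ASSUMPTION: multi-head outputs are
    concatenated (width = channels * heads); this is not confirmed by the docs.
    """
    if num_mid_layers < 0 or heads < 1:
        raise ValueError("num_mid_layers must be >= 0 and heads >= 1")
    # Initial embedding shared by 'tx' and boundary nodes (size init_emb).
    layers = [LayerSpec("initial_embedding", init_emb)]
    # First attention layer.
    layers.append(LayerSpec("first", hidden_channels * heads))
    # Middle layers; num_mid_layers excludes the first and last layers.
    layers.extend(
        LayerSpec(f"mid_{i}", hidden_channels * heads) for i in range(num_mid_layers)
    )
    # Last attention layer produces the output node embeddings.
    layers.append(LayerSpec("last", out_channels * heads))
    return layers


for spec in sketch_layer_stack():
    print(f"{spec.name:>18}: {spec.out_dim}")
```

With the defaults this gives five attention layers (first, three middle, last) after a 16-dimensional initial embedding. Under the concatenation assumption each of them produces 96 features per node.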
decode
decode(z, edge_index)
Decode the node embeddings to predict edge values.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
z | Tensor | Node embeddings. | required |
edge_index | EdgeIndex | Edge label indices. | required |
Returns:
Type | Description |
---|---|
Tensor | Predicted edge values. |
Source code in src/segger/models/segger_model.py
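A common edge decoder in GNN link prediction is the inner product of the two endpoint embeddings. The sketch below assumes `decode` works that way. The docstring only says it predicts edge values from node embeddings, so the scoring rule here is an assumption. `decode_dot_product` is a hypothetical name, and the code uses plain Python lists instead of tensors so that it runs without PyTorch. `edge_index` is taken in the two-row COO layout used by PyTorch Geometric, where row 0 holds source nodes and row 1 holds destination nodes.

```python
from typing import Sequence


def decode_dot_product(
    z: Sequence[Sequence[float]],
    edge_index: Sequence[Sequence[int]],
) -> list[float]:
    """Hypothetical decoder: score edge (i, j) as the inner product z[i] . z[j]."""
    src, dst = edge_index  # two rows: source indices, destination indices
    if len(src) != len(dst):
        raise ValueError("both rows of edge_index must have the same length")
    return [sum(a * b for a, b in zip(z[i], z[j])) for i, j in zip(src, dst)]


# Three nodes with 2-dimensional embeddings.
z = [[1.0, 0.0], [0.5, 2.0], [-1.0, 1.0]]
# Two candidate edges: 0 -> 1 and 1 -> 2.
edge_index = [[0, 1], [1, 2]]
print(decode_dot_product(z, edge_index))  # [0.5, 1.5]
```

With PyTorch tensors the same rule is usually written `(z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)`. If the scores are treated as logits (again an assumption), `torch.sigmoid` maps them to values in (0, 1).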
forward
forward(x, edge_index)
Forward pass for the Segger model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Tensor | Node features. | required |
edge_index | Tensor | Edge indices. | required |
Returns:
Type | Description |
---|---|
Tensor | Output node embeddings. |
Source code in src/segger/models/segger_model.py
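The `init_emb` and `num_tx_tokens` parameters suggest that `forward` first brings both node types to a common `init_emb`-dimensional representation before message passing. 'tx' nodes would go through a lookup table with `num_tx_tokens` rows, and boundary (non-token) nodes through some feature projection. The sketch below illustrates that input stage under two assumptions not stated in these docs. First, 'tx' nodes arrive as integer token ids. Second, boundary nodes arrive as float feature vectors mapped by a bias-free linear layer. `initial_embedding`, `tx_table`, and `boundary_weight` are hypothetical names, and the attention layers that follow are not shown.

```python
from typing import Sequence, Union

# ASSUMPTION: a node is either a 'tx' token id (int) or a boundary feature vector.
NodeInput = Union[int, Sequence[float]]


def initial_embedding(
    x: NodeInput,
    tx_table: Sequence[Sequence[float]],
    boundary_weight: Sequence[Sequence[float]],
) -> list[float]:
    """Hypothetical input stage: map one node to an init_emb-dimensional vector.

    tx_table has shape (num_tx_tokens, init_emb); boundary_weight has shape
    (num_features, init_emb) and acts as a bias-free linear projection.
    """
    if isinstance(x, int):
        # 'tx' node: the token id selects one row of the embedding table.
        if not 0 <= x < len(tx_table):
            raise IndexError(f"token id {x} outside [0, {len(tx_table)})")
        return list(tx_table[x])
    # Boundary node: out[k] = sum over f of x[f] * W[f][k].
    if len(x) != len(boundary_weight):
        raise ValueError("feature length must match the rows of boundary_weight")
    init_emb = len(boundary_weight[0])
    return [sum(x_f * row[k] for x_f, row in zip(x, boundary_weight)) for k in range(init_emb)]


tx_table = [[0.0, 1.0], [2.0, 3.0]]                     # num_tx_tokens=2, init_emb=2
boundary_weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 features -> init_emb=2
print(initial_embedding(1, tx_table, boundary_weight))                # [2.0, 3.0]
print(initial_embedding([1.0, 2.0, 3.0], tx_table, boundary_weight))  # [4.0, 5.0]
```

Whatever the actual input handling is, the parameter table implies the same outcome: after this stage 'tx' and boundary nodes carry vectors of the same width, `init_emb`.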