segger.models
Models module for Segger.
Contains the implementation of the Segger model, which is built on graph neural networks (GNNs).
Segger
Segger(is_token_based, num_node_features, init_emb=16, hidden_channels=32, num_mid_layers=3, out_channels=32, heads=3)
Bases: Module
Initializes the Segger model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
is_token_based | bool | Whether the model uses token-based embeddings or scRNA-seq embeddings. | required |
num_node_features | dict[str, int] | Number of node features for each node type. | required |
init_emb | int | Initial embedding size for both 'tx' and boundary (non-token) nodes. | 16 |
hidden_channels | int | Number of hidden channels. | 32 |
num_mid_layers | int | Number of hidden layers (excluding the first and last layers). | 3 |
out_channels | int | Number of output channels. | 32 |
heads | int | Number of attention heads. | 3 |
Source code in src/segger/models/segger_model.py (lines 32-88)
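The constructor arguments fix the depth and width of the network. The sketch below is a minimal, hypothetical helper (layer_widths is not part of segger) that tallies the attention layers implied by num_mid_layers, which counts only the layers between the first and the last, together with each layer's output width. It assumes GAT-style layers that concatenate their heads, so a layer emits channels * heads features; whether every Segger layer, including the last, concatenates heads this way is an assumption, not something this page states.

```python
# Hypothetical helper, not part of segger: tally the attention layers implied
# by the Segger constructor arguments and their output widths.
# ASSUMPTION: each layer concatenates its attention heads, so a layer with
# `channels` channels and `heads` heads emits channels * heads features.


def layer_widths(
    init_emb: int = 16,
    hidden_channels: int = 32,
    num_mid_layers: int = 3,
    out_channels: int = 32,
    heads: int = 3,
) -> list[tuple[str, int]]:
    """Return (layer name, output width) pairs, starting with the input embedding."""
    widths = [("embedding", init_emb)]
    # First attention layer maps the initial embedding to the hidden width.
    widths.append(("first", hidden_channels * heads))
    # num_mid_layers counts only the layers between the first and the last.
    for i in range(num_mid_layers):
        widths.append((f"mid_{i}", hidden_channels * heads))
    # Last attention layer produces the final node embeddings.
    widths.append(("last", out_channels * heads))
    return widths


if __name__ == "__main__":
    for name, width in layer_widths():
        print(f"{name}: {width}")
```

With the default arguments the helper lists an embedding width of 16 followed by five attention layers (first, three middle, last), each 96 features wide under the concatenation assumption.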
decode
decode(z_dict, edge_index)
Decode the node embeddings to predict edge values.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
z_dict | dict[str, Tensor] | Node embeddings for each node type. | required |
edge_index | EdgeIndex | Edge label indices. | required |
Returns:
Type | Description |
---|---|
Tensor | Predicted edge values. |
Source code in src/segger/models/segger_model.py (lines 127-144)
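decode turns node embeddings plus a set of labeled edges into one predicted value per edge. A common scoring rule for this step in link prediction is the dot product of the two endpoint embeddings; the plain-Python sketch below assumes that rule and is not taken from segger's source. It also assumes the edges run from 'tx' nodes to boundary nodes, keys the boundary type with the hypothetical name "bd", and represents edge_index as a pair of index lists (sources, targets) rather than an EdgeIndex tensor.

```python
# Minimal sketch of dot-product edge decoding, in plain Python for illustration.
# ASSUMPTIONS: scores are dot products of endpoint embeddings; edges go from
# "tx" nodes to boundary nodes; the key "bd" for boundary nodes is hypothetical.

Embedding = list[float]


def decode_dot(
    z_dict: dict[str, list[Embedding]],
    edge_index: tuple[list[int], list[int]],
    src_type: str = "tx",
    dst_type: str = "bd",
) -> list[float]:
    """Score each edge (source index, target index) by its endpoints' dot product."""
    sources, targets = edge_index
    if len(sources) != len(targets):
        raise ValueError("edge_index rows must have equal length")
    scores = []
    for s, t in zip(sources, targets):
        z_src = z_dict[src_type][s]
        z_dst = z_dict[dst_type][t]
        scores.append(sum(a * b for a, b in zip(z_src, z_dst)))
    return scores


if __name__ == "__main__":
    z = {
        "tx": [[1.0, 0.0], [0.5, 0.5]],
        "bd": [[2.0, 1.0]],
    }
    # Two edges: tx 0 -> bd 0 and tx 1 -> bd 0.
    print(decode_dot(z, ([0, 1], [0, 0])))  # [2.0, 1.5]
```

The scores are raw similarities; if the predicted edge values are meant to be probabilities, a sigmoid would typically be applied on top, and whether Segger does that inside decode or in the loss is not stated on this page.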
forward
forward(x_dict, edge_index_dict)
Forward pass for the Segger model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x_dict | dict[str, Tensor] | Node features for each node type. | required |
edge_index_dict | dict[str, Tensor] | Edge indices for each edge type. | required |
Source code in src/segger/models/segger_model.py (lines 98-125)
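forward consumes one feature matrix per node type (x_dict) and one edge index per edge type (edge_index_dict), so messages are passed along each edge type separately. The sketch below illustrates that data layout with a single round of heterogeneous message passing in plain Python. It deliberately swaps Segger's attention layers for plain mean aggregation, assumes all node types share one feature width, and uses a hypothetical "src__relation__dst" string key format with made-up relation names ("neighbors", "belongs").

```python
# One round of heterogeneous message passing with mean aggregation.
# This is a simplified stand-in for attention layers, NOT Segger's forward().
# ASSUMPTIONS: edge types are keyed "src__relation__dst" (hypothetical format);
# each edge index is a pair of lists (source indices, target indices); all node
# types share one feature width, so messages can be averaged directly (real
# heterogeneous layers project each edge type with its own weights).
from collections import defaultdict

Features = list[list[float]]


def mean_message_pass(
    x_dict: dict[str, Features],
    edge_index_dict: dict[str, tuple[list[int], list[int]]],
) -> dict[str, Features]:
    """Return, per node type, the mean of incoming neighbor features.

    Nodes with no incoming edges keep a copy of their own features.
    """
    # sums[(node_type, node_idx)] accumulates incoming feature vectors.
    sums: dict[tuple[str, int], list[float]] = {}
    counts: dict[tuple[str, int], int] = defaultdict(int)
    for key, (sources, targets) in edge_index_dict.items():
        src_type, _relation, dst_type = key.split("__")
        for s, t in zip(sources, targets):
            msg = x_dict[src_type][s]
            acc = sums.setdefault((dst_type, t), [0.0] * len(msg))
            for i, v in enumerate(msg):
                acc[i] += v
            counts[(dst_type, t)] += 1

    out: dict[str, Features] = {}
    for node_type, feats in x_dict.items():
        new_feats = []
        for idx, own in enumerate(feats):
            k = (node_type, idx)
            if counts[k] == 0:
                new_feats.append(list(own))  # no neighbors: keep own features
            else:
                new_feats.append([v / counts[k] for v in sums[k]])
        out[node_type] = new_feats
    return out


if __name__ == "__main__":
    x = {"tx": [[1.0, 2.0], [3.0, 4.0]], "bd": [[0.0, 0.0]]}
    edges = {
        "tx__neighbors__tx": ([0], [1]),    # tx 0 -> tx 1
        "tx__belongs__bd": ([0, 1], [0, 0]),  # both tx nodes -> bd 0
    }
    print(mean_message_pass(x, edges))
```

Here boundary node 0 receives the mean of both transcript feature vectors, while transcript node 0, which has no incoming edges, keeps its own features. A real layer would also combine each node's own representation with the aggregated messages.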