Description
Hi Team,
First of all, thank you for the incredible framework you maintain.
I believe the `forward` method of the A3T-GCN model is not in accordance with the paper. In the paper (Section 2.5), the authors describe this step as follows:

> The calculation of the T-GCN is shown in eq. (11), where $h_{t-1}$ is the output at $t-1$.
The way it is implemented in the framework, I understand that the hidden state `H` is never updated inside the loop: every period is fed the same initial hidden state instead of the output of the previous period.
If I am correct in my statement, a possible solution would be as follows:

```python
H_accum = 0
probs = torch.nn.functional.softmax(self._attention, dim=0)
for period in range(self.periods):
    H = self._base_tgcn(X[:, :, :, period], edge_index, edge_weight, H)
    H_accum = H_accum + probs[period] * H
return H_accum
```

Below is the reference in the repository:
pytorch_geometric_temporal/torch_geometric_temporal/nn/recurrent/attentiontemporalgcn.py
Lines 130 to 157 in 6c98fb3
```python
def forward(
    self,
    X: torch.FloatTensor,
    edge_index: torch.LongTensor,
    edge_weight: torch.FloatTensor = None,
    H: torch.FloatTensor = None
) -> torch.FloatTensor:
    """
    Making a forward pass. If edge weights are not present the forward pass
    defaults to an unweighted graph. If the hidden state matrix is not present
    when the forward pass is called it is initialized with zeros.

    Arg types:
        * **X** (PyTorch Float Tensor): Node features for T time periods.
        * **edge_index** (PyTorch Long Tensor): Graph edge indices.
        * **edge_weight** (PyTorch Float Tensor, optional): Edge weight vector.
        * **H** (PyTorch Float Tensor, optional): Hidden state matrix for all nodes.

    Return types:
        * **H** (PyTorch Float Tensor): Hidden state matrix for all nodes.
    """
    H_accum = 0
    probs = torch.nn.functional.softmax(self._attention, dim=0)
    for period in range(self.periods):
        # H is passed in unchanged on every iteration
        H_accum = H_accum + probs[period] * self._base_tgcn(X[:, :, :, period], edge_index, edge_weight, H)  # ([32, 207, 32])
    return H_accum
```
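To make the difference concrete, here is a minimal, self-contained sketch (plain Python, not the library's API). The function `cell` is a hypothetical stand-in for `self._base_tgcn`, and the attention weights are assumed to be already softmaxed; the point is only to show that reusing the initial hidden state and propagating it give different results.

```python
def cell(x, h):
    # hypothetical stand-in for self._base_tgcn: mixes input and hidden state
    return 0.5 * x + 0.5 * h

def forward_current(xs, probs, h0=0.0):
    # Current behavior: every period receives the same initial hidden state h0.
    h_accum = 0.0
    for x, p in zip(xs, probs):
        h_accum += p * cell(x, h0)
    return h_accum

def forward_proposed(xs, probs, h0=0.0):
    # Proposed fix: the hidden state is updated each period,
    # so period t receives h_{t-1} as in eq. (11).
    h_accum, h = 0.0, h0
    for x, p in zip(xs, probs):
        h = cell(x, h)
        h_accum += p * h
    return h_accum

xs = [1.0, 2.0, 3.0]          # toy inputs, one per period
probs = [0.2, 0.3, 0.5]       # attention weights (already softmaxed)
print(forward_current(xs, probs))   # ≈ 1.15
print(forward_proposed(xs, probs))  # ≈ 1.5375
```

With the current implementation the recurrence degenerates into an attention-weighted sum of independent single-step outputs, whereas the proposed loop actually carries temporal information from one period to the next.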