Fix a bug introduced while supporting torch script. (#79)

This commit is contained in:
Fangjun Kuang 2021-10-14 20:09:38 +08:00 committed by GitHub
parent 5016ee3c95
commit f2387fe523
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -660,7 +660,7 @@ class PositionalEncoding(nn.Module):
self.xscale = math.sqrt(self.d_model)
self.dropout = nn.Dropout(p=dropout)
# not doing: self.pe = None because of errors thrown by torchscript
- self.pe = torch.zeros(0, self.d_model, dtype=torch.float32)
+ self.pe = torch.zeros(0, 0, dtype=torch.float32)
def extend_pe(self, x: torch.Tensor) -> None:
"""Extend the time t in the positional encoding if required.