mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Fix a bug introduced while supporting torch script. (#79)
This commit is contained in:
parent
5016ee3c95
commit
f2387fe523
@ -660,7 +660,7 @@ class PositionalEncoding(nn.Module):
|
|||||||
self.xscale = math.sqrt(self.d_model)
|
self.xscale = math.sqrt(self.d_model)
|
||||||
self.dropout = nn.Dropout(p=dropout)
|
self.dropout = nn.Dropout(p=dropout)
|
||||||
# not doing: self.pe = None because of errors thrown by torchscript
|
# not doing: self.pe = None because of errors thrown by torchscript
|
||||||
self.pe = torch.zeros(0, self.d_model, dtype=torch.float32)
|
self.pe = torch.zeros(0, 0, dtype=torch.float32)
|
||||||
|
|
||||||
def extend_pe(self, x: torch.Tensor) -> None:
|
def extend_pe(self, x: torch.Tensor) -> None:
|
||||||
"""Extend the time t in the positional encoding if required.
|
"""Extend the time t in the positional encoding if required.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user