Fix exporting emformer with torchscript using torch 1.6.0 (#402)

This commit is contained in:
Fangjun Kuang 2022-06-07 09:19:37 +08:00 committed by GitHub
parent 29fa878fff
commit 80c46f0abd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -202,6 +202,7 @@ class Emformer(EncoderInterface):
         )
         self.log_eps = math.log(1e-10)
+        self._has_init_state = False
         self._init_state = torch.jit.Attribute([], List[List[torch.Tensor]])

     def forward(
@@ -296,7 +297,7 @@ class Emformer(EncoderInterface):
         Return the initial state of each layer. NOTE: the returned
         tensors are on the given device. `len(ans) == num_emformer_layers`.
         """
-        if len(self._init_state) > 0:
+        if self._has_init_state:
             # Note(fangjun): It is OK to share the init state as it is
             # not going to be modified by the model
             return self._init_state
@@ -308,6 +309,7 @@ class Emformer(EncoderInterface):
             s = layer._init_state(batch_size=batch_size, device=device)
             ans.append(s)
+        self._has_init_state = True
         self._init_state = ans
         return ans