mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Fix exporting emformer with torchscript using torch 1.6.0 (#402)
This commit is contained in:
parent
29fa878fff
commit
80c46f0abd
@ -202,6 +202,7 @@ class Emformer(EncoderInterface):
|
||||
)
|
||||
self.log_eps = math.log(1e-10)
|
||||
|
||||
self._has_init_state = False
|
||||
self._init_state = torch.jit.Attribute([], List[List[torch.Tensor]])
|
||||
|
||||
def forward(
|
||||
@ -296,7 +297,7 @@ class Emformer(EncoderInterface):
|
||||
Return the initial state of each layer. NOTE: the returned
|
||||
tensors are on the given device. `len(ans) == num_emformer_layers`.
|
||||
"""
|
||||
if len(self._init_state) > 0:
|
||||
if self._has_init_state:
|
||||
# Note(fangjun): It is OK to share the init state as it is
|
||||
# not going to be modified by the model
|
||||
return self._init_state
|
||||
@ -308,6 +309,7 @@ class Emformer(EncoderInterface):
|
||||
s = layer._init_state(batch_size=batch_size, device=device)
|
||||
ans.append(s)
|
||||
|
||||
self._has_init_state = True
|
||||
self._init_state = ans
|
||||
|
||||
return ans
|
||||
|
Loading…
x
Reference in New Issue
Block a user