mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Fix exporting emformer with torchscript using torch 1.6.0 (#402)
This commit is contained in:
parent
29fa878fff
commit
80c46f0abd
@ -202,6 +202,7 @@ class Emformer(EncoderInterface):
|
|||||||
)
|
)
|
||||||
self.log_eps = math.log(1e-10)
|
self.log_eps = math.log(1e-10)
|
||||||
|
|
||||||
|
self._has_init_state = False
|
||||||
self._init_state = torch.jit.Attribute([], List[List[torch.Tensor]])
|
self._init_state = torch.jit.Attribute([], List[List[torch.Tensor]])
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -296,7 +297,7 @@ class Emformer(EncoderInterface):
|
|||||||
Return the initial state of each layer. NOTE: the returned
|
Return the initial state of each layer. NOTE: the returned
|
||||||
tensors are on the given device. `len(ans) == num_emformer_layers`.
|
tensors are on the given device. `len(ans) == num_emformer_layers`.
|
||||||
"""
|
"""
|
||||||
if len(self._init_state) > 0:
|
if self._has_init_state:
|
||||||
# Note(fangjun): It is OK to share the init state as it is
|
# Note(fangjun): It is OK to share the init state as it is
|
||||||
# not going to be modified by the model
|
# not going to be modified by the model
|
||||||
return self._init_state
|
return self._init_state
|
||||||
@ -308,6 +309,7 @@ class Emformer(EncoderInterface):
|
|||||||
s = layer._init_state(batch_size=batch_size, device=device)
|
s = layer._init_state(batch_size=batch_size, device=device)
|
||||||
ans.append(s)
|
ans.append(s)
|
||||||
|
|
||||||
|
self._has_init_state = True
|
||||||
self._init_state = ans
|
self._init_state = ans
|
||||||
|
|
||||||
return ans
|
return ans
|
||||||
|
Loading…
x
Reference in New Issue
Block a user