From 80c46f0abd386398595073a503f270d3afc90bd0 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 7 Jun 2022 09:19:37 +0800
Subject: [PATCH] Fix exporting emformer with torchscript using torch 1.6.0
 (#402)

---
 .../ASR/pruned_stateless_emformer_rnnt2/emformer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py
index 7d4702f11..318cd5094 100644
--- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py
+++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py
@@ -202,6 +202,7 @@ class Emformer(EncoderInterface):
         )
 
         self.log_eps = math.log(1e-10)
+        self._has_init_state = False
         self._init_state = torch.jit.Attribute([], List[List[torch.Tensor]])
 
     def forward(
@@ -296,7 +297,7 @@ class Emformer(EncoderInterface):
         Return the initial state of each layer. NOTE: the returned tensors are
         on the given device. `len(ans) == num_emformer_layers`.
         """
-        if len(self._init_state) > 0:
+        if self._has_init_state:
             # Note(fangjun): It is OK to share the init state as it is
             # not going to be modified by the model
             return self._init_state
@@ -308,6 +309,7 @@ class Emformer(EncoderInterface):
             s = layer._init_state(batch_size=batch_size, device=device)
             ans.append(s)
 
+        self._has_init_state = True
         self._init_state = ans
 
         return ans
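
The patch avoids calling len() on the `torch.jit.Attribute`-typed list when
deciding whether the cached state exists; per the subject line, that check
broke `torch.jit.script` export under torch 1.6.0, so a plain boolean flag
tracks whether the cache has been filled. Below is a minimal, self-contained
sketch of the same lazy-cache pattern. `StateCache`, `num_layers`, and the
tensor shapes are hypothetical stand-ins (the real Emformer caches
`List[List[torch.Tensor]]` built from each layer's `_init_state`), and the
sketch declares the attribute type with a PEP 526 class annotation rather
than `torch.jit.Attribute`:

    import torch
    from typing import List


    class StateCache(torch.nn.Module):
        # PEP 526 class annotations tell TorchScript the attribute types.
        _has_init_state: bool
        _init_state: List[torch.Tensor]

        def __init__(self, num_layers: int):
            super().__init__()
            self.num_layers = num_layers
            # A plain bool is trivially scriptable; the patch uses it
            # instead of testing len() on the typed empty list below.
            self._has_init_state = False
            self._init_state = []

        def forward(self, batch_size: int) -> List[torch.Tensor]:
            if self._has_init_state:
                # Sharing the cached state is safe: callers never mutate it.
                return self._init_state

            ans: List[torch.Tensor] = []
            for _ in range(self.num_layers):
                # Hypothetical per-layer state shape.
                ans.append(torch.zeros(batch_size, 4))

            self._has_init_state = True
            self._init_state = ans
            return ans


    # Exporting with torchscript and reusing the cached state:
    scripted = torch.jit.script(StateCache(num_layers=2))
    states = scripted(8)        # builds and caches the state
    states_again = scripted(8)  # returned from the cache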