diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py
index 7d4702f11..318cd5094 100644
--- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py
+++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py
@@ -202,6 +202,7 @@ class Emformer(EncoderInterface):
         )
 
         self.log_eps = math.log(1e-10)
+        self._has_init_state = False
         self._init_state = torch.jit.Attribute([], List[List[torch.Tensor]])
 
     def forward(
@@ -296,7 +297,7 @@ class Emformer(EncoderInterface):
         Return the initial state of each layer. NOTE: the returned tensors are
         on the given device. `len(ans) == num_emformer_layers`.
         """
-        if len(self._init_state) > 0:
+        if self._has_init_state:
            # Note(fangjun): It is OK to share the init state as it is
            # not going to be modified by the model
            return self._init_state
@@ -308,6 +309,7 @@ class Emformer(EncoderInterface):
             s = layer._init_state(batch_size=batch_size, device=device)
             ans.append(s)
 
+        self._has_init_state = True
         self._init_state = ans
         return ans
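
The sketch below is a minimal, hypothetical illustration of the caching pattern this diff introduces: a separate boolean flag records whether the cached init state has already been built, instead of inspecting the cached attribute itself, so the check behaves consistently in eager mode and under torch.jit.script. The names ToyEncoder, num_layers, and state_dim are made up for illustration and are not part of this PR.

# Minimal sketch of the "_has_init_state flag" caching pattern (assumed names,
# not taken from emformer.py except where noted).
from typing import List

import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    def __init__(self, num_layers: int = 2, state_dim: int = 4) -> None:
        super().__init__()
        self.num_layers = num_layers
        self.state_dim = state_dim
        self._has_init_state = False
        # Typed empty-list attribute, declared the same way as in emformer.py.
        self._init_state = torch.jit.Attribute([], List[List[torch.Tensor]])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

    @torch.jit.export
    def get_init_state(self, device: torch.device) -> List[List[torch.Tensor]]:
        # Sharing the cached states is fine: the model never modifies them.
        if self._has_init_state:
            return self._init_state

        ans: List[List[torch.Tensor]] = []
        for _ in range(self.num_layers):
            ans.append([torch.zeros([self.state_dim], device=device)])

        self._has_init_state = True
        self._init_state = ans
        return ans


if __name__ == "__main__":
    m = torch.jit.script(ToyEncoder())
    states = m.get_init_state(torch.device("cpu"))
    print(len(states))  # -> 2, one entry per layer

Using the boolean flag keeps the "is the cache populated?" check independent of how `_init_state` happens to be represented before and after scripting, which is the point of the change in `get_init_state` above.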