From 1f6c822dc0ac3bfbc688cf5d5ba624b57272d097 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Wed, 6 Jul 2022 17:46:36 +0800 Subject: [PATCH] add @torch.jit.export for init_states function --- .../ASR/conv_emformer_transducer_stateless/emformer.py | 3 +++ .../ASR/conv_emformer_transducer_stateless2/emformer.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py index 8d1a56736..61b7dec9c 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py @@ -1636,6 +1636,7 @@ class EmformerEncoder(nn.Module): ) return output, output_lengths, output_states + @torch.jit.export def init_states(self, device: torch.device = torch.device("cpu")): """Create initial states.""" attn_caches = [ @@ -1684,6 +1685,7 @@ class Emformer(EncoderInterface): self.subsampling_factor = subsampling_factor self.right_context_length = right_context_length + self.chunk_length = chunk_length if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") if chunk_length % subsampling_factor != 0: @@ -1832,6 +1834,7 @@ class Emformer(EncoderInterface): return output, output_lengths, output_states + @torch.jit.export def init_states(self, device: torch.device = torch.device("cpu")): """Create initial states.""" return self.encoder.init_states(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py index 015ce9b9e..b3fce32e4 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py @@ -1539,6 +1539,7 @@ class EmformerEncoder(nn.Module): ) return output, output_lengths, output_states + @torch.jit.export def init_states(self, device: torch.device = torch.device("cpu")): """Create initial states.""" attn_caches = [ @@ -1587,6 +1588,7 @@ class Emformer(EncoderInterface): self.subsampling_factor = subsampling_factor self.right_context_length = right_context_length + self.chunk_length = chunk_length if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") if chunk_length % subsampling_factor != 0: @@ -1735,6 +1737,7 @@ class Emformer(EncoderInterface): return output, output_lengths, output_states + @torch.jit.export def init_states(self, device: torch.device = torch.device("cpu")): """Create initial states.""" return self.encoder.init_states(device)