diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py
index fb16ab6c1..d05bef337 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py
@@ -141,12 +141,20 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--streaming-model",
+        type=str2bool,
+        default=False,
+        help="""Whether to export a streaming model. If the models in exp-dir
+        are streaming models, this should be True.
+        """,
+    )
     parser.add_argument(
         "--causal-convolution",
         type=str2bool,
         default=False,
         help="""Whether to use causal convolution, this requires to be True when
-        using dynamic_chunk_training.
+        exporting a streaming model.
         """,
     )
 
@@ -174,6 +182,9 @@ def main():
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
 
+    if params.streaming_model:
+        assert params.causal_convolution
+
     logging.info(params)
 
     logging.info("About to create model")
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
index bde13f32d..bb786c775 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -122,7 +122,7 @@ class Conformer(EncoderInterface):
             causal,
         )
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
-        self._init_state = torch.jit.Attribute([], List[torch.Tensor])
+        self._init_state: List[torch.Tensor] = [torch.empty(0)]
 
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@@ -1396,7 +1396,7 @@ class ConvolutionModule(nn.Module):
         self,
         x: Tensor,
         cache: Optional[Tensor] = None,
-        right_context=0,
+        right_context: int = 0,
     ) -> Tuple[Tensor, Tensor]:
         """Compute convolution module.
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py
index a8cbbbd9b..c018bfa03 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py
@@ -156,12 +156,20 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--streaming-model",
+        type=str2bool,
+        default=False,
+        help="""Whether to export a streaming model. If the models in exp-dir
+        are streaming models, this should be True.
+        """,
+    )
     parser.add_argument(
         "--causal-convolution",
         type=str2bool,
         default=False,
         help="""Whether to use causal convolution, this requires to be True when
-        using dynamic_chunk_training.
+        exporting a streaming model.
         """,
     )
 
@@ -188,6 +196,9 @@ def main():
     params.blank_id = sp.piece_to_id("<blk>")
     params.vocab_size = sp.get_piece_size()
 
+    if params.streaming_model:
+        assert params.causal_convolution
+
     logging.info(params)
 
     logging.info("About to create model")
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
index 0781dfb85..65bffd063 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
@@ -480,7 +480,9 @@ def decode_dataset(
     decode_results = []
     # Contain decode streams currently running.
     decode_streams = []
-    initial_states = model.get_init_state(params.left_context, device=device)
+    initial_states = model.encoder.get_init_state(
+        params.left_context, device=device
+    )
     for num, cut in enumerate(cuts):
         # each utterance has a DecodeStream.
         decode_stream = DecodeStream(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
index e674fb360..f7f96f9a6 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
@@ -125,6 +125,55 @@ def get_parser():
         "2 means tri-gram",
     )
 
+    parser.add_argument(
+        "--dynamic-chunk-training",
+        type=str2bool,
+        default=False,
+        help="""Whether to use dynamic_chunk_training. If you want a streaming
+        model, this must be True.
+        Note: not needed here; added only to construct the transducer model,
+        as we reuse the code in train.py.
+        """,
+    )
+
+    parser.add_argument(
+        "--short-chunk-size",
+        type=int,
+        default=25,
+        help="""Chunk length for dynamic training. The chunk size is either the
+        max sequence length of the current batch or uniformly sampled from (1, short_chunk_size).
+        Note: not needed here; added only to construct the transducer model,
+        as we reuse the code in train.py.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-left-chunks",
+        type=int,
+        default=4,
+        help="""How many left chunks of context can be seen when calculating attention.
+        Note: not needed here; added only to construct the transducer model,
+        as we reuse the code in train.py.
+        """,
+    )
+
+    parser.add_argument(
+        "--streaming-model",
+        type=str2bool,
+        default=False,
+        help="""Whether to export a streaming model. If the models in exp-dir
+        are streaming models, this should be True.
+        """,
+    )
+    parser.add_argument(
+        "--causal-convolution",
+        type=str2bool,
+        default=False,
+        help="""Whether to use causal convolution. This must be True when
+        exporting a streaming model.
+        """,
+    )
+
     return parser
 
@@ -148,6 +197,9 @@ def main():
     params.blank_id = sp.piece_to_id("<blk>")
     params.vocab_size = sp.get_piece_size()
 
+    if params.streaming_model:
+        assert params.causal_convolution
+
     logging.info(params)
 
     logging.info("About to create model")
diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index 89e424084..031d2414a 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -125,7 +125,7 @@ class Conformer(Transformer):
         # and throws an error without this change.
         self.after_norm = identity
 
-        self._init_state = torch.jit.Attribute([], List[torch.Tensor])
+        self._init_state: List[torch.Tensor] = [torch.empty(0)]
 
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor
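
Note: a minimal stand-alone sketch (hypothetical, not part of the diff) of how the new `--streaming-model` and `--causal-convolution` flags interact: a streaming export only makes sense if the encoder was built with causal convolution, which is what the `assert` added to `main()` enforces. The `str2bool` below is a stand-in mirroring the helper the export scripts import from `icefall.utils`.

```python
# Hypothetical sketch of the flag handshake added in the export scripts above;
# the argument names match the diff, everything else is assumed.
import argparse


def str2bool(v: str) -> bool:
    # Stand-in for icefall.utils.str2bool: accept common true/false spellings.
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--streaming-model", type=str2bool, default=False)
parser.add_argument("--causal-convolution", type=str2bool, default=False)
args = parser.parse_args(["--streaming-model", "1", "--causal-convolution", "1"])

# The consistency check added to main(): a streaming export requires an
# encoder built with causal convolution.
if args.streaming_model:
    assert args.causal_convolution, (
        "--causal-convolution must be True when exporting a streaming model"
    )
```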
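The `_init_state` change in both conformer.py files replaces `torch.jit.Attribute` (intended for use inside `torch.jit.ScriptModule` subclasses) with a plain annotated attribute, which `torch.jit.script` picks up directly on an ordinary `nn.Module`. A minimal sketch of the pattern, with a toy module standing in for the Conformer:

```python
# Toy module (not the real Conformer) showing the scriptable-attribute pattern.
from typing import List

import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Annotated attribute: TorchScript reads List[torch.Tensor] from the
        # annotation; the placeholder tensor keeps the list non-empty so
        # "uninitialized state" is representable without torch.jit.Attribute.
        self._init_state: List[torch.Tensor] = [torch.empty(0)]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


scripted = torch.jit.script(TinyEncoder())  # scripts cleanly
```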
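Likewise, the streaming_decode.py fix calls `get_init_state` through `model.encoder`, where the method is actually defined, rather than on the wrapping transducer model. A toy illustration of the ownership (class names here are made up; only the call shape follows the diff):

```python
# Toy classes mirroring the ownership: the encoder owns get_init_state,
# the transducer model merely wraps the encoder.
from typing import List

import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    def get_init_state(
        self, left_context: int, device: torch.device
    ) -> List[torch.Tensor]:
        # Stand-in for Conformer.get_init_state: per-layer decode caches.
        return [torch.zeros(left_context, device=device)]


class ToyTransducer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.encoder = ToyEncoder()


model = ToyTransducer()
# Before the fix, model.get_init_state(...) raises AttributeError in this toy;
# after the fix, the method is reached through the encoder attribute.
initial_states = model.encoder.get_init_state(64, device=torch.device("cpu"))
```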