From 87d9491fba42599af7b6718cf1f727e0cf2654d3 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Thu, 13 Apr 2023 17:20:25 +0800 Subject: [PATCH] minor fix in decode.py, about args --- .../pruned_transducer_stateless7/decode.py | 64 ++++++------------- 1 file changed, 21 insertions(+), 43 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index 69ffab2f0..e7e041578 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -301,27 +301,18 @@ def get_parser(): fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) - parser.add_argument( - "--simulate-streaming", - type=str2bool, - default=False, - help="""Whether to simulate streaming in decoding, this is a good way to - test a streaming model. - """, - ) - parser.add_argument( "--decode-chunk-size", type=int, default=16, - help="The chunk size for decoding (in frames after subsampling)", + help="The chunk size for decoding (in frames after subsampling), at 50Hz frame rate", ) parser.add_argument( - "--left-context", + "--decode-left-context", type=int, default=64, - help="left context can be seen during decoding (in frames after subsampling)", + help="left context can be seen during decoding (in frames after subsampling), at 50Hz frame rate", ) add_model_arguments(parser) @@ -378,28 +369,19 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - # this seems to cause insertions at the end of the utterance if used with zipformer. - #feature_lens += params.left_context - #feature = torch.nn.functional.pad( - # feature, - # pad=(0, 0, 0, params.left_context), - # value=LOG_EPS, - #) + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = params.decode_chunk_size + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) - if params.simulate_streaming: - # the chunk size and left context are now stored with the model. - # TODO: implement streaming_forward. - encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( - x=feature, - x_lens=feature_lens, - #chunk_size=params.decode_chunk_size, - #left_context=params.left_context, - #simulate_streaming=True, - ) - else: - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -669,11 +651,12 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - # TODO: may still want to add something here? for now I am just - # moving the decoding directories around after decoding. - #if params.simulate_streaming: - #params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" - #params.suffix += f"-left-context-{params.left_context}" + if params.causal: + # 'chunk_size' and 'left_context_frames' are used in function 'get_encoder_model' in train.py + params.chunk_size = str(params.decode_chunk_size) + params.left_context_frames = str(params.decode_left_context) + params.suffix += f"-decode-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-decode-left-context-{params.decode_left_context}" if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" @@ -712,11 +695,6 @@ def main(): params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() - if params.simulate_streaming: - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" - logging.info(params) logging.info("About to create model")