From c90f57afdbc8116944771aab1c3a42217a53eeec Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Tue, 4 Apr 2023 11:04:00 +0800 Subject: [PATCH] Remove simulate streaming from stateless8 (#985) --- .../pruned_transducer_stateless8/decode.py | 49 +------------------ 1 file changed, 1 insertion(+), 48 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py index 7b651a632..e07777c9f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py @@ -301,29 +301,6 @@ def get_parser(): fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) - parser.add_argument( - "--simulate-streaming", - type=str2bool, - default=False, - help="""Whether to simulate streaming in decoding, this is a good way to - test a streaming model. - """, - ) - - parser.add_argument( - "--decode-chunk-size", - type=int, - default=16, - help="The chunk size for decoding (in frames after subsampling)", - ) - - parser.add_argument( - "--left-context", - type=int, - default=64, - help="left context can be seen during decoding (in frames after subsampling)", - ) - add_model_arguments(parser) return parser @@ -378,22 +355,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) - encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( - x=feature, - x_lens=feature_lens, - chunk_size=params.decode_chunk_size, - left_context=params.left_context, - simulate_streaming=True, - ) - else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -651,10 +613,6 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.simulate_streaming: - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" - params.suffix += f"-left-context-{params.left_context}" - if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" @@ -690,11 +648,6 @@ def main(): params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() - if params.simulate_streaming: - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" - logging.info(params) logging.info("About to create model")