From 46bf6df62fb3d78b4b9bf1b0592c889a04d7be9b Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Mon, 3 Apr 2023 14:55:45 +0800 Subject: [PATCH] Remove simulate streaming from stateless7 (#983) * Remove simulate streaming from stateless7 --- .../pruned_transducer_stateless7/decode.py | 49 +------------------ 1 file changed, 1 insertion(+), 48 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index 32b3134b9..576621e24 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -343,29 +343,6 @@ def get_parser(): fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) - parser.add_argument( - "--simulate-streaming", - type=str2bool, - default=False, - help="""Whether to simulate streaming in decoding, this is a good way to - test a streaming model. - """, - ) - - parser.add_argument( - "--decode-chunk-size", - type=int, - default=16, - help="The chunk size for decoding (in frames after subsampling)", - ) - - parser.add_argument( - "--left-context", - type=int, - default=64, - help="left context can be seen during decoding (in frames after subsampling)", - ) - parser.add_argument( "--use-shallow-fusion", type=str2bool, @@ -474,22 +451,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) - encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( - x=feature, - x_lens=feature_lens, - chunk_size=params.decode_chunk_size, - left_context=params.left_context, - simulate_streaming=True, - ) - else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, 
encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -782,10 +744,6 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.simulate_streaming: - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" - params.suffix += f"-left-context-{params.left_context}" - if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" @@ -834,11 +792,6 @@ def main(): params.unk_id = sp.piece_to_id("<unk>") params.vocab_size = sp.get_piece_size() - if params.simulate_streaming: - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" - logging.info(params) logging.info("About to create model")