diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
index 8c69cfd6e..8ba36e582 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
@@ -86,7 +86,7 @@ Usage:
     --beam 20.0 \
     --max-contexts 8 \
     --max-states 64
-
+
 (8) modified beam search with RNNLM shallow fusion (with LG)
 ./pruned_transducer_stateless5/decode.py \
     --epoch 35 \
@@ -103,7 +103,7 @@ Usage:
     --rnn-lm-num-layers 3 \
     --rnn-lm-tie-weights 1
-
+
 """
@@ -198,7 +198,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="lstm_transducer_stateless2/exp",
+        default="pruned_transducer_stateless5/exp",
         help="The experiment dir",
     )

@@ -228,7 +228,7 @@
           - fast_beam_search_nbest
           - fast_beam_search_nbest_oracle
           - fast_beam_search_nbest_LG
-          - modified-beam-search3 # for rnn lm shallow fusion
+          - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
         If you use fast_beam_search_nbest_LG, you have to specify
         `--lang-dir`, which should contain `LG.pt`.
         """,
     )
@@ -265,7 +265,21 @@ def get_parser():
         It specifies the scale for n-gram LM scores.
         """,
     )
-
+
+    parser.add_argument(
+        "--decode-chunk-size",
+        type=int,
+        default=16,
+        help="The chunk size for decoding (in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--left-context",
+        type=int,
+        default=64,
+        help="How much left context can be seen during decoding (in frames after subsampling)",
+    )
+
     parser.add_argument(
         "--max-contexts",
         type=int,
@@ -317,7 +331,7 @@ def get_parser():
         Used only when the decoding method is fast_beam_search_nbest,
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )
-
+
     parser.add_argument(
         "--simulate-streaming",
         type=str2bool,
@@ -331,7 +345,7 @@ def get_parser():
         "--rnn-lm-scale",
         type=float,
         default=0.0,
-        help="""Used only when --method is modified_beam_search3.
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
         It specifies the path to RNN LM exp dir.
         """,
     )
@@ -430,7 +444,7 @@ def decode_one_batch(
       word_table:
         The word symbol table.
       decoding_graph:
-        The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
         only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
         fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
     Returns:
@@ -560,7 +574,7 @@ def decode_one_batch(
     for i in range(batch_size):
         # fmt: off
-        encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
+        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
         # fmt: on
         if params.decoding_method == "greedy_search":
             hyp = greedy_search(
@@ -606,7 +620,7 @@ def decode_dataset(
     decoding_graph: Optional[k2.Fsa] = None,
     rnnlm: Optional[RnnLmModel] = None,
     rnnlm_scale: float = 1.0,
-) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.

     Args:
@@ -683,7 +697,7 @@ def decode_dataset(
 def save_results(
     params: AttributeDict,
     test_set_name: str,
-    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
@@ -751,7 +765,9 @@ def main():
         params.suffix = f"iter-{params.iter}-avg-{params.avg}"
     else:
         params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-
+    if params.simulate_streaming:
+        params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
+        params.suffix += f"-left-context-{params.left_context}"
     if "fast_beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
@@ -791,6 +807,11 @@ def main():
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()

+    if params.simulate_streaming:
+        assert (
+            params.causal_convolution
+        ), "Decoding in streaming requires causal convolution"
+
     logging.info(params)

     logging.info("About to create model")
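
Note on the new streaming flags: per the help strings above, --decode-chunk-size and --left-context are both measured in frames after subsampling, and --simulate-streaming requires params.causal_convolution because a streaming encoder must not look ahead. Below is a minimal, self-contained sketch (illustrative names only, not the icefall implementation) of the receptive-field rule that chunk-wise simulated streaming implies: each frame may attend to its own chunk plus a fixed window of left context.

    import torch


    def simulated_streaming_mask(
        num_frames: int, decode_chunk_size: int, left_context: int
    ) -> torch.Tensor:
        """Boolean (num_frames, num_frames) mask; True marks positions
        a frame is allowed to attend to."""
        mask = torch.zeros(num_frames, num_frames, dtype=torch.bool)
        for t in range(num_frames):
            chunk_start = (t // decode_chunk_size) * decode_chunk_size
            chunk_end = min(chunk_start + decode_chunk_size, num_frames)
            mask[t, max(0, chunk_start - left_context):chunk_end] = True
        return mask


    # With the new defaults (--decode-chunk-size 16 --left-context 64),
    # frame 100 sees frames 32..111: its chunk 96..111 plus 64 left frames.
    m = simulated_streaming_mask(128, decode_chunk_size=16, left_context=64)
    assert m[100].nonzero().min().item() == 32
    assert m[100].nonzero().max().item() == 111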
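
On the renamed modified_beam_search_rnnlm_shallow_fusion method: --rnn-lm-scale weights the RNN LM during beam search. The following is a minimal sketch of the standard shallow-fusion interpolation the name refers to (illustrative names, not icefall's implementation): when a hypothesis is extended with a non-blank token, the LM's log-probability for that token is added to the transducer's, scaled by the flag.

    import math


    def fused_token_score(
        am_logprob: float, lm_logprob: float, rnn_lm_scale: float
    ) -> float:
        """Shallow fusion: log P_am(token) + scale * log P_lm(token)."""
        return am_logprob + rnn_lm_scale * lm_logprob


    # A token the transducer likes (p=0.6) but the LM finds unlikely (p=0.01)
    # is penalised in proportion to the scale; with the default
    # --rnn-lm-scale 0.0 the LM has no effect.
    print(fused_token_score(math.log(0.6), math.log(0.01), rnn_lm_scale=0.3))
    # about -1.89, versus about -0.51 with rnn_lm_scale=0.0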