From 8187d6236c2926500da5ee854f758e621df803cc Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 24 Dec 2021 21:48:40 +0800 Subject: [PATCH] Minor fix to maximum number of symbols per frame for RNN-T decoding. (#157) * Minor fix to maximum number of symbols per frame RNN-T decoding. * Minor fixes. --- .../ASR/transducer_stateless/beam_search.py | 19 ++++++++++++------- .../ASR/transducer_stateless/decode.py | 15 ++++++++++++++- .../ASR/transducer_stateless/pretrained.py | 15 ++++++++++++++- 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index 45118a8bc..9ed9b2ad1 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -22,13 +22,18 @@ import torch from model import Transducer -def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: +def greedy_search( + model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int +) -> List[int]: """ Args: model: An instance of `Transducer`. encoder_out: A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + max_sym_per_frame: + Maximum number of symbols per frame. If it is set to 0, the WER + would be 100%. Returns: Return the decoded result. """ @@ -55,10 +60,6 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: # Maximum symbols per utterance. max_sym_per_utt = 1000 - # If at frame t, it decodes more than this number of symbols, - # it will move to the next step t+1 - max_sym_per_frame = 3 - # symbols per frame sym_per_frame = 0 @@ -66,6 +67,11 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: sym_per_utt = 0 while t < T and sym_per_utt < max_sym_per_utt: + if sym_per_frame >= max_sym_per_frame: + sym_per_frame = 0 + t += 1 + continue + # fmt: off current_encoder_out = encoder_out[:, t:t+1, :] # fmt: on @@ -83,8 +89,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: sym_per_utt += 1 sym_per_frame += 1 - - if y == blank_id or sym_per_frame > max_sym_per_frame: + else: sym_per_frame = 0 t += 1 hyp = hyp[context_size:] # remove blanks diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index 82175e8db..51bebed5a 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -114,6 +114,13 @@ def get_parser(): help="Used only when --decoding-method is beam_search", ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=3, + help="Maximum number of symbols per frame", + ) + return parser @@ -237,7 +244,11 @@ def decode_one_batch( encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] # fmt: on if params.decoding_method == "greedy_search": - hyp = greedy_search(model=model, encoder_out=encoder_out_i) + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) elif params.decoding_method == "beam_search": hyp = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size @@ -381,6 +392,8 @@ def main(): params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" if params.decoding_method == "beam_search": params.suffix += f"-beam-{params.beam_size}" + else: + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py index 49efa6749..6a6626371 100755 --- a/egs/librispeech/ASR/transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py @@ -110,6 +110,15 @@ def get_parser(): help="Used only when --method is beam_search", ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=3, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + return parser @@ -279,7 +288,11 @@ def main(): encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] # fmt: on if params.method == "greedy_search": - hyp = greedy_search(model=model, encoder_out=encoder_out_i) + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) elif params.method == "beam_search": hyp = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size