Merge remote-tracking branch 'dan/master' into rnnt-stateless2

This commit is contained in:
Fangjun Kuang 2021-12-27 15:07:01 +08:00
commit b5735ae16f
3 changed files with 37 additions and 9 deletions

View File

@@ -22,13 +22,18 @@ import torch
from model import Transducer from model import Transducer
def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: def greedy_search(
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
) -> List[int]:
""" """
Args: Args:
model: model:
An instance of `Transducer`. An instance of `Transducer`.
encoder_out: encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
max_sym_per_frame:
Maximum number of symbols per frame. If it is set to 0, the WER
would be 100%.
Returns: Returns:
Return the decoded result. Return the decoded result.
""" """
@@ -55,10 +60,6 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
# Maximum symbols per utterance. # Maximum symbols per utterance.
max_sym_per_utt = 1000 max_sym_per_utt = 1000
# If at frame t, it decodes more than this number of symbols,
# it will move to the next step t+1
max_sym_per_frame = 3
# symbols per frame # symbols per frame
sym_per_frame = 0 sym_per_frame = 0
@@ -66,6 +67,11 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
sym_per_utt = 0 sym_per_utt = 0
while t < T and sym_per_utt < max_sym_per_utt: while t < T and sym_per_utt < max_sym_per_utt:
if sym_per_frame >= max_sym_per_frame:
sym_per_frame = 0
t += 1
continue
# fmt: off # fmt: off
current_encoder_out = encoder_out[:, t:t+1, :] current_encoder_out = encoder_out[:, t:t+1, :]
# fmt: on # fmt: on
@@ -83,8 +89,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
sym_per_utt += 1 sym_per_utt += 1
sym_per_frame += 1 sym_per_frame += 1
else:
if y == blank_id or sym_per_frame > max_sym_per_frame:
sym_per_frame = 0 sym_per_frame = 0
t += 1 t += 1
hyp = hyp[context_size:] # remove blanks hyp = hyp[context_size:] # remove blanks

View File

@@ -121,6 +121,12 @@ def get_parser():
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram", "2 means tri-gram",
) )
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=3,
help="Maximum number of symbols per frame",
)
return parser return parser
@@ -241,7 +247,11 @@ def decode_one_batch(
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on # fmt: on
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
hyp = greedy_search(model=model, encoder_out=encoder_out_i) hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search": elif params.decoding_method == "beam_search":
hyp = beam_search( hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size model=model, encoder_out=encoder_out_i, beam=params.beam_size
@@ -387,6 +397,7 @@ def main():
params.suffix += f"-beam-{params.beam_size}" params.suffix += f"-beam-{params.beam_size}"
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started") logging.info("Decoding started")

View File

@@ -117,6 +117,14 @@ def get_parser():
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram", "2 means tri-gram",
) )
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=3,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
return parser return parser
@@ -283,7 +291,11 @@ def main():
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on # fmt: on
if params.method == "greedy_search": if params.method == "greedy_search":
hyp = greedy_search(model=model, encoder_out=encoder_out_i) hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search": elif params.method == "beam_search":
hyp = beam_search( hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size model=model, encoder_out=encoder_out_i, beam=params.beam_size