Minor fixes.

This commit is contained in:
Fangjun Kuang 2021-12-24 11:16:29 +08:00
parent 0fa4ca7f02
commit 60696d3eb2

View File

@ -110,6 +110,15 @@ def get_parser():
help="Used only when --method is beam_search",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=3,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
return parser
@ -279,7 +288,11 @@ def main():
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(model=model, encoder_out=encoder_out_i)
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size