mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Minor fix to maximum number of symbols per frame for RNN-T decoding. (#157)
* Minor fix to maximum number of symbols per frame RNN-T decoding. * Minor fixes.
This commit is contained in:
parent
5b6699a835
commit
8187d6236c
@ -22,13 +22,18 @@ import torch
|
|||||||
from model import Transducer
|
from model import Transducer
|
||||||
|
|
||||||
|
|
||||||
def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
def greedy_search(
|
||||||
|
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
|
||||||
|
) -> List[int]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
model:
|
model:
|
||||||
An instance of `Transducer`.
|
An instance of `Transducer`.
|
||||||
encoder_out:
|
encoder_out:
|
||||||
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
||||||
|
max_sym_per_frame:
|
||||||
|
Maximum number of symbols per frame. If it is set to 0, the WER
|
||||||
|
would be 100%.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoded result.
|
Return the decoded result.
|
||||||
"""
|
"""
|
||||||
@ -55,10 +60,6 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
|||||||
# Maximum symbols per utterance.
|
# Maximum symbols per utterance.
|
||||||
max_sym_per_utt = 1000
|
max_sym_per_utt = 1000
|
||||||
|
|
||||||
# If at frame t, it decodes more than this number of symbols,
|
|
||||||
# it will move to the next step t+1
|
|
||||||
max_sym_per_frame = 3
|
|
||||||
|
|
||||||
# symbols per frame
|
# symbols per frame
|
||||||
sym_per_frame = 0
|
sym_per_frame = 0
|
||||||
|
|
||||||
@ -66,6 +67,11 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
|||||||
sym_per_utt = 0
|
sym_per_utt = 0
|
||||||
|
|
||||||
while t < T and sym_per_utt < max_sym_per_utt:
|
while t < T and sym_per_utt < max_sym_per_utt:
|
||||||
|
if sym_per_frame >= max_sym_per_frame:
|
||||||
|
sym_per_frame = 0
|
||||||
|
t += 1
|
||||||
|
continue
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
current_encoder_out = encoder_out[:, t:t+1, :]
|
current_encoder_out = encoder_out[:, t:t+1, :]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
@ -83,8 +89,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
|||||||
|
|
||||||
sym_per_utt += 1
|
sym_per_utt += 1
|
||||||
sym_per_frame += 1
|
sym_per_frame += 1
|
||||||
|
else:
|
||||||
if y == blank_id or sym_per_frame > max_sym_per_frame:
|
|
||||||
sym_per_frame = 0
|
sym_per_frame = 0
|
||||||
t += 1
|
t += 1
|
||||||
hyp = hyp[context_size:] # remove blanks
|
hyp = hyp[context_size:] # remove blanks
|
||||||
|
@ -114,6 +114,13 @@ def get_parser():
|
|||||||
help="Used only when --decoding-method is beam_search",
|
help="Used only when --decoding-method is beam_search",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-sym-per-frame",
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help="Maximum number of symbols per frame",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -237,7 +244,11 @@ def decode_one_batch(
|
|||||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
hyp = greedy_search(model=model, encoder_out=encoder_out_i)
|
hyp = greedy_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
max_sym_per_frame=params.max_sym_per_frame,
|
||||||
|
)
|
||||||
elif params.decoding_method == "beam_search":
|
elif params.decoding_method == "beam_search":
|
||||||
hyp = beam_search(
|
hyp = beam_search(
|
||||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||||
@ -381,6 +392,8 @@ def main():
|
|||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
if params.decoding_method == "beam_search":
|
if params.decoding_method == "beam_search":
|
||||||
params.suffix += f"-beam-{params.beam_size}"
|
params.suffix += f"-beam-{params.beam_size}"
|
||||||
|
else:
|
||||||
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
|
|
||||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
|
@ -110,6 +110,15 @@ def get_parser():
|
|||||||
help="Used only when --method is beam_search",
|
help="Used only when --method is beam_search",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-sym-per-frame",
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help="""Maximum number of symbols per frame. Used only when
|
||||||
|
--method is greedy_search.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -279,7 +288,11 @@ def main():
|
|||||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
if params.method == "greedy_search":
|
if params.method == "greedy_search":
|
||||||
hyp = greedy_search(model=model, encoder_out=encoder_out_i)
|
hyp = greedy_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
max_sym_per_frame=params.max_sym_per_frame,
|
||||||
|
)
|
||||||
elif params.method == "beam_search":
|
elif params.method == "beam_search":
|
||||||
hyp = beam_search(
|
hyp = beam_search(
|
||||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||||
|
Loading…
x
Reference in New Issue
Block a user