Decrease num_paths when CUDA OOM occurs

This commit is contained in:
Guanbo Wang 2022-04-06 19:47:44 -04:00
parent d9addb7c43
commit a4e1471d1d

View File

@ -824,15 +824,41 @@ def rescore_with_attention_decoder(
ngram_lm_scale_attention_scale and the value is the
best decoding path for each utterance in the lattice.
"""
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# nbest.fsa.scores are all 0s at this point
max_loop_count = 10
loop_count = 0
while loop_count <= max_loop_count:
try:
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# nbest.fsa.scores are all 0s at this point
nbest = nbest.intersect(lattice)
break
except RuntimeError as e:
logging.info(f"Caught exception:\n{e}\n")
logging.info(
f"num_paths before decreasing: {num_paths}"
)
num_paths = int(num_paths / 2)
if loop_count >= max_loop_count or num_paths <= 0:
logging.info(
"Return None as the resulting lattice is too large."
)
return None
logging.info(
"This OOM is not an error. You can ignore it. "
"If your model does not converge well, or --max-duration "
"is too large, or the input sound file is difficult to "
"decode, you will meet this exception."
)
logging.info(
f"num_paths after decreasing: {num_paths}"
)
loop_count += 1
nbest = nbest.intersect(lattice)
# Now nbest.fsa has its scores set.
# Also, nbest.fsa inherits the attributes from `lattice`.
assert hasattr(nbest.fsa, "lm_scores")