Decrease num_paths when a CUDA OOM occurs

This commit is contained in:
Guanbo Wang 2022-04-11 21:45:29 +00:00
parent 3d2c261684
commit 6d07cf9245

View File

@ -630,15 +630,37 @@ def rescore_with_n_best_list(
assert G.device == device
assert hasattr(G, "aux_labels") is False
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# nbest.fsa.scores are all 0s at this point
max_loop_count = 10
loop_count = 0
while loop_count <= max_loop_count:
try:
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# nbest.fsa.scores are all 0s at this point
nbest = nbest.intersect(lattice)
break
except RuntimeError as e:
logging.info(f"Caught exception:\n{e}\n")
logging.info(f"num_paths before decreasing: {num_paths}")
num_paths = int(num_paths / 2)
if loop_count >= max_loop_count or num_paths <= 0:
logging.info(
"Return None as the resulting lattice is too large."
)
return None
logging.info(
"This OOM is not an error. You can ignore it. "
"If your model does not converge well, or --max-duration "
"is too large, or the input sound file is difficult to "
"decode, you will meet this exception."
)
logging.info(f"num_paths after decreasing: {num_paths}")
loop_count += 1
nbest = nbest.intersect(lattice)
# Now nbest.fsa has its scores set
assert hasattr(nbest.fsa, "lm_scores")