mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 22:54:18 +00:00
Decrease num_paths while CUDA OOM
This commit is contained in:
parent
3d2c261684
commit
6d07cf9245
@ -630,15 +630,37 @@ def rescore_with_n_best_list(
|
||||
assert G.device == device
|
||||
assert hasattr(G, "aux_labels") is False
|
||||
|
||||
nbest = Nbest.from_lattice(
|
||||
lattice=lattice,
|
||||
num_paths=num_paths,
|
||||
use_double_scores=use_double_scores,
|
||||
nbest_scale=nbest_scale,
|
||||
)
|
||||
# nbest.fsa.scores are all 0s at this point
|
||||
max_loop_count = 10
|
||||
loop_count = 0
|
||||
while loop_count <= max_loop_count:
|
||||
try:
|
||||
nbest = Nbest.from_lattice(
|
||||
lattice=lattice,
|
||||
num_paths=num_paths,
|
||||
use_double_scores=use_double_scores,
|
||||
nbest_scale=nbest_scale,
|
||||
)
|
||||
# nbest.fsa.scores are all 0s at this point
|
||||
nbest = nbest.intersect(lattice)
|
||||
break
|
||||
except RuntimeError as e:
|
||||
logging.info(f"Caught exception:\n{e}\n")
|
||||
logging.info(f"num_paths before decreasing: {num_paths}")
|
||||
num_paths = int(num_paths / 2)
|
||||
if loop_count >= max_loop_count or num_paths <= 0:
|
||||
logging.info(
|
||||
"Return None as the resulting lattice is too large."
|
||||
)
|
||||
return None
|
||||
logging.info(
|
||||
"This OOM is not an error. You can ignore it. "
|
||||
"If your model does not converge well, or --max-duration "
|
||||
"is too large, or the input sound file is difficult to "
|
||||
"decode, you will meet this exception."
|
||||
)
|
||||
logging.info(f"num_paths after decreasing: {num_paths}")
|
||||
loop_count += 1
|
||||
|
||||
nbest = nbest.intersect(lattice)
|
||||
# Now nbest.fsa has its scores set
|
||||
assert hasattr(nbest.fsa, "lm_scores")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user