mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
Decrease num_paths while CUDA OOM
This commit is contained in:
parent
d9addb7c43
commit
a4e1471d1d
@ -824,15 +824,41 @@ def rescore_with_attention_decoder(
|
||||
ngram_lm_scale_attention_scale and the value is the
|
||||
best decoding path for each utterance in the lattice.
|
||||
"""
|
||||
nbest = Nbest.from_lattice(
|
||||
lattice=lattice,
|
||||
num_paths=num_paths,
|
||||
use_double_scores=use_double_scores,
|
||||
nbest_scale=nbest_scale,
|
||||
)
|
||||
# nbest.fsa.scores are all 0s at this point
|
||||
max_loop_count = 10
|
||||
loop_count = 0
|
||||
while loop_count <= max_loop_count:
|
||||
try:
|
||||
nbest = Nbest.from_lattice(
|
||||
lattice=lattice,
|
||||
num_paths=num_paths,
|
||||
use_double_scores=use_double_scores,
|
||||
nbest_scale=nbest_scale,
|
||||
)
|
||||
# nbest.fsa.scores are all 0s at this point
|
||||
nbest = nbest.intersect(lattice)
|
||||
break
|
||||
except RuntimeError as e:
|
||||
logging.info(f"Caught exception:\n{e}\n")
|
||||
logging.info(
|
||||
f"num_paths before decreasing: {num_paths}"
|
||||
)
|
||||
num_paths = int(num_paths / 2)
|
||||
if loop_count >= max_loop_count or num_paths <= 0:
|
||||
logging.info(
|
||||
"Return None as the resulting lattice is too large."
|
||||
)
|
||||
return None
|
||||
logging.info(
|
||||
"This OOM is not an error. You can ignore it. "
|
||||
"If your model does not converge well, or --max-duration "
|
||||
"is too large, or the input sound file is difficult to "
|
||||
"decode, you will meet this exception."
|
||||
)
|
||||
logging.info(
|
||||
f"num_paths after decreasing: {num_paths}"
|
||||
)
|
||||
loop_count += 1
|
||||
|
||||
nbest = nbest.intersect(lattice)
|
||||
# Now nbest.fsa has its scores set.
|
||||
# Also, nbest.fsa inherits the attributes from `lattice`.
|
||||
assert hasattr(nbest.fsa, "lm_scores")
|
||||
|
Loading…
x
Reference in New Issue
Block a user