mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Incremental pruning threshold (#214)
* Incremental pruning threshold * flake8 * black * minor fix
This commit is contained in:
parent
70a3c56a18
commit
e8eb408760
@ -716,10 +716,13 @@ def rescore_with_whole_lattice(
|
||||
|
||||
b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
|
||||
|
||||
# NOTE: The choice of the threshold list is arbitrary here to avoid OOM.
|
||||
# You may need to fine tune it.
|
||||
prune_th_list = [1e-10, 1e-9, 1e-8, 1e-7, 1e-6]
|
||||
prune_th_list += [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
|
||||
max_loop_count = 10
|
||||
loop_count = 0
|
||||
while loop_count <= max_loop_count:
|
||||
loop_count += 1
|
||||
try:
|
||||
rescoring_lattice = k2.intersect_device(
|
||||
G_with_epsilon_loops,
|
||||
@ -731,6 +734,11 @@ def rescore_with_whole_lattice(
|
||||
break
|
||||
except RuntimeError as e:
|
||||
logging.info(f"Caught exception:\n{e}\n")
|
||||
if loop_count >= max_loop_count:
|
||||
logging.info(
|
||||
"Return None as the resulting lattice is too large."
|
||||
)
|
||||
return None
|
||||
logging.info(
|
||||
f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
|
||||
)
|
||||
@ -740,16 +748,15 @@ def rescore_with_whole_lattice(
|
||||
"is too large, or the input sound file is difficult to "
|
||||
"decode, you will meet this exception."
|
||||
)
|
||||
|
||||
# NOTE(fangjun): The choice of the threshold 1e-9 is arbitrary here
|
||||
# to avoid OOM. You may need to fine tune it.
|
||||
inv_lattice = k2.prune_on_arc_post(inv_lattice, 1e-9, True)
|
||||
inv_lattice = k2.prune_on_arc_post(
|
||||
inv_lattice,
|
||||
prune_th_list[loop_count],
|
||||
True,
|
||||
)
|
||||
logging.info(
|
||||
f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
|
||||
)
|
||||
if loop_count > max_loop_count:
|
||||
logging.info("Return None as the resulting lattice is too large")
|
||||
return None
|
||||
loop_count += 1
|
||||
|
||||
# lat has token IDs as labels
|
||||
# and word IDs as aux_labels.
|
||||
|
Loading…
x
Reference in New Issue
Block a user