Incremental pruning threshold (#214)

* Incremental pruning threshold

* flake8

* black

* minor fix
This commit is contained in:
Wang, Guanbo 2022-02-16 03:59:27 -05:00 committed by GitHub
parent 70a3c56a18
commit e8eb408760
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -716,10 +716,13 @@ def rescore_with_whole_lattice(
b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
# NOTE: The choice of the threshold list is arbitrary here to avoid OOM.
# You may need to fine tune it.
prune_th_list = [1e-10, 1e-9, 1e-8, 1e-7, 1e-6]
prune_th_list += [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
max_loop_count = 10
loop_count = 0
while loop_count <= max_loop_count:
loop_count += 1
try:
rescoring_lattice = k2.intersect_device(
G_with_epsilon_loops,
@@ -731,6 +734,11 @@ def rescore_with_whole_lattice(
break
except RuntimeError as e:
logging.info(f"Caught exception:\n{e}\n")
if loop_count >= max_loop_count:
logging.info(
"Return None as the resulting lattice is too large."
)
return None
logging.info(
f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
)
@@ -740,16 +748,15 @@ def rescore_with_whole_lattice(
"is too large, or the input sound file is difficult to "
"decode, you will meet this exception."
)
# NOTE(fangjun): The choice of the threshold 1e-9 is arbitrary here
# to avoid OOM. You may need to fine tune it.
inv_lattice = k2.prune_on_arc_post(inv_lattice, 1e-9, True)
inv_lattice = k2.prune_on_arc_post(
inv_lattice,
prune_th_list[loop_count],
True,
)
logging.info(
f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
)
if loop_count > max_loop_count:
logging.info("Return None as the resulting lattice is too large")
return None
loop_count += 1
# lat has token IDs as labels
# and word IDs as aux_labels.