diff --git a/icefall/decode.py b/icefall/decode.py index 4c2a8e01b..d3e420eec 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -716,10 +716,13 @@ def rescore_with_whole_lattice( b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32) + # NOTE: The choice of the threshold list is arbitrary here to avoid OOM. + # You may need to fine tune it. + prune_th_list = [1e-10, 1e-9, 1e-8, 1e-7, 1e-6] + prune_th_list += [1e-5, 1e-4, 1e-3, 1e-2, 1e-1] max_loop_count = 10 loop_count = 0 while loop_count <= max_loop_count: - loop_count += 1 try: rescoring_lattice = k2.intersect_device( G_with_epsilon_loops, @@ -731,6 +734,11 @@ def rescore_with_whole_lattice( break except RuntimeError as e: logging.info(f"Caught exception:\n{e}\n") + if loop_count >= max_loop_count: + logging.info( + "Return None as the resulting lattice is too large." + ) + return None logging.info( f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}" ) @@ -740,16 +748,15 @@ def rescore_with_whole_lattice( "is too large, or the input sound file is difficult to " "decode, you will meet this exception." ) - - # NOTE(fangjun): The choice of the threshold 1e-9 is arbitrary here - # to avoid OOM. You may need to fine tune it. - inv_lattice = k2.prune_on_arc_post(inv_lattice, 1e-9, True) + inv_lattice = k2.prune_on_arc_post( + inv_lattice, + prune_th_list[loop_count], + True, + ) logging.info( f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}" ) - if loop_count > max_loop_count: - logging.info("Return None as the resulting lattice is too large") - return None + loop_count += 1 # lat has token IDs as labels # and word IDs as aux_labels.