Add whole-lattice rescoring.

Fangjun Kuang 2021-09-17 21:05:00 +08:00
parent a44b4f84a5
commit b0b942b355
2 changed files with 120 additions and 11 deletions

View File

@@ -42,6 +42,7 @@ from icefall.decode2 import (
     nbest_decoding,
    nbest_oracle as nbest_oracle2,
    rescore_with_n_best_list as rescore_with_n_best_list2,
+    rescore_with_whole_lattice as rescore_with_whole_lattice2,
 )
 from icefall.lexicon import Lexicon
 from icefall.utils import (
@@ -261,9 +262,7 @@ def decode_one_batch(
         )
         hyps = get_texts(best_path)
         hyps = [[word_table[i] for i in ids] for ids in hyps]
-        key = (
-            f"oracle_{num_paths}_lattice_score_scale_{lattice_score_scale}"
-        )
+        key = f"oracle_{num_paths}_lattice_score_scale_{params.lattice_score_scale}"  # noqa
         return {key: hyps}
     else:
         return nbest_oracle(
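
Besides reflowing the key onto one line, the new version reads the scale from `params`; the removed lines referenced a bare `lattice_score_scale`, which is not defined in that scope. A quick sketch of the key the new line produces, with hypothetical values:

    num_paths = 100
    lattice_score_scale = 0.5  # stands in for params.lattice_score_scale
    key = f"oracle_{num_paths}_lattice_score_scale_{lattice_score_scale}"
    assert key == "oracle_100_lattice_score_scale_0.5"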
@@ -322,8 +321,18 @@ def decode_one_batch(
             scale=params.lattice_score_scale,
         )
     elif params.method == "whole-lattice-rescoring":
-        best_path_dict = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
-        )
+        if True:
+            # TODO: remove "else" branch
+            best_path_dict = rescore_with_whole_lattice2(
+                lattice=lattice,
+                G_with_epsilon_loops=G,
+                lm_scale_list=lm_scale_list,
+            )
+        else:
+            best_path_dict = rescore_with_whole_lattice(
+                lattice=lattice,
+                G_with_epsilon_loops=G,
+                lm_scale_list=lm_scale_list,
+            )
     elif params.method == "attention-decoder":
         # lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
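
The `if True:` toggle keeps the old implementation reachable behind the `else` branch while the refactored `rescore_with_whole_lattice2` from `icefall.decode2` is being validated; the `TODO` marks the old branch for deletion once it is. Both branches expect `G` to be an n-gram LM FsaVec that already carries epsilon self-loops and an `lm_scores` attribute. A minimal sketch of how such a `G` might be prepared; the file path and device are assumptions, not part of this commit:

    import k2
    import torch

    device = torch.device("cuda", 0)
    # Assumed: a 4-gram LM FSA serialized earlier with torch.save().
    G = k2.Fsa.from_dict(torch.load("data/lm/G_4_gram.pt")).to(device)
    G = k2.add_epsilon_self_loops(G)  # required by k2.intersect_device
    G = k2.arc_sort(G)                # sorted_match_a=True needs sorted arcs
    G.lm_scores = G.scores.clone()    # rescoring reads LM scores from G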
@@ -345,10 +354,14 @@ def decode_one_batch(
         assert False, f"Unsupported decoding method: {params.method}"

     ans = dict()
-    for lm_scale_str, best_path in best_path_dict.items():
-        hyps = get_texts(best_path)
-        hyps = [[word_table[i] for i in ids] for ids in hyps]
-        ans[lm_scale_str] = hyps
+    if best_path_dict is not None:
+        for lm_scale_str, best_path in best_path_dict.items():
+            hyps = get_texts(best_path)
+            hyps = [[word_table[i] for i in ids] for ids in hyps]
+            ans[lm_scale_str] = hyps
+    else:
+        for lm_scale in lm_scale_list:
+            ans[f"lm_scale_{lm_scale}"] = [[] for _ in range(lattice.shape[0])]
     return ans
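
When whole-lattice rescoring returns None (the lattice stayed too large even after pruning), the `else` branch still emits one entry per LM scale so downstream scoring sees every expected key, with one empty word list per utterance. A small illustration with hypothetical sizes:

    lm_scale_list = [0.4, 0.5]
    num_utts = 3  # lattice.shape[0], utterances in the batch
    ans = {
        f"lm_scale_{s}": [[] for _ in range(num_utts)]
        for s in lm_scale_list
    }
    assert ans["lm_scale_0.4"] == [[], [], []]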

View File

@@ -17,7 +17,8 @@
 # NOTE: This file is a refactor of decode.py
 # We will delete decode.py and rename this file to decode.py

-from typing import Dict, List
+import logging
+from typing import Dict, List, Optional, Union

 import k2
 import torch
@@ -505,6 +506,27 @@ def rescore_with_n_best_list(
 ) -> Dict[str, k2.Fsa]:
     """Rescore an n-best list with an n-gram LM.
     The path with the maximum score is used as the decoding output.
+
+    Args:
+      lattice:
+        An FsaVec with axes [utt][state][arc]. It must have the following
+        attributes: ``aux_labels`` and ``lm_scores``. Its labels are
+        token IDs and its ``aux_labels`` are word IDs.
+      G:
+        An FsaVec containing only a single FSA. It is an n-gram LM.
+      num_paths:
+        Size of the n-best list.
+      lm_scale_list:
+        A list of floats representing LM score scales.
+      lattice_score_scale:
+        Scale to be applied to ``lattice.scores`` when sampling paths
+        with ``k2.random_paths``.
+      use_double_scores:
+        True to use double precision during the computation. False to
+        use single precision.
+    Returns:
+      A dict of FsaVec whose key is an lm_scale and whose value is the
+      best decoding path for each utterance in the lattice.
+    """
     device = lattice.device
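
Given the documented signature, a call might look like this sketch; all argument values are illustrative, and `lattice` and `G` are assumed to be prepared as the docstring describes:

    best_path_dict = rescore_with_n_best_list(
        lattice=lattice,        # FsaVec with aux_labels and lm_scores
        G=G,                    # FsaVec holding a single n-gram LM FSA
        num_paths=100,
        lm_scale_list=[0.6, 0.8, 1.0],
        lattice_score_scale=0.5,
        use_double_scores=True,
    )
    best_path = best_path_dict["lm_scale_0.8"]  # best paths, one per utterance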
@@ -543,3 +565,77 @@
         key = f"lm_scale_{lm_scale}"
         ans[key] = best_path
     return ans
+
+
+def rescore_with_whole_lattice(
+    lattice: k2.Fsa,
+    G_with_epsilon_loops: k2.Fsa,
+    lm_scale_list: Optional[List[float]] = None,
+    use_double_scores: bool = True,
+) -> Union[k2.Fsa, Dict[str, k2.Fsa]]:
+    # This is not an n-best based decoding method.
+    assert hasattr(lattice, "lm_scores")
+    assert G_with_epsilon_loops.shape == (1, None, None)
+
+    device = lattice.device
+    lattice.scores = lattice.scores - lattice.lm_scores
+    # We will use lm_scores from G, so remove lattice.lm_scores here.
+    del lattice.lm_scores
+    assert hasattr(G_with_epsilon_loops, "lm_scores")
+
+    # Now lattice.scores contains only am_scores.
+
+    # inv_lattice has word IDs as labels.
+    # Its aux_labels are token IDs.
+    inv_lattice = k2.invert(lattice)
+    num_seqs = lattice.shape[0]
+
+    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
+
+    max_loop_count = 10
+    loop_count = 0
+    while loop_count <= max_loop_count:
+        loop_count += 1
+        try:
+            rescoring_lattice = k2.intersect_device(
+                G_with_epsilon_loops,
+                inv_lattice,
+                b_to_a_map,
+                sorted_match_a=True,
+            )
+            rescoring_lattice = k2.top_sort(k2.connect(rescoring_lattice))
+            break
+        except RuntimeError as e:
+            logging.info(f"Caught exception:\n{e}\n")
+            logging.info(
+                f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
+            )
+            # NOTE(fangjun): The choice of the threshold 1e-9 is arbitrary here
+            # to avoid OOM. You may need to fine tune it.
+            inv_lattice = k2.prune_on_arc_post(inv_lattice, 1e-9, True)
+            logging.info(
+                f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
+            )
+
+    if loop_count > max_loop_count:
+        logging.info("Return None as the resulting lattice is too large")
+        return None
+
+    # lat has token IDs as labels
+    # and word IDs as aux_labels.
+    lat = k2.invert(rescoring_lattice)
+
+    if lm_scale_list is None:
+        return lat
+
+    ans = dict()
+    saved_am_scores = lat.scores - lat.lm_scores
+    for lm_scale in lm_scale_list:
+        am_scores = saved_am_scores / lm_scale
+        lat.scores = am_scores + lat.lm_scores
+
+        best_path = k2.shortest_path(lat, use_double_scores=use_double_scores)
+        key = f"lm_scale_{lm_scale}"
+        ans[key] = best_path
+    return ans
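
In short, the new function strips the first-pass LM scores from the lattice, intersects the inverted lattice with G on the device (pruning and retrying on OOM up to ten times), and then ranks paths per scale by am / lm_scale + lm, which orders paths the same as am + lm_scale * lm, since the two totals differ only by the positive factor lm_scale. A hedged usage sketch, assuming `lattice` and `G` were prepared as in the decoding script above:

    best_path_dict = rescore_with_whole_lattice(
        lattice=lattice,  # note: modified in place (lm_scores is deleted)
        G_with_epsilon_loops=G,
        lm_scale_list=[0.8, 0.9, 1.0],
    )
    if best_path_dict is None:
        print("Lattice was still too large after repeated pruning")
    else:
        best_path = best_path_dict["lm_scale_0.9"]
    # With lm_scale_list=None the call would instead return the rescored
    # lattice itself, e.g. as input to a later attention-decoder pass.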