Add whole-lattice rescoring.

Fangjun Kuang 2021-09-17 21:05:00 +08:00
parent a44b4f84a5
commit b0b942b355
2 changed files with 120 additions and 11 deletions

View File

@@ -42,6 +42,7 @@ from icefall.decode2 import (
nbest_decoding,
nbest_oracle as nbest_oracle2,
rescore_with_n_best_list as rescore_with_n_best_list2,
rescore_with_whole_lattice as rescore_with_whole_lattice2,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
@@ -261,9 +262,7 @@ def decode_one_batch(
)
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
key = (
f"oracle_{num_paths}_lattice_score_scale_{lattice_score_scale}"
)
key = f"oracle_{num_paths}_lattice_score_scale_{params.lattice_score_scale}" # noqa
return {key: hyps}
else:
return nbest_oracle(
@@ -322,8 +321,18 @@ def decode_one_batch(
scale=params.lattice_score_scale,
)
elif params.method == "whole-lattice-rescoring":
if True:
# TODO: remove "else" branch
best_path_dict = rescore_with_whole_lattice2(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=lm_scale_list,
)
else:
best_path_dict = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=lm_scale_list,
)
elif params.method == "attention-decoder":
# lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
@@ -345,10 +354,14 @@ def decode_one_batch(
assert False, f"Unsupported decoding method: {params.method}"
ans = dict()
if best_path_dict is not None:
for lm_scale_str, best_path in best_path_dict.items():
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
ans[lm_scale_str] = hyps
else:
for lm_scale in lm_scale_list:
ans[f"lm_scale_{lm_scale}"] = [[] for _ in range(lattice.shape[0])]
return ans
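Note (not part of the diff): `rescore_with_whole_lattice2` expects `G` to be an acceptor with epsilon self-loops and an `lm_scores` attribute. A minimal preparation sketch follows, assuming a 4-gram LM checkpoint at `data/lm/G_4_gram.pt`; the path, device choice, and variable names are illustrative assumptions, not part of this commit.

import torch
import k2

device = torch.device("cuda", 0)  # assumption: same device as the decoding lattice

# Load a pre-built 4-gram LM saved as a k2 FSA (hypothetical path).
d = torch.load("data/lm/G_4_gram.pt", map_location="cpu")
G = k2.Fsa.from_dict(d).to(device)

# The whole-lattice intersection needs epsilon self-loops in G because the
# lattice contains epsilon arcs; arc-sorting G allows sorted_match_a=True
# in k2.intersect_device().
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)

# lm_scores from G replace the lattice's lm_scores during rescoring.
G.lm_scores = G.scores.clone()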

View File

@@ -17,7 +17,8 @@
# NOTE: This file is a refactor of decode.py
# We will delete decode.py and rename this file to decode.py
from typing import Dict, List
import logging
from typing import Dict, List, Optional, Union
import k2
import torch
@@ -505,6 +506,27 @@ def rescore_with_n_best_list(
) -> Dict[str, k2.Fsa]:
"""Rescore a nbest list with an n-gram LM.
The path with a maximum score is used as the decoding output.
Args:
lattice:
An FsaVec with axes [utt][state][arc]. It must have the following
attributes: ``aux_labels`` and ``lm_scores``. Its labels are
token IDs and ``aux_labels`` word IDs.
G:
An FsaVec containing only a single FSA. It is an n-gram LM.
num_paths:
Size of the n-best list.
lm_scale_list:
A list of float representing LM score scales.
lattice_score_scale:
Scale to be applied to ``lattice.score`` when sampling paths
using ``k2.random_paths``.
use_double_scores:
True to use double precision during computation. False to use
single precision.
Returns:
A dict of FsaVec, whose key is an lm_scale and the value is the
best decoding path for each utterance in the lattice.
"""
device = lattice.device
@@ -543,3 +565,77 @@ def rescore_with_n_best_list(
key = f"lm_scale_{lm_scale}"
ans[key] = best_path
return ans
def rescore_with_whole_lattice(
lattice: k2.Fsa,
G_with_epsilon_loops: k2.Fsa,
lm_scale_list: Optional[List[float]] = None,
use_double_scores: bool = True,
) -> Optional[Union[k2.Fsa, Dict[str, k2.Fsa]]]:
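"""Intersect ``lattice`` with an n-gram LM and use the shortest path
of the rescored lattice as the decoding output.
Args:
lattice:
An FsaVec with axes [utt][state][arc]. It must have an attribute
``lm_scores``.
G_with_epsilon_loops:
An FsaVec containing only a single FSA: an n-gram LM with epsilon
self-loops and an attribute ``lm_scores``.
lm_scale_list:
A list of float representing LM score scales. If None, return the
rescored lattice without extracting best paths.
use_double_scores:
True to use double precision during computation. False to use
single precision.
Returns:
If ``lm_scale_list`` is None, return the rescored lattice. Otherwise,
return a dict whose key is an lm_scale and the value is the best
decoding path for each utterance in the lattice. Return None if the
intersection fails even after pruning.
"""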
# This is not an Nbest-based decoding method.
assert hasattr(lattice, "lm_scores")
assert G_with_epsilon_loops.shape == (1, None, None)
device = lattice.device
lattice.scores = lattice.scores - lattice.lm_scores
# We will use lm_scores from G, so remove lattice.lm_scores here
del lattice.lm_scores
assert hasattr(G_with_epsilon_loops, "lm_scores")
# Now, lattice.scores contains only am_scores
# inv_lattice has word IDs as labels.
# Its aux_labels are token IDs
inv_lattice = k2.invert(lattice)
num_seqs = lattice.shape[0]
b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
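# G_with_epsilon_loops contains only a single FSA that is shared by all
# utterances, so every sequence in inv_lattice maps to FSA index 0.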
max_loop_count = 10
loop_count = 0
while loop_count <= max_loop_count:
loop_count += 1
try:
rescoring_lattice = k2.intersect_device(
G_with_epsilon_loops,
inv_lattice,
b_to_a_map,
sorted_match_a=True,
)
rescoring_lattice = k2.top_sort(k2.connect(rescoring_lattice))
break
except RuntimeError as e:
logging.info(f"Caught exception:\n{e}\n")
logging.info(
f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
)
# NOTE(fangjun): The choice of the threshold 1e-9 is arbitrary here
# to avoid OOM. You may need to fine-tune it.
inv_lattice = k2.prune_on_arc_post(inv_lattice, 1e-9, True)
logging.info(
f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
)
if loop_count > max_loop_count:
logging.info("Return None as the resulting lattice is too large")
return None
# lat has token IDs as labels
# and word IDs as aux_labels.
lat = k2.invert(rescoring_lattice)
if lm_scale_list is None:
return lat
ans = dict()
saved_am_scores = lat.scores - lat.lm_scores
for lm_scale in lm_scale_list:
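# Dividing am_scores by lm_scale is equivalent, up to a constant positive
# factor on every path score, to multiplying lm_scores by lm_scale, so the
# best path found by shortest_path() is unchanged.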
am_scores = saved_am_scores / lm_scale
lat.scores = am_scores + lat.lm_scores
best_path = k2.shortest_path(lat, use_double_scores=use_double_scores)
key = f"lm_scale_{lm_scale}_yy"
ans[key] = best_path
return ans
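Not part of the diff: a minimal usage sketch of `rescore_with_whole_lattice`, assuming `lattice` comes from a first-pass HLG decoding with ``lm_scores`` attached and `G` is prepared as sketched above; the scale values and the `hyps_per_scale` name are illustrative assumptions.

from icefall.utils import get_texts

# Assumed inputs: `lattice` (FsaVec with lm_scores) and `G` (see earlier sketch).
lm_scale_list = [0.6, 0.7, 0.8, 0.9, 1.0]

best_path_dict = rescore_with_whole_lattice(
    lattice=lattice,
    G_with_epsilon_loops=G,
    lm_scale_list=lm_scale_list,
)

if best_path_dict is None:
    # Intersection failed even after repeated pruning; no hypotheses produced.
    hyps_per_scale = {f"lm_scale_{s}": [] for s in lm_scale_list}
else:
    # Each value is an FsaVec of linear FSAs (one best path per utterance)
    # whose aux_labels are word IDs; get_texts() extracts them as int lists.
    hyps_per_scale = {
        key: get_texts(best_path) for key, best_path in best_path_dict.items()
    }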