Add rescoring with n-best lists.

Fangjun Kuang 2021-09-17 19:51:45 +08:00
parent 8807381401
commit a44b4f84a5
2 changed files with 183 additions and 41 deletions

View File

@@ -32,13 +32,17 @@ from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import get_lattice
from icefall.decode import (
one_best_decoding,
one_best_decoding, # done
rescore_with_attention_decoder,
rescore_with_n_best_list,
rescore_with_n_best_list, # done
rescore_with_whole_lattice,
nbest_oracle,
nbest_oracle, # done
)
from icefall.decode2 import (
nbest_decoding,
nbest_oracle as nbest_oracle2,
rescore_with_n_best_list as rescore_with_n_best_list2,
)
from icefall.decode2 import nbest_decoding, nbest_oracle as nbest_oracle2
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@@ -257,7 +261,10 @@ def decode_one_batch(
)
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
return {"nbest-orcale": hyps}
key = (
f"oracle_{num_paths}_lattice_score_scale_{lattice_score_scale}"
)
return {key: hyps}
else:
return nbest_oracle(
lattice=lattice,
@@ -297,13 +304,23 @@ def decode_one_batch(
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
if params.method == "nbest-rescoring":
best_path_dict = rescore_with_n_best_list(
lattice=lattice,
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
scale=params.lattice_score_scale,
)
if True:
# TODO: remove the "else" branch
best_path_dict = rescore_with_n_best_list2(
lattice=lattice,
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
lattice_score_scale=params.lattice_score_scale,
)
else:
best_path_dict = rescore_with_n_best_list(
lattice=lattice,
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
scale=params.lattice_score_scale,
)
elif params.method == "whole-lattice-rescoring":
best_path_dict = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
@@ -385,7 +402,8 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
if batch_idx > 20:
# TODO: remove it
if batch_idx > 100:
break
hyps_dict = decode_one_batch(

View File

@@ -25,6 +25,47 @@ import torch
from icefall.utils import get_texts
def _intersect_device(
a_fsas: k2.Fsa,
b_fsas: k2.Fsa,
b_to_a_map: torch.Tensor,
sorted_match_a: bool,
batch_size: int = 50,
) -> k2.Fsa:
"""This is a wrapper of k2.intersect_device and its purpose is to split
b_fsas into several batches and process each batch separately to avoid
CUDA OOM error.
The arguments and return value of this function are the same as
k2.intersect_device.
"""
num_fsas = b_fsas.shape[0]
if num_fsas <= batch_size:
return k2.intersect_device(
a_fsas, b_fsas, b_to_a_map=b_to_a_map, sorted_match_a=sorted_match_a
)
num_batches = (num_fsas + batch_size - 1) // batch_size
splits = []
for i in range(num_batches):
start = i * batch_size
end = min(start + batch_size, num_fsas)
splits.append((start, end))
ans = []
for start, end in splits:
indexes = torch.arange(start, end).to(b_to_a_map)
fsas = k2.index_fsa(b_fsas, indexes)
b_to_a = k2.index_select(b_to_a_map, indexes)
path_lattice = k2.intersect_device(
a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a
)
ans.append(path_lattice)
return k2.cat(ans)
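For reference, a minimal usage sketch (the names `graph` and `path_fsas` are hypothetical, not from this commit): intersecting many path FSAs with a single shared graph, where every path maps to FSA index 0:

    import torch

    # graph: an arc-sorted FsaVec holding a single FSA (e.g., an n-gram G)
    # path_fsas: an FsaVec with one linear FSA per path
    b_to_a_map = torch.zeros(
        path_fsas.shape[0], dtype=torch.int32, device=path_fsas.device
    )
    path_lattice = _intersect_device(
        graph,
        path_fsas,
        b_to_a_map=b_to_a_map,
        sorted_match_a=True,
        batch_size=50,  # at most 50 FSAs per k2.intersect_device call
    )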
# TODO(fangjun): Use Kangwei's C++ implementation that also
# supports List[List[int]]
def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
@@ -181,18 +222,22 @@ class Nbest(object):
fsa = k2.linear_fsa(labels)
fsa.aux_labels = aux_labels
# Caution: fsa.scores are all 0s.
return Nbest(fsa=fsa, shape=utt_to_path_shape)
def intersect(self, lattice: k2.Fsa, use_double_scores=True) -> "Nbest":
"""Intersect this Nbest object with a lattice and get 1-best
path from the resulting FsaVec.
Caution:
We assume `self.fsa.labels` and `lattice.labels` are token IDs.
The purpose of this function is to attach scores to an Nbest.
Args:
lattice:
An FsaVec with axes [utt][state][arc]
An FsaVec with axes [utt][state][arc]. If it has `aux_labels`, then
we assume its `labels` are token IDs and `aux_labels` are word IDs.
If it has only `labels`, we assume its `labels` are word IDs.
use_double_scores:
True to use double precision when computing shortest path.
False to use single precision.
@@ -201,16 +246,6 @@ class Nbest(object):
while its `fsa` is the 1-best path from intersecting `self.fsa` and
`lattice`.
"""
assert (
self.fsa.device == lattice.device
), f"{self.fsa.device} vs {lattice.device}"
assert len(lattice.shape) == 3, f"{lattice.shape}"
assert (
lattice.arcs.dim0() == self.shape.dim0
), f"{lattice.arcs.dim0()} vs {self.shape.dim0}"
# Note: We view each linear FSA as a word sequence
# and we use the passed lattice to give each word sequence a score.
#
@@ -221,8 +256,10 @@ class Nbest(object):
# We use a word fsa to intersect with k2.invert(lattice)
word_fsa = k2.invert(self.fsa)
# delete token IDs as it is not needed
del word_fsa.aux_labels
if hasattr(lattice, "aux_labels"):
# delete token IDs as they are not needed
del word_fsa.aux_labels
word_fsa.scores.zero_()
word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
word_fsa
@@ -230,17 +267,28 @@ class Nbest(object):
path_to_utt_map = self.shape.row_ids(1)
# lattice has token IDs as labels and word IDs as aux_labels.
# inv_lattice has word IDs as labels and token IDs as aux_labels
inv_lattice = k2.invert(lattice)
inv_lattice = k2.arc_sort(inv_lattice)
if hasattr(lattice, "aux_labels"):
# lattice has token IDs as labels and word IDs as aux_labels.
# inv_lattice has word IDs as labels and token IDs as aux_labels
inv_lattice = k2.invert(lattice)
inv_lattice = k2.arc_sort(inv_lattice)
else:
inv_lattice = k2.arc_sort(lattice)
path_lattice = k2.intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=path_to_utt_map,
sorted_match_a=True,
)
if inv_lattice.shape[0] == 1:
# inv_lattice contains just one FSA (e.g., a shared G), so every
# path is mapped to FSA index 0
path_lattice = _intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=torch.zeros_like(path_to_utt_map),
sorted_match_a=True,
)
else:
path_lattice = _intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=path_to_utt_map,
sorted_match_a=True,
)
# path_lattice has word IDs as labels and token IDs as aux_labels
path_lattice = k2.top_sort(k2.connect(path_lattice))
@@ -254,6 +302,29 @@ class Nbest(object):
return Nbest(fsa=one_best, shape=self.shape)
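A short sketch of the intended call pattern, mirroring how nbest_decoding below uses this method (the numeric arguments are only examples):

    nbest = Nbest.from_lattice(
        lattice=lattice,
        num_paths=100,
        use_double_scores=True,
        lattice_score_scale=0.5,
    )
    # nbest.fsa.scores are all 0s at this point
    nbest = nbest.intersect(lattice)
    # Now each linear FSA in nbest.fsa carries the scores (and lm_scores)
    # of its best-matching path in the lattice.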
def compute_am_scores(self) -> k2.RaggedTensor:
"""Compute AM scores of each linear FSA (i.e., each path within
an utterance).
Hint:
`self.fsa.scores` contains two parts: am scores and lm scores.
Returns:
Return a ragged tensor with 2 axes [utt][path_scores].
Its dtype is torch.float64.
"""
saved_scores = self.fsa.scores
# The `scores` of every arc consist of `am_scores` and `lm_scores`
self.fsa.scores = self.fsa.scores - self.fsa.lm_scores
am_scores = self.fsa.get_tot_scores(
use_double_scores=True, log_semiring=False
)
self.fsa.scores = saved_scores
return k2.RaggedTensor(self.shape, am_scores)
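A toy illustration of the subtraction above, with made-up numbers:

    import torch

    scores = torch.tensor([-1.5, -2.0])     # per-arc am + lm scores
    lm_scores = torch.tensor([-0.5, -1.0])  # per-arc lm scores
    am = scores - lm_scores                 # tensor([-1.0, -1.0])

Summing the per-arc AM part along each linear path (log_semiring=False) yields one AM score per path.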
def tot_scores(self) -> k2.RaggedTensor:
"""Get total scores of the FSAs in this Nbest.
@@ -263,10 +334,11 @@ class Nbest(object):
Returns:
Return a ragged tensor with two axes [utt][path_scores].
Its dtype is torch.float64.
"""
# Use double precision to avoid accuracy loss when summing scores.
scores = self.fsa.get_tot_scores(
use_double_scores=False, log_semiring=False
use_double_scores=True, log_semiring=False
)
return k2.RaggedTensor(self.shape, scores)
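For intuition, the result has one sub-list of path scores per utterance, and argmax(), used by the callers below, returns one index into `.values` per utterance. With made-up numbers:

    >>> scores = k2.RaggedTensor("[ [-5.2 -6.1] [-3.3 -2.9 -4.0] ]")
    >>> scores.argmax()
    tensor([0, 3], dtype=torch.int32)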
@@ -317,11 +389,13 @@ def nbest_decoding(
use_double_scores=use_double_scores,
lattice_score_scale=lattice_score_scale,
)
# nbest.fsa.scores contains 0s
nbest = nbest.intersect(lattice)
# now nbest.fsa.scores is set
# max_indexes contains the indexes for the max scores
# of paths within an utterance.
# max_indexes contains, for each utterance, the index of the path
# with the maximum total score.
max_indexes = nbest.tot_scores().argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
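A usage sketch of nbest_decoding as a whole, assuming the signature implied by the hunk above and that `lattice` comes from get_lattice() in the decoding script (the argument values are only examples):

    best_path = nbest_decoding(
        lattice=lattice,
        num_paths=100,
        use_double_scores=True,
        lattice_score_scale=0.5,
    )
    hyps = get_texts(best_path)  # word-ID sequences, one per utterance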
@@ -419,3 +493,53 @@ def nbest_oracle(
best_path = k2.index_fsa(nbest.fsa, max_indexes)
return best_path
def rescore_with_n_best_list(
lattice: k2.Fsa,
G: k2.Fsa,
num_paths: int,
lm_scale_list: List[float],
lattice_score_scale: float = 1.0,
use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]:
"""Rescore a nbest list with an n-gram LM.
The path with a maximum score is used as the decoding output.
"""
device = lattice.device
assert len(lattice.shape) == 3
assert hasattr(lattice, "aux_labels")
assert hasattr(lattice, "lm_scores")
assert G.shape == (1, None, None)
assert G.device == device
assert hasattr(G, "aux_labels") is False
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
lattice_score_scale=lattice_score_scale,
)
# nbest.fsa.scores are all 0s at this point
nbest = nbest.intersect(lattice)
# Now nbest.fsa has its scores set
assert hasattr(nbest.fsa, "lm_scores")
am_scores = nbest.compute_am_scores()
nbest = nbest.intersect(G)
# Now nbest.fsa.scores contains only the LM scores obtained from G
lm_scores = nbest.tot_scores()
ans = dict()
for lm_scale in lm_scale_list:
tot_scores = am_scores.values / lm_scale + lm_scores.values
tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
key = f"lm_scale_{lm_scale}"
ans[key] = best_path
return ans
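Finally, a usage sketch mirroring the caller in decode_one_batch (the lm_scale values are only examples):

    best_path_dict = rescore_with_n_best_list(
        lattice=lattice,  # from get_lattice(); must have lm_scores
        G=G,              # word-level n-gram LM: a single FSA, no aux_labels
        num_paths=100,
        lm_scale_list=[0.6, 0.7, 0.8],
        lattice_score_scale=0.5,
    )
    for key, best_path in best_path_dict.items():
        # key looks like "lm_scale_0.6"
        hyps = get_texts(best_path)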