Add rescoring with n-best lists.
commit a44b4f84a5 (parent 8807381401)
@@ -32,13 +32,17 @@ from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.decode import get_lattice
 from icefall.decode import (
-    one_best_decoding,
+    one_best_decoding,  # done
     rescore_with_attention_decoder,
-    rescore_with_n_best_list,
+    rescore_with_n_best_list,  # done
     rescore_with_whole_lattice,
-    nbest_oracle,
+    nbest_oracle,  # done
 )
-from icefall.decode2 import nbest_decoding, nbest_oracle as nbest_oracle2
+from icefall.decode2 import (
+    nbest_decoding,
+    nbest_oracle as nbest_oracle2,
+    rescore_with_n_best_list as rescore_with_n_best_list2,
+)
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -257,7 +261,10 @@ def decode_one_batch(
         )
         hyps = get_texts(best_path)
         hyps = [[word_table[i] for i in ids] for ids in hyps]
-        return {"nbest-orcale": hyps}
+        key = (
+            f"oracle_{num_paths}_lattice_score_scale_{lattice_score_scale}"
+        )
+        return {key: hyps}
     else:
         return nbest_oracle(
             lattice=lattice,
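The key built above replaces the fixed (and misspelled) "nbest-orcale" key: it encodes both decoding knobs, so results from different settings no longer overwrite one another. A standalone illustration with hypothetical values (plain Python):

    num_paths = 100
    lattice_score_scale = 0.5
    key = f"oracle_{num_paths}_lattice_score_scale_{lattice_score_scale}"
    print(key)  # oracle_100_lattice_score_scale_0.5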
@@ -297,13 +304,23 @@ def decode_one_batch(
     lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]

     if params.method == "nbest-rescoring":
-        best_path_dict = rescore_with_n_best_list(
-            lattice=lattice,
-            G=G,
-            num_paths=params.num_paths,
-            lm_scale_list=lm_scale_list,
-            scale=params.lattice_score_scale,
-        )
+        if True:
+            # TODO: remove the "else" branch
+            best_path_dict = rescore_with_n_best_list2(
+                lattice=lattice,
+                G=G,
+                num_paths=params.num_paths,
+                lm_scale_list=lm_scale_list,
+                lattice_score_scale=params.lattice_score_scale,
+            )
+        else:
+            best_path_dict = rescore_with_n_best_list(
+                lattice=lattice,
+                G=G,
+                num_paths=params.num_paths,
+                lm_scale_list=lm_scale_list,
+                scale=params.lattice_score_scale,
+            )
     elif params.method == "whole-lattice-rescoring":
         best_path_dict = rescore_with_whole_lattice(
             lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
@@ -385,7 +402,8 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
-        if batch_idx > 20:
+        # TODO: remove it
+        if batch_idx > 100:
             break

         hyps_dict = decode_one_batch(
@@ -25,6 +25,47 @@ import torch
 from icefall.utils import get_texts


+def _intersect_device(
+    a_fsas: k2.Fsa,
+    b_fsas: k2.Fsa,
+    b_to_a_map: torch.Tensor,
+    sorted_match_a: bool,
+    batch_size: int = 50,
+) -> k2.Fsa:
+    """This is a wrapper of k2.intersect_device. Its purpose is to split
+    b_fsas into several batches and process each batch separately to avoid
+    a CUDA OOM error.
+
+    The arguments and return value of this function are the same as
+    k2.intersect_device.
+    """
+    num_fsas = b_fsas.shape[0]
+    if num_fsas <= batch_size:
+        return k2.intersect_device(
+            a_fsas, b_fsas, b_to_a_map=b_to_a_map, sorted_match_a=sorted_match_a
+        )
+
+    num_batches = (num_fsas + batch_size - 1) // batch_size
+    splits = []
+    for i in range(num_batches):
+        start = i * batch_size
+        end = min(start + batch_size, num_fsas)
+        splits.append((start, end))
+
+    ans = []
+    for start, end in splits:
+        indexes = torch.arange(start, end).to(b_to_a_map)
+
+        fsas = k2.index_fsa(b_fsas, indexes)
+        b_to_a = k2.index_select(b_to_a_map, indexes)
+        path_lattice = k2.intersect_device(
+            a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a
+        )
+        ans.append(path_lattice)
+
+    return k2.cat(ans)
+
+
 # TODO(fangjun): Use Kangwei's C++ implementation that also
 # supports List[List[int]]
 def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
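`_intersect_device` above trades one large `k2.intersect_device` call for several smaller ones; only peak memory changes, not the result. A minimal sketch of its splitting arithmetic, using a hypothetical `num_fsas = 123` (plain Python, no k2 required):

    batch_size = 50
    num_fsas = 123  # hypothetical; the wrapper uses b_fsas.shape[0]

    # Ceiling division: 123 FSAs -> 3 batches of at most 50 each.
    num_batches = (num_fsas + batch_size - 1) // batch_size
    splits = []
    for i in range(num_batches):
        start = i * batch_size
        end = min(start + batch_size, num_fsas)
        splits.append((start, end))

    print(splits)  # [(0, 50), (50, 100), (100, 123)]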
@@ -181,18 +222,22 @@ class Nbest(object):

         fsa = k2.linear_fsa(labels)
         fsa.aux_labels = aux_labels
+        # Caution: fsa.scores are all 0s.
         return Nbest(fsa=fsa, shape=utt_to_path_shape)

     def intersect(self, lattice: k2.Fsa, use_double_scores=True) -> "Nbest":
         """Intersect this Nbest object with a lattice and get 1-best
         path from the resulting FsaVec.

         Caution:
-          We assume `self.fsa.labels` and `lattice.labels` are token IDs.
+          The purpose of this function is to attach scores to an Nbest.

         Args:
           lattice:
-            An FsaVec with axes [utt][state][arc]
+            An FsaVec with axes [utt][state][arc]. If it has `aux_labels`,
+            then we assume its `labels` are token IDs and `aux_labels` are
+            word IDs. If it has only `labels`, we assume its `labels` are
+            word IDs.

           use_double_scores:
             True to use double precision when computing shortest path.
             False to use single precision.
@@ -201,16 +246,6 @@ class Nbest(object):
         while its `fsa` is the 1-best path from intersecting `self.fsa` and
         `lattice`.
         """
-        assert (
-            self.fsa.device == lattice.device
-        ), f"{self.fsa.device} vs {lattice.device}"
-
-        assert len(lattice.shape) == 3, f"{lattice.shape}"
-
-        assert (
-            lattice.arcs.dim0() == self.shape.dim0
-        ), f"{lattice.arcs.dim0()} vs {self.shape.dim0}"
-
         # Note: We view each linear FSA as a word sequence
         # and we use the passed lattice to give each word sequence a score.
         #
@@ -221,8 +256,10 @@ class Nbest(object):
         # We use a word fsa to intersect with k2.invert(lattice)
         word_fsa = k2.invert(self.fsa)

-        # delete token IDs as it is not needed
-        del word_fsa.aux_labels
+        if hasattr(lattice, "aux_labels"):
+            # delete token IDs as it is not needed
+            del word_fsa.aux_labels

         word_fsa.scores.zero_()
         word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
             word_fsa
@@ -230,17 +267,28 @@ class Nbest(object):

         path_to_utt_map = self.shape.row_ids(1)

-        # lattice has token IDs as labels and word IDs as aux_labels.
-        # inv_lattice has word IDs as labels and token IDs as aux_labels
-        inv_lattice = k2.invert(lattice)
-        inv_lattice = k2.arc_sort(inv_lattice)
+        if hasattr(lattice, "aux_labels"):
+            # lattice has token IDs as labels and word IDs as aux_labels.
+            # inv_lattice has word IDs as labels and token IDs as aux_labels
+            inv_lattice = k2.invert(lattice)
+            inv_lattice = k2.arc_sort(inv_lattice)
+        else:
+            inv_lattice = k2.arc_sort(lattice)

-        path_lattice = k2.intersect_device(
-            inv_lattice,
-            word_fsa_with_epsilon_loops,
-            b_to_a_map=path_to_utt_map,
-            sorted_match_a=True,
-        )
+        if inv_lattice.shape[0] == 1:
+            path_lattice = _intersect_device(
+                inv_lattice,
+                word_fsa_with_epsilon_loops,
+                b_to_a_map=torch.zeros_like(path_to_utt_map),
+                sorted_match_a=True,
+            )
+        else:
+            path_lattice = _intersect_device(
+                inv_lattice,
+                word_fsa_with_epsilon_loops,
+                b_to_a_map=path_to_utt_map,
+                sorted_match_a=True,
+            )

         # path_lattice has word IDs as labels and token IDs as aux_labels
         path_lattice = k2.top_sort(k2.connect(path_lattice))
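`b_to_a_map` tells `k2.intersect_device` which FSA in `a_fsas` each FSA in `b_fsas` is intersected with. The single-FSA branch above exists because, when `intersect` is called with an n-gram LM `G` (one shared FSA), every path must map to index 0. A small illustration with hypothetical sizes (torch only):

    import torch

    # Say 6 paths belong to 3 utterances; self.shape.row_ids(1) would be:
    path_to_utt_map = torch.tensor([0, 0, 1, 1, 2, 2], dtype=torch.int32)

    # With a single shared FSA on the a_fsas side, every path maps to 0:
    b_to_a_map = torch.zeros_like(path_to_utt_map)
    print(b_to_a_map)  # tensor([0, 0, 0, 0, 0, 0], dtype=torch.int32)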
@@ -254,6 +302,29 @@ class Nbest(object):

         return Nbest(fsa=one_best, shape=self.shape)

+    def compute_am_scores(self) -> k2.RaggedTensor:
+        """Compute AM scores of each linear FSA (i.e., each path within
+        an utterance).
+
+        Hint:
+          `self.fsa.scores` contains two parts: am scores and lm scores.
+
+        Returns:
+          Return a ragged tensor with 2 axes [utt][path_scores].
+          Its dtype is torch.float64.
+        """
+        saved_scores = self.fsa.scores
+
+        # The `scores` of every arc consists of `am_scores` and `lm_scores`
+        self.fsa.scores = self.fsa.scores - self.fsa.lm_scores
+
+        am_scores = self.fsa.get_tot_scores(
+            use_double_scores=True, log_semiring=False
+        )
+        self.fsa.scores = saved_scores
+
+        return k2.RaggedTensor(self.shape, am_scores)
+
     def tot_scores(self) -> k2.RaggedTensor:
         """Get total scores of the FSAs in this Nbest.

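`compute_am_scores` relies on every arc score being the sum of an AM part and an LM part, so subtracting `lm_scores` leaves pure AM scores; the combined scores are restored afterwards. The same save/subtract/restore trick on plain tensors (hypothetical values):

    import torch

    scores = torch.tensor([-1.5, -2.0, -0.5])      # combined am + lm per arc
    lm_scores = torch.tensor([-0.5, -1.0, -0.25])  # lm part per arc

    saved_scores = scores
    scores = scores - lm_scores  # pure AM part: [-1.0, -1.0, -0.25]
    # ... total scores would be computed from the AM-only scores here ...
    scores = saved_scores        # restore the combined scores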
@@ -263,10 +334,11 @@ class Nbest(object):

         Returns:
           Return a ragged tensor with two axes [utt][path_scores].
+          Its dtype is torch.float64.
         """
-        # Use single precision since there are only additions.
         scores = self.fsa.get_tot_scores(
-            use_double_scores=False, log_semiring=False
+            use_double_scores=True, log_semiring=False
         )
         return k2.RaggedTensor(self.shape, scores)

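The deleted comment claimed single precision is fine because only additions are involved, but long chains of float32 additions do accumulate rounding error, which is presumably why the commit switches to `use_double_scores=True`. A quick demonstration (torch only):

    import torch

    x = torch.full((1_000_000,), 0.1, dtype=torch.float32)
    print(x.sum())           # drifts away from the true value (float32 rounding)
    print(x.double().sum())  # ~100000.00149, the exact sum of those float32 values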
@@ -317,11 +389,13 @@ def nbest_decoding(
         use_double_scores=use_double_scores,
         lattice_score_scale=lattice_score_scale,
     )
+    # nbest.fsa.scores contains 0s

     nbest = nbest.intersect(lattice)
+    # now nbest.fsa.scores gets assigned

-    # max_indexes contains the indexes for the max scores
-    # of paths within an utterance.
+    # max_indexes contains the indexes for the path with the maximum score
+    # within an utterance.
     max_indexes = nbest.tot_scores().argmax()

     best_path = k2.index_fsa(nbest.fsa, max_indexes)
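`tot_scores()` returns a ragged tensor with axes [utt][path_scores]; `argmax()` is taken per utterance and yields indexes that are linear over all paths, which is exactly what `k2.index_fsa` consumes. A sketch, assuming k2 is installed (the shape string and scores are made up):

    import k2
    import torch

    shape = k2.RaggedShape("[ [x x x] [x x] ]")  # 2 utts: 3 paths, then 2
    scores = torch.tensor([0.5, 2.0, 1.0, -0.5, 0.25], dtype=torch.float64)
    ragged = k2.RaggedTensor(shape, scores)

    # Per-utterance argmax as linear path indexes:
    print(ragged.argmax())  # [1, 4]: path 1 (utt 0's best) and path 4 (utt 1's best)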
@@ -419,3 +493,53 @@ def nbest_oracle(

     best_path = k2.index_fsa(nbest.fsa, max_indexes)
     return best_path
+
+
+def rescore_with_n_best_list(
+    lattice: k2.Fsa,
+    G: k2.Fsa,
+    num_paths: int,
+    lm_scale_list: List[float],
+    lattice_score_scale: float = 1.0,
+    use_double_scores: bool = True,
+) -> Dict[str, k2.Fsa]:
+    """Rescore an n-best list with an n-gram LM.
+    The path with the maximum score is used as the decoding output.
+    """
+    device = lattice.device
+
+    assert len(lattice.shape) == 3
+    assert hasattr(lattice, "aux_labels")
+    assert hasattr(lattice, "lm_scores")
+
+    assert G.shape == (1, None, None)
+    assert G.device == device
+    assert hasattr(G, "aux_labels") is False
+
+    nbest = Nbest.from_lattice(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=use_double_scores,
+        lattice_score_scale=lattice_score_scale,
+    )
+    # nbest.fsa.scores are all 0s at this point
+
+    nbest = nbest.intersect(lattice)
+    # Now nbest.fsa has its scores set
+    assert hasattr(nbest.fsa, "lm_scores")
+
+    am_scores = nbest.compute_am_scores()
+
+    nbest = nbest.intersect(G)
+    # Now nbest contains only lm scores
+    lm_scores = nbest.tot_scores()
+
+    ans = dict()
+    for lm_scale in lm_scale_list:
+        tot_scores = am_scores.values / lm_scale + lm_scores.values
+        tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+        max_indexes = tot_scores.argmax()
+        best_path = k2.index_fsa(nbest.fsa, max_indexes)
+        key = f"lm_scale_{lm_scale}"
+        ans[key] = best_path
+    return ans
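The loop above implements the usual n-best rescoring combination, tot = am / lm_scale + lm, and keys each decoding result by its scale. A worked example with hypothetical scores for three paths of one utterance (torch only):

    import torch

    am = torch.tensor([-10.0, -12.0, -9.0], dtype=torch.float64)
    lm = torch.tensor([-5.0, -2.0, -8.0], dtype=torch.float64)

    for lm_scale in [0.5, 1.0, 2.0]:
        tot = am / lm_scale + lm
        print(f"lm_scale_{lm_scale}: best path index {tot.argmax().item()}")

    # lm_scale_0.5: best path index 0  (am / 0.5 doubles the AM part)
    # lm_scale_1.0: best path index 1
    # lm_scale_2.0: best path index 1  (the n-gram LM gets more weight)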