mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-15 20:22:42 +00:00

Add rescore with nbest lists.

parent 8807381401
commit a44b4f84a5
@@ -32,13 +32,17 @@ from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.decode import get_lattice
 from icefall.decode import (
-    one_best_decoding,
+    one_best_decoding,  # done
     rescore_with_attention_decoder,
-    rescore_with_n_best_list,
+    rescore_with_n_best_list,  # done
     rescore_with_whole_lattice,
-    nbest_oracle,
+    nbest_oracle,  # done
+)
+from icefall.decode2 import (
+    nbest_decoding,
+    nbest_oracle as nbest_oracle2,
+    rescore_with_n_best_list as rescore_with_n_best_list2,
 )
-from icefall.decode2 import nbest_decoding, nbest_oracle as nbest_oracle2
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -257,7 +261,10 @@ def decode_one_batch(
         )
         hyps = get_texts(best_path)
         hyps = [[word_table[i] for i in ids] for ids in hyps]
-        return {"nbest-orcale": hyps}
+        key = (
+            f"oracle_{num_paths}_lattice_score_scale_{lattice_score_scale}"
+        )
+        return {key: hyps}
     else:
         return nbest_oracle(
             lattice=lattice,
@@ -297,6 +304,16 @@ def decode_one_batch(
         lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]

     if params.method == "nbest-rescoring":
+        if True:
+            # TODO: remove the "else" branch
+            best_path_dict = rescore_with_n_best_list2(
+                lattice=lattice,
+                G=G,
+                num_paths=params.num_paths,
+                lm_scale_list=lm_scale_list,
+                lattice_score_scale=params.lattice_score_scale,
+            )
+        else:
             best_path_dict = rescore_with_n_best_list(
                 lattice=lattice,
                 G=G,
@@ -385,7 +402,8 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
-        if batch_idx > 20:
+        # TODO: remove it
+        if batch_idx > 100:
             break

         hyps_dict = decode_one_batch(
@@ -25,6 +25,47 @@ import torch
 from icefall.utils import get_texts


+def _intersect_device(
+    a_fsas: k2.Fsa,
+    b_fsas: k2.Fsa,
+    b_to_a_map: torch.Tensor,
+    sorted_match_a: bool,
+    batch_size: int = 50,
+) -> k2.Fsa:
+    """This is a wrapper of k2.intersect_device and its purpose is to split
+    b_fsas into several batches and process each batch separately to avoid
+    CUDA OOM error.
+
+    The arguments and return value of this function are the same as
+    k2.intersect_device.
+    """
+    num_fsas = b_fsas.shape[0]
+    if num_fsas <= batch_size:
+        return k2.intersect_device(
+            a_fsas, b_fsas, b_to_a_map=b_to_a_map, sorted_match_a=sorted_match_a
+        )
+
+    num_batches = (num_fsas + batch_size - 1) // batch_size
+    splits = []
+    for i in range(num_batches):
+        start = i * batch_size
+        end = min(start + batch_size, num_fsas)
+        splits.append((start, end))
+
+    ans = []
+    for start, end in splits:
+        indexes = torch.arange(start, end).to(b_to_a_map)
+
+        fsas = k2.index_fsa(b_fsas, indexes)
+        b_to_a = k2.index_select(b_to_a_map, indexes)
+        path_lattice = k2.intersect_device(
+            a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a
+        )
+        ans.append(path_lattice)
+
+    return k2.cat(ans)
+
+
 # TODO(fangjun): Use Kangwei's C++ implementation that also
 # supports List[List[int]]
 def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
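
The batching in `_intersect_device` is plain ceil-division chunking. Below is a minimal torch-only sketch of the same pattern, with a doubling operation standing in for the real intersection (the names here are illustrative, not from the commit):

    import torch

    def process_in_batches(xs: torch.Tensor, batch_size: int = 50) -> torch.Tensor:
        # Same splitting scheme as _intersect_device: ceil division into
        # fixed-size chunks, process each chunk, then concatenate the results.
        num_items = xs.shape[0]
        if num_items <= batch_size:
            return xs * 2  # stand-in for the real per-batch operation
        num_batches = (num_items + batch_size - 1) // batch_size
        ans = []
        for i in range(num_batches):
            start = i * batch_size
            end = min(start + batch_size, num_items)
            ans.append(xs[start:end] * 2)
        return torch.cat(ans)

    assert process_in_batches(torch.arange(120)).shape[0] == 120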
@@ -181,18 +222,22 @@ class Nbest(object):

         fsa = k2.linear_fsa(labels)
         fsa.aux_labels = aux_labels
+        # Caution: fsa.scores are all 0s.
         return Nbest(fsa=fsa, shape=utt_to_path_shape)

     def intersect(self, lattice: k2.Fsa, use_double_scores=True) -> "Nbest":
         """Intersect this Nbest object with a lattice and get 1-best
         path from the resulting FsaVec.

-        Caution:
-          We assume `self.fsa.labels` and `lattice.labels` are token IDs.
+        The purpose of this function is to attach scores to an Nbest.

         Args:
           lattice:
-            An FsaVec with axes [utt][state][arc]
+            An FsaVec with axes [utt][state][arc]. If it has `aux_labels`, then
+            we assume its `labels` are token IDs and `aux_labels` are word IDs.
+            If it has only `labels`, we assume its `labels` are word IDs.
+
           use_double_scores:
             True to use double precision when computing shortest path.
             False to use single precision.
@@ -201,16 +246,6 @@ class Nbest(object):
         while its `fsa` is the 1-best path from intersecting `self.fsa` and
         `lattice`.
         """
-        assert (
-            self.fsa.device == lattice.device
-        ), f"{self.fsa.device} vs {lattice.device}"
-
-        assert len(lattice.shape) == 3, f"{lattice.shape}"
-
-        assert (
-            lattice.arcs.dim0() == self.shape.dim0
-        ), f"{lattice.arcs.dim0()} vs {self.shape.dim0}"
-
         # Note: We view each linear FSA as a word sequence
         # and we use the passed lattice to give each word sequence a score.
         #
@@ -221,8 +256,10 @@ class Nbest(object):
         # We use a word fsa to intersect with k2.invert(lattice)
         word_fsa = k2.invert(self.fsa)
+
+        if hasattr(lattice, "aux_labels"):
             # delete token IDs as it is not needed
             del word_fsa.aux_labels
+
         word_fsa.scores.zero_()
         word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
             word_fsa
@@ -230,12 +267,23 @@ class Nbest(object):

         path_to_utt_map = self.shape.row_ids(1)

+        if hasattr(lattice, "aux_labels"):
             # lattice has token IDs as labels and word IDs as aux_labels.
             # inv_lattice has word IDs as labels and token IDs as aux_labels
             inv_lattice = k2.invert(lattice)
             inv_lattice = k2.arc_sort(inv_lattice)
+        else:
+            inv_lattice = k2.arc_sort(lattice)
+
-        path_lattice = k2.intersect_device(
+        if inv_lattice.shape[0] == 1:
+            path_lattice = _intersect_device(
+                inv_lattice,
+                word_fsa_with_epsilon_loops,
+                b_to_a_map=torch.zeros_like(path_to_utt_map),
+                sorted_match_a=True,
+            )
+        else:
+            path_lattice = _intersect_device(
                 inv_lattice,
                 word_fsa_with_epsilon_loops,
                 b_to_a_map=path_to_utt_map,
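
The `inv_lattice.shape[0] == 1` branch covers an FsaVec holding a single FSA shared by every path; `torch.zeros_like(path_to_utt_map)` then maps all paths to that one FSA. A tiny sketch with made-up path counts:

    import torch

    # Three utterances contributing 2, 2 and 1 paths (path -> utterance map).
    path_to_utt_map = torch.tensor([0, 0, 1, 1, 2], dtype=torch.int32)

    # With a single shared FSA in the vector, every path intersects FSA 0.
    b_to_a_map = torch.zeros_like(path_to_utt_map)
    print(b_to_a_map.tolist())  # [0, 0, 0, 0, 0]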
@@ -254,6 +302,29 @@ class Nbest(object):

         return Nbest(fsa=one_best, shape=self.shape)

+    def compute_am_scores(self) -> k2.RaggedTensor:
+        """Compute AM scores of each linear FSA (i.e., each path within
+        an utterance).
+
+        Hint:
+          `self.fsa.scores` contains two parts: am scores and lm scores.
+
+        Returns:
+          Return a ragged tensor with 2 axes [utt][path_scores].
+          Its dtype is torch.float64.
+        """
+        saved_scores = self.fsa.scores
+
+        # The `scores` of every arc consists of `am_scores` and `lm_scores`
+        self.fsa.scores = self.fsa.scores - self.fsa.lm_scores
+
+        am_scores = self.fsa.get_tot_scores(
+            use_double_scores=True, log_semiring=False
+        )
+        self.fsa.scores = saved_scores
+
+        return k2.RaggedTensor(self.shape, am_scores)
+
     def tot_scores(self) -> k2.RaggedTensor:
         """Get total scores of the FSAs in this Nbest.

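
`compute_am_scores` leans on the decomposition stated in its hint: each arc's `scores` is am + lm, so subtracting the stored `lm_scores` isolates the AM part, which is then summed per path. A toy illustration of the subtract/sum/restore pattern with plain tensors (no k2; values invented):

    import torch

    scores = torch.tensor([-3.0, -5.0, -2.5])     # combined am + lm per arc
    lm_scores = torch.tensor([-1.0, -2.0, -0.5])  # lm component per arc

    saved_scores = scores.clone()    # save, as compute_am_scores does
    am_scores = scores - lm_scores   # leave only the am component
    path_am_score = am_scores.sum()  # per-path total, log_semiring=False
    scores = saved_scores            # restore the original scores
    print(path_am_score.item())  # -7.0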
@@ -263,10 +334,11 @@ class Nbest(object):

         Returns:
           Return a ragged tensor with two axes [utt][path_scores].
+          Its dtype is torch.float64.
         """
         # Use single precision since there are only additions.
         scores = self.fsa.get_tot_scores(
-            use_double_scores=False, log_semiring=False
+            use_double_scores=True, log_semiring=False
         )
         return k2.RaggedTensor(self.shape, scores)

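
The switch to `use_double_scores=True` (note that the comment above it is now stale) guards against float32 rounding when many arc scores are summed per path. A quick demonstration of the accumulation effect, with made-up numbers:

    import torch

    x = torch.full((1_000_000,), 0.1)  # float32 by default
    print(x.sum().item())              # float32 accumulation drifts from 100000
    print(x.double().sum().item())     # float64 keeps the total much closer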
@@ -317,11 +389,13 @@ def nbest_decoding(
         use_double_scores=use_double_scores,
         lattice_score_scale=lattice_score_scale,
     )
+    # nbest.fsa.scores contains 0s

     nbest = nbest.intersect(lattice)
+    # now nbest.fsa.scores gets assigned

-    # max_indexes contains the indexes for the max scores
-    # of paths within an utterance.
+    # max_indexes contains the indexes for the path with the maximum score
+    # within an utterance.
     max_indexes = nbest.tot_scores().argmax()

     best_path = k2.index_fsa(nbest.fsa, max_indexes)
@@ -419,3 +493,53 @@ def nbest_oracle(

     best_path = k2.index_fsa(nbest.fsa, max_indexes)
     return best_path
+
+
+def rescore_with_n_best_list(
+    lattice: k2.Fsa,
+    G: k2.Fsa,
+    num_paths: int,
+    lm_scale_list: List[float],
+    lattice_score_scale: float = 1.0,
+    use_double_scores: bool = True,
+) -> Dict[str, k2.Fsa]:
+    """Rescore an nbest list with an n-gram LM.
+    The path with a maximum score is used as the decoding output.
+    """
+    device = lattice.device
+
+    assert len(lattice.shape) == 3
+    assert hasattr(lattice, "aux_labels")
+    assert hasattr(lattice, "lm_scores")
+
+    assert G.shape == (1, None, None)
+    assert G.device == device
+    assert hasattr(G, "aux_labels") is False
+
+    nbest = Nbest.from_lattice(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=use_double_scores,
+        lattice_score_scale=lattice_score_scale,
+    )
+    # nbest.fsa.scores are all 0s at this point
+
+    nbest = nbest.intersect(lattice)
+    # Now nbest.fsa has its scores set
+    assert hasattr(nbest.fsa, "lm_scores")
+
+    am_scores = nbest.compute_am_scores()
+
+    nbest = nbest.intersect(G)
+    # Now nbest contains only lm scores
+    lm_scores = nbest.tot_scores()
+
+    ans = dict()
+    for lm_scale in lm_scale_list:
+        tot_scores = am_scores.values / lm_scale + lm_scores.values
+        tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+        max_indexes = tot_scores.argmax()
+        best_path = k2.index_fsa(nbest.fsa, max_indexes)
+        key = f"lm_scale_{lm_scale}"
+        ans[key] = best_path
+    return ans
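
The closing loop of `rescore_with_n_best_list` is a grid search over LM scales: it forms `tot = am / lm_scale + lm` over the flat `.values` of the ragged score tensors, then takes a per-utterance argmax. A self-contained sketch of that selection step (plain torch; numbers invented):

    import torch

    # 2 utterances with 3 candidate paths each, flattened like RaggedTensor.values.
    am = torch.tensor([-10.0, -9.0, -11.0, -7.0, -8.0, -6.5])
    lm = torch.tensor([-4.0, -5.0, -3.0, -2.0, -1.0, -2.5])
    row_splits = [0, 3, 6]  # path-to-utterance boundaries

    for lm_scale in [0.5, 1.0]:
        tot = am / lm_scale + lm
        for u in range(len(row_splits) - 1):
            s, e = row_splits[u], row_splits[u + 1]
            best = s + int(torch.argmax(tot[s:e]))
            print(f"lm_scale={lm_scale}: utterance {u} keeps path {best}")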