Add nbest decoding.

Fangjun Kuang 2021-09-14 13:45:51 +08:00
parent 968e4a6609
commit d2bedbe02e
2 changed files with 38 additions and 5 deletions

View File

@@ -32,13 +32,13 @@ from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.decode import (
     get_lattice,
-    nbest_decoding,
     nbest_oracle,
     one_best_decoding,
     rescore_with_attention_decoder,
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
+from icefall.decode2 import nbest_decoding
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -371,6 +371,8 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        if batch_idx > 20:
+            break
         hyps_dict = decode_one_batch(
             params=params,
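Two things change in this file: `nbest_decoding` now comes from the new `icefall.decode2` module instead of `icefall.decode`, and decoding stops after the first 21 batches, which reads like a temporary cap for quick sanity runs. Below is a minimal sketch of how the re-imported function is presumably called inside `decode_one_batch()`; `lattice` and `params.num_paths` are assumptions based on the usual icefall decode scripts, not part of this diff.

# Hypothetical call site in decode_one_batch(); `lattice` is the output
# of get_lattice() and `params.num_paths` is an assumed configuration knob.
best_path = nbest_decoding(
    lattice=lattice,
    num_paths=params.num_paths,
    use_double_scores=True,
    scale=1.0,
)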

View File

@@ -148,8 +148,7 @@ class Nbest(object):
         path from the resulting FsaVec.

         Caution:
           We assume FSAs in `self.fsa` don't have epsilon self-loops.
-          We also assume `self.fsa.labels` and `lattice.labels` are token IDs.
+          We assume `self.fsa.labels` and `lattice.labels` are token IDs.

         Args:
           lattice:
@@ -172,14 +171,22 @@
             lattice.arcs.dim0() == self.shape.dim0
         ), f"{lattice.arcs.dim0()} vs {self.shape.dim0}"

+        # Note: We view each linear FSA as a word sequence
+        # and use the given lattice to assign each word sequence a score.
+        #
+        # We are not viewing each linear FSA as a token sequence.
+        #
+        # So we use k2.invert() here.
+
+        # We use a word FSA to intersect with k2.invert(lattice)
         word_fsa = k2.invert(self.fsa)

+        # Delete token IDs as they are not needed
         del word_fsa.aux_labels
         word_fsa.scores.zero_()

-        word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
+        word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
+            word_fsa
+        )

         path_to_utt_map = self.shape.row_ids(1)
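To make the inversion step concrete, here is a self-contained sketch of the word-FSA preparation performed in this hunk; the toy labels are made up, and only `k2` and `torch` are assumed.

import torch
import k2

# A toy linear FSA standing in for one n-best path: labels are token
# IDs and aux_labels are word IDs (all values here are made up).
path = k2.linear_fsa([3, 5])  # arc labels: [3, 5, -1]
path.aux_labels = torch.tensor([10, 20, -1], dtype=torch.int32)

# k2.invert() swaps labels and aux_labels, so the path is now viewed
# as a word sequence rather than a token sequence.
word_fsa = k2.invert(path)
del word_fsa.aux_labels  # token IDs are no longer needed
word_fsa.scores.zero_()  # scores will come from the lattice

# Self-loops allow the epsilon-free lattice to align with the path
# arc by arc during intersection.
word_fsa = k2.remove_epsilon_and_add_self_loops(word_fsa)

With the scores zeroed and self-loops added, intersecting against the lattice (presumably via k2.intersect_device with the `path_to_utt_map` computed above) attaches lattice scores to each word sequence.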
@@ -222,3 +229,27 @@ class Nbest(object):
             use_double_scores=False, log_semiring=False
         )
         return k2.RaggedTensor(self.shape, scores)
+
+def nbest_decoding(
+    lattice: k2.Fsa,
+    num_paths: int,
+    use_double_scores: bool = True,
+    scale: float = 1.0,
+) -> k2.Fsa:
+    """Select the best path for each utterance via n-best rescoring.
+
+    It extracts `num_paths` paths from the given lattice, scores each
+    path with the lattice scores (see Nbest.intersect), and returns
+    the highest-scoring path per utterance.
+
+    Args:
+      lattice: An FsaVec with axes [utt][state][arc].
+      num_paths: The number of paths to extract from the lattice.
+      use_double_scores: True to use float64 for score computation.
+      scale: Scale applied to the lattice scores when extracting paths
+        (passed to Nbest.from_lattice).
+    Returns:
+      An FsaVec containing the best path of each utterance.
+    """
+    nbest = Nbest.from_lattice(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=use_double_scores,
+        scale=scale,
+    )
+    nbest = nbest.intersect(lattice)
+
+    # max_indexes contains the index of the best-scoring path
+    # within each utterance.
+    max_indexes = nbest.tot_scores().argmax()
+
+    best_path = k2.index_fsa(nbest.fsa, max_indexes)
+    return best_path
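Finally, a hedged usage sketch of the new function, loosely mirroring how icefall decode scripts consume the other decoding methods; `lattice` (from `get_lattice()`) and the existing helper `icefall.utils.get_texts` are assumed to be available.

from icefall.decode2 import nbest_decoding
from icefall.utils import get_texts

# `lattice` is assumed to be the output of get_lattice() for one batch.
best_path = nbest_decoding(
    lattice=lattice,
    num_paths=100,
    use_double_scores=True,
    scale=1.0,
)

# The best paths keep word IDs as aux_labels; get_texts() extracts one
# word-ID sequence per utterance.
hyps = get_texts(best_path)

Because `nbest.tot_scores()` is ragged with one sublist per utterance, `.argmax()` yields exactly one winning path index per utterance, which is what `k2.index_fsa` consumes.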