Mirror of https://github.com/k2-fsa/icefall.git
Add nbest decoding.

commit d2bedbe02e (parent 968e4a6609)
@@ -32,13 +32,13 @@ from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.decode import (
     get_lattice,
-    nbest_decoding,
     nbest_oracle,
     one_best_decoding,
     rescore_with_attention_decoder,
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
+from icefall.decode2 import nbest_decoding
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -371,6 +371,8 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        if batch_idx > 20:
+            break

         hyps_dict = decode_one_batch(
             params=params,
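The new `if batch_idx > 20: break` caps this run at the first 21 batches, which reads like a temporary shortcut for iterating quickly while `icefall.decode2` is under development. An equivalent, slightly more explicit formulation (a sketch, not part of the commit) moves the cap into the loop header with `itertools.islice`:

```python
from itertools import islice
from typing import Iterable


def decode_first_n_batches(dl: Iterable, n: int = 21) -> None:
    """Sketch: decode only the first ``n`` batches of ``dl``.

    Equivalent to the ``if batch_idx > 20: break`` guard above;
    islice() stops pulling from the dataloader after ``n`` batches.
    """
    for batch_idx, batch in enumerate(islice(dl, n)):
        texts = batch["supervisions"]["text"]
        # ... call decode_one_batch() here, as decode_dataset() does ...
```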
@@ -148,8 +148,7 @@ class Nbest(object):
         path from the resulting FsaVec.

         Caution:
-          We assume FSAs in `self.fsa` don't have epsilon self-loops.
-          We also assume `self.fsa.labels` and `lattice.labels` are token IDs.
+          We assume `self.fsa.labels` and `lattice.labels` are token IDs.

         Args:
           lattice:
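Dropping the epsilon-self-loop caution matches the code change in the next hunk: the intersection no longer requires epsilon-free inputs, because `k2.remove_epsilon_and_add_self_loops()` removes epsilon arcs itself before adding the self-loops. A minimal sketch on a toy acceptor (the FSA and its scores are made up for illustration):

```python
import k2

# Toy acceptor with one epsilon arc (label 0).
# Each line is: src_state dest_state label score; the last line gives
# the final state.
s = """
0 1 0 0.1
1 2 3 0.2
2 3 -1 0.3
3
"""
fsa = k2.Fsa.from_str(s)

# Safe even though `fsa` is not epsilon-free: the epsilon arc is
# removed first, then epsilon self-loops are added at every state.
ready = k2.remove_epsilon_and_add_self_loops(fsa)
```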
@@ -172,14 +171,22 @@ class Nbest(object):
             lattice.arcs.dim0() == self.shape.dim0
         ), f"{lattice.arcs.dim0()} vs {self.shape.dim0}"

+        # Note: We view each linear FSA as a word sequence
+        # and we use the passed lattice to give each word sequence a score.
+        #
+        # We are not viewing each linear FSA as a token sequence.
+        #
+        # So we use k2.invert() here.

         # We use a word fsa to intersect with k2.invert(lattice)
         word_fsa = k2.invert(self.fsa)

         # delete token IDs as they are not needed
         del word_fsa.aux_labels
         word_fsa.scores.zero_()

-        word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
+        word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
+            word_fsa
+        )

         path_to_utt_map = self.shape.row_ids(1)
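The added comment block documents the central trick of `Nbest.intersect()`: each sampled path is reinterpreted as a word sequence and then scored by intersecting it with the inverted lattice. A standalone sketch of those steps, using only the k2 calls that appear in the diff (the function name and the `paths` argument are illustrative):

```python
import k2


def word_level_fsa(paths: k2.Fsa) -> k2.Fsa:
    """Sketch: prepare n-best paths for word-level intersection.

    ``paths`` is assumed to be an FsaVec whose labels are token IDs and
    whose aux_labels are word IDs, like ``self.fsa`` in ``Nbest``.
    """
    # k2.invert() swaps labels and aux_labels, so word IDs become the
    # labels of the result.
    word_fsa = k2.invert(paths)

    # Token IDs (now on aux_labels) are not needed at the word level.
    del word_fsa.aux_labels

    # Zero the scores so that, after intersection, every score comes
    # from the lattice rather than from the sampled paths.
    word_fsa.scores.zero_()

    # Unlike k2.add_epsilon_self_loops(), this also removes epsilon
    # arcs first; hence the dropped caution in the docstring above.
    return k2.remove_epsilon_and_add_self_loops(word_fsa)
```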
@@ -222,3 +229,27 @@ class Nbest(object):
             use_double_scores=False, log_semiring=False
         )
         return k2.RaggedTensor(self.shape, scores)
+
+
+def nbest_decoding(
+    lattice: k2.Fsa,
+    num_paths: int,
+    use_double_scores: bool = True,
+    scale: float = 1.0,
+) -> k2.Fsa:
+    """Select the highest-scoring path per utterance from an n-best list."""
+    nbest = Nbest.from_lattice(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=use_double_scores,
+        scale=scale,
+    )
+
+    nbest = nbest.intersect(lattice)
+
+    # max_indexes contains the indexes for the max scores
+    # of paths within an utterance.
+    max_indexes = nbest.tot_scores().argmax()
+
+    best_path = k2.index_fsa(nbest.fsa, max_indexes)
+    return best_path
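For reference, a hedged usage sketch of the new `nbest_decoding()`; the wrapper function, the `num_paths` value, and the comments are illustrative rather than taken from this commit:

```python
import k2

from icefall.decode2 import nbest_decoding


def decode_with_nbest(lattice: k2.Fsa) -> k2.Fsa:
    """Sketch: pick the best of ``num_paths`` sampled paths per utterance.

    ``lattice`` is assumed to come from get_lattice(), as elsewhere in
    icefall's decoding scripts.
    """
    best_path = nbest_decoding(
        lattice=lattice,
        num_paths=100,           # illustrative; paths sampled per utterance
        use_double_scores=True,  # float64 for more stable total scores
        scale=1.0,               # assumed to scale scores before sampling
    )
    # best_path is an FsaVec with one linear FSA per utterance: the
    # argmax of tot_scores() selects it, k2.index_fsa() extracts it.
    return best_path
```

The design follows the usual n-best rescoring shape: sample paths from the lattice, rescore them against it via `Nbest.intersect()`, and keep the per-utterance argmax with `k2.index_fsa()`.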