Refactor nbest-oracle.

Fangjun Kuang 2021-09-17 15:46:23 +08:00
parent d6a995978a
commit b9fc46f432
4 changed files with 38 additions and 22 deletions

View File

@@ -30,14 +30,15 @@ from conformer import Conformer
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.decode import get_lattice
 from icefall.decode import (
+    get_lattice,
     one_best_decoding,
     rescore_with_attention_decoder,
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
+    nbest_oracle,
 )
-from icefall.decode2 import nbest_decoding, nbest_oracle
+from icefall.decode2 import nbest_decoding, nbest_oracle as nbest_oracle2
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -244,17 +245,27 @@ def decode_one_batch(
         # We choose the HLG decoded lattice for speed reasons
         # as HLG decoding is faster and the oracle WER
         # is slightly worse than that of rescored lattices.
-        best_path = nbest_oracle(
-            lattice=lattice,
-            num_paths=params.num_paths,
-            ref_texts=supervisions["text"],
-            word_table=word_table,
-            lattice_score_scale=params.lattice_score_scale,
-            oov="<UNK>",
-        )
-        hyps = get_texts(best_path)
-        hyps = [[word_table[i] for i in ids] for ids in hyps]
-        return {"nbest-orcale": hyps}
+        if True:
+            # TODO: delete the `else` branch
+            best_path = nbest_oracle2(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                ref_texts=supervisions["text"],
+                word_table=word_table,
+                lattice_score_scale=params.lattice_score_scale,
+                oov="<UNK>",
+            )
+            hyps = get_texts(best_path)
+            hyps = [[word_table[i] for i in ids] for ids in hyps]
+            return {"nbest-orcale": hyps}
+        else:
+            return nbest_oracle(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                ref_texts=supervisions["text"],
+                word_table=word_table,
+                scale=params.lattice_score_scale,
+            )

     if params.method in ["1best", "nbest"]:
         if params.method == "1best":
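Aside (not part of the commit): both branches above compute an n-best "oracle" result, i.e. for each utterance they pick the candidate closest to the reference transcript, which is what the "oracle WER" comment refers to. The sketch below shows that selection in plain Python over already-decoded word strings; word_edit_distance and oracle_hyp are illustrative names, not icefall functions, and no k2 lattices are involved.

from typing import List


def word_edit_distance(ref: List[str], hyp: List[str]) -> int:
    """Classic dynamic-programming Levenshtein distance over words."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, 1):
            cur[j] = min(
                prev[j] + 1,  # deletion
                cur[j - 1] + 1,  # insertion
                prev[j - 1] + (r != h),  # substitution or match
            )
        prev = cur
    return prev[-1]


def oracle_hyp(ref_text: str, nbest_hyps: List[str]) -> str:
    """Return the hypothesis with the smallest edit distance to the reference."""
    return min(
        nbest_hyps,
        key=lambda hyp: word_edit_distance(ref_text.split(), hyp.split()),
    )


assert oracle_hyp("a b c", ["a x c", "a b c", "x y z"]) == "a b c"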

View File

@@ -669,7 +669,7 @@ def nbest_oracle(
         word_seq = k2.ragged.index(lattice.aux_labels, path)
     else:
         word_seq = lattice.aux_labels.index(path)
-        word_seq = word_seq.remove_axis(1)
+        word_seq = word_seq.remove_axis(word_seq.num_axes - 2)

     word_seq = word_seq.remove_values_leq(0)
     unique_word_seq, _, _ = word_seq.unique(
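Aside (not part of the commit): replacing remove_axis(1) with remove_axis(word_seq.num_axes - 2) always removes the second-to-last axis, so the call keeps working if word_seq carries an extra leading axis; the same pattern is added to Nbest.from_lattice further down. A rough illustration of removing an interior axis, assuming a recent k2 where k2.RaggedTensor exposes num_axes and remove_axis as used in this diff:

import k2

# A 3-axis ragged tensor, e.g. something shaped like [path][arc][word].
r = k2.RaggedTensor("[ [ [10 20] [30] ] [ [40 50] ] ]")
assert r.num_axes == 3

# Removing the second-to-last axis merges the sublists that it separated,
# leaving one flat list per outermost entry: [ [10 20 30] [40 50] ].
merged = r.remove_axis(r.num_axes - 2)
print(merged)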

View File

@@ -25,7 +25,8 @@ import torch
 from icefall.utils import get_texts


-# TODO(fangjun): Use Kangwei's C++ implementation that also List[List[int]]
+# TODO(fangjun): Use Kangwei's C++ implementation that also
+# supports List[List[int]]
 def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
     """Construct a graph to compute Levenshtein distance.
@@ -42,10 +43,10 @@ def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
     final_state = len(symbol_ids) + 1
     arcs = []
     for i in range(final_state - 1):
-        arcs.append([i, i, 0, 0, -1.01])
+        arcs.append([i, i, 0, 0, -0.5])
+        arcs.append([i, i + 1, 0, symbol_ids[i], -0.5])
         arcs.append([i, i + 1, symbol_ids[i], symbol_ids[i], 0])
-        arcs.append([i, i + 1, 0, symbol_ids[i], -1])
-    arcs.append([final_state - 1, final_state - 1, 0, 0, -1.01])
+    arcs.append([final_state - 1, final_state - 1, 0, 0, -0.5])
     arcs.append([final_state - 1, final_state, -1, -1, 0])
     arcs.append([final_state])
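Aside (not part of the commit): to see what the reworked loop builds, one can run just the arc-construction part on a toy reference such as symbol_ids = [2, 3]. Each 5-element list presumably ends up in k2's textual FSA format, where an arc reads "src dst label aux_label score" and the last line holds the final state; label 0 acts as the epsilon here. The comments are descriptive only.

symbol_ids = [2, 3]  # toy reference
final_state = len(symbol_ids) + 1

arcs = []
for i in range(final_state - 1):
    arcs.append([i, i, 0, 0, -0.5])  # epsilon self-loop on state i
    arcs.append([i, i + 1, 0, symbol_ids[i], -0.5])  # epsilon in, symbol out
    arcs.append([i, i + 1, symbol_ids[i], symbol_ids[i], 0])  # exact match, score 0
arcs.append([final_state - 1, final_state - 1, 0, 0, -0.5])  # self-loop on last real state
arcs.append([final_state - 1, final_state, -1, -1, 0])  # arc into the final state
arcs.append([final_state])  # final-state line

for arc in arcs:
    print(arc)
# [0, 0, 0, 0, -0.5]
# [0, 1, 0, 2, -0.5]
# [0, 1, 2, 2, 0]
# [1, 1, 0, 0, -0.5]
# [1, 2, 0, 3, -0.5]
# [1, 2, 3, 3, 0]
# [2, 2, 0, 0, -0.5]
# [2, 3, -1, -1, 0]
# [3]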
@@ -134,7 +135,8 @@ class Nbest(object):
         if isinstance(lattice.aux_labels, torch.Tensor):
             word_seq = k2.ragged.index(lattice.aux_labels, path)
         else:
-            word_seq = lattice.aux_labels.index(path, remove_axis=True)
+            word_seq = lattice.aux_labels.index(path)
+            word_seq = word_seq.remove_axis(word_seq.num_axes - 2)

         # Each utterance has `num_paths` paths but some of them transduces
         # to the same word sequence, so we need to remove repeated word
@@ -170,11 +172,11 @@
             # It has 2 axes [arc][word], so aux_labels is also a ragged tensor
             # with 2 axes [arc][word]
             aux_labels, _ = lattice.aux_labels.index(
-                indexes=kept_path.data, axis=0, need_value_indexes=False
+                indexes=kept_path.values, axis=0, need_value_indexes=False
             )
         else:
             assert isinstance(lattice.aux_labels, torch.Tensor)
-            aux_labels = k2.index_select(lattice.aux_labels, kept_path.data)
+            aux_labels = k2.index_select(lattice.aux_labels, kept_path.values)
             # aux_labels is a 1-D torch.Tensor. It also contains -1 and 0.

         fsa = k2.linear_fsa(labels)
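Aside (not part of the commit): the .data -> .values changes follow the newer k2 ragged API, where the flattened elements of a ragged tensor are exposed as .values (older releases used .data). A small sketch with toy stand-ins for kept_path and lattice.aux_labels, assuming a recent k2; k2.index_select is called the same way as in the diff:

import torch
import k2

# Toy stand-ins: a ragged "path" whose values index into a 1-D aux_labels tensor.
kept_path = k2.RaggedTensor("[ [0 2] [1] ]")
aux_labels = torch.tensor([10, 20, 30], dtype=torch.int32)

# The flat element tensor of a ragged tensor is `.values` in the newer API.
print(kept_path.values)  # tensor([0, 2, 1], dtype=torch.int32)

# k2.index_select gathers from a torch tensor using an int32 index tensor,
# mirroring k2.index_select(lattice.aux_labels, kept_path.values) above.
print(k2.index_select(aux_labels, kept_path.values))  # tensor([10, 30, 20], dtype=torch.int32)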

View File

@@ -40,7 +40,10 @@ def test_nbest_from_lattice():
     lattice = k2.Fsa.from_fsas([lattice, lattice])
     nbest = Nbest.from_lattice(
-        lattice=lattice, num_paths=10, use_double_scores=True, scale=0.5
+        lattice=lattice,
+        num_paths=10,
+        use_double_scores=True,
+        lattice_score_scale=0.5,
     )

     # each lattice has only 4 distinct paths that have different word sequences:
     # 10->30