Refactor nbest-oracle.

Fangjun Kuang 2021-09-17 15:46:23 +08:00
parent d6a995978a
commit b9fc46f432
4 changed files with 38 additions and 22 deletions

View File

@@ -30,14 +30,15 @@ from conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import get_lattice
from icefall.decode import (
get_lattice,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_n_best_list,
rescore_with_whole_lattice,
nbest_oracle,
)
from icefall.decode2 import nbest_decoding, nbest_oracle
from icefall.decode2 import nbest_decoding, nbest_oracle as nbest_oracle2
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@@ -244,7 +245,9 @@ def decode_one_batch(
# We choose the HLG decoded lattice for speed reasons
# as HLG decoding is faster and the oracle WER
# is slightly worse than that of rescored lattices.
best_path = nbest_oracle(
if True:
# TODO: delete the `else` branch
best_path = nbest_oracle2(
lattice=lattice,
num_paths=params.num_paths,
ref_texts=supervisions["text"],
@@ -255,6 +258,14 @@
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
return {"nbest-orcale": hyps}
else:
return nbest_oracle(
lattice=lattice,
num_paths=params.num_paths,
ref_texts=supervisions["text"],
word_table=word_table,
scale=params.lattice_score_scale,
)
if params.method in ["1best", "nbest"]:
if params.method == "1best":

View File

@@ -669,7 +669,7 @@ def nbest_oracle(
word_seq = k2.ragged.index(lattice.aux_labels, path)
else:
word_seq = lattice.aux_labels.index(path)
word_seq = word_seq.remove_axis(1)
word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
word_seq = word_seq.remove_values_leq(0)
unique_word_seq, _, _ = word_seq.unique(

View File

@@ -25,7 +25,8 @@ import torch
from icefall.utils import get_texts
# TODO(fangjun): Use Kangwei's C++ implementation that also List[List[int]]
# TODO(fangjun): Use Kangwei's C++ implementation that also
# supports List[List[int]]
def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
"""Construct a graph to compute Levenshtein distance.
@@ -42,10 +43,10 @@ def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
final_state = len(symbol_ids) + 1
arcs = []
for i in range(final_state - 1):
arcs.append([i, i, 0, 0, -1.01])
arcs.append([i, i, 0, 0, -0.5])
arcs.append([i, i + 1, 0, symbol_ids[i], -0.5])
arcs.append([i, i + 1, symbol_ids[i], symbol_ids[i], 0])
arcs.append([i, i + 1, 0, symbol_ids[i], -1])
arcs.append([final_state - 1, final_state - 1, 0, 0, -1.01])
arcs.append([final_state - 1, final_state - 1, 0, 0, -0.5])
arcs.append([final_state - 1, final_state, -1, -1, 0])
arcs.append([final_state])
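
The loop above adds three kinds of arcs per reference position, and this commit rebalances their scores from -1.01/-1 to matched -0.5 penalties, presumably so that penalties pair up to a cost of 1 per insertion, deletion, or substitution when the graph is intersected with a graph built from the other sequence. Below is a standalone sketch of the arc list for a toy reference (plain Python, no k2); the [src, dst, label, aux_label, score] layout and the trailing [final_state] entry mirror the construction above, which the real function then turns into a k2.Fsa.

# Toy reference word IDs.
symbol_ids = [2, 5]
final_state = len(symbol_ids) + 1  # states 0..len(symbol_ids), plus a super-final state

arcs = []
for i in range(final_state - 1):
    # Self-loop with label 0: stay at position i (score -0.5 after this commit).
    arcs.append([i, i, 0, 0, -0.5])
    # Advance past symbol_ids[i] without matching it (score -0.5).
    arcs.append([i, i + 1, 0, symbol_ids[i], -0.5])
    # Exact match of symbol_ids[i] (score 0).
    arcs.append([i, i + 1, symbol_ids[i], symbol_ids[i], 0])
# Self-loop at the last position, then the final arc (label -1, as k2 expects).
arcs.append([final_state - 1, final_state - 1, 0, 0, -0.5])
arcs.append([final_state - 1, final_state, -1, -1, 0])
# k2's text format ends with a line holding only the final state id.
arcs.append([final_state])

for arc in arcs:
    print(arc)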
@@ -134,7 +135,8 @@ class Nbest(object):
if isinstance(lattice.aux_labels, torch.Tensor):
word_seq = k2.ragged.index(lattice.aux_labels, path)
else:
word_seq = lattice.aux_labels.index(path, remove_axis=True)
word_seq = lattice.aux_labels.index(path)
word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
# Each utterance has `num_paths` paths but some of them transduce
# to the same word sequence, so we need to remove repeated word
@@ -170,11 +172,11 @@
# It has 2 axes [arc][word], so aux_labels is also a ragged tensor
# with 2 axes [arc][word]
aux_labels, _ = lattice.aux_labels.index(
indexes=kept_path.data, axis=0, need_value_indexes=False
indexes=kept_path.values, axis=0, need_value_indexes=False
)
else:
assert isinstance(lattice.aux_labels, torch.Tensor)
aux_labels = k2.index_select(lattice.aux_labels, kept_path.data)
aux_labels = k2.index_select(lattice.aux_labels, kept_path.values)
# aux_labels is a 1-D torch.Tensor. It also contains -1 and 0.
fsa = k2.linear_fsa(labels)
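
Several hunks above swap a hard-coded axis index and the old .data attribute for num_axes-relative indexing and .values. Below is a small sketch of the k2.RaggedTensor operations involved, assuming a k2 build that provides the RaggedTensor API; the toy tensor merely stands in for word_seq, with axes [utt][path][word].

import k2
import torch

# Two utterances with two sampled paths each; 0 plays the role of <eps>.
word_seq = k2.RaggedTensor("[ [ [5 8 0] [5 8 0] ] [ [7 0 6] [7 6] ] ]")
assert word_seq.num_axes == 3

# Removing the second-to-last axis merges the sublists under it.  With the
# 3-axis layout here that equals the old remove_axis(1); it also removes the
# intended axis when word_seq picks up an extra [arc] axis (4 axes in total).
merged = word_seq.remove_axis(word_seq.num_axes - 2)
# merged is [ [5 8 0 5 8 0] [7 0 6 7 6] ]

# Drop epsilons (0) and final-arc markers (-1), as remove_values_leq(0) does above.
no_eps = word_seq.remove_values_leq(0)

# Deduplicate identical innermost word sequences within each utterance,
# mirroring the word_seq.unique(...) call in nbest_oracle.
unique_seq, _, _ = no_eps.unique()

# .values (this commit switches to it from the older .data name) is the flat
# 1-D torch.Tensor holding every element, which k2.index_select() operates on.
assert isinstance(word_seq.values, torch.Tensor)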

View File

@@ -40,7 +40,10 @@ def test_nbest_from_lattice():
lattice = k2.Fsa.from_fsas([lattice, lattice])
nbest = Nbest.from_lattice(
lattice=lattice, num_paths=10, use_double_scores=True, scale=0.5
lattice=lattice,
num_paths=10,
use_double_scores=True,
lattice_score_scale=0.5,
)
# each lattice has only 4 distinct paths that have different word sequences:
# 10->30