Mirror of https://github.com/k2-fsa/icefall.git
Refactor nbest-oracle.

commit b9fc46f432
parent d6a995978a
@@ -30,14 +30,15 @@ from conformer import Conformer
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.decode import (
     get_lattice,
     one_best_decoding,
     rescore_with_attention_decoder,
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
+    nbest_oracle,
 )
-from icefall.decode2 import nbest_decoding, nbest_oracle
+from icefall.decode2 import nbest_decoding, nbest_oracle as nbest_oracle2
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -244,17 +245,27 @@ def decode_one_batch(
         # We choose the HLG decoded lattice for speed reasons
         # as HLG decoding is faster and the oracle WER
         # is slightly worse than that of rescored lattices.
-        best_path = nbest_oracle(
-            lattice=lattice,
-            num_paths=params.num_paths,
-            ref_texts=supervisions["text"],
-            word_table=word_table,
-            lattice_score_scale=params.lattice_score_scale,
-            oov="<UNK>",
-        )
-        hyps = get_texts(best_path)
-        hyps = [[word_table[i] for i in ids] for ids in hyps]
-        return {"nbest-orcale": hyps}
+        if True:
+            # TODO: delete the `else` branch
+            best_path = nbest_oracle2(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                ref_texts=supervisions["text"],
+                word_table=word_table,
+                lattice_score_scale=params.lattice_score_scale,
+                oov="<UNK>",
+            )
+            hyps = get_texts(best_path)
+            hyps = [[word_table[i] for i in ids] for ids in hyps]
+            return {"nbest-orcale": hyps}
+        else:
+            return nbest_oracle(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                ref_texts=supervisions["text"],
+                word_table=word_table,
+                scale=params.lattice_score_scale,
+            )
 
     if params.method in ["1best", "nbest"]:
         if params.method == "1best":
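Note: `nbest_oracle` reports an oracle WER: for each utterance it samples `num_paths` paths from the lattice and keeps the word sequence with the smallest edit distance to the reference, an upper bound on what any n-best rescoring could achieve. A minimal pure-Python sketch of that selection step (the helper names and toy data are illustrative, not icefall code):

from typing import List


def edit_distance(a: List[str], b: List[str]) -> int:
    # Classic Levenshtein DP over two word sequences, kept one row at a time.
    dp = list(range(len(b) + 1))
    for i in range(1, len(a) + 1):
        prev, dp[0] = dp[0], i
        for j in range(1, len(b) + 1):
            cur = dp[j]
            dp[j] = min(
                dp[j] + 1,  # drop a[i - 1]
                dp[j - 1] + 1,  # drop b[j - 1]
                prev + (a[i - 1] != b[j - 1]),  # substitute or match
            )
            prev = cur
    return dp[-1]


def oracle_hyp(nbest: List[List[str]], ref: List[str]) -> List[str]:
    # The oracle keeps the candidate closest to the reference.
    return min(nbest, key=lambda hyp: edit_distance(hyp, ref))


nbest = [["a", "cat", "sat"], ["the", "cat", "sat"], ["the", "cat", "sad"]]
print(oracle_hyp(nbest, ["the", "cat", "sat"]))  # ['the', 'cat', 'sat']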
@@ -669,7 +669,7 @@ def nbest_oracle(
         word_seq = k2.ragged.index(lattice.aux_labels, path)
     else:
         word_seq = lattice.aux_labels.index(path)
-        word_seq = word_seq.remove_axis(1)
+        word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
 
     word_seq = word_seq.remove_values_leq(0)
     unique_word_seq, _, _ = word_seq.unique(
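Note: `remove_axis(1)` hard-coded which axis to drop, while `remove_axis(word_seq.num_axes - 2)` always drops the second-to-last axis, presumably so the same line works however many leading axes `word_seq` carries. A rough nested-list analogue, assuming (as for k2 ragged tensors) that removing a middle axis splices its sublists into the parent:

from typing import List


def remove_second_to_last_axis(ragged: List[List[List[int]]]) -> List[List[int]]:
    # Merge the middle axis: [path][arc][word] -> [path][word]
    return [[x for sub in row for x in sub] for row in ragged]


word_seq = [[[5, 7], [9]], [[5], [8, 9]]]  # toy ragged data
print(remove_second_to_last_axis(word_seq))  # [[5, 7, 9], [5, 8, 9]]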
@@ -25,7 +25,8 @@ import torch
 from icefall.utils import get_texts
 
 
-# TODO(fangjun): Use Kangwei's C++ implementation that also List[List[int]]
+# TODO(fangjun): Use Kangwei's C++ implementation that also
+# supports List[List[int]]
 def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
     """Construct a graph to compute Levenshtein distance.
 
@@ -42,10 +43,10 @@ def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
     final_state = len(symbol_ids) + 1
     arcs = []
     for i in range(final_state - 1):
-        arcs.append([i, i, 0, 0, -1.01])
+        arcs.append([i, i, 0, 0, -0.5])
+        arcs.append([i, i + 1, 0, symbol_ids[i], -0.5])
         arcs.append([i, i + 1, symbol_ids[i], symbol_ids[i], 0])
-        arcs.append([i, i + 1, 0, symbol_ids[i], -1])
-    arcs.append([final_state - 1, final_state - 1, 0, 0, -1.01])
+    arcs.append([final_state - 1, final_state - 1, 0, 0, -0.5])
 
     arcs.append([final_state - 1, final_state, -1, -1, 0])
     arcs.append([final_state])
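Note: each entry above appears to follow k2's arc convention [src, dst, input_label, output_label, score], with label 0 acting as epsilon and -1 entering the final state. The patch replaces the asymmetric -1.01/-1 penalties with a uniform -0.5, plausibly so that every edit collects -0.5 from this graph and -0.5 from a matching construction on the other side of the intersection, making the best-path score equal to minus the edit distance (the quantity the `edit_distance` sketch earlier computes). Expanded by hand for a one-symbol sequence, one plausible reading of the patched arcs:

# Illustrative expansion only; the role comments are an interpretation,
# not taken from the icefall sources.
symbol_ids = [10]
final_state = len(symbol_ids) + 1  # == 2
arcs = [
    [0, 0, 0, 0, -0.5],   # self-loop: absorb a stray symbol (half an edit)
    [0, 1, 0, 10, -0.5],  # advance while emitting 10 unmatched (half an edit)
    [0, 1, 10, 10, 0],    # consume and emit 10: an exact match is free
    [1, 1, 0, 0, -0.5],   # the same absorb-loop on the last real state
    [1, 2, -1, -1, 0],    # arc into the final state
    [final_state],        # final-state sentinel, as in the code above
]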
@@ -134,7 +135,8 @@ class Nbest(object):
         if isinstance(lattice.aux_labels, torch.Tensor):
             word_seq = k2.ragged.index(lattice.aux_labels, path)
         else:
-            word_seq = lattice.aux_labels.index(path, remove_axis=True)
+            word_seq = lattice.aux_labels.index(path)
+            word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
 
         # Each utterance has `num_paths` paths but some of them transduces
         # to the same word sequence, so we need to remove repeated word
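Note: as the surviving comment says, several of the `num_paths` sampled paths usually transduce to the same word sequence, and `Nbest.from_lattice` deduplicates them before building the n-best FSA (the `remove_axis=True` shortcut is unpacked into an explicit `remove_axis` call, matching the `nbest_oracle` change above). A plain-Python stand-in for the deduplication step:

from typing import List


def unique_word_seqs(word_seqs: List[List[int]]) -> List[List[int]]:
    # Keep the first occurrence of each distinct word sequence.
    seen = set()
    out = []
    for seq in word_seqs:
        key = tuple(seq)
        if key not in seen:
            seen.add(key)
            out.append(seq)
    return out


paths = [[5, 7, 9], [5, 7, 9], [5, 8]]  # two paths give the same words
print(unique_word_seqs(paths))  # [[5, 7, 9], [5, 8]]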
@@ -170,11 +172,11 @@ class Nbest(object):
             # It has 2 axes [arc][word], so aux_labels is also a ragged tensor
             # with 2 axes [arc][word]
             aux_labels, _ = lattice.aux_labels.index(
-                indexes=kept_path.data, axis=0, need_value_indexes=False
+                indexes=kept_path.values, axis=0, need_value_indexes=False
             )
         else:
             assert isinstance(lattice.aux_labels, torch.Tensor)
-            aux_labels = k2.index_select(lattice.aux_labels, kept_path.data)
+            aux_labels = k2.index_select(lattice.aux_labels, kept_path.values)
             # aux_labels is a 1-D torch.Tensor. It also contains -1 and 0.
 
         fsa = k2.linear_fsa(labels)
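Note: the `.data` -> `.values` switch tracks the k2 ragged-tensor API, where newer releases expose the flat elements of a ragged tensor as `.values`. Code that has to straddle both spellings could use a duck-typed helper along these lines (hypothetical helper, not part of icefall):

import torch


def ragged_values(ragged) -> torch.Tensor:
    # Prefer the new attribute name and fall back to the old one.
    return ragged.values if hasattr(ragged, "values") else ragged.data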
@@ -40,7 +40,10 @@ def test_nbest_from_lattice():
     lattice = k2.Fsa.from_fsas([lattice, lattice])
 
     nbest = Nbest.from_lattice(
-        lattice=lattice, num_paths=10, use_double_scores=True, scale=0.5
+        lattice=lattice,
+        num_paths=10,
+        use_double_scores=True,
+        lattice_score_scale=0.5,
     )
     # each lattice has only 4 distinct paths that have different word sequences:
     # 10->30
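Note: renaming `scale` to `lattice_score_scale` spells out what gets scaled: the lattice scores are multiplied by this factor before paths are sampled, so a value below 1 flattens the arc posteriors and makes the sampled n-best list more diverse. A toy illustration of the flattening (numbers are illustrative only):

import torch

scores = torch.tensor([4.0, 2.0, 0.0])  # log-scores of three competing arcs
for scale in (1.0, 0.5):
    print(scale, torch.softmax(scores * scale, dim=0))
# With scale=1.0 most of the mass sits on the best arc; with scale=0.5 the
# distribution is flatter, so sampling explores more alternative paths.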