mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 12:02:21 +00:00
Refactor nbest-oracle.
This commit is contained in:
parent
d6a995978a
commit
b9fc46f432
@ -30,14 +30,15 @@ from conformer import Conformer
|
|||||||
|
|
||||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
|
from icefall.decode import get_lattice
|
||||||
from icefall.decode import (
|
from icefall.decode import (
|
||||||
get_lattice,
|
|
||||||
one_best_decoding,
|
one_best_decoding,
|
||||||
rescore_with_attention_decoder,
|
rescore_with_attention_decoder,
|
||||||
rescore_with_n_best_list,
|
rescore_with_n_best_list,
|
||||||
rescore_with_whole_lattice,
|
rescore_with_whole_lattice,
|
||||||
|
nbest_oracle,
|
||||||
)
|
)
|
||||||
from icefall.decode2 import nbest_decoding, nbest_oracle
|
from icefall.decode2 import nbest_decoding, nbest_oracle as nbest_oracle2
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
@ -244,17 +245,27 @@ def decode_one_batch(
|
|||||||
# We choose the HLG decoded lattice for speed reasons
|
# We choose the HLG decoded lattice for speed reasons
|
||||||
# as HLG decoding is faster and the oracle WER
|
# as HLG decoding is faster and the oracle WER
|
||||||
# is slightly worse than that of rescored lattices.
|
# is slightly worse than that of rescored lattices.
|
||||||
best_path = nbest_oracle(
|
if True:
|
||||||
lattice=lattice,
|
# TODO: delete the `else` branch
|
||||||
num_paths=params.num_paths,
|
best_path = nbest_oracle2(
|
||||||
ref_texts=supervisions["text"],
|
lattice=lattice,
|
||||||
word_table=word_table,
|
num_paths=params.num_paths,
|
||||||
lattice_score_scale=params.lattice_score_scale,
|
ref_texts=supervisions["text"],
|
||||||
oov="<UNK>",
|
word_table=word_table,
|
||||||
)
|
lattice_score_scale=params.lattice_score_scale,
|
||||||
hyps = get_texts(best_path)
|
oov="<UNK>",
|
||||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
)
|
||||||
return {"nbest-orcale": hyps}
|
hyps = get_texts(best_path)
|
||||||
|
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||||
|
return {"nbest-orcale": hyps}
|
||||||
|
else:
|
||||||
|
return nbest_oracle(
|
||||||
|
lattice=lattice,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
ref_texts=supervisions["text"],
|
||||||
|
word_table=word_table,
|
||||||
|
scale=params.lattice_score_scale,
|
||||||
|
)
|
||||||
|
|
||||||
if params.method in ["1best", "nbest"]:
|
if params.method in ["1best", "nbest"]:
|
||||||
if params.method == "1best":
|
if params.method == "1best":
|
||||||
|
@ -669,7 +669,7 @@ def nbest_oracle(
|
|||||||
word_seq = k2.ragged.index(lattice.aux_labels, path)
|
word_seq = k2.ragged.index(lattice.aux_labels, path)
|
||||||
else:
|
else:
|
||||||
word_seq = lattice.aux_labels.index(path)
|
word_seq = lattice.aux_labels.index(path)
|
||||||
word_seq = word_seq.remove_axis(1)
|
word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
|
||||||
|
|
||||||
word_seq = word_seq.remove_values_leq(0)
|
word_seq = word_seq.remove_values_leq(0)
|
||||||
unique_word_seq, _, _ = word_seq.unique(
|
unique_word_seq, _, _ = word_seq.unique(
|
||||||
|
@ -25,7 +25,8 @@ import torch
|
|||||||
from icefall.utils import get_texts
|
from icefall.utils import get_texts
|
||||||
|
|
||||||
|
|
||||||
# TODO(fangjun): Use Kangwei's C++ implementation that also List[List[int]]
|
# TODO(fangjun): Use Kangwei's C++ implementation that also
|
||||||
|
# supports List[List[int]]
|
||||||
def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
|
def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
|
||||||
"""Construct a graph to compute Levenshtein distance.
|
"""Construct a graph to compute Levenshtein distance.
|
||||||
|
|
||||||
@ -42,10 +43,10 @@ def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
|
|||||||
final_state = len(symbol_ids) + 1
|
final_state = len(symbol_ids) + 1
|
||||||
arcs = []
|
arcs = []
|
||||||
for i in range(final_state - 1):
|
for i in range(final_state - 1):
|
||||||
arcs.append([i, i, 0, 0, -1.01])
|
arcs.append([i, i, 0, 0, -0.5])
|
||||||
|
arcs.append([i, i + 1, 0, symbol_ids[i], -0.5])
|
||||||
arcs.append([i, i + 1, symbol_ids[i], symbol_ids[i], 0])
|
arcs.append([i, i + 1, symbol_ids[i], symbol_ids[i], 0])
|
||||||
arcs.append([i, i + 1, 0, symbol_ids[i], -1])
|
arcs.append([final_state - 1, final_state - 1, 0, 0, -0.5])
|
||||||
arcs.append([final_state - 1, final_state - 1, 0, 0, -1.01])
|
|
||||||
arcs.append([final_state - 1, final_state, -1, -1, 0])
|
arcs.append([final_state - 1, final_state, -1, -1, 0])
|
||||||
arcs.append([final_state])
|
arcs.append([final_state])
|
||||||
|
|
||||||
@ -134,7 +135,8 @@ class Nbest(object):
|
|||||||
if isinstance(lattice.aux_labels, torch.Tensor):
|
if isinstance(lattice.aux_labels, torch.Tensor):
|
||||||
word_seq = k2.ragged.index(lattice.aux_labels, path)
|
word_seq = k2.ragged.index(lattice.aux_labels, path)
|
||||||
else:
|
else:
|
||||||
word_seq = lattice.aux_labels.index(path, remove_axis=True)
|
word_seq = lattice.aux_labels.index(path)
|
||||||
|
word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
|
||||||
|
|
||||||
# Each utterance has `num_paths` paths but some of them transduce
|
# Each utterance has `num_paths` paths but some of them transduce
|
||||||
# to the same word sequence, so we need to remove repeated word
|
# to the same word sequence, so we need to remove repeated word
|
||||||
@ -170,11 +172,11 @@ class Nbest(object):
|
|||||||
# It has 2 axes [arc][word], so aux_labels is also a ragged tensor
|
# It has 2 axes [arc][word], so aux_labels is also a ragged tensor
|
||||||
# with 2 axes [arc][word]
|
# with 2 axes [arc][word]
|
||||||
aux_labels, _ = lattice.aux_labels.index(
|
aux_labels, _ = lattice.aux_labels.index(
|
||||||
indexes=kept_path.data, axis=0, need_value_indexes=False
|
indexes=kept_path.values, axis=0, need_value_indexes=False
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert isinstance(lattice.aux_labels, torch.Tensor)
|
assert isinstance(lattice.aux_labels, torch.Tensor)
|
||||||
aux_labels = k2.index_select(lattice.aux_labels, kept_path.data)
|
aux_labels = k2.index_select(lattice.aux_labels, kept_path.values)
|
||||||
# aux_labels is a 1-D torch.Tensor. It also contains -1 and 0.
|
# aux_labels is a 1-D torch.Tensor. It also contains -1 and 0.
|
||||||
|
|
||||||
fsa = k2.linear_fsa(labels)
|
fsa = k2.linear_fsa(labels)
|
||||||
|
@ -40,7 +40,10 @@ def test_nbest_from_lattice():
|
|||||||
lattice = k2.Fsa.from_fsas([lattice, lattice])
|
lattice = k2.Fsa.from_fsas([lattice, lattice])
|
||||||
|
|
||||||
nbest = Nbest.from_lattice(
|
nbest = Nbest.from_lattice(
|
||||||
lattice=lattice, num_paths=10, use_double_scores=True, scale=0.5
|
lattice=lattice,
|
||||||
|
num_paths=10,
|
||||||
|
use_double_scores=True,
|
||||||
|
lattice_score_scale=0.5,
|
||||||
)
|
)
|
||||||
# each lattice has only 4 distinct paths that have different word sequences:
|
# each lattice has only 4 distinct paths that have different word sequences:
|
||||||
# 10->30
|
# 10->30
|
||||||
|
Loading…
x
Reference in New Issue
Block a user