Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-13 12:02:21 +00:00.
Commit d2c7fb9cea (parent a96f10a96d): Use Levenshtein graphs/alignment from k2 v1.9.
This commit is contained in:
@ -128,43 +128,6 @@ def get_lattice(
|
||||
return lattice
|
||||
|
||||
|
||||
# TODO(fangjun): Use Kangwei's C++ implementation that also
# supports List[List[int]]
def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
    """Construct a graph for computing Levenshtein (edit) distance.

    An example graph for ``levenshtein_graph([1, 2, 3])`` can be found
    at https://git.io/Ju7eW
    (assuming the symbol table is 1<->a, 2<->b, 3<->c, blk<->0).

    Args:
      symbol_ids:
        A list of symbol IDs (excluding 0 and -1).
    Returns:
      Return an Fsa (with 2 axes [state][arc]).
    """
    assert 0 not in symbol_ids
    assert -1 not in symbol_ids

    final_state = len(symbol_ids) + 1
    arcs = []
    # From each non-final state we may:
    #   - self-loop on blank (score -0.5),
    #   - advance on blank, outputting the expected symbol (score -0.5),
    #   - advance by matching the expected symbol (score 0).
    for state, sym in enumerate(symbol_ids):
        arcs.append([state, state, 0, 0, -0.5])
        arcs.append([state, state + 1, 0, sym, -0.5])
        arcs.append([state, state + 1, sym, sym, 0])
    # Penultimate state: blank self-loop, then the final (-1) arc.
    arcs.append([final_state - 1, final_state - 1, 0, 0, -0.5])
    arcs.append([final_state - 1, final_state, -1, -1, 0])
    # k2's textual FSA format ends with a line naming the final state.
    arcs.sort(key=lambda arc: arc[0])
    fsa_text = "\n".join(
        " ".join(str(field) for field in arc) for arc in arcs
    )

    fsa = k2.Fsa.from_str(fsa_text, acceptor=False)
    return k2.arc_sort(fsa)
|
||||
|
||||
|
||||
class Nbest(object):
|
||||
"""
|
||||
An Nbest object contains two fields:
|
||||
@ -456,9 +419,8 @@ class Nbest(object):
|
||||
|
||||
def build_levenshtein_graphs(self) -> k2.Fsa:
    """Return an FsaVec with axes [utt][state][arc].

    Builds one Levenshtein graph per utterance from the word IDs of
    ``self.fsa`` using ``k2.levenshtein_graph``, which accepts the
    ragged tensor produced by ``get_texts(..., return_ragged=True)``.
    """
    # NOTE(review): the previous revision contained both the old
    # list-based implementation and this one; everything after the
    # first ``return`` was unreachable. Keep only the k2 v1.9
    # ragged-tensor version.
    word_ids = get_texts(self.fsa, return_ragged=True)
    return k2.levenshtein_graph(word_ids)
|
||||
|
||||
|
||||
def one_best_decoding(
|
||||
@ -590,7 +552,7 @@ def nbest_oracle(
|
||||
lattice_score_scale=lattice_score_scale,
|
||||
)
|
||||
|
||||
hyps = nbest.build_levenshtein_graphs().to(device)
|
||||
hyps = nbest.build_levenshtein_graphs()
|
||||
|
||||
oov_id = word_table[oov]
|
||||
word_ids_list = []
|
||||
@ -603,24 +565,16 @@ def nbest_oracle(
|
||||
word_ids.append(oov_id)
|
||||
word_ids_list.append(word_ids)
|
||||
|
||||
levenshtein_graphs = [levenshtein_graph(ids) for ids in word_ids_list]
|
||||
refs = k2.Fsa.from_fsas(levenshtein_graphs).to(device)
|
||||
refs = k2.levenshtein_graph(word_ids_list, device=device)
|
||||
|
||||
# Now compute the edit distance between hyps and refs
|
||||
hyps.rename_tensor_attribute_("aux_labels", "aux_labels2")
|
||||
edit_dist_lattice = k2.intersect_device(
|
||||
refs,
|
||||
hyps,
|
||||
b_to_a_map=nbest.shape.row_ids(1),
|
||||
sorted_match_a=True,
|
||||
levenshtein_alignment = k2.levenshtein_alignment(
|
||||
refs=refs,
|
||||
hyps=hyps,
|
||||
hyp_to_ref_map=nbest.shape.row_ids(1),
|
||||
sorted_match_ref=True,
|
||||
)
|
||||
edit_dist_lattice = k2.remove_epsilon_self_loops(edit_dist_lattice)
|
||||
edit_dist_best_path = k2.shortest_path(
|
||||
edit_dist_lattice, use_double_scores=True
|
||||
).invert_()
|
||||
edit_dist_best_path.rename_tensor_attribute_("aux_labels2", "aux_labels")
|
||||
|
||||
tot_scores = edit_dist_best_path.get_tot_scores(
|
||||
tot_scores = levenshtein_alignment.get_tot_scores(
|
||||
use_double_scores=False, log_semiring=False
|
||||
)
|
||||
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
||||
|
@ -186,7 +186,9 @@ def encode_supervisions(
|
||||
return supervision_segments, texts
|
||||
|
||||
|
||||
def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
|
||||
def get_texts(
|
||||
best_paths: k2.Fsa, return_ragged: bool = False
|
||||
) -> Union[List[List[int]], k2.RaggedTensor]:
|
||||
"""Extract the texts (as word IDs) from the best-path FSAs.
|
||||
Args:
|
||||
best_paths:
|
||||
@ -194,6 +196,9 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
|
||||
containing multiple FSAs, which is expected to be the result
|
||||
of k2.shortest_path (otherwise the returned values won't
|
||||
be meaningful).
|
||||
return_ragged:
|
||||
True to return a ragged tensor with two axes [utt][word_id].
|
||||
False to return a list-of-list word IDs.
|
||||
Returns:
|
||||
Returns a list of lists of int, containing the label sequences we
|
||||
decoded.
|
||||
@ -216,6 +221,9 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
|
||||
aux_labels = aux_labels.remove_values_leq(0)
|
||||
|
||||
assert aux_labels.num_axes == 2
|
||||
if return_ragged:
|
||||
return aux_labels
|
||||
else:
|
||||
return aux_labels.tolist()
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user