Use Levenshtein graphs/alignment from k2 v1.9

This commit is contained in:
Fangjun Kuang 2021-09-20 15:12:08 +08:00
parent a96f10a96d
commit d2c7fb9cea
2 changed files with 20 additions and 58 deletions

View File

@ -128,43 +128,6 @@ def get_lattice(
return lattice
# TODO(fangjun): Use Kangwei's C++ implementation that also
# supports List[List[int]]
def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
"""Construct a graph to compute Levenshtein distance.
An example graph for `levenshtein_graph([1, 2, 3]` can be found
at https://git.io/Ju7eW
(Assuming the symbol table is 1<->a, 2<->b, 3<->c, blk<->0)
Args:
symbol_ids:
A list of symbol IDs (excluding 0 and -1)
Returns:
Return an Fsa (with 2 axes [state][arc]).
"""
assert 0 not in symbol_ids
assert -1 not in symbol_ids
final_state = len(symbol_ids) + 1
arcs = []
for i in range(final_state - 1):
arcs.append([i, i, 0, 0, -0.5])
arcs.append([i, i + 1, 0, symbol_ids[i], -0.5])
arcs.append([i, i + 1, symbol_ids[i], symbol_ids[i], 0])
arcs.append([final_state - 1, final_state - 1, 0, 0, -0.5])
arcs.append([final_state - 1, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
fsa = k2.arc_sort(fsa)
return fsa
class Nbest(object):
"""
An Nbest object contains two fields:
@ -456,9 +419,8 @@ class Nbest(object):
def build_levenshtein_graphs(self) -> k2.Fsa:
"""Return an FsaVec with axes [utt][state][arc]."""
word_ids = get_texts(self.fsa)
word_levenshtein_graphs = [levenshtein_graph(ids) for ids in word_ids]
return k2.Fsa.from_fsas(word_levenshtein_graphs)
word_ids = get_texts(self.fsa, return_ragged=True)
return k2.levenshtein_graph(word_ids)
def one_best_decoding(
@ -590,7 +552,7 @@ def nbest_oracle(
lattice_score_scale=lattice_score_scale,
)
hyps = nbest.build_levenshtein_graphs().to(device)
hyps = nbest.build_levenshtein_graphs()
oov_id = word_table[oov]
word_ids_list = []
@ -603,24 +565,16 @@ def nbest_oracle(
word_ids.append(oov_id)
word_ids_list.append(word_ids)
levenshtein_graphs = [levenshtein_graph(ids) for ids in word_ids_list]
refs = k2.Fsa.from_fsas(levenshtein_graphs).to(device)
refs = k2.levenshtein_graph(word_ids_list, device=device)
# Now compute the edit distance between hyps and refs
hyps.rename_tensor_attribute_("aux_labels", "aux_labels2")
edit_dist_lattice = k2.intersect_device(
refs,
hyps,
b_to_a_map=nbest.shape.row_ids(1),
sorted_match_a=True,
levenshtein_alignment = k2.levenshtein_alignment(
refs=refs,
hyps=hyps,
hyp_to_ref_map=nbest.shape.row_ids(1),
sorted_match_ref=True,
)
edit_dist_lattice = k2.remove_epsilon_self_loops(edit_dist_lattice)
edit_dist_best_path = k2.shortest_path(
edit_dist_lattice, use_double_scores=True
).invert_()
edit_dist_best_path.rename_tensor_attribute_("aux_labels2", "aux_labels")
tot_scores = edit_dist_best_path.get_tot_scores(
tot_scores = levenshtein_alignment.get_tot_scores(
use_double_scores=False, log_semiring=False
)
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)

View File

@ -186,7 +186,9 @@ def encode_supervisions(
return supervision_segments, texts
def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
def get_texts(
best_paths: k2.Fsa, return_ragged: bool = False
) -> Union[List[List[int]], k2.RaggedTensor]:
"""Extract the texts (as word IDs) from the best-path FSAs.
Args:
best_paths:
@ -194,6 +196,9 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
containing multiple FSAs, which is expected to be the result
of k2.shortest_path (otherwise the returned values won't
be meaningful).
return_ragged:
True to return a ragged tensor with two axes [utt][word_id].
False to return a list-of-list word IDs.
Returns:
Returns a list of lists of int, containing the label sequences we
decoded.
@ -216,6 +221,9 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
aux_labels = aux_labels.remove_values_leq(0)
assert aux_labels.num_axes == 2
if return_ragged:
return aux_labels
else:
return aux_labels.tolist()