Compute edit distance with k2.

This commit is contained in:
Fangjun Kuang 2021-09-14 21:07:54 +08:00
parent d2bedbe02e
commit 664b87a57b
3 changed files with 178 additions and 11 deletions

View File

@ -32,13 +32,12 @@ from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import (
get_lattice,
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from icefall.decode2 import nbest_decoding
from icefall.decode2 import nbest_decoding, nbest_oracle
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@ -245,13 +244,17 @@ def decode_one_batch(
# We choose the HLG decoded lattice for speed reasons
# as HLG decoding is faster and the oracle WER
# is slightly worse than that of rescored lattices.
return nbest_oracle(
best_path = nbest_oracle(
lattice=lattice,
num_paths=params.num_paths,
ref_texts=supervisions["text"],
word_table=word_table,
scale=params.lattice_score_scale,
lattice_score_scale=params.lattice_score_scale,
oov="<UNK>",
)
hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps]
return {"nbest-orcale": hyps}
if params.method in ["1best", "nbest"]:
if params.method == "1best":
@ -264,7 +267,7 @@ def decode_one_batch(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
scale=params.lattice_score_scale,
lattice_score_scale=params.lattice_score_scale,
)
key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa

View File

@ -17,9 +17,47 @@
# NOTE: This file is a refactor of decode.py
# We will delete decode.py and rename this file to decode.py
from typing import Dict, List
import k2
import torch
from icefall.utils import get_texts
# TODO(fangjun): Use Kangwei's C++ implementation that also accepts List[List[int]]
def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
    """Build the graph used to compute Levenshtein distance for one sequence.

    For `n` input symbols the graph has states `0 .. n` plus one final
    state `n + 1`. Every state `i <= n` carries a self-loop with label 0,
    aux_label 0 and score -1.01. Every state `i < n` has two arcs to
    `i + 1`: one with label and aux_label both equal to `symbol_ids[i]`
    and score 0, and one with label 0, aux_label `symbol_ids[i]` and
    score -1. State `n` enters the final state on label -1 with score 0.

    An example graph for `levenshtein_graph([1, 2, 3])` can be found
    at https://git.io/Ju7eW
    (Assuming the symbol table is 1<->a, 2<->b, 3<->c, blk<->0)

    Args:
      symbol_ids:
        A list of symbol IDs (excluding 0 and -1)
    Returns:
      The arc-sorted FSA built from the arcs described above; it is
      created with `acceptor=False`, i.e., it has aux_labels.
    """
    assert 0 not in symbol_ids
    assert -1 not in symbol_ids

    last_state = len(symbol_ids)
    final_state = last_state + 1

    rows = []
    for state, symbol in enumerate(symbol_ids):
        # Row layout: src_state dst_state label aux_label score
        rows.append((state, state, 0, 0, -1.01))
        rows.append((state, state + 1, symbol, symbol, 0))
        rows.append((state, state + 1, 0, symbol, -1))

    rows.append((last_state, last_state, 0, 0, -1.01))
    rows.append((last_state, final_state, -1, -1, 0))
    # The final state is given on a line of its own, after all arcs.
    rows.append((final_state,))

    # Stable sort on the source state, so arcs leaving the same state
    # keep their insertion order.
    rows.sort(key=lambda row: row[0])

    fsa_text = "\n".join(" ".join(map(str, row)) for row in rows)
    return k2.arc_sort(k2.Fsa.from_str(fsa_text, acceptor=False))
class Nbest(object):
"""
@ -63,7 +101,7 @@ class Nbest(object):
lattice: k2.Fsa,
num_paths: int,
use_double_scores: bool = True,
scale: float = 0.5,
lattice_score_scale: float = 0.5,
) -> "Nbest":
"""Construct an Nbest object by **sampling** `num_paths` from a lattice.
@ -82,7 +120,7 @@ class Nbest(object):
to sample the path with the best score.
"""
saved_scores = lattice.scores.clone()
lattice.scores *= scale
lattice.scores *= lattice_score_scale
# path is a ragged tensor with dtype torch.int32.
# It has three axes [utt][path][arc_pos]
path = k2.random_paths(
@ -230,19 +268,52 @@ class Nbest(object):
)
return k2.RaggedTensor(self.shape, scores)
def build_levenshtein_graphs(self) -> k2.Fsa:
    """Build one Levenshtein graph per word sequence found in `self.fsa`.

    The word IDs are read from `self.fsa` with `get_texts`; each resulting
    ID list is converted with `levenshtein_graph`.

    Returns:
      An FsaVec with axes [utt][state][arc].
    """
    graphs = []
    for ids in get_texts(self.fsa):
        graphs.append(levenshtein_graph(ids))
    return k2.Fsa.from_fsas(graphs)
def nbest_decoding(
lattice: k2.Fsa,
num_paths: int,
use_double_scores: bool = True,
scale: float = 1.0,
lattice_score_scale: float = 1.0,
) -> k2.Fsa:
""" """
"""It implements something like CTC prefix beam search using n-best lists.
The basic idea is to first extract `num_paths` paths from the given lattice,
build a word sequence from these paths, and compute the total scores
of the word sequence in the tropical semiring. The one with the max score
is used as the decoding output.
Caution:
Don't be confused by `best` in the name `n-best`. Paths are selected
**randomly**, not by ranking their scores.
Args:
lattice:
The decoding lattice, e.g., can be the return value of
:func:`get_lattice`. It has 3 axes [utt][state][arc].
num_paths:
It specifies the size `n` in n-best. Note: Paths are selected randomly
and those containing identical word sequences are removed and only one
of them is kept.
use_double_scores:
True to use double precision floating point in the computation.
False to use single precision.
lattice_score_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
Returns:
An FsaVec containing linear FSAs. Its axes are [utt][state][arc].
"""
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
scale=scale,
lattice_score_scale=lattice_score_scale,
)
nbest = nbest.intersect(lattice)
@ -253,3 +324,96 @@ def nbest_decoding(
best_path = k2.index_fsa(nbest.fsa, max_indexes)
return best_path
def nbest_oracle(
    lattice: k2.Fsa,
    num_paths: int,
    ref_texts: List[str],
    word_table: k2.SymbolTable,
    use_double_scores: bool = True,
    lattice_score_scale: float = 0.5,
    oov: str = "<UNK>",
) -> k2.Fsa:
    """Select the best hypothesis given a lattice and a reference transcript.

    The basic idea is to extract n paths from the given lattice, unique them,
    and select the one that has the minimum edit distance with the
    corresponding reference transcript as the decoding output.

    The decoding result returned from this function is the best result that
    we can obtain using n-best decoding with all kinds of rescoring
    techniques.

    This function is useful to tune the value of `lattice_score_scale`.

    Args:
      lattice:
        An FsaVec with axes [utt][state][arc].
        Note: We assume its aux_labels contain word IDs.
      num_paths:
        The size of `n` in n-best.
      ref_texts:
        A list of reference transcripts. Each entry contains
        whitespace-separated words.
      word_table:
        It is the word symbol table.
      use_double_scores:
        True to use double precision for computation. False to use
        single precision. It is only forwarded to `Nbest.from_lattice`;
        the edit-distance computation below uses fixed precisions.
      lattice_score_scale:
        It's the scale applied to the lattice.scores. A smaller value
        yields more unique paths.
      oov:
        The out of vocabulary word. Words of `ref_texts` that are not in
        `word_table` are mapped to the ID of this word.
    Returns:
      The FsaVec obtained by indexing `nbest.fsa` with, for each entry of
      `nbest.shape`'s first axis, the position of the path whose edit-distance
      score is the highest (scores are non-positive costs, so the highest
      score is the smallest cost).
    """
    device = lattice.device

    nbest = Nbest.from_lattice(
        lattice=lattice,
        num_paths=num_paths,
        use_double_scores=use_double_scores,
        lattice_score_scale=lattice_score_scale,
    )

    hyps = nbest.build_levenshtein_graphs().to(device)

    # Map every reference word to its ID, falling back to the OOV ID.
    oov_id = word_table[oov]
    word_ids_list = [
        [
            word_table[word] if word in word_table else oov_id
            for word in text.split()
        ]
        for text in ref_texts
    ]

    levenshtein_graphs = [levenshtein_graph(ids) for ids in word_ids_list]
    refs = k2.Fsa.from_fsas(levenshtein_graphs).to(device)

    # Now compute the edit distance between hyps and refs.
    # Both operands carry an `aux_labels` attribute; rename the one on
    # `hyps` before intersecting and restore the name afterwards.
    hyps.rename_tensor_attribute_("aux_labels", "aux_labels2")
    edit_dist_lattice = k2.intersect_device(
        refs,
        hyps,
        # Pair the i-th hypothesis graph with the reference selected by
        # the row ID of path i in `nbest.shape`.
        b_to_a_map=nbest.shape.row_ids(1),
        sorted_match_a=True,
    )
    edit_dist_lattice = k2.remove_epsilon_self_loops(edit_dist_lattice)
    edit_dist_best_path = k2.shortest_path(
        edit_dist_lattice, use_double_scores=True
    ).invert_()
    edit_dist_best_path.rename_tensor_attribute_("aux_labels2", "aux_labels")

    tot_scores = edit_dist_best_path.get_tot_scores(
        use_double_scores=False, log_semiring=False
    )
    ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)

    # argmax is taken within each sub-list of the ragged tensor.
    max_indexes = ragged_tot_scores.argmax()

    best_path = k2.index_fsa(nbest.fsa, max_indexes)
    return best_path

View File

@ -106,7 +106,7 @@ class CtcTrainingGraphCompiler(object):
word_ids_list = []
for text in texts:
word_ids = []
for word in text.split(" "):
for word in text.split():
if word in self.word_table:
word_ids.append(self.word_table[word])
else: