mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 12:02:21 +00:00
Compute edit distance with k2.
This commit is contained in:
parent
d2bedbe02e
commit
664b87a57b
@ -32,13 +32,12 @@ from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
nbest_oracle,
|
||||
one_best_decoding,
|
||||
rescore_with_attention_decoder,
|
||||
rescore_with_n_best_list,
|
||||
rescore_with_whole_lattice,
|
||||
)
|
||||
from icefall.decode2 import nbest_decoding
|
||||
from icefall.decode2 import nbest_decoding, nbest_oracle
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -245,13 +244,17 @@ def decode_one_batch(
|
||||
# We choose the HLG decoded lattice for speed reasons
|
||||
# as HLG decoding is faster and the oracle WER
|
||||
# is slightly worse than that of rescored lattices.
|
||||
return nbest_oracle(
|
||||
best_path = nbest_oracle(
|
||||
lattice=lattice,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=supervisions["text"],
|
||||
word_table=word_table,
|
||||
scale=params.lattice_score_scale,
|
||||
lattice_score_scale=params.lattice_score_scale,
|
||||
oov="<UNK>",
|
||||
)
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
return {"nbest-orcale": hyps}
|
||||
|
||||
if params.method in ["1best", "nbest"]:
|
||||
if params.method == "1best":
|
||||
@ -264,7 +267,7 @@ def decode_one_batch(
|
||||
lattice=lattice,
|
||||
num_paths=params.num_paths,
|
||||
use_double_scores=params.use_double_scores,
|
||||
scale=params.lattice_score_scale,
|
||||
lattice_score_scale=params.lattice_score_scale,
|
||||
)
|
||||
key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa
|
||||
|
||||
|
@ -17,9 +17,47 @@
|
||||
# NOTE: This file is a refactor of decode.py
|
||||
# We will delete decode.py and rename this file to decode.py
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from icefall.utils import get_texts
|
||||
|
||||
|
||||
# TODO(fangjun): Use Kangwei's C++ implementation that also accepts List[List[int]]
|
||||
def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
    """Build the FSA used to compute a Levenshtein distance for one sequence.

    For every position `i` in `symbol_ids` the graph has three arcs leaving
    state `i`:

      - a self-loop with label 0 and aux_label 0, score -1.01
      - an arc to state `i + 1` with label and aux_label both equal to
        `symbol_ids[i]`, score 0
      - an arc to state `i + 1` with label 0 and aux_label `symbol_ids[i]`,
        score -1

    The last non-final state additionally has the 0/0 self-loop (score -1.01)
    and the -1/-1 arc (score 0) entering the final state.

    An example graph for `levenshtein_graph([1, 2, 3])` can be found
    at https://git.io/Ju7eW
    (Assuming the symbol table is 1<->a, 2<->b, 3<->c, blk<->0)

    Args:
      symbol_ids:
        A list of symbol IDs (excluding 0 and -1).
    Returns:
      The arc-sorted FSA built from the arcs described above.
    """
    assert 0 not in symbol_ids
    assert -1 not in symbol_ids

    num_symbols = len(symbol_ids)
    final_state = num_symbols + 1

    # Each line is "src dst label aux_label score". Lines are emitted in
    # increasing order of the source state, so no extra sorting is needed.
    lines = []
    for state, symbol in enumerate(symbol_ids):
        lines.append(f"{state} {state} 0 0 -1.01")
        lines.append(f"{state} {state + 1} {symbol} {symbol} 0")
        lines.append(f"{state} {state + 1} 0 {symbol} -1")

    lines.append(f"{num_symbols} {num_symbols} 0 0 -1.01")
    lines.append(f"{num_symbols} {final_state} -1 -1 0")
    # A line holding only a state number denotes the final state.
    lines.append(f"{final_state}")

    graph = k2.Fsa.from_str("\n".join(lines), acceptor=False)
    return k2.arc_sort(graph)
|
||||
|
||||
|
||||
class Nbest(object):
|
||||
"""
|
||||
@ -63,7 +101,7 @@ class Nbest(object):
|
||||
lattice: k2.Fsa,
|
||||
num_paths: int,
|
||||
use_double_scores: bool = True,
|
||||
scale: float = 0.5,
|
||||
lattice_score_scale: float = 0.5,
|
||||
) -> "Nbest":
|
||||
"""Construct an Nbest object by **sampling** `num_paths` from a lattice.
|
||||
|
||||
@ -82,7 +120,7 @@ class Nbest(object):
|
||||
to sample the path with the best score.
|
||||
"""
|
||||
saved_scores = lattice.scores.clone()
|
||||
lattice.scores *= scale
|
||||
lattice.scores *= lattice_score_scale
|
||||
# path is a ragged tensor with dtype torch.int32.
|
||||
# It has three axes [utt][path][arc_pos]
|
||||
path = k2.random_paths(
|
||||
@ -230,19 +268,52 @@ class Nbest(object):
|
||||
)
|
||||
return k2.RaggedTensor(self.shape, scores)
|
||||
|
||||
def build_levenshtein_graphs(self) -> k2.Fsa:
    """Return an FsaVec with axes [utt][state][arc].

    One Levenshtein graph is built from each word-ID sequence that
    `get_texts` extracts from `self.fsa`; the graphs are then stacked
    into a single FsaVec.
    """
    per_path_graphs = []
    for path_word_ids in get_texts(self.fsa):
        per_path_graphs.append(levenshtein_graph(path_word_ids))
    return k2.Fsa.from_fsas(per_path_graphs)
|
||||
|
||||
|
||||
def nbest_decoding(
|
||||
lattice: k2.Fsa,
|
||||
num_paths: int,
|
||||
use_double_scores: bool = True,
|
||||
scale: float = 1.0,
|
||||
lattice_score_scale: float = 1.0,
|
||||
) -> k2.Fsa:
|
||||
""" """
|
||||
"""It implements something like CTC prefix beam search using n-best lists.
|
||||
|
||||
The basic idea is to first extract `num_paths` paths from the given lattice,
|
||||
build a word sequence from these paths, and compute the total scores
|
||||
of the word sequence in the tropical semiring. The one with the max score
|
||||
is used as the decoding output.
|
||||
|
||||
Caution:
|
||||
Don't be confused by `best` in the name `n-best`. Paths are selected
|
||||
**randomly**, not by ranking their scores.
|
||||
|
||||
Args:
|
||||
lattice:
|
||||
The decoding lattice, e.g., can be the return value of
|
||||
:func:`get_lattice`. It has 3 axes [utt][state][arc].
|
||||
num_paths:
|
||||
It specifies the size `n` in n-best. Note: Paths are selected randomly
|
||||
and those containing identical word sequences are removed and only one
|
||||
of them is kept.
|
||||
use_double_scores:
|
||||
True to use double precision floating point in the computation.
|
||||
False to use single precision.
|
||||
lattice_score_scale:
|
||||
It's the scale applied to the lattice.scores. A smaller value
|
||||
yields more unique paths.
|
||||
Returns:
|
||||
An FsaVec containing linear FSAs. Its axes are [utt][state][arc].
|
||||
"""
|
||||
nbest = Nbest.from_lattice(
|
||||
lattice=lattice,
|
||||
num_paths=num_paths,
|
||||
use_double_scores=use_double_scores,
|
||||
scale=scale,
|
||||
lattice_score_scale=lattice_score_scale,
|
||||
)
|
||||
|
||||
nbest = nbest.intersect(lattice)
|
||||
@ -253,3 +324,96 @@ def nbest_decoding(
|
||||
|
||||
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||
return best_path
|
||||
|
||||
|
||||
def nbest_oracle(
    lattice: k2.Fsa,
    num_paths: int,
    ref_texts: List[str],
    word_table: k2.SymbolTable,
    use_double_scores: bool = True,
    lattice_score_scale: float = 0.5,
    oov: str = "<UNK>",
) -> k2.Fsa:
    """Select the best hypothesis given a lattice and a reference transcript.

    The basic idea is to extract n paths from the given lattice, unique them,
    and select the one that has the minimum edit distance with the
    corresponding reference transcript as the decoding output.

    The decoding result returned from this function is the best result that
    we can obtain using n-best decoding with all kinds of rescoring techniques.

    This function is useful to tune the value of `lattice_score_scale`.

    Args:
      lattice:
        An FsaVec with axes [utt][state][arc].
        Note: We assume its aux_labels contain word IDs.
      num_paths:
        The size of `n` in n-best.
      ref_texts:
        A list of reference transcripts. Each entry contains space(s)
        separated words.
      word_table:
        It is the word symbol table.
      use_double_scores:
        True to use double precision for computation. False to use
        single precision. It is forwarded to `Nbest.from_lattice`.
      lattice_score_scale:
        It's the scale applied to the lattice.scores. A smaller value
        yields more unique paths.
      oov:
        The out of vocabulary word. Words of `ref_texts` that are not in
        `word_table` are mapped to the ID of this word.
    Returns:
      The paths selected from the n-best list with `k2.index_fsa`: for each
      utterance, the path whose edit-distance total score against the
      reference is the largest. The caller can pass the result to
      `get_texts` to obtain word IDs.
    """
    device = lattice.device

    nbest = Nbest.from_lattice(
        lattice=lattice,
        num_paths=num_paths,
        use_double_scores=use_double_scores,
        lattice_score_scale=lattice_score_scale,
    )

    hyps = nbest.build_levenshtein_graphs().to(device)

    # Map every reference word to its ID, falling back to the OOV ID for
    # words that are missing from the word table.
    oov_id = word_table[oov]
    word_ids_list = [
        [
            word_table[word] if word in word_table else oov_id
            for word in text.split()
        ]
        for text in ref_texts
    ]
    levenshtein_graphs = [levenshtein_graph(ids) for ids in word_ids_list]
    refs = k2.Fsa.from_fsas(levenshtein_graphs).to(device)

    # Now compute the edit distance between hyps and refs.
    # The hyps' aux_labels are renamed so that they do not clash with the
    # aux_labels of refs in the intersection result.
    hyps.rename_tensor_attribute_("aux_labels", "aux_labels2")
    edit_dist_lattice = k2.intersect_device(
        refs,
        hyps,
        # Pairs each hypothesis path with a reference graph; presumably the
        # one of the utterance the path was sampled from -- TODO confirm.
        b_to_a_map=nbest.shape.row_ids(1),
        sorted_match_a=True,
    )
    edit_dist_lattice = k2.remove_epsilon_self_loops(edit_dist_lattice)
    edit_dist_best_path = k2.shortest_path(
        edit_dist_lattice, use_double_scores=True
    ).invert_()
    edit_dist_best_path.rename_tensor_attribute_("aux_labels2", "aux_labels")

    # Arc scores in the Levenshtein graphs are 0, -1 or -1.01, so the
    # largest total score identifies the path with the fewest edits.
    tot_scores = edit_dist_best_path.get_tot_scores(
        use_double_scores=False, log_semiring=False
    )
    ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)

    max_indexes = ragged_tot_scores.argmax()

    best_path = k2.index_fsa(nbest.fsa, max_indexes)
    return best_path
|
||||
|
@ -106,7 +106,7 @@ class CtcTrainingGraphCompiler(object):
|
||||
word_ids_list = []
|
||||
for text in texts:
|
||||
word_ids = []
|
||||
for word in text.split(" "):
|
||||
for word in text.split():
|
||||
if word in self.word_table:
|
||||
word_ids.append(self.word_table[word])
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user