mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 20:12:24 +00:00
Compute edit distance with k2.
This commit is contained in:
parent
d2bedbe02e
commit
664b87a57b
@ -32,13 +32,12 @@ from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
|||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.decode import (
|
from icefall.decode import (
|
||||||
get_lattice,
|
get_lattice,
|
||||||
nbest_oracle,
|
|
||||||
one_best_decoding,
|
one_best_decoding,
|
||||||
rescore_with_attention_decoder,
|
rescore_with_attention_decoder,
|
||||||
rescore_with_n_best_list,
|
rescore_with_n_best_list,
|
||||||
rescore_with_whole_lattice,
|
rescore_with_whole_lattice,
|
||||||
)
|
)
|
||||||
from icefall.decode2 import nbest_decoding
|
from icefall.decode2 import nbest_decoding, nbest_oracle
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
@ -245,13 +244,17 @@ def decode_one_batch(
|
|||||||
# We choose the HLG decoded lattice for speed reasons
|
# We choose the HLG decoded lattice for speed reasons
|
||||||
# as HLG decoding is faster and the oracle WER
|
# as HLG decoding is faster and the oracle WER
|
||||||
# is slightly worse than that of rescored lattices.
|
# is slightly worse than that of rescored lattices.
|
||||||
return nbest_oracle(
|
best_path = nbest_oracle(
|
||||||
lattice=lattice,
|
lattice=lattice,
|
||||||
num_paths=params.num_paths,
|
num_paths=params.num_paths,
|
||||||
ref_texts=supervisions["text"],
|
ref_texts=supervisions["text"],
|
||||||
word_table=word_table,
|
word_table=word_table,
|
||||||
scale=params.lattice_score_scale,
|
lattice_score_scale=params.lattice_score_scale,
|
||||||
|
oov="<UNK>",
|
||||||
)
|
)
|
||||||
|
hyps = get_texts(best_path)
|
||||||
|
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||||
|
return {"nbest-orcale": hyps}
|
||||||
|
|
||||||
if params.method in ["1best", "nbest"]:
|
if params.method in ["1best", "nbest"]:
|
||||||
if params.method == "1best":
|
if params.method == "1best":
|
||||||
@ -264,7 +267,7 @@ def decode_one_batch(
|
|||||||
lattice=lattice,
|
lattice=lattice,
|
||||||
num_paths=params.num_paths,
|
num_paths=params.num_paths,
|
||||||
use_double_scores=params.use_double_scores,
|
use_double_scores=params.use_double_scores,
|
||||||
scale=params.lattice_score_scale,
|
lattice_score_scale=params.lattice_score_scale,
|
||||||
)
|
)
|
||||||
key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa
|
key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa
|
||||||
|
|
||||||
|
@ -17,9 +17,47 @@
|
|||||||
# NOTE: This file is a refactor of decode.py
|
# NOTE: This file is a refactor of decode.py
|
||||||
# We will delete decode.py and rename this file to decode.py
|
# We will delete decode.py and rename this file to decode.py
|
||||||
|
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from icefall.utils import get_texts
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(fangjun): Use Kangwei's C++ implementation that also List[List[int]]
|
||||||
|
def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
|
||||||
|
"""Construct a graph to compute Levenshtein distance.
|
||||||
|
|
||||||
|
An example graph for `levenshtein_graph([1, 2, 3]` can be found
|
||||||
|
at https://git.io/Ju7eW
|
||||||
|
(Assuming the symbol table is 1<->a, 2<->b, 3<->c, blk<->0)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol_ids:
|
||||||
|
A list of symbol IDs (excluding 0 and -1)
|
||||||
|
"""
|
||||||
|
assert 0 not in symbol_ids
|
||||||
|
assert -1 not in symbol_ids
|
||||||
|
final_state = len(symbol_ids) + 1
|
||||||
|
arcs = []
|
||||||
|
for i in range(final_state - 1):
|
||||||
|
arcs.append([i, i, 0, 0, -1.01])
|
||||||
|
arcs.append([i, i + 1, symbol_ids[i], symbol_ids[i], 0])
|
||||||
|
arcs.append([i, i + 1, 0, symbol_ids[i], -1])
|
||||||
|
arcs.append([final_state - 1, final_state - 1, 0, 0, -1.01])
|
||||||
|
arcs.append([final_state - 1, final_state, -1, -1, 0])
|
||||||
|
arcs.append([final_state])
|
||||||
|
|
||||||
|
arcs = sorted(arcs, key=lambda arc: arc[0])
|
||||||
|
arcs = [[str(i) for i in arc] for arc in arcs]
|
||||||
|
arcs = [" ".join(arc) for arc in arcs]
|
||||||
|
arcs = "\n".join(arcs)
|
||||||
|
|
||||||
|
fsa = k2.Fsa.from_str(arcs, acceptor=False)
|
||||||
|
fsa = k2.arc_sort(fsa)
|
||||||
|
return fsa
|
||||||
|
|
||||||
|
|
||||||
class Nbest(object):
|
class Nbest(object):
|
||||||
"""
|
"""
|
||||||
@ -63,7 +101,7 @@ class Nbest(object):
|
|||||||
lattice: k2.Fsa,
|
lattice: k2.Fsa,
|
||||||
num_paths: int,
|
num_paths: int,
|
||||||
use_double_scores: bool = True,
|
use_double_scores: bool = True,
|
||||||
scale: float = 0.5,
|
lattice_score_scale: float = 0.5,
|
||||||
) -> "Nbest":
|
) -> "Nbest":
|
||||||
"""Construct an Nbest object by **sampling** `num_paths` from a lattice.
|
"""Construct an Nbest object by **sampling** `num_paths` from a lattice.
|
||||||
|
|
||||||
@ -82,7 +120,7 @@ class Nbest(object):
|
|||||||
to sample the path with the best score.
|
to sample the path with the best score.
|
||||||
"""
|
"""
|
||||||
saved_scores = lattice.scores.clone()
|
saved_scores = lattice.scores.clone()
|
||||||
lattice.scores *= scale
|
lattice.scores *= lattice_score_scale
|
||||||
# path is a ragged tensor with dtype torch.int32.
|
# path is a ragged tensor with dtype torch.int32.
|
||||||
# It has three axes [utt][path][arc_pos
|
# It has three axes [utt][path][arc_pos
|
||||||
path = k2.random_paths(
|
path = k2.random_paths(
|
||||||
@ -230,19 +268,52 @@ class Nbest(object):
|
|||||||
)
|
)
|
||||||
return k2.RaggedTensor(self.shape, scores)
|
return k2.RaggedTensor(self.shape, scores)
|
||||||
|
|
||||||
|
def build_levenshtein_graphs(self) -> k2.Fsa:
|
||||||
|
"""Return an FsaVec with axes [utt][state][arc]."""
|
||||||
|
word_ids = get_texts(self.fsa)
|
||||||
|
word_levenshtein_graphs = [levenshtein_graph(ids) for ids in word_ids]
|
||||||
|
return k2.Fsa.from_fsas(word_levenshtein_graphs)
|
||||||
|
|
||||||
|
|
||||||
def nbest_decoding(
|
def nbest_decoding(
|
||||||
lattice: k2.Fsa,
|
lattice: k2.Fsa,
|
||||||
num_paths: int,
|
num_paths: int,
|
||||||
use_double_scores: bool = True,
|
use_double_scores: bool = True,
|
||||||
scale: float = 1.0,
|
lattice_score_scale: float = 1.0,
|
||||||
) -> k2.Fsa:
|
) -> k2.Fsa:
|
||||||
""" """
|
"""It implements something like CTC prefix beam search using n-best lists.
|
||||||
|
|
||||||
|
The basic idea is to first extra `num_paths` paths from the given lattice,
|
||||||
|
build a word sequence from these paths, and compute the total scores
|
||||||
|
of the word sequence in the tropical semiring. The one with the max score
|
||||||
|
is used as the decoding output.
|
||||||
|
|
||||||
|
Caution:
|
||||||
|
Don't be confused by `best` in the name `n-best`. Paths are selected
|
||||||
|
**randomly**, not by ranking their scores.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lattice:
|
||||||
|
The decoding lattice, e.g., can be the return value of
|
||||||
|
:func:`get_lattice`. It has 3 axes [utt][state][arc].
|
||||||
|
num_paths:
|
||||||
|
It specifies the size `n` in n-best. Note: Paths are selected randomly
|
||||||
|
and those containing identical word sequences are removed and only one
|
||||||
|
of them is kept.
|
||||||
|
use_double_scores:
|
||||||
|
True to use double precision floating point in the computation.
|
||||||
|
False to use single precision.
|
||||||
|
lattice_score_scale:
|
||||||
|
It's the scale applied to the lattice.scores. A smaller value
|
||||||
|
yields more unique paths.
|
||||||
|
Returns:
|
||||||
|
An FsaVec containing linear FSAs. It axes are [utt][state][arc].
|
||||||
|
"""
|
||||||
nbest = Nbest.from_lattice(
|
nbest = Nbest.from_lattice(
|
||||||
lattice=lattice,
|
lattice=lattice,
|
||||||
num_paths=num_paths,
|
num_paths=num_paths,
|
||||||
use_double_scores=use_double_scores,
|
use_double_scores=use_double_scores,
|
||||||
scale=scale,
|
lattice_score_scale=lattice_score_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
nbest = nbest.intersect(lattice)
|
nbest = nbest.intersect(lattice)
|
||||||
@ -253,3 +324,96 @@ def nbest_decoding(
|
|||||||
|
|
||||||
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||||
return best_path
|
return best_path
|
||||||
|
|
||||||
|
|
||||||
|
def nbest_oracle(
|
||||||
|
lattice: k2.Fsa,
|
||||||
|
num_paths: int,
|
||||||
|
ref_texts: List[str],
|
||||||
|
word_table: k2.SymbolTable,
|
||||||
|
use_double_scores: bool = True,
|
||||||
|
lattice_score_scale: float = 0.5,
|
||||||
|
oov: str = "<UNK>",
|
||||||
|
) -> Dict[str, List[List[int]]]:
|
||||||
|
"""Select the best hypothesis given a lattice and a reference transcript.
|
||||||
|
|
||||||
|
The basic idea is to extract n paths from the given lattice, unique them,
|
||||||
|
and select the one that has the minimum edit distance with the corresponding
|
||||||
|
reference transcript as the decoding output.
|
||||||
|
|
||||||
|
The decoding result returned from this function is the best result that
|
||||||
|
we can obtain using n-best decoding with all kinds of rescoring techniques.
|
||||||
|
|
||||||
|
This function is useful to tune the value of `lattice_score_scale`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lattice:
|
||||||
|
An FsaVec with axes [utt][state][arc].
|
||||||
|
Note: We assume its aux_labels contain word IDs.
|
||||||
|
num_paths:
|
||||||
|
The size of `n` in n-best.
|
||||||
|
ref_texts:
|
||||||
|
A list of reference transcript. Each entry contains space(s)
|
||||||
|
separated words
|
||||||
|
word_table:
|
||||||
|
It is the word symbol table.
|
||||||
|
use_double_scores:
|
||||||
|
True to use double precision for computation. False to use
|
||||||
|
single precision.
|
||||||
|
lattice_score_scale:
|
||||||
|
It's the scale applied to the lattice.scores. A smaller value
|
||||||
|
yields more unique paths.
|
||||||
|
oov:
|
||||||
|
The out of vocabulary word.
|
||||||
|
Return:
|
||||||
|
Return a dict. Its key contains the information about the parameters
|
||||||
|
when calling this function, while its value contains the decoding output.
|
||||||
|
`len(ans_dict) == len(ref_texts)`
|
||||||
|
"""
|
||||||
|
device = lattice.device
|
||||||
|
|
||||||
|
nbest = Nbest.from_lattice(
|
||||||
|
lattice=lattice,
|
||||||
|
num_paths=num_paths,
|
||||||
|
use_double_scores=use_double_scores,
|
||||||
|
lattice_score_scale=lattice_score_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
hyps = nbest.build_levenshtein_graphs().to(device)
|
||||||
|
|
||||||
|
oov_id = word_table[oov]
|
||||||
|
word_ids_list = []
|
||||||
|
for text in ref_texts:
|
||||||
|
word_ids = []
|
||||||
|
for word in text.split():
|
||||||
|
if word in word_table:
|
||||||
|
word_ids.append(word_table[word])
|
||||||
|
else:
|
||||||
|
word_ids.append(oov_id)
|
||||||
|
word_ids_list.append(word_ids)
|
||||||
|
levenshtein_graphs = [levenshtein_graph(ids) for ids in word_ids_list]
|
||||||
|
refs = k2.Fsa.from_fsas(levenshtein_graphs).to(device)
|
||||||
|
|
||||||
|
# Now compute the edit distance between hyps and refs
|
||||||
|
hyps.rename_tensor_attribute_("aux_labels", "aux_labels2")
|
||||||
|
edit_dist_lattice = k2.intersect_device(
|
||||||
|
refs,
|
||||||
|
hyps,
|
||||||
|
b_to_a_map=nbest.shape.row_ids(1),
|
||||||
|
sorted_match_a=True,
|
||||||
|
)
|
||||||
|
edit_dist_lattice = k2.remove_epsilon_self_loops(edit_dist_lattice)
|
||||||
|
edit_dist_best_path = k2.shortest_path(
|
||||||
|
edit_dist_lattice, use_double_scores=True
|
||||||
|
).invert_()
|
||||||
|
edit_dist_best_path.rename_tensor_attribute_("aux_labels2", "aux_labels")
|
||||||
|
|
||||||
|
tot_scores = edit_dist_best_path.get_tot_scores(
|
||||||
|
use_double_scores=False, log_semiring=False
|
||||||
|
)
|
||||||
|
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
||||||
|
|
||||||
|
max_indexes = ragged_tot_scores.argmax()
|
||||||
|
|
||||||
|
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||||
|
return best_path
|
||||||
|
@ -106,7 +106,7 @@ class CtcTrainingGraphCompiler(object):
|
|||||||
word_ids_list = []
|
word_ids_list = []
|
||||||
for text in texts:
|
for text in texts:
|
||||||
word_ids = []
|
word_ids = []
|
||||||
for word in text.split(" "):
|
for word in text.split():
|
||||||
if word in self.word_table:
|
if word in self.word_table:
|
||||||
word_ids.append(self.word_table[word])
|
word_ids.append(self.word_table[word])
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user