Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-13 12:02:21 +00:00)

Add whole-lattice rescoring.

parent a44b4f84a5
commit b0b942b355
@@ -42,6 +42,7 @@ from icefall.decode2 import (
     nbest_decoding,
     nbest_oracle as nbest_oracle2,
     rescore_with_n_best_list as rescore_with_n_best_list2,
+    rescore_with_whole_lattice as rescore_with_whole_lattice2,
 )
 from icefall.lexicon import Lexicon
 from icefall.utils import (
@@ -261,9 +262,7 @@ def decode_one_batch(
         )
         hyps = get_texts(best_path)
         hyps = [[word_table[i] for i in ids] for ids in hyps]
-        key = (
-            f"oracle_{num_paths}_lattice_score_scale_{lattice_score_scale}"
-        )
+        key = f"oracle_{num_paths}_lattice_score_scale_{params.lattice_score_scale}"  # noqa
         return {key: hyps}
     else:
         return nbest_oracle(
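
Note: with the key now built as a single f-string, both settings appear in the
result key directly. For illustration (the values below are assumptions, not
from this commit):

    # Hypothetical values:
    num_paths = 100
    lattice_score_scale = 0.5
    key = f"oracle_{num_paths}_lattice_score_scale_{lattice_score_scale}"
    # -> "oracle_100_lattice_score_scale_0.5"
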
@@ -322,9 +321,19 @@ def decode_one_batch(
             scale=params.lattice_score_scale,
         )
     elif params.method == "whole-lattice-rescoring":
-        best_path_dict = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
-        )
+        if True:
+            # TODO: remove "else" branch
+            best_path_dict = rescore_with_whole_lattice2(
+                lattice=lattice,
+                G_with_epsilon_loops=G,
+                lm_scale_list=lm_scale_list,
+            )
+        else:
+            best_path_dict = rescore_with_whole_lattice(
+                lattice=lattice,
+                G_with_epsilon_loops=G,
+                lm_scale_list=lm_scale_list,
+            )
     elif params.method == "attention-decoder":
         # lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
         rescored_lattice = rescore_with_whole_lattice(
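
Note: the new branch passes G as ``G_with_epsilon_loops``, and the function
added below asserts that G already carries ``lm_scores`` and is a
single-element FsaVec. A minimal sketch of how such a G is typically prepared;
the file path and exact preprocessing steps are assumptions, not part of this
commit:

    import k2
    import torch

    # Load a word-level 4-gram LM saved as a k2 FSA (hypothetical path).
    G = k2.Fsa.from_dict(torch.load("data/lm/G_4_gram.pt"))

    # Epsilon self-loops let k2.intersect_device traverse the epsilon arcs
    # of the inverted lattice; arc sorting enables sorted_match_a=True.
    G = k2.add_epsilon_self_loops(G)
    G = k2.arc_sort(G)

    # Wrap as a single-element FsaVec so that G.shape == (1, None, None),
    # matching the assert in rescore_with_whole_lattice below.
    if len(G.shape) == 2:
        G = k2.create_fsa_vec([G])

    # Keep a copy of the LM scores so rescoring can separate AM and LM parts.
    G.lm_scores = G.scores.clone()
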
@@ -345,10 +354,14 @@ def decode_one_batch(
         assert False, f"Unsupported decoding method: {params.method}"

     ans = dict()
-    for lm_scale_str, best_path in best_path_dict.items():
-        hyps = get_texts(best_path)
-        hyps = [[word_table[i] for i in ids] for ids in hyps]
-        ans[lm_scale_str] = hyps
+    if best_path_dict is not None:
+        for lm_scale_str, best_path in best_path_dict.items():
+            hyps = get_texts(best_path)
+            hyps = [[word_table[i] for i in ids] for ids in hyps]
+            ans[lm_scale_str] = hyps
+    else:
+        for lm_scale in lm_scale_list:
+            ans[f"lm_scale_{lm_scale}"] = [[] for _ in range(lattice.shape[0])]
     return ans

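Note: the None guard above handles the case where whole-lattice rescoring gives
up (the new function below returns None when the intersection stays too large
even after pruning); every LM scale then maps to empty hypotheses. A small
sketch of consuming the returned dict, with made-up values:

    # Made-up example of the mapping built above:
    ans = {
        "lm_scale_0.7": [["HELLO", "WORLD"], ["GOOD", "MORNING"]],
        "lm_scale_0.9": [["HELLO", "WORLD"], ["GOOD", "MORNING"]],
    }
    for lm_scale_str, hyps in ans.items():
        texts = [" ".join(words) for words in hyps]
        print(lm_scale_str, texts)
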
@@ -17,7 +17,8 @@
 # NOTE: This file is a refactor of decode.py
 # We will delete decode.py and rename this file to decode.py

-from typing import Dict, List
+import logging
+from typing import Dict, List, Optional, Union

 import k2
 import torch
@@ -505,6 +506,27 @@ def rescore_with_n_best_list(
 ) -> Dict[str, k2.Fsa]:
     """Rescore an n-best list with an n-gram LM.
     The path with a maximum score is used as the decoding output.
+
+    Args:
+      lattice:
+        An FsaVec with axes [utt][state][arc]. It must have the following
+        attributes: ``aux_labels`` and ``lm_scores``. Its labels are
+        token IDs and its ``aux_labels`` are word IDs.
+      G:
+        An FsaVec containing only a single FSA. It is an n-gram LM.
+      num_paths:
+        Size of the n-best list.
+      lm_scale_list:
+        A list of floats representing LM score scales.
+      lattice_score_scale:
+        Scale to be applied to ``lattice.scores`` when sampling paths
+        using ``k2.random_paths``.
+      use_double_scores:
+        True to use double precision during computation. False to use
+        single precision.
+    Returns:
+      A dict of FsaVec, whose key is an lm_scale and the value is the
+      best decoding path for each utterance in the lattice.
     """
     device = lattice.device

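
Note: a call matching the documented signature might look as follows; the
concrete argument values are illustrative assumptions only:

    best_path_dict = rescore_with_n_best_list(
        lattice=lattice,                # FsaVec with aux_labels and lm_scores
        G=G,                            # single-FSA FsaVec holding the n-gram LM
        num_paths=100,                  # size of the n-best list
        lm_scale_list=[0.6, 0.8, 1.0],
        lattice_score_scale=0.5,        # used when sampling with k2.random_paths
        use_double_scores=True,
    )
    # e.g. best_path_dict["lm_scale_0.8"] is the best path per utterance
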
@@ -543,3 +565,77 @@ def rescore_with_n_best_list(
         key = f"lm_scale_{lm_scale}"
         ans[key] = best_path
     return ans
+
+
+def rescore_with_whole_lattice(
+    lattice: k2.Fsa,
+    G_with_epsilon_loops: k2.Fsa,
+    lm_scale_list: Optional[List[float]] = None,
+    use_double_scores: bool = True,
+) -> Union[k2.Fsa, Dict[str, k2.Fsa]]:
+    # This is not an Nbest-based decoding method.
+    assert hasattr(lattice, "lm_scores")
+    assert G_with_epsilon_loops.shape == (1, None, None)
+
+    device = lattice.device
+    lattice.scores = lattice.scores - lattice.lm_scores
+    # We will use lm_scores from G, so remove lattice.lm_scores here
+    del lattice.lm_scores
+
+    assert hasattr(G_with_epsilon_loops, "lm_scores")
+
+    # Now, lattice.scores contains only am_scores
+
+    # inv_lattice has word IDs as labels.
+    # Its aux_labels are token IDs.
+    inv_lattice = k2.invert(lattice)
+    num_seqs = lattice.shape[0]
+
+    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
+
+    max_loop_count = 10
+    loop_count = 0
+    while loop_count <= max_loop_count:
+        loop_count += 1
+        try:
+            rescoring_lattice = k2.intersect_device(
+                G_with_epsilon_loops,
+                inv_lattice,
+                b_to_a_map,
+                sorted_match_a=True,
+            )
+            rescoring_lattice = k2.top_sort(k2.connect(rescoring_lattice))
+            break
+        except RuntimeError as e:
+            logging.info(f"Caught exception:\n{e}\n")
+            logging.info(
+                f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
+            )
+
+            # NOTE(fangjun): The choice of the threshold 1e-9 is arbitrary here
+            # to avoid OOM. You may need to fine tune it.
+            inv_lattice = k2.prune_on_arc_post(inv_lattice, 1e-9, True)
+            logging.info(
+                f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
+            )
+    if loop_count > max_loop_count:
+        logging.info("Return None as the resulting lattice is too large")
+        return None
+
+    # lat has token IDs as labels
+    # and word IDs as aux_labels.
+    lat = k2.invert(rescoring_lattice)
+
+    if lm_scale_list is None:
+        return lat
+
+    ans = dict()
+    saved_am_scores = lat.scores - lat.lm_scores
+    for lm_scale in lm_scale_list:
+        am_scores = saved_am_scores / lm_scale
+        lat.scores = am_scores + lat.lm_scores
+
+        best_path = k2.shortest_path(lat, use_double_scores=use_double_scores)
+        key = f"lm_scale_{lm_scale}_yy"
+        ans[key] = best_path
+    return ans
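
Note: dividing ``am_scores`` by ``lm_scale`` before adding ``lm_scores`` ranks
paths the same way as using ``am_scores + lm_scale * lm_scores``, since the
shortest path is unchanged when all scores are multiplied by the positive
constant ``lm_scale``. A brief usage sketch covering both return modes of the
new function (argument values are assumptions):

    # Mode 1: lm_scale_list=None returns the rescored lattice itself,
    # or None if it stayed too large even after repeated pruning.
    rescored = rescore_with_whole_lattice(
        lattice=lattice,
        G_with_epsilon_loops=G,
        lm_scale_list=None,
    )

    # Mode 2: a list of scales returns one shortest path per scale,
    # keyed "lm_scale_<scale>_yy".
    best_path_dict = rescore_with_whole_lattice(
        lattice=lattice,
        G_with_epsilon_loops=G,
        lm_scale_list=[0.6, 0.8, 1.0],
    )
    if best_path_dict is not None:
        best_path = best_path_dict["lm_scale_0.8_yy"]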
|
Loading…
x
Reference in New Issue
Block a user