mirror of https://github.com/k2-fsa/icefall.git

Add whole-lattice rescoring.

parent a44b4f84a5
commit b0b942b355

@@ -42,6 +42,7 @@ from icefall.decode2 import (
    nbest_decoding,
    nbest_oracle as nbest_oracle2,
    rescore_with_n_best_list as rescore_with_n_best_list2,
    rescore_with_whole_lattice as rescore_with_whole_lattice2,
)
from icefall.lexicon import Lexicon
from icefall.utils import (

@@ -261,9 +262,7 @@ def decode_one_batch(
            )
            hyps = get_texts(best_path)
            hyps = [[word_table[i] for i in ids] for ids in hyps]
            key = f"oracle_{num_paths}_lattice_score_scale_{params.lattice_score_scale}"  # noqa
            return {key: hyps}
        else:
            return nbest_oracle(

@@ -322,8 +321,18 @@ def decode_one_batch(
            scale=params.lattice_score_scale,
        )
    elif params.method == "whole-lattice-rescoring":
        if True:
            # TODO: remove the "else" branch
            best_path_dict = rescore_with_whole_lattice2(
                lattice=lattice,
                G_with_epsilon_loops=G,
                lm_scale_list=lm_scale_list,
            )
        else:
            best_path_dict = rescore_with_whole_lattice(
                lattice=lattice,
                G_with_epsilon_loops=G,
                lm_scale_list=lm_scale_list,
            )
    elif params.method == "attention-decoder":
        # lattice uses a 3-gram LM. We rescore it with a 4-gram LM.

@@ -345,10 +354,14 @@ def decode_one_batch(
        assert False, f"Unsupported decoding method: {params.method}"

    ans = dict()
    if best_path_dict is not None:
        for lm_scale_str, best_path in best_path_dict.items():
            hyps = get_texts(best_path)
            hyps = [[word_table[i] for i in ids] for ids in hyps]
            ans[lm_scale_str] = hyps
    else:
        for lm_scale in lm_scale_list:
            ans[f"lm_scale_{lm_scale}"] = [[] for _ in range(lattice.shape[0])]
    return ans
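
For reference, the id-to-word mapping above is a plain nested comprehension. A runnable toy version follows; the dict below merely stands in for the real word_table, which is a symbol table built from the Lexicon:

# Toy stand-in for the word table; the real one comes from the Lexicon.
word_table = {1: "HELLO", 2: "WORLD"}
hyps_ids = [[1, 2], [2]]  # word IDs as returned by get_texts()
hyps = [[word_table[i] for i in ids] for ids in hyps_ids]
assert hyps == [["HELLO", "WORLD"], ["WORLD"]]
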
@@ -17,7 +17,8 @@
# NOTE: This file is a refactor of decode.py.
# We will delete decode.py and rename this file to decode.py.

import logging
from typing import Dict, List, Optional, Union

import k2
import torch

@@ -505,6 +506,27 @@ def rescore_with_n_best_list(
) -> Dict[str, k2.Fsa]:
    """Rescore an n-best list with an n-gram LM.
    The path with the maximum score is used as the decoding output.

    Args:
      lattice:
        An FsaVec with axes [utt][state][arc]. It must have the following
        attributes: ``aux_labels`` and ``lm_scores``. Its labels are
        token IDs and its ``aux_labels`` are word IDs.
      G:
        An FsaVec containing only a single FSA. It is an n-gram LM.
      num_paths:
        Size of the n-best list.
      lm_scale_list:
        A list of floats representing LM score scales.
      lattice_score_scale:
        Scale to be applied to ``lattice.scores`` when sampling paths
        using ``k2.random_paths``.
      use_double_scores:
        True to use double precision during computation. False to use
        single precision.
    Returns:
      A dict of FsaVec, whose keys are LM-scale strings and whose values
      are the best decoding paths for each utterance in the lattice.
    """
    device = lattice.device

@@ -543,3 +565,77 @@ def rescore_with_n_best_list(
        key = f"lm_scale_{lm_scale}"
        ans[key] = best_path
    return ans
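
The keys of the returned dict follow a fixed f"lm_scale_{lm_scale}" pattern, so callers can look up a particular scale directly. A runnable mini-demo of that convention (the None placeholders stand in for the k2.Fsa best paths):

# Key convention shared by both rescoring functions in this file.
lm_scale_list = [0.6, 0.7, 0.8]
ans = {f"lm_scale_{lm_scale}": None for lm_scale in lm_scale_list}
assert sorted(ans) == ["lm_scale_0.6", "lm_scale_0.7", "lm_scale_0.8"]
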
def rescore_with_whole_lattice(
    lattice: k2.Fsa,
    G_with_epsilon_loops: k2.Fsa,
    lm_scale_list: Optional[List[float]] = None,
    use_double_scores: bool = True,
) -> Optional[Union[k2.Fsa, Dict[str, k2.Fsa]]]:
    """Rescore a lattice by intersecting it as a whole with an n-gram LM.
    This is not an Nbest-based decoding method.

    If ``lm_scale_list`` is None, the rescored lattice is returned.
    Otherwise, a dict mapping ``f"lm_scale_{lm_scale}"`` to the best
    decoding path per utterance is returned; None is returned if the
    intersection is still too large after repeated pruning.
    """
    assert hasattr(lattice, "lm_scores")
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lattice.device
    lattice.scores = lattice.scores - lattice.lm_scores
    # We will use lm_scores from G, so remove lattice.lm_scores here.
    del lattice.lm_scores

    assert hasattr(G_with_epsilon_loops, "lm_scores")

    # Now lattice.scores contains only am_scores.

    # inv_lattice has word IDs as labels.
    # Its aux_labels are token IDs.
    inv_lattice = k2.invert(lattice)
    num_seqs = lattice.shape[0]

    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
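
The b_to_a_map above pairs each utterance FSA in inv_lattice with the single FSA in G_with_epsilon_loops, which is why every entry is zero. A runnable sketch of just that tensor (the batch size 4 is an arbitrary choice for illustration):

import torch

# b_to_a_map[i] is the index of the FSA in G that the i-th utterance
# lattice is intersected with; G holds exactly one FSA, so all zeros.
num_seqs = 4  # assumed batch size, for illustration only
b_to_a_map = torch.zeros(num_seqs, dtype=torch.int32)
print(b_to_a_map)  # tensor([0, 0, 0, 0], dtype=torch.int32)
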
    max_loop_count = 10
    loop_count = 0
    while loop_count <= max_loop_count:
        try:
            rescoring_lattice = k2.intersect_device(
                G_with_epsilon_loops,
                inv_lattice,
                b_to_a_map,
                sorted_match_a=True,
            )
            rescoring_lattice = k2.top_sort(k2.connect(rescoring_lattice))
            break
        except RuntimeError as e:
            logging.info(f"Caught exception:\n{e}\n")
            logging.info(
                f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}"
            )

            # NOTE(fangjun): The choice of the threshold 1e-9 is arbitrary here
            # to avoid OOM. You may need to fine-tune it.
            inv_lattice = k2.prune_on_arc_post(inv_lattice, 1e-9, True)
            logging.info(
                f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}"
            )
            loop_count += 1

    if loop_count > max_loop_count:
        logging.info("Return None as the resulting lattice is too large")
        return None

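The shape of this recovery loop generalizes: attempt the expensive intersection, and on RuntimeError prune the input and retry, up to a fixed bound. A self-contained sketch of the same pattern (the helper and its names are illustrative, not an icefall API):

from typing import Callable, Optional, TypeVar

T = TypeVar("T")

def try_with_pruning(
    op: Callable[[], T], prune: Callable[[], None], max_attempts: int = 10
) -> Optional[T]:
    # Run op(); each time it raises RuntimeError (e.g., OOM), prune the
    # underlying input and try again, giving up after max_attempts.
    for _ in range(max_attempts):
        try:
            return op()
        except RuntimeError:
            prune()
    return None
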
    # lat has token IDs as labels
    # and word IDs as aux_labels.
    lat = k2.invert(rescoring_lattice)

    if lm_scale_list is None:
        return lat

    ans = dict()
    saved_am_scores = lat.scores - lat.lm_scores
    for lm_scale in lm_scale_list:
        am_scores = saved_am_scores / lm_scale
        lat.scores = am_scores + lat.lm_scores

        best_path = k2.shortest_path(lat, use_double_scores=use_double_scores)
key = f"lm_scale_{lm_scale}_yy"
|
||||
        ans[key] = best_path
    return ans
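
Dividing the AM scores by lm_scale above, rather than multiplying the LM scores, preserves the ranking of paths: am / s + lm and am + s * lm differ only by the positive factor s. A tiny runnable check, with plain floats standing in for the score tensors:

# Floats stand in for lat.scores / lat.lm_scores; the assertion checks
# that the two combinations agree up to the positive factor lm_scale.
am, lm, lm_scale = -12.3, -4.5, 0.7
combined = am / lm_scale + lm
assert abs(combined * lm_scale - (am + lm_scale * lm)) < 1e-9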