mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Use attention decoder for rescoring.
This commit is contained in:
parent
f65854cca5
commit
bd69e4be32
@ -21,6 +21,7 @@ from icefall.decode import (
|
||||
get_lattice,
|
||||
nbest_decoding,
|
||||
one_best_decoding,
|
||||
rescore_with_attention_decoder,
|
||||
rescore_with_n_best_list,
|
||||
rescore_with_whole_lattice,
|
||||
)
|
||||
@ -82,9 +83,12 @@ def get_params() -> AttributeDict:
|
||||
# - nbest
|
||||
# - nbest-rescoring
|
||||
# - whole-lattice-rescoring
|
||||
"method": "nbest-rescoring",
|
||||
# num_paths is used when method is "nbest" and "nbest-rescoring"
|
||||
"num_paths": 100,
|
||||
# - attention-decoder
|
||||
# "method": "whole-lattice-rescoring",
|
||||
"method": "attention-decoder",
|
||||
# num_paths is used when method is "nbest", "nbest-rescoring",
|
||||
# and attention-decoder
|
||||
"num_paths": 1000,
|
||||
}
|
||||
)
|
||||
return params
|
||||
@ -147,7 +151,7 @@ def decode_one_batch(
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
|
||||
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
|
||||
nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
|
||||
# nnet_output is [N, C, T]
|
||||
|
||||
nnet_output = nnet_output.permute(0, 2, 1)
|
||||
@ -191,7 +195,11 @@ def decode_one_batch(
|
||||
hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
|
||||
return {key: hyps}
|
||||
|
||||
assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"]
|
||||
assert params.method in [
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
"attention-decoder",
|
||||
]
|
||||
|
||||
lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
|
||||
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
|
||||
@ -203,10 +211,25 @@ def decode_one_batch(
|
||||
num_paths=params.num_paths,
|
||||
lm_scale_list=lm_scale_list,
|
||||
)
|
||||
else:
|
||||
elif params.method == "whole-lattice-rescoring":
|
||||
best_path_dict = rescore_with_whole_lattice(
|
||||
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
|
||||
)
|
||||
elif params.method == "attention-decoder":
|
||||
# lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
|
||||
rescored_lattice = rescore_with_whole_lattice(
|
||||
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
|
||||
)
|
||||
|
||||
best_path_dict = rescore_with_attention_decoder(
|
||||
lattice=rescored_lattice,
|
||||
num_paths=params.num_paths,
|
||||
model=model,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
)
|
||||
else:
|
||||
assert False, f"Unsupported decoding method: {params.method}"
|
||||
|
||||
ans = dict()
|
||||
for lm_scale_str, best_path in best_path_dict.items():
|
||||
@ -351,7 +374,11 @@ def main():
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
HLG.lm_scores = HLG.scores.clone()
|
||||
|
||||
if params.method in ["nbest-rescoring", "whole-lattice-rescoring"]:
|
||||
if params.method in (
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
"attention-decoder",
|
||||
):
|
||||
if not (params.lm_dir / "G_4_gram.pt").is_file():
|
||||
logging.info("Loading G_4_gram.fst.txt")
|
||||
logging.warning("It may take 8 minutes.")
|
||||
@ -374,7 +401,7 @@ def main():
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt")
|
||||
G = k2.Fsa.from_dict(d).to(device)
|
||||
|
||||
if params.method == "whole-lattice-rescoring":
|
||||
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
# it with the whole lattice later
|
||||
G = k2.add_epsilon_self_loops(G)
|
||||
|
@ -259,7 +259,7 @@ class Transformer(nn.Module):
|
||||
return decoder_loss
|
||||
|
||||
def decoder_nll(
|
||||
self, x: Tensor, encoder_mask: Tensor, token_ids: List[int] = None
|
||||
self, x: Tensor, encoder_mask: Tensor, token_ids: List[List[int]] = None
|
||||
) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
|
9
egs/librispeech/ASR/local/compile_hlg.py
Normal file → Executable file
9
egs/librispeech/ASR/local/compile_hlg.py
Normal file → Executable file
@ -32,7 +32,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
|
||||
"""
|
||||
lexicon = Lexicon(lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
print(f"building ctc_top. max_token_id: {max_token_id}")
|
||||
print(f"Building ctc_topo. max_token_id: {max_token_id}")
|
||||
H = k2.ctc_topo(max_token_id)
|
||||
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
|
||||
|
||||
@ -86,13 +86,18 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
|
||||
LG = k2.arc_sort(LG)
|
||||
|
||||
print("Composing H and LG")
|
||||
HLG = k2.compose(H, LG, inner_labels="phones")
|
||||
# CAUTION: The name of the inner_labels is fixed
|
||||
# to `tokens`. If you want to change it, please
|
||||
# also change other places in icefall that are using
|
||||
# it.
|
||||
HLG = k2.compose(H, LG, inner_labels="tokens")
|
||||
|
||||
print("Connecting LG")
|
||||
HLG = k2.connect(HLG)
|
||||
|
||||
print("Arc sorting LG")
|
||||
HLG = k2.arc_sort(HLG)
|
||||
print(f"HLG.shape: {HLG.shape}")
|
||||
|
||||
return HLG
|
||||
|
||||
|
@ -1,8 +1,9 @@
|
||||
import logging
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def _intersect_device(
|
||||
@ -11,7 +12,7 @@ def _intersect_device(
|
||||
b_to_a_map: torch.Tensor,
|
||||
sorted_match_a: bool,
|
||||
batch_size: int = 50,
|
||||
):
|
||||
) -> k2.Fsa:
|
||||
"""This is a wrapper of k2.intersect_device and its purpose is to split
|
||||
b_fsas into several batches and process each batch separately to avoid
|
||||
CUDA OOM error.
|
||||
@ -55,7 +56,7 @@ def get_lattice(
|
||||
min_active_states: int,
|
||||
max_active_states: int,
|
||||
subsampling_factor: int = 1,
|
||||
):
|
||||
) -> k2.Fsa:
|
||||
"""Get the decoding lattice from a decoding graph and neural
|
||||
network output.
|
||||
|
||||
@ -129,7 +130,7 @@ def one_best_decoding(
|
||||
|
||||
def nbest_decoding(
|
||||
lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True
|
||||
):
|
||||
) -> k2.Fsa:
|
||||
"""It implements something like CTC prefix beam search using n-best lists.
|
||||
|
||||
The basic idea is to first extra n-best paths from the given lattice,
|
||||
@ -253,11 +254,11 @@ def nbest_decoding(
|
||||
return best_path_fsa
|
||||
|
||||
|
||||
def compute_am_scores(
|
||||
def compute_am_and_lm_scores(
|
||||
lattice: k2.Fsa,
|
||||
word_fsa_with_epsilon_loops: k2.Fsa,
|
||||
path_to_seq_map: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute AM scores of n-best lists (represented as word_fsas).
|
||||
|
||||
Args:
|
||||
@ -272,8 +273,8 @@ def compute_am_scores(
|
||||
which sequence the i-th Fsa in word_fsa_with_epsilon_loops belongs to.
|
||||
path_to_seq_map.numel() == word_fsas_with_epsilon_loops.arcs.dim0().
|
||||
Returns:
|
||||
Return a 1-D torch.Tensor containing the AM scores of each path.
|
||||
`ans.numel() == word_fsas_with_epsilon_loops.shape[0]`
|
||||
Return a tuple containing two 1-D torch.Tensors: (am_scores, lm_scores).
|
||||
Each tensor's `numel()' equals to `word_fsas_with_epsilon_loops.shape[0]`
|
||||
"""
|
||||
assert len(lattice.shape) == 3
|
||||
assert hasattr(lattice, "lm_scores")
|
||||
@ -293,23 +294,29 @@ def compute_am_scores(
|
||||
del inv_lattice.aux_labels
|
||||
inv_lattice = k2.arc_sort(inv_lattice)
|
||||
|
||||
am_path_lattice = _intersect_device(
|
||||
path_lattice = _intersect_device(
|
||||
inv_lattice,
|
||||
word_fsa_with_epsilon_loops,
|
||||
b_to_a_map=path_to_seq_map,
|
||||
sorted_match_a=True,
|
||||
)
|
||||
|
||||
am_path_lattice = k2.top_sort(k2.connect(am_path_lattice))
|
||||
path_lattice = k2.top_sort(k2.connect(path_lattice))
|
||||
|
||||
# The `scores` of every arc consists of `am_scores` and `lm_scores`
|
||||
am_path_lattice.scores = am_path_lattice.scores - am_path_lattice.lm_scores
|
||||
path_lattice.scores = path_lattice.scores - path_lattice.lm_scores
|
||||
|
||||
am_scores = am_path_lattice.get_tot_scores(
|
||||
am_scores = path_lattice.get_tot_scores(
|
||||
use_double_scores=True, log_semiring=False
|
||||
)
|
||||
|
||||
return am_scores
|
||||
path_lattice.scores = path_lattice.lm_scores
|
||||
|
||||
lm_scores = path_lattice.get_tot_scores(
|
||||
use_double_scores=True, log_semiring=False
|
||||
)
|
||||
|
||||
return am_scores.to(torch.float32), lm_scores.to(torch.float32)
|
||||
|
||||
|
||||
def rescore_with_n_best_list(
|
||||
@ -395,7 +402,7 @@ def rescore_with_n_best_list(
|
||||
|
||||
word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
|
||||
|
||||
am_scores = compute_am_scores(
|
||||
am_scores, _ = compute_am_and_lm_scores(
|
||||
lattice, word_fsa_with_epsilon_loops, path_to_seq_map
|
||||
)
|
||||
|
||||
@ -456,8 +463,10 @@ def rescore_with_n_best_list(
|
||||
|
||||
|
||||
def rescore_with_whole_lattice(
|
||||
lattice: k2.Fsa, G_with_epsilon_loops: k2.Fsa, lm_scale_list: List[float]
|
||||
) -> Dict[str, k2.Fsa]:
|
||||
lattice: k2.Fsa,
|
||||
G_with_epsilon_loops: k2.Fsa,
|
||||
lm_scale_list: Optional[List[float]] = None,
|
||||
) -> Union[k2.Fsa, Dict[str, k2.Fsa]]:
|
||||
"""Use whole lattice to rescore.
|
||||
|
||||
Args:
|
||||
@ -467,10 +476,13 @@ def rescore_with_whole_lattice(
|
||||
An FsaVec representing the language model (LM). Note that it
|
||||
is an FsaVec, but it contains only one Fsa.
|
||||
lm_scale_list:
|
||||
A list containing lm_scale values.
|
||||
A list containing lm_scale values or None.
|
||||
Returns:
|
||||
A dict of FsaVec, whose key is a lm_scale and the value represents the
|
||||
best decoding path for each sequence in the lattice.
|
||||
If lm_scale_list is not None, return a dict of FsaVec, whose key
|
||||
is a lm_scale and the value represents the best decoding path for
|
||||
each sequence in the lattice.
|
||||
If lm_scale_list is not None, return a lattice that is rescored
|
||||
with the given LM.
|
||||
"""
|
||||
assert len(lattice.shape) == 3
|
||||
assert hasattr(lattice, "lm_scores")
|
||||
@ -517,6 +529,9 @@ def rescore_with_whole_lattice(
|
||||
# and word IDs as aux_labels.
|
||||
lat = k2.invert(rescoring_lattice)
|
||||
|
||||
if lm_scale_list is None:
|
||||
return lat
|
||||
|
||||
ans = dict()
|
||||
#
|
||||
# The following implements
|
||||
@ -532,3 +547,165 @@ def rescore_with_whole_lattice(
|
||||
key = f"lm_scale_{lm_scale}"
|
||||
ans[key] = best_path
|
||||
return ans
|
||||
|
||||
|
||||
def rescore_with_attention_decoder(
|
||||
lattice: k2.Fsa,
|
||||
num_paths: int,
|
||||
model: nn.Module,
|
||||
memory: torch.Tensor,
|
||||
memory_key_padding_mask: torch.Tensor,
|
||||
) -> Dict[str, k2.Fsa]:
|
||||
"""This function extracts n paths from the given lattice and uses
|
||||
an attention decoder to rescore them. The path with the highest
|
||||
score is used as the decoding output.
|
||||
|
||||
lattice:
|
||||
An FsaVec. It can be the return value of :func:`get_lattice`.
|
||||
num_paths:
|
||||
Number of paths to extract from the given lattice for rescoring.
|
||||
model:
|
||||
A transformer model. See the class "Transformer" in
|
||||
conformer_ctc/transformer.py for its interface.
|
||||
memory:
|
||||
The encoder memory of the given model. It is the output of
|
||||
the last torch.nn.TransformerEncoder layer in the given model.
|
||||
Its shape is `[T, N, C]`.
|
||||
memory_key_padding_mask:
|
||||
The padding mask for memory with shape [N, T].
|
||||
Returns:
|
||||
A dict of FsaVec, whose key contains a string
|
||||
ngram_lm_scale_attention_scale and the value is the
|
||||
best decoding path for each sequence in the lattice.
|
||||
"""
|
||||
# First, extract `num_paths` paths for each sequence.
|
||||
# path is a k2.RaggedInt with axes [seq][path][arc_pos]
|
||||
path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
|
||||
|
||||
# word_seq is a k2.RaggedInt sharing the same shape as `path`
|
||||
# but it contains word IDs. Note that it also contains 0s and -1s.
|
||||
# The last entry in each sublist is -1.
|
||||
word_seq = k2.index(lattice.aux_labels, path)
|
||||
|
||||
# Remove epsilons and -1 from word_seq
|
||||
word_seq = k2.ragged.remove_values_leq(word_seq, 0)
|
||||
|
||||
# Remove paths that has identical word sequences.
|
||||
#
|
||||
# unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word]
|
||||
# except that there are no repeated paths with the same word_seq
|
||||
# within a sequence.
|
||||
#
|
||||
# num_repeats is also a k2.RaggedInt with 2 axes containing the
|
||||
# multiplicities of each path.
|
||||
# num_repeats.num_elements() == unique_word_seqs.num_elements()
|
||||
#
|
||||
# Since k2.ragged.unique_sequences will reorder paths within a seq,
|
||||
# `new2old` is a 1-D torch.Tensor mapping from the output path index
|
||||
# to the input path index.
|
||||
# new2old.numel() == unique_word_seqs.tot_size(1)
|
||||
unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
|
||||
word_seq, need_num_repeats=True, need_new2old_indexes=True
|
||||
)
|
||||
|
||||
seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)
|
||||
|
||||
# path_to_seq_map is a 1-D torch.Tensor.
|
||||
# path_to_seq_map[i] is the seq to which the i-th path
|
||||
# belongs.
|
||||
path_to_seq_map = seq_to_path_shape.row_ids(1)
|
||||
|
||||
# Remove the seq axis.
|
||||
# Now unique_word_seq has only two axes [path][word]
|
||||
unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)
|
||||
|
||||
# word_fsa is an FsaVec with axes [path][state][arc]
|
||||
word_fsa = k2.linear_fsa(unique_word_seq)
|
||||
|
||||
word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
|
||||
|
||||
am_scores, ngram_lm_scores = compute_am_and_lm_scores(
|
||||
lattice, word_fsa_with_epsilon_loops, path_to_seq_map
|
||||
)
|
||||
# Now we use the attention decoder to compute another
|
||||
# score: attention_scores.
|
||||
#
|
||||
# To do that, we have to get the input and output for the attention
|
||||
# decoder.
|
||||
|
||||
# CAUTION: The "tokens" attribute is set in the file
|
||||
# local/compile_hlg.py
|
||||
token_seq = k2.index(lattice.tokens, path)
|
||||
|
||||
# Remove epsilons and -1 from token_seq
|
||||
token_seq = k2.ragged.remove_values_leq(token_seq, 0)
|
||||
|
||||
# Remove the seq axis.
|
||||
token_seq = k2.ragged.remove_axis(token_seq, 0)
|
||||
|
||||
token_seq, _ = k2.ragged.index(
|
||||
token_seq, indexes=new2old, axis=0, need_value_indexes=False
|
||||
)
|
||||
|
||||
# Now word in unique_word_seq has its corresponding token IDs.
|
||||
token_ids = k2.ragged.to_list(token_seq)
|
||||
|
||||
num_word_seqs = new2old.numel()
|
||||
|
||||
path_to_seq_map_long = path_to_seq_map.to(torch.long)
|
||||
expanded_memory = memory.index_select(1, path_to_seq_map_long)
|
||||
|
||||
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
|
||||
0, path_to_seq_map_long
|
||||
)
|
||||
|
||||
nll = model.decoder_nll(
|
||||
expanded_memory, expanded_memory_key_padding_mask, token_ids
|
||||
)
|
||||
assert nll.ndim == 2
|
||||
assert nll.shape[0] == num_word_seqs
|
||||
|
||||
attention_scores = -nll.sum(dim=1)
|
||||
assert attention_scores.ndim == 1
|
||||
assert attention_scores.numel() == num_word_seqs
|
||||
|
||||
ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
||||
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
||||
|
||||
attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
||||
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
||||
|
||||
path_2axes = k2.ragged.remove_axis(path, 0)
|
||||
|
||||
ans = dict()
|
||||
for n_scale in ngram_lm_scale_list:
|
||||
for a_scale in attention_scale_list:
|
||||
tot_scores = (
|
||||
am_scores
|
||||
+ n_scale * ngram_lm_scores
|
||||
+ a_scale * attention_scores
|
||||
)
|
||||
ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape, tot_scores)
|
||||
argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
|
||||
|
||||
best_path_indexes = k2.index(new2old, argmax_indexes)
|
||||
|
||||
# best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
|
||||
best_path = k2.index(path_2axes, best_path_indexes)
|
||||
|
||||
# labels is a k2.RaggedInt with 2 axes [path][token_id]
|
||||
# Note that it contains -1s.
|
||||
labels = k2.index(lattice.labels.contiguous(), best_path)
|
||||
|
||||
labels = k2.ragged.remove_values_eq(labels, -1)
|
||||
|
||||
# lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so
|
||||
# aux_labels is also a k2.RaggedInt with 2 axes
|
||||
aux_labels = k2.index(lattice.aux_labels, best_path.values())
|
||||
|
||||
best_path_fsa = k2.linear_fsa(labels)
|
||||
best_path_fsa.aux_labels = aux_labels
|
||||
|
||||
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
|
||||
ans[key] = best_path_fsa
|
||||
return ans
|
||||
|
Loading…
x
Reference in New Issue
Block a user