Use attention decoder for rescoring.

This commit is contained in:
Fangjun Kuang 2021-07-28 12:22:09 +08:00
parent f65854cca5
commit bd69e4be32
4 changed files with 239 additions and 30 deletions

View File

@ -21,6 +21,7 @@ from icefall.decode import (
get_lattice,
nbest_decoding,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
@ -82,9 +83,12 @@ def get_params() -> AttributeDict:
# - nbest
# - nbest-rescoring
# - whole-lattice-rescoring
"method": "nbest-rescoring",
# num_paths is used when method is "nbest" and "nbest-rescoring"
"num_paths": 100,
# - attention-decoder
# "method": "whole-lattice-rescoring",
"method": "attention-decoder",
# num_paths is used when method is "nbest", "nbest-rescoring",
# and attention-decoder
"num_paths": 1000,
}
)
return params
@ -147,7 +151,7 @@ def decode_one_batch(
supervisions = batch["supervisions"]
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
# nnet_output is [N, C, T]
nnet_output = nnet_output.permute(0, 2, 1)
@ -191,7 +195,11 @@ def decode_one_batch(
hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
return {key: hyps}
assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"]
assert params.method in [
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
]
lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
@ -203,10 +211,25 @@ def decode_one_batch(
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
)
else:
elif params.method == "whole-lattice-rescoring":
best_path_dict = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
)
elif params.method == "attention-decoder":
# lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
)
best_path_dict = rescore_with_attention_decoder(
lattice=rescored_lattice,
num_paths=params.num_paths,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
)
else:
assert False, f"Unsupported decoding method: {params.method}"
ans = dict()
for lm_scale_str, best_path in best_path_dict.items():
@ -351,7 +374,11 @@ def main():
if not hasattr(HLG, "lm_scores"):
HLG.lm_scores = HLG.scores.clone()
if params.method in ["nbest-rescoring", "whole-lattice-rescoring"]:
if params.method in (
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
):
if not (params.lm_dir / "G_4_gram.pt").is_file():
logging.info("Loading G_4_gram.fst.txt")
logging.warning("It may take 8 minutes.")
@ -374,7 +401,7 @@ def main():
d = torch.load(params.lm_dir / "G_4_gram.pt")
G = k2.Fsa.from_dict(d).to(device)
if params.method == "whole-lattice-rescoring":
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)

View File

@ -259,7 +259,7 @@ class Transformer(nn.Module):
return decoder_loss
def decoder_nll(
self, x: Tensor, encoder_mask: Tensor, token_ids: List[int] = None
self, x: Tensor, encoder_mask: Tensor, token_ids: List[List[int]] = None
) -> Tensor:
"""
Args:

9
egs/librispeech/ASR/local/compile_hlg.py Normal file → Executable file
View File

@ -32,7 +32,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
"""
lexicon = Lexicon(lang_dir)
max_token_id = max(lexicon.tokens)
print(f"building ctc_top. max_token_id: {max_token_id}")
print(f"Building ctc_topo. max_token_id: {max_token_id}")
H = k2.ctc_topo(max_token_id)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
@ -86,13 +86,18 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
LG = k2.arc_sort(LG)
print("Composing H and LG")
HLG = k2.compose(H, LG, inner_labels="phones")
# CAUTION: The name of the inner_labels is fixed
# to `tokens`. If you want to change it, please
# also change other places in icefall that are using
# it.
HLG = k2.compose(H, LG, inner_labels="tokens")
print("Connecting LG")
HLG = k2.connect(HLG)
print("Arc sorting LG")
HLG = k2.arc_sort(HLG)
print(f"HLG.shape: {HLG.shape}")
return HLG

View File

@ -1,8 +1,9 @@
import logging
from typing import Dict, List
from typing import Dict, List, Optional, Tuple, Union
import k2
import torch
import torch.nn as nn
def _intersect_device(
@ -11,7 +12,7 @@ def _intersect_device(
b_to_a_map: torch.Tensor,
sorted_match_a: bool,
batch_size: int = 50,
):
) -> k2.Fsa:
"""This is a wrapper of k2.intersect_device and its purpose is to split
b_fsas into several batches and process each batch separately to avoid
CUDA OOM error.
@ -55,7 +56,7 @@ def get_lattice(
min_active_states: int,
max_active_states: int,
subsampling_factor: int = 1,
):
) -> k2.Fsa:
"""Get the decoding lattice from a decoding graph and neural
network output.
@ -129,7 +130,7 @@ def one_best_decoding(
def nbest_decoding(
lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True
):
) -> k2.Fsa:
"""It implements something like CTC prefix beam search using n-best lists.
The basic idea is to first extra n-best paths from the given lattice,
@ -253,11 +254,11 @@ def nbest_decoding(
return best_path_fsa
def compute_am_scores(
def compute_am_and_lm_scores(
lattice: k2.Fsa,
word_fsa_with_epsilon_loops: k2.Fsa,
path_to_seq_map: torch.Tensor,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute AM scores of n-best lists (represented as word_fsas).
Args:
@ -272,8 +273,8 @@ def compute_am_scores(
which sequence the i-th Fsa in word_fsa_with_epsilon_loops belongs to.
path_to_seq_map.numel() == word_fsas_with_epsilon_loops.arcs.dim0().
Returns:
Return a 1-D torch.Tensor containing the AM scores of each path.
`ans.numel() == word_fsas_with_epsilon_loops.shape[0]`
Return a tuple containing two 1-D torch.Tensors: (am_scores, lm_scores).
Each tensor's `numel()' equals to `word_fsas_with_epsilon_loops.shape[0]`
"""
assert len(lattice.shape) == 3
assert hasattr(lattice, "lm_scores")
@ -293,23 +294,29 @@ def compute_am_scores(
del inv_lattice.aux_labels
inv_lattice = k2.arc_sort(inv_lattice)
am_path_lattice = _intersect_device(
path_lattice = _intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=path_to_seq_map,
sorted_match_a=True,
)
am_path_lattice = k2.top_sort(k2.connect(am_path_lattice))
path_lattice = k2.top_sort(k2.connect(path_lattice))
# The `scores` of every arc consists of `am_scores` and `lm_scores`
am_path_lattice.scores = am_path_lattice.scores - am_path_lattice.lm_scores
path_lattice.scores = path_lattice.scores - path_lattice.lm_scores
am_scores = am_path_lattice.get_tot_scores(
am_scores = path_lattice.get_tot_scores(
use_double_scores=True, log_semiring=False
)
return am_scores
path_lattice.scores = path_lattice.lm_scores
lm_scores = path_lattice.get_tot_scores(
use_double_scores=True, log_semiring=False
)
return am_scores.to(torch.float32), lm_scores.to(torch.float32)
def rescore_with_n_best_list(
@ -395,7 +402,7 @@ def rescore_with_n_best_list(
word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
am_scores = compute_am_scores(
am_scores, _ = compute_am_and_lm_scores(
lattice, word_fsa_with_epsilon_loops, path_to_seq_map
)
@ -456,8 +463,10 @@ def rescore_with_n_best_list(
def rescore_with_whole_lattice(
lattice: k2.Fsa, G_with_epsilon_loops: k2.Fsa, lm_scale_list: List[float]
) -> Dict[str, k2.Fsa]:
lattice: k2.Fsa,
G_with_epsilon_loops: k2.Fsa,
lm_scale_list: Optional[List[float]] = None,
) -> Union[k2.Fsa, Dict[str, k2.Fsa]]:
"""Use whole lattice to rescore.
Args:
@ -467,10 +476,13 @@ def rescore_with_whole_lattice(
An FsaVec representing the language model (LM). Note that it
is an FsaVec, but it contains only one Fsa.
lm_scale_list:
A list containing lm_scale values.
A list containing lm_scale values or None.
Returns:
A dict of FsaVec, whose key is a lm_scale and the value represents the
best decoding path for each sequence in the lattice.
If lm_scale_list is not None, return a dict of FsaVec, whose key
is a lm_scale and the value represents the best decoding path for
each sequence in the lattice.
If lm_scale_list is not None, return a lattice that is rescored
with the given LM.
"""
assert len(lattice.shape) == 3
assert hasattr(lattice, "lm_scores")
@ -517,6 +529,9 @@ def rescore_with_whole_lattice(
# and word IDs as aux_labels.
lat = k2.invert(rescoring_lattice)
if lm_scale_list is None:
return lat
ans = dict()
#
# The following implements
@ -532,3 +547,165 @@ def rescore_with_whole_lattice(
key = f"lm_scale_{lm_scale}"
ans[key] = best_path
return ans
def rescore_with_attention_decoder(
lattice: k2.Fsa,
num_paths: int,
model: nn.Module,
memory: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
) -> Dict[str, k2.Fsa]:
"""This function extracts n paths from the given lattice and uses
an attention decoder to rescore them. The path with the highest
score is used as the decoding output.
lattice:
An FsaVec. It can be the return value of :func:`get_lattice`.
num_paths:
Number of paths to extract from the given lattice for rescoring.
model:
A transformer model. See the class "Transformer" in
conformer_ctc/transformer.py for its interface.
memory:
The encoder memory of the given model. It is the output of
the last torch.nn.TransformerEncoder layer in the given model.
Its shape is `[T, N, C]`.
memory_key_padding_mask:
The padding mask for memory with shape [N, T].
Returns:
A dict of FsaVec, whose key contains a string
ngram_lm_scale_attention_scale and the value is the
best decoding path for each sequence in the lattice.
"""
# First, extract `num_paths` paths for each sequence.
# path is a k2.RaggedInt with axes [seq][path][arc_pos]
path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
# word_seq is a k2.RaggedInt sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s.
# The last entry in each sublist is -1.
word_seq = k2.index(lattice.aux_labels, path)
# Remove epsilons and -1 from word_seq
word_seq = k2.ragged.remove_values_leq(word_seq, 0)
# Remove paths that has identical word sequences.
#
# unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word]
# except that there are no repeated paths with the same word_seq
# within a sequence.
#
# num_repeats is also a k2.RaggedInt with 2 axes containing the
# multiplicities of each path.
# num_repeats.num_elements() == unique_word_seqs.num_elements()
#
# Since k2.ragged.unique_sequences will reorder paths within a seq,
# `new2old` is a 1-D torch.Tensor mapping from the output path index
# to the input path index.
# new2old.numel() == unique_word_seqs.tot_size(1)
unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
word_seq, need_num_repeats=True, need_new2old_indexes=True
)
seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)
# path_to_seq_map is a 1-D torch.Tensor.
# path_to_seq_map[i] is the seq to which the i-th path
# belongs.
path_to_seq_map = seq_to_path_shape.row_ids(1)
# Remove the seq axis.
# Now unique_word_seq has only two axes [path][word]
unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)
# word_fsa is an FsaVec with axes [path][state][arc]
word_fsa = k2.linear_fsa(unique_word_seq)
word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
am_scores, ngram_lm_scores = compute_am_and_lm_scores(
lattice, word_fsa_with_epsilon_loops, path_to_seq_map
)
# Now we use the attention decoder to compute another
# score: attention_scores.
#
# To do that, we have to get the input and output for the attention
# decoder.
# CAUTION: The "tokens" attribute is set in the file
# local/compile_hlg.py
token_seq = k2.index(lattice.tokens, path)
# Remove epsilons and -1 from token_seq
token_seq = k2.ragged.remove_values_leq(token_seq, 0)
# Remove the seq axis.
token_seq = k2.ragged.remove_axis(token_seq, 0)
token_seq, _ = k2.ragged.index(
token_seq, indexes=new2old, axis=0, need_value_indexes=False
)
# Now word in unique_word_seq has its corresponding token IDs.
token_ids = k2.ragged.to_list(token_seq)
num_word_seqs = new2old.numel()
path_to_seq_map_long = path_to_seq_map.to(torch.long)
expanded_memory = memory.index_select(1, path_to_seq_map_long)
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
0, path_to_seq_map_long
)
nll = model.decoder_nll(
expanded_memory, expanded_memory_key_padding_mask, token_ids
)
assert nll.ndim == 2
assert nll.shape[0] == num_word_seqs
attention_scores = -nll.sum(dim=1)
assert attention_scores.ndim == 1
assert attention_scores.numel() == num_word_seqs
ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
path_2axes = k2.ragged.remove_axis(path, 0)
ans = dict()
for n_scale in ngram_lm_scale_list:
for a_scale in attention_scale_list:
tot_scores = (
am_scores
+ n_scale * ngram_lm_scores
+ a_scale * attention_scores
)
ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape, tot_scores)
argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
best_path_indexes = k2.index(new2old, argmax_indexes)
# best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
best_path = k2.index(path_2axes, best_path_indexes)
# labels is a k2.RaggedInt with 2 axes [path][token_id]
# Note that it contains -1s.
labels = k2.index(lattice.labels.contiguous(), best_path)
labels = k2.ragged.remove_values_eq(labels, -1)
# lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so
# aux_labels is also a k2.RaggedInt with 2 axes
aux_labels = k2.index(lattice.aux_labels, best_path.values())
best_path_fsa = k2.linear_fsa(labels)
best_path_fsa.aux_labels = aux_labels
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
ans[key] = best_path_fsa
return ans