Use attention decoder for rescoring.

Fangjun Kuang 2021-07-28 12:22:09 +08:00
parent f65854cca5
commit bd69e4be32
4 changed files with 239 additions and 30 deletions
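
In short: after the usual whole-lattice rescoring with the 4-gram LM, n-best paths are sampled from the rescored lattice, and each path receives an acoustic score, an n-gram LM score, and an attention-decoder score; the best path per utterance is kept for a grid of scale pairs. A rough sketch of the scoring rule (names here are illustrative; the actual bookkeeping lives in the icefall/decode.py diff below):

    # Sketch of the per-path score combination used by the new
    # rescore_with_attention_decoder (illustrative names only).
    import torch

    def combine_path_scores(
        am_scores: torch.Tensor,         # acoustic score of each n-best path
        ngram_lm_scores: torch.Tensor,   # n-gram LM score of each path
        attention_scores: torch.Tensor,  # negated attention-decoder NLL per path
        ngram_lm_scale: float,
        attention_scale: float,
    ) -> torch.Tensor:
        # The path with the largest combined score is kept per utterance.
        return (
            am_scores
            + ngram_lm_scale * ngram_lm_scores
            + attention_scale * attention_scores
        )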


@@ -21,6 +21,7 @@ from icefall.decode import (
     get_lattice,
     nbest_decoding,
     one_best_decoding,
+    rescore_with_attention_decoder,
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
@@ -82,9 +83,12 @@ def get_params() -> AttributeDict:
             #  - nbest
             #  - nbest-rescoring
             #  - whole-lattice-rescoring
-            "method": "nbest-rescoring",
-            # num_paths is used when method is "nbest" and "nbest-rescoring"
-            "num_paths": 100,
+            #  - attention-decoder
+            # "method": "whole-lattice-rescoring",
+            "method": "attention-decoder",
+            # num_paths is used when method is "nbest", "nbest-rescoring",
+            # and attention-decoder
+            "num_paths": 1000,
         }
     )
     return params
@@ -147,7 +151,7 @@ def decode_one_batch(
     supervisions = batch["supervisions"]

-    nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
+    nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
     # nnet_output is [N, C, T]
     nnet_output = nnet_output.permute(0, 2, 1)
@ -191,7 +195,11 @@ def decode_one_batch(
hyps = [[lexicon.words[i] for i in ids] for ids in hyps] hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
return {key: hyps} return {key: hyps}
assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"] assert params.method in [
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
]
lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
@@ -203,10 +211,25 @@ def decode_one_batch(
             num_paths=params.num_paths,
             lm_scale_list=lm_scale_list,
         )
-    else:
+    elif params.method == "whole-lattice-rescoring":
         best_path_dict = rescore_with_whole_lattice(
             lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
         )
+    elif params.method == "attention-decoder":
+        # lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
+        rescored_lattice = rescore_with_whole_lattice(
+            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
+        )
+
+        best_path_dict = rescore_with_attention_decoder(
+            lattice=rescored_lattice,
+            num_paths=params.num_paths,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+        )
+    else:
+        assert False, f"Unsupported decoding method: {params.method}"

     ans = dict()
     for lm_scale_str, best_path in best_path_dict.items():
@@ -351,7 +374,11 @@ def main():
     if not hasattr(HLG, "lm_scores"):
         HLG.lm_scores = HLG.scores.clone()

-    if params.method in ["nbest-rescoring", "whole-lattice-rescoring"]:
+    if params.method in (
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+    ):
         if not (params.lm_dir / "G_4_gram.pt").is_file():
             logging.info("Loading G_4_gram.fst.txt")
             logging.warning("It may take 8 minutes.")
@@ -374,7 +401,7 @@ def main():
             d = torch.load(params.lm_dir / "G_4_gram.pt")
             G = k2.Fsa.from_dict(d).to(device)

-        if params.method == "whole-lattice-rescoring":
+        if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
             # Add epsilon self-loops to G as we will compose
             # it with the whole lattice later
             G = k2.add_epsilon_self_loops(G)
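
The get_params() change above makes "attention-decoder" the active method and raises num_paths to 1000, since the attention rescoring operates on an n-best list. A minimal sketch of switching the method programmatically (field names as in the diff; this assumes AttributeDict allows attribute-style assignment, as elsewhere in icefall):

    # Sketch: choose the decoding method without editing decode.py.
    params = get_params()
    params.method = "whole-lattice-rescoring"   # or "attention-decoder"
    params.num_paths = 100                      # only used by n-best based methods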


@@ -259,7 +259,7 @@ class Transformer(nn.Module):
         return decoder_loss

     def decoder_nll(
-        self, x: Tensor, encoder_mask: Tensor, token_ids: List[int] = None
+        self, x: Tensor, encoder_mask: Tensor, token_ids: List[List[int]] = None
     ) -> Tensor:
         """
         Args:
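
The only model-side change is that decoder_nll now accepts one token-ID list per n-best path. A sketch of the call as the new rescoring code uses it (shapes inferred from the asserts in rescore_with_attention_decoder further below; the model and memory tensors are assumed to exist):

    # One list of token IDs per unique path; ragged lengths are fine.
    token_ids = [[23, 7, 105], [23, 7, 9, 4]]
    # expanded_memory: [T, num_paths, C]; expanded_memory_key_padding_mask: [num_paths, T]
    nll = model.decoder_nll(expanded_memory, expanded_memory_key_padding_mask, token_ids)
    # nll has shape [num_paths, max_token_len]; the negated per-path sum
    # becomes the attention score used for rescoring.
    attention_scores = -nll.sum(dim=1)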

egs/librispeech/ASR/local/compile_hlg.py Normal file → Executable file

@@ -32,7 +32,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     """
     lexicon = Lexicon(lang_dir)
     max_token_id = max(lexicon.tokens)
-    print(f"building ctc_top. max_token_id: {max_token_id}")
+    print(f"Building ctc_topo. max_token_id: {max_token_id}")
     H = k2.ctc_topo(max_token_id)
     L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

@@ -86,13 +86,18 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     LG = k2.arc_sort(LG)

     print("Composing H and LG")
-    HLG = k2.compose(H, LG, inner_labels="phones")
+    # CAUTION: The name of the inner_labels is fixed
+    # to `tokens`. If you want to change it, please
+    # also change other places in icefall that are using
+    # it.
+    HLG = k2.compose(H, LG, inner_labels="tokens")

     print("Connecting LG")
     HLG = k2.connect(HLG)

     print("Arc sorting LG")
     HLG = k2.arc_sort(HLG)
+    print(f"HLG.shape: {HLG.shape}")

     return HLG
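
The inner_labels rename is what lets the attention rescoring recover token IDs later: the inner symbols of composing H with LG are stored as an attribute named `tokens` on HLG, and k2 propagates named attributes through intersection, so decoding lattices derived from HLG also expose `lattice.tokens` (read by rescore_with_attention_decoder below). A one-line sketch, assuming H and LG are built as in this script:

    HLG = k2.compose(H, LG, inner_labels="tokens")
    assert hasattr(HLG, "tokens")  # lattices obtained from HLG carry this attribute too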


@@ -1,8 +1,9 @@
 import logging
-from typing import Dict, List
+from typing import Dict, List, Optional, Tuple, Union

 import k2
 import torch
+import torch.nn as nn


 def _intersect_device(
@@ -11,7 +12,7 @@ def _intersect_device(
     b_to_a_map: torch.Tensor,
     sorted_match_a: bool,
     batch_size: int = 50,
-):
+) -> k2.Fsa:
     """This is a wrapper of k2.intersect_device and its purpose is to split
     b_fsas into several batches and process each batch separately to avoid
     CUDA OOM error.
@@ -55,7 +56,7 @@ def get_lattice(
     min_active_states: int,
     max_active_states: int,
     subsampling_factor: int = 1,
-):
+) -> k2.Fsa:
     """Get the decoding lattice from a decoding graph and neural
     network output.
@@ -129,7 +130,7 @@ def one_best_decoding(

 def nbest_decoding(
     lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True
-):
+) -> k2.Fsa:
     """It implements something like CTC prefix beam search using n-best lists.

     The basic idea is to first extract n-best paths from the given lattice,
@@ -253,11 +254,11 @@ def nbest_decoding(
     return best_path_fsa


-def compute_am_scores(
+def compute_am_and_lm_scores(
     lattice: k2.Fsa,
     word_fsa_with_epsilon_loops: k2.Fsa,
     path_to_seq_map: torch.Tensor,
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """Compute AM scores of n-best lists (represented as word_fsas).

     Args:
@@ -272,8 +273,8 @@ def compute_am_scores(
         which sequence the i-th Fsa in word_fsa_with_epsilon_loops belongs to.
         path_to_seq_map.numel() == word_fsas_with_epsilon_loops.arcs.dim0().
     Returns:
-      Return a 1-D torch.Tensor containing the AM scores of each path.
-      `ans.numel() == word_fsas_with_epsilon_loops.shape[0]`
+      Return a tuple containing two 1-D torch.Tensors: (am_scores, lm_scores).
+      Each tensor's `numel()` equals `word_fsas_with_epsilon_loops.shape[0]`.
     """
     assert len(lattice.shape) == 3
     assert hasattr(lattice, "lm_scores")
@@ -293,23 +294,29 @@ def compute_am_scores(
     del inv_lattice.aux_labels
     inv_lattice = k2.arc_sort(inv_lattice)

-    am_path_lattice = _intersect_device(
+    path_lattice = _intersect_device(
         inv_lattice,
         word_fsa_with_epsilon_loops,
         b_to_a_map=path_to_seq_map,
         sorted_match_a=True,
     )

-    am_path_lattice = k2.top_sort(k2.connect(am_path_lattice))
+    path_lattice = k2.top_sort(k2.connect(path_lattice))

     # The `scores` of every arc consists of `am_scores` and `lm_scores`
-    am_path_lattice.scores = am_path_lattice.scores - am_path_lattice.lm_scores
+    path_lattice.scores = path_lattice.scores - path_lattice.lm_scores

-    am_scores = am_path_lattice.get_tot_scores(
+    am_scores = path_lattice.get_tot_scores(
         use_double_scores=True, log_semiring=False
     )

-    return am_scores
+    path_lattice.scores = path_lattice.lm_scores
+    lm_scores = path_lattice.get_tot_scores(
+        use_double_scores=True, log_semiring=False
+    )
+
+    return am_scores.to(torch.float32), lm_scores.to(torch.float32)


 def rescore_with_n_best_list(
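
The split works because each arc's `scores` is the sum of its acoustic part and its `lm_scores` part: subtracting lm_scores isolates the AM scores, and then overwriting scores with lm_scores isolates the LM scores. A toy numeric illustration (plain tensors, no FSAs):

    import torch

    scores = torch.tensor([3.0, 5.0])     # total per-arc scores = am + lm
    lm_scores = torch.tensor([1.0, 2.5])
    am_scores = scores - lm_scores        # tensor([2.0, 2.5]): acoustic part only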
@@ -395,7 +402,7 @@ def rescore_with_n_best_list(
     word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)

-    am_scores = compute_am_scores(
+    am_scores, _ = compute_am_and_lm_scores(
         lattice, word_fsa_with_epsilon_loops, path_to_seq_map
     )
@@ -456,8 +463,10 @@ def rescore_with_n_best_list(


 def rescore_with_whole_lattice(
-    lattice: k2.Fsa, G_with_epsilon_loops: k2.Fsa, lm_scale_list: List[float]
-) -> Dict[str, k2.Fsa]:
+    lattice: k2.Fsa,
+    G_with_epsilon_loops: k2.Fsa,
+    lm_scale_list: Optional[List[float]] = None,
+) -> Union[k2.Fsa, Dict[str, k2.Fsa]]:
     """Use whole lattice to rescore.

     Args:
@@ -467,10 +476,13 @@ def rescore_with_whole_lattice(
         An FsaVec representing the language model (LM). Note that it
         is an FsaVec, but it contains only one Fsa.
       lm_scale_list:
-        A list containing lm_scale values.
+        A list containing lm_scale values or None.
     Returns:
-      A dict of FsaVec, whose key is a lm_scale and the value represents the
-      best decoding path for each sequence in the lattice.
+      If lm_scale_list is not None, return a dict of FsaVec, whose key
+      is an lm_scale and the value represents the best decoding path for
+      each sequence in the lattice.
+      If lm_scale_list is None, return a lattice that is rescored
+      with the given LM.
     """
     assert len(lattice.shape) == 3
     assert hasattr(lattice, "lm_scores")
@@ -517,6 +529,9 @@ def rescore_with_whole_lattice(
     # and word IDs as aux_labels.
     lat = k2.invert(rescoring_lattice)

+    if lm_scale_list is None:
+        return lat
+
     ans = dict()
     #
     # The following implements
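
After this change, rescore_with_whole_lattice has two calling modes, which is exactly what decode.py relies on above. A short sketch (lattice and G as in the conformer_ctc recipe):

    # Mode 1: pick a best path per lm_scale, as before.
    best_path_dict = rescore_with_whole_lattice(
        lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=[0.8, 1.0, 1.2]
    )
    # Mode 2: return the whole rescored lattice, to be passed on to
    # rescore_with_attention_decoder.
    rescored_lattice = rescore_with_whole_lattice(
        lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
    )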
@@ -532,3 +547,165 @@ def rescore_with_whole_lattice(
         key = f"lm_scale_{lm_scale}"
         ans[key] = best_path
     return ans
+
+
+def rescore_with_attention_decoder(
+    lattice: k2.Fsa,
+    num_paths: int,
+    model: nn.Module,
+    memory: torch.Tensor,
+    memory_key_padding_mask: torch.Tensor,
+) -> Dict[str, k2.Fsa]:
+    """This function extracts n paths from the given lattice and uses
+    an attention decoder to rescore them. The path with the highest
+    score is used as the decoding output.
+
+    Args:
+      lattice:
+        An FsaVec. It can be the return value of :func:`get_lattice`.
+      num_paths:
+        Number of paths to extract from the given lattice for rescoring.
+      model:
+        A transformer model. See the class "Transformer" in
+        conformer_ctc/transformer.py for its interface.
+      memory:
+        The encoder memory of the given model. It is the output of
+        the last torch.nn.TransformerEncoder layer in the given model.
+        Its shape is `[T, N, C]`.
+      memory_key_padding_mask:
+        The padding mask for memory with shape [N, T].
+    Returns:
+      A dict of FsaVec, whose key contains a string
+      ngram_lm_scale_attention_scale and the value is the
+      best decoding path for each sequence in the lattice.
+    """
+    # First, extract `num_paths` paths for each sequence.
+    # path is a k2.RaggedInt with axes [seq][path][arc_pos]
+    path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
+
+    # word_seq is a k2.RaggedInt sharing the same shape as `path`
+    # but it contains word IDs. Note that it also contains 0s and -1s.
+    # The last entry in each sublist is -1.
+    word_seq = k2.index(lattice.aux_labels, path)
+
+    # Remove epsilons and -1 from word_seq
+    word_seq = k2.ragged.remove_values_leq(word_seq, 0)
+
+    # Remove paths that have identical word sequences.
+    #
+    # unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word]
+    # except that there are no repeated paths with the same word_seq
+    # within a sequence.
+    #
+    # num_repeats is also a k2.RaggedInt with 2 axes containing the
+    # multiplicities of each path.
+    # num_repeats.num_elements() == unique_word_seq.num_elements()
+    #
+    # Since k2.ragged.unique_sequences will reorder paths within a seq,
+    # `new2old` is a 1-D torch.Tensor mapping from the output path index
+    # to the input path index.
+    # new2old.numel() == unique_word_seq.tot_size(1)
+    unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
+        word_seq, need_num_repeats=True, need_new2old_indexes=True
+    )
+
+    seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)
+
+    # path_to_seq_map is a 1-D torch.Tensor.
+    # path_to_seq_map[i] is the seq to which the i-th path
+    # belongs.
+    path_to_seq_map = seq_to_path_shape.row_ids(1)
+
+    # Remove the seq axis.
+    # Now unique_word_seq has only two axes [path][word]
+    unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)
+
+    # word_fsa is an FsaVec with axes [path][state][arc]
+    word_fsa = k2.linear_fsa(unique_word_seq)
+
+    word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
+
+    am_scores, ngram_lm_scores = compute_am_and_lm_scores(
+        lattice, word_fsa_with_epsilon_loops, path_to_seq_map
+    )
+
+    # Now we use the attention decoder to compute another
+    # score: attention_scores.
+    #
+    # To do that, we have to get the input and output for the attention
+    # decoder.
+
+    # CAUTION: The "tokens" attribute is set in the file
+    # local/compile_hlg.py
+    token_seq = k2.index(lattice.tokens, path)
+
+    # Remove epsilons and -1 from token_seq
+    token_seq = k2.ragged.remove_values_leq(token_seq, 0)
+
+    # Remove the seq axis.
+    token_seq = k2.ragged.remove_axis(token_seq, 0)
+
+    token_seq, _ = k2.ragged.index(
+        token_seq, indexes=new2old, axis=0, need_value_indexes=False
+    )
+
+    # Now each word in unique_word_seq has its corresponding token IDs.
+    token_ids = k2.ragged.to_list(token_seq)
+
+    num_word_seqs = new2old.numel()
+
+    path_to_seq_map_long = path_to_seq_map.to(torch.long)
+    expanded_memory = memory.index_select(1, path_to_seq_map_long)
+
+    expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
+        0, path_to_seq_map_long
+    )
+
+    nll = model.decoder_nll(
+        expanded_memory, expanded_memory_key_padding_mask, token_ids
+    )
+    assert nll.ndim == 2
+    assert nll.shape[0] == num_word_seqs
+
+    attention_scores = -nll.sum(dim=1)
+    assert attention_scores.ndim == 1
+    assert attention_scores.numel() == num_word_seqs
+
+    ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+    ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+
+    attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+    attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+
+    path_2axes = k2.ragged.remove_axis(path, 0)
+
+    ans = dict()
+    for n_scale in ngram_lm_scale_list:
+        for a_scale in attention_scale_list:
+            tot_scores = (
+                am_scores
+                + n_scale * ngram_lm_scores
+                + a_scale * attention_scores
+            )
+            ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape, tot_scores)
+            argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
+
+            best_path_indexes = k2.index(new2old, argmax_indexes)
+
+            # best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
+            best_path = k2.index(path_2axes, best_path_indexes)
+
+            # labels is a k2.RaggedInt with 2 axes [path][token_id]
+            # Note that it contains -1s.
+            labels = k2.index(lattice.labels.contiguous(), best_path)
+            labels = k2.ragged.remove_values_eq(labels, -1)
+
+            # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so
+            # aux_labels is also a k2.RaggedInt with 2 axes
+            aux_labels = k2.index(lattice.aux_labels, best_path.values())
+
+            best_path_fsa = k2.linear_fsa(labels)
+            best_path_fsa.aux_labels = aux_labels
+
+            key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
+            ans[key] = best_path_fsa
+    return ans
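
For reference, the way decode.py above wires everything together can be summarized as follows (a sketch; model, feature, supervisions, lattice, and G are assumed to be prepared as in the recipe):

    # 1. Run the encoder once; keep the memory for the attention decoder.
    nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)

    # 2. Replace the 3-gram LM scores in the lattice with 4-gram ones,
    #    returning the rescored lattice itself (lm_scale_list=None).
    rescored_lattice = rescore_with_whole_lattice(
        lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
    )

    # 3. Rescore an n-best list from that lattice with the attention decoder.
    best_path_dict = rescore_with_attention_decoder(
        lattice=rescored_lattice,
        num_paths=1000,
        model=model,
        memory=memory,
        memory_key_padding_mask=memory_key_padding_mask,
    )

    # Keys look like "ngram_lm_scale_0.5_attention_scale_0.7"; each value is an
    # FsaVec holding the best path per utterance for that scale pair.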