Add attention rescore pipeline

pkufool 2021-08-09 12:47:11 +08:00
parent 286dce7b0f
commit 0669aa8ab9
5 changed files with 317 additions and 108 deletions
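At a high level, the new attention-decoder-v2 path is expected to work roughly as follows. This is a hedged outline assembled from the functions changed below, not code from the commit; variable names follow icefall/decode.py and the values of num_paths and top_k come from the decoding script's params.

rescored_lattice = rescore_with_whole_lattice(         # 1) swap the 3-gram scores for 4-gram
    lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
)
nbest = generate_nbest_list(rescored_lattice, num_paths)  # 2) sample unique paths
nbest = nbest.intersect(rescored_lattice)                 #    and attach their lattice scores
nbest_topk, nbest_remain = nbest.split(k=top_k)           # 3) keep the k best paths per utterance
rescored_topk = rescore_nbest_with_attention_decoder(     # 4) exact attention-decoder scores
    nbest=nbest_topk, model=model, memory=memory,
    memory_key_padding_mask=memory_key_padding_mask,
    sos_id=sos_id, eos_id=eos_id,
)
# 5) Estimate the scores of the remaining paths from best-matching statistics,
#    rescore the most promising of them, and pick the best path per utterance.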

View File

@ -100,6 +100,7 @@ def decode_one_batch(
model: nn.Module,
HLG: k2.Fsa,
batch: dict,
batch_idx: int,
lexicon: Lexicon,
sos_id: int,
eos_id: int,
@ -201,6 +202,7 @@ def decode_one_batch(
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
"attention-decoder-v2",
]
lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
@ -232,6 +234,23 @@ def decode_one_batch(
sos_id=sos_id,
eos_id=eos_id,
)
elif params.method == "attention-decoder-v2":
# The lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
)
best_path_dict = rescore_with_attention_decoder_v2(
lattice=rescored_lattice,
batch_idx=batch_idx,
dump_best_matching_feature=params.dump_feature,
num_paths=params.num_paths,
top_k=params.top_k,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
)
else:
assert False, f"Unsupported decoding method: {params.method}"
@ -295,6 +314,7 @@ def decode_dataset(
model=model,
HLG=HLG,
batch=batch,
batch_idx=batch_idx,
lexicon=lexicon,
G=G,
sos_id=sos_id,
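The call site in decode_dataset (shown in the hunk above) now threads the batch index through as well. A minimal sketch of the full updated call with keyword arguments spelled out; dl, params, and G are assumed from the surrounding conformer_ctc/decode.py script:

for batch_idx, batch in enumerate(dl):
    hyps_dict = decode_one_batch(
        params=params,
        model=model,
        HLG=HLG,
        batch=batch,
        batch_idx=batch_idx,  # new argument introduced by this commit
        lexicon=lexicon,
        G=G,
        sos_id=sos_id,
        eos_id=eos_id,
    )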

View File

@ -25,7 +25,7 @@ stop_stage=100
# - librispeech-vocab.txt
# - librispeech-lexicon.txt
#
# - $do_dir/musan
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#

View File

@ -721,3 +721,248 @@ def rescore_with_attention_decoder(
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
ans[key] = best_path_fsa
return ans
def rescore_nbest_with_attention_decoder(
nbest: Nbest,
model: nn.Module,
memory: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
sos_id: int,
eos_id: int,
) -> Nbest:
"""This function rescores an nbest list with an attention decoder. The paths
with rescored scores are returned as a new nbest.
Args:
nbest:
An Nbest containing the n-best paths of the given sequences.
It can be the return value of :func:`generate_nbest_list`.
model:
A transformer model. See the class "Transformer" in
conformer_ctc/transformer.py for its interface.
memory:
The encoder memory of the given model. It is the output of
the last torch.nn.TransformerEncoder layer in the given model.
Its shape is `[T, N, C]`.
memory_key_padding_mask:
The padding mask for memory with shape [N, T].
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
Returns:
A new Nbest with the same paths as the input, but whose scores are
the attention-decoder scores.
"""
num_seqs = nbest.shape.dim0()
token_seq = k2.RaggedInt(nbest.shape, nbest.fsa.labels.contiguous())
# Remove -1 from token_seq. There are no epsilon tokens in token_seq;
# they were already removed when generating the nbest list.
token_seq = k2.ragged.remove_values_leq(token_seq, -1)
token_ids = k2.ragged.to_list(token_seq)
path_to_seq_map_long = token_seq.shape.row_ids(1).to(torch.long)
expanded_memory = memory.index_select(1, path_to_seq_map_long)
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
0, path_to_seq_map_long
)
nll = model.decoder_nll(
memory=expanded_memory,
memory_key_padding_mask=expanded_memory_key_padding_mask,
token_ids=token_ids,
sos_id=sos_id,
eos_id=eos_id,
)
assert nll.ndim == 2
assert nll.shape[0] == num_seqs
attention_scores = torch.zeros(
nbest.fsa.labels.size(0),
dtype=torch.float32,
device=nbest.fsa.device,
)
start_index = 0
for i in range(num_seqs):
# Plus 1 to account for the score of the final arc
tokens_num = len(token_ids[i]) + 1
attention_scores[start_index : start_index + tokens_num] = nll[i][0:tokens_num]
start_index += tokens_num
fsas = nbest.fsa.clone()
fsas.scores = attention_scores
return Nbest(fsas, nbest.shape)
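# A hedged usage sketch of the helper above (not part of the commit).
# `nbest` is assumed to already carry lattice scores (see
# rescore_with_attention_decoder_v2 below); `model`, `memory`,
# `memory_key_padding_mask`, `sos_id` and `eos_id` are prepared as in
# conformer_ctc/decode.py.
rescored = rescore_nbest_with_attention_decoder(
    nbest=nbest,
    model=model,
    memory=memory,                                    # [T, N, C]
    memory_key_padding_mask=memory_key_padding_mask,  # [N, T]
    sos_id=sos_id,
    eos_id=eos_id,
)
# The returned Nbest keeps the same paths; only the arc scores are replaced
# with the per-token attention-decoder scores.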
def rescore_with_attention_decoder_v2(
lattice: k2.Fsa,
batch_idx: int,
dump_best_matching_feature: bool,
num_paths: int,
top_k: int,
model: nn.Module,
memory: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
sos_id: int,
eos_id: int,
) -> Dict[str, k2.Fsa]:
"""This function extracts n paths from the given lattice and uses
an attention decoder to rescore them. The path with the highest
score is used as the decoding output.
Args:
lattice:
An FsaVec. It can be the return value of :func:`get_lattice`.
batch_idx:
The index of the current batch in the decoding dataloader.
dump_best_matching_feature:
True to collect best-matching statistics for training the
score-estimation model instead of doing the full rescoring.
num_paths:
Number of paths to extract from the given lattice for rescoring.
top_k:
Number of paths per sequence that are rescored exactly with the
attention decoder.
model:
A transformer model. See the class "Transformer" in
conformer_ctc/transformer.py for its interface.
memory:
The encoder memory of the given model. It is the output of
the last torch.nn.TransformerEncoder layer in the given model.
Its shape is `[T, N, C]`.
memory_key_padding_mask:
The padding mask for memory with shape [N, T].
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
Returns:
A dict whose values are FsaVecs containing the best decoding path
for each sequence in the lattice.
"""
nbest = generate_nbest_list(lattice, num_paths)
# Intersect with the lattice to attach the lattice scores to the paths
nbest = nbest.intersect(lattice)
if dump_best_matching_feature:
nbest_k, nbest_q = nbest.split(k=top_k, sort=False)
rescored_nbest_k = rescore_nbest_with_attention_decoder(
nbest=nbest_k,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
)
stats_tensor = get_best_matching_stats(
rescored_nbest_k,
nbest_q,
max_order=3
)
rescored_nbest_q = rescore_nbest_with_attention_decoder(
nbest=nbest_q,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
)
# TODO: return the features & labels, or dump them to a file
nbest_topk, nbest_remain = nbest.split(k=top_k)
rescored_nbest_topk = rescore_nbest_with_attention_decoder(
nbest=nbest_topk,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
)
stats_tensor = get_best_matching_stats(
rescored_nbest_topk,
nbest_remain,
max_order=3
)
# Run the score-estimation model to get the mean and variance of the
# score of each token
mean, var = rescore_est_model(stats_tensor)
# Calculate the estimated scores of nbest_remain and select the top-k paths
nbest_remain_topk = nbest_remain.top_k(k=top_k)
rescored_nbest_remain_topk = rescore_nbest_with_attention_decoder(
nbest=nbest_remain_topk,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
)
best_path_dict = get_best_path_from_nbests(
rescored_nbest_topk,
rescored_nbest_remain_topk,
)
return best_path_dict
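# Hypothetical sketch (not the commit's implementation) of the "estimate,
# then filter" step that the mean/var outputs above build towards: treat the
# estimated per-token means of the remaining paths as scores, sum them per
# path, and keep only the k most promising paths for exact attention-decoder
# rescoring.
import torch  # already imported at the top of this file in practice


def select_promising_paths(
    mean: torch.Tensor, path_lens: torch.Tensor, k: int
) -> torch.Tensor:
    # mean: (num_tokens,) estimated per-token scores of the remaining paths,
    # concatenated path by path; path_lens: (num_paths,) number of tokens
    # per path. Returns the indexes of the k paths with the highest
    # estimated total score.
    per_path = torch.split(mean, path_lens.tolist())
    est_total = torch.stack([p.sum() for p in per_path])
    return torch.topk(est_total, k=min(k, est_total.numel())).indices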
def generate_nbest_list(
lats: k2.Fsa,
num_paths: int,
aux_labels: bool = False
) -> Nbest:
'''Generate an n-best list from a lattice.
Args:
lats:
The decoding lattice from the first pass after LM rescoring.
lats is an FsaVec. It can be the return value of
:func:`rescore_with_whole_lattice`
num_paths:
Size of n for n-best list. CAUTION: After removing paths
that represent the same word sequences, the number of paths
in different sequences may not be equal.
Return:
Return an Nbest object. Note the returned FSAs don't have epsilon
self-loops.
'''
assert len(lats.shape) == 3
# First, extract `num_paths` paths for each sequence.
# paths is a k2.RaggedInt with axes [seq][path][arc_pos]
paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)
# Seqs is a k2.RaggedInt sharing the same shape as `paths`.
# Note that it also contains 0s and -1s.
# The last entry in each sublist is -1.
# Its axes are [seq][path][word_id]
if aux_labels:
# If aux_labels is enabled, seqs contains word IDs
assert hasattr(lats, "aux_labels")
seqs = k2.index(lats.aux_labels, paths)
else:
# CAUTION: We use `phones` instead of `tokens` here because
# :func:`compile_HLG` uses `phones`
#
# Note: compile_HLG is from k2-fsa/snowfall
assert hasattr(lats, 'phones')
assert not hasattr(lats, 'tokens')
lats.tokens = lats.phones
seqs = k2.index(lats.tokens, paths)
# Remove epsilons (0s) and -1 from seqs
seqs = k2.ragged.remove_values_leq(seqs, 0)
# unique_seqs is still a k2.RaggedInt with axes [seq][path][token_id]
# (or [seq][path][word_id] when aux_labels is True), but the number of
# paths in each sequence may be different.
unique_seqs, _, _ = k2.ragged.unique_sequences(
seqs, need_num_repeats=False, need_new2old_indexes=False)
seq_to_path_shape = k2.ragged.get_layer(unique_seqs.shape(), 0)
# Remove the seq axis.
# Now unique_seqs has only two axes [path][token_id or word_id]
unique_seqs = k2.ragged.remove_axis(unique_seqs, 0)
fsas = k2.linear_fsa(unique_seqs)
return Nbest(fsa=fsas, shape=seq_to_path_shape)
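# A small sketch of how the list is prepared before rescoring, mirroring
# rescore_with_attention_decoder_v2 above (`rescored_lattice` stands for the
# output of rescore_with_whole_lattice). The linear FSAs produced by
# k2.linear_fsa have zero scores, so the paths are intersected with the
# lattice to recover their first-pass scores before splitting.
nbest = generate_nbest_list(rescored_lattice, num_paths=100)
nbest = nbest.intersect(rescored_lattice)
nbest_topk, nbest_remain = nbest.split(k=10)  # see Nbest.split in icefall/nbest.py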

View File

@ -5,10 +5,9 @@
# See https://github.com/k2-fsa/snowfall/issues/232 for more details
#
import logging
from typing import List
from typing import List, Tuple
import torch
import _k2
import k2
# Note: We use `utterance` and `sequence` interchangeably in the comment
@ -19,7 +18,7 @@ class Nbest(object):
An Nbest object contains two fields:
(1) fsa, its type is k2.Fsa
(2) shape, its type is k2.RaggedShape (alias to _k2.RaggedShape)
(2) shape, its type is k2.RaggedShape
The field `fsa` is an FsaVec containing a vector of **linear** FSAs.
@ -29,7 +28,7 @@ class Nbest(object):
of paths, which is also the number of FSAs in `fsa`.
'''
def __init__(self, fsa: k2.Fsa, shape: _k2.RaggedShape) -> None:
def __init__(self, fsa: k2.Fsa, shape: k2.RaggedShape) -> None:
assert len(fsa.shape) == 3, f'fsa.shape: {fsa.shape}'
assert shape.num_axes() == 2, f'num_axes: {shape.num_axes()}'
@ -85,7 +84,7 @@ class Nbest(object):
return Nbest(fsa=one_best, shape=self.shape)
def total_scores(self) -> _k2.RaggedFloat:
def total_scores(self) -> k2.RaggedFloat:
'''Get total scores of the FSAs in this Nbest.
Note:
@ -99,7 +98,7 @@ class Nbest(object):
log_semiring=False)
# We use single precision here since we only wrap k2.RaggedFloat.
# If k2.RaggedDouble is wrapped, we can use double precision here.
return _k2.RaggedFloat(self.shape, scores.float())
return k2.RaggedFloat(self.shape, scores.float())
def top_k(self, k: int) -> 'Nbest':
'''Get a subset of paths in the Nbest. The resulting Nbest is regular
@ -144,121 +143,66 @@ class Nbest(object):
return Nbest(top_k_fsas, top_k_shape)
def whole_lattice_rescoring(lats: k2.Fsa, G_with_epsilon_loops: k2.Fsa) -> k2.Fsa:
'''Rescore the 1st pass lattice with an LM.
def split(self, k: int, sort: bool = True) -> Tuple['Nbest', 'Nbest']:
'''Split the paths in the Nbest into two parts: the first part contains
the first k paths of each sequence in the Nbest, and the second part
contains the remaining paths. There may be fewer than k paths for a
given sequence in the first part.
In general, the G in HLG used to obtain `lats` is a 3-gram LM.
This function replaces the 3-gram LM in `lats` with a 4-gram LM.
If the sort flag is true, we select the top-k paths according to the
total_scores of each path in descending order. If an utterance has
fewer than k paths, the first part will contain all of its paths and
the second part will be empty.
Args:
lats:
The decoding lattice from the 1st pass. We assume it is the result
of intersecting HLG with the network output.
G_with_epsilon_loops:
An LM. It is usually a 4-gram LM with epsilon self-loops.
It should be arc sorted.
k:
Number of paths in the first part of each utterance.
Returns:
Return a new lattice rescored with a given G.
Return a tuple of two new Nbest objects.
'''
assert len(lats.shape) == 3, f'{lats.shape}'
assert hasattr(lats, 'lm_scores')
assert G_with_epsilon_loops.shape == (1, None, None), \
f'{G_with_epsilon_loops.shape}'
# indexes contains idx01's for self.shape
indexes = torch.arange(
self.shape.num_elements(), dtype=torch.int32,
device=self.shape.device
)
device = lats.device
lats.scores = lats.scores - lats.lm_scores
# Now lats contains only acoustic scores
if sort:
ragged_scores = self.total_scores()
# We will use lm_scores from the given G, so remove lats.lm_scores here
del lats.lm_scores
assert hasattr(lats, 'lm_scores') is False
# ragged_scores.values()[indexes] is sorted
indexes = k2.ragged.sort_sublist(
ragged_scores, descending=True, need_new2old_indexes=True
)
# inverted_lats has word IDs as labels.
# Its aux_labels are token IDs, which is a ragged tensor k2.RaggedInt
# if lats.aux_labels is a ragged tensor
inverted_lats = k2.invert(lats)
num_seqs = lats.shape[0]
ragged_indexes = k2.RaggedInt(self.shape, indexes)
b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
padded_indexes = k2.ragged.pad(ragged_indexes, value=-1)
while True:
try:
rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
inverted_lats,
b_to_a_map,
sorted_match_a=True)
break
except RuntimeError as e:
logging.info(f'Caught exception:\n{e}\n')
# Usually, this is an OOM exception. We reduce
# the size of the lattice and redo k2.intersect_device()
# Select the idx01's of top-k paths of each utterance
first_indexes = padded_indexes[:, :k].flatten().contiguous()
# NOTE(fangjun): The choice of the threshold 1e-5 is arbitrary here
# to avoid OOM. We may need to fine tune it.
logging.info(f'num_arcs before: {inverted_lats.num_arcs}')
inverted_lats = k2.prune_on_arc_post(inverted_lats, 1e-5, True)
logging.info(f'num_arcs after: {inverted_lats.num_arcs}')
# Remove the padding elements
first_indexes = first_indexes[first_indexes >= 0]
rescoring_lats = k2.top_sort(k2.connect(rescoring_lats))
first_fsas = k2.index_fsa(self.fsa, first_indexes)
# inv_rescoring_lats has token IDs as labels
# and word IDs as aux_labels.
inv_rescoring_lats = k2.invert(rescoring_lats)
return inv_rescoring_lats
first_row_ids = k2.index(self.shape.row_ids(1), first_indexes)
first_shape = k2.ragged.create_ragged_shape2(row_ids=first_row_ids)
first_nbest = Nbest(first_fsas, first_shape)
def generate_nbest_list(lats: k2.Fsa, num_paths: int) -> Nbest:
'''Generate an n-best list from a lattice.
# Select the idx01's of remaining paths of each utterance
second_indexes = padded_indexes[:, k:].flatten().contiguous()
Args:
lats:
The decoding lattice from the first pass after LM rescoring.
lats is an FsaVec. It can be the return value of
:func:`whole_lattice_rescoring`
num_paths:
Size of n for n-best list. CAUTION: After removing paths
that represent the same token sequences, the number of paths
in different sequences may not be equal.
Return:
Return an Nbest object. Note the returned FSAs don't have epsilon
self-loops.
'''
assert len(lats.shape) == 3
# Remove the padding elements
second_indexes = second_indexes[second_indexes >= 0]
# CAUTION: We use `phones` instead of `tokens` here because
# :func:`compile_HLG` uses `phones`
#
# Note: compile_HLG is from k2-fsa/snowfall
assert hasattr(lats, 'phones')
second_fsas = k2.index_fsa(self.fsa, second_indexes)
assert not hasattr(lats, 'tokens')
lats.tokens = lats.phones
# we use tokens instead of phones in the following code
second_row_ids = k2.index(self.shape.row_ids(1), second_indexes)
second_shape = k2.ragged.create_ragged_shape2(row_ids=second_row_ids)
# First, extract `num_paths` paths for each sequence.
# paths is a k2.RaggedInt with axes [seq][path][arc_pos]
paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)
second_nbest = Nbest(second_fsas, second_shape)
# token_seqs is a k2.RaggedInt sharing the same shape as `paths`
# but it contains token IDs. Note that it also contains 0s and -1s.
# The last entry in each sublist is -1.
# Its axes are [seq][path][token_id]
token_seqs = k2.index(lats.tokens, paths)
return first_nbest, second_nbest
# Remove epsilons (0s) and -1 from token_seqs
token_seqs = k2.ragged.remove_values_leq(token_seqs, 0)
# unique_token_seqs is still a k2.RaggedInt with axes [seq][path][token_id],
# but the number of paths in each sequence may be different.
unique_token_seqs, _, _ = k2.ragged.unique_sequences(
token_seqs, need_num_repeats=False, need_new2old_indexes=False)
seq_to_path_shape = k2.ragged.get_layer(unique_token_seqs.shape(), 0)
# Remove the seq axis.
# Now unique_token_seqs has only two axes [path][token_id]
unique_token_seqs = k2.ragged.remove_axis(unique_token_seqs, 0)
token_fsas = k2.linear_fsa(unique_token_seqs)
return Nbest(fsa=token_fsas, shape=seq_to_path_shape)
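A brief, hedged example of the new split method alongside the existing top_k; the nbest object and the value of k are assumed, and both calls return Nbest objects:

# sort=True (default): the first part holds the k best-scoring paths per utterance.
nbest_topk, nbest_remain = nbest.split(k=10)
# sort=False: take the first k paths in their original order, as used when
# dumping best-matching features.
nbest_k, nbest_q = nbest.split(k=10, sort=False)
# The existing top_k helper keeps only the k best paths and drops the rest.
best_path_nbest = nbest.top_k(k=1)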

View File

@ -5,7 +5,7 @@ import subprocess
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from nbest import Nbest
from icefall.nbest import Nbest
from pathlib import Path
from typing import Dict, Iterable, List, TextIO, Tuple, Union