Add rescoring with attention decoder.

Fangjun Kuang 2021-09-18 13:28:02 +08:00
parent b0b942b355
commit 38cfd06ccb
3 changed files with 157 additions and 15 deletions

View File

@@ -33,9 +33,9 @@ from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.decode import get_lattice
 from icefall.decode import (
     one_best_decoding,  # done
-    rescore_with_attention_decoder,
+    rescore_with_attention_decoder,  # done
     rescore_with_n_best_list,  # done
-    rescore_with_whole_lattice,
+    rescore_with_whole_lattice,  # done
     nbest_oracle,  # done
 )
 from icefall.decode2 import (
@@ -43,6 +43,7 @@ from icefall.decode2 import (
     nbest_oracle as nbest_oracle2,
     rescore_with_n_best_list as rescore_with_n_best_list2,
     rescore_with_whole_lattice as rescore_with_whole_lattice2,
+    rescore_with_attention_decoder as rescore_with_attention_decoder2,
 )
 from icefall.lexicon import Lexicon
 from icefall.utils import (
@@ -340,16 +341,28 @@ def decode_one_batch(
             lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
         )
-        best_path_dict = rescore_with_attention_decoder(
-            lattice=rescored_lattice,
-            num_paths=params.num_paths,
-            model=model,
-            memory=memory,
-            memory_key_padding_mask=memory_key_padding_mask,
-            sos_id=sos_id,
-            eos_id=eos_id,
-            scale=params.lattice_score_scale,
-        )
+        if True:
+            best_path_dict = rescore_with_attention_decoder2(
+                lattice=rescored_lattice,
+                num_paths=params.num_paths,
+                model=model,
+                memory=memory,
+                memory_key_padding_mask=memory_key_padding_mask,
+                sos_id=sos_id,
+                eos_id=eos_id,
+                lattice_score_scale=params.lattice_score_scale,
+            )
+        else:
+            best_path_dict = rescore_with_attention_decoder(
+                lattice=rescored_lattice,
+                num_paths=params.num_paths,
+                model=model,
+                memory=memory,
+                memory_key_padding_mask=memory_key_padding_mask,
+                sos_id=sos_id,
+                eos_id=eos_id,
+                scale=params.lattice_score_scale,
+            )
     else:
         assert False, f"Unsupported decoding method: {params.method}"
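
For orientation, the dictionary returned by either rescoring function is keyed by the scale combination and is consumed a little further down in decode.py roughly as sketched below; get_texts and word_table are assumptions about the surrounding icefall code, not part of this diff.

# Hedged sketch, not part of the commit: converting the rescored best
# paths into word hypotheses.  Assumes icefall.utils.get_texts and a
# lexicon word_table are in scope, as they are elsewhere in decode.py.
ans = dict()
for lm_scale_str, best_path in best_path_dict.items():
    # best_path is a batch of linear FSAs, one per utterance
    hyps = get_texts(best_path)  # lists of word IDs
    hyps = [[word_table[i] for i in ids] for ids in hyps]
    ans[lm_scale_str] = hyps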

View File

@@ -857,13 +857,15 @@ def rescore_with_attention_decoder(
     assert attention_scores.numel() == num_word_seqs

     if ngram_lm_scale is None:
-        ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        ngram_lm_scale_list = [0.01, 0.05, 0.08]
+        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
     else:
         ngram_lm_scale_list = [ngram_lm_scale]

     if attention_scale is None:
-        attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        attention_scale_list = [0.01, 0.05, 0.08]
+        attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
     else:
         attention_scale_list = [attention_scale]
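
Restating the numbers from this hunk: each scale list grows from 14 to 17 values, so rescoring now searches 17 × 17 = 289 scale pairs instead of 196. The snippet below is only an illustrative restatement, not part of the commit.

# Each list has 3 + 7 + 7 = 17 values after this change.
ngram_lm_scale_list = [0.01, 0.05, 0.08]
ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
attention_scale_list = list(ngram_lm_scale_list)  # the same grid is used for attention

# 17 * 17 = 289 (ngram_lm_scale, attention_scale) combinations are scored,
# up from 14 * 14 = 196 before the three small scales were added.
print(len(ngram_lm_scale_list) * len(attention_scale_list))  # 289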

View File

@@ -310,6 +310,9 @@ class Nbest(object):
         Hint:
           `self.fsa.scores` contains two parts: am scores and lm scores.

+        Caution:
+          We require that ``self.fsa`` has an attribute ``lm_scores``.
+
         Returns:
           Return a ragged tensor with 2 axes [utt][path_scores].
           Its dtype is torch.float64.
@@ -326,6 +329,35 @@ class Nbest(object):
         return k2.RaggedTensor(self.shape, am_scores)

+    def compute_lm_scores(self) -> k2.RaggedTensor:
+        """Compute LM scores of each linear FSA (i.e., each path within
+        an utterance).
+
+        Hint:
+          `self.fsa.scores` contains two parts: am scores and lm scores.
+
+        Caution:
+          We require that ``self.fsa`` has an attribute ``lm_scores``.
+
+        Returns:
+          Return a ragged tensor with 2 axes [utt][path_scores].
+          Its dtype is torch.float64.
+        """
+        saved_scores = self.fsa.scores
+
+        # The `scores` of every arc consists of `am_scores` and `lm_scores`
+        self.fsa.scores = self.fsa.lm_scores
+
+        # Caution: self.fsa.lm_scores is per arc,
+        # while lm_scores in the following is per path.
+        lm_scores = self.fsa.get_tot_scores(
+            use_double_scores=True, log_semiring=False
+        )
+
+        self.fsa.scores = saved_scores
+
+        return k2.RaggedTensor(self.shape, lm_scores)
+
     def tot_scores(self) -> k2.RaggedTensor:
         """Get total scores of the FSAs in this Nbest.
@@ -547,7 +579,7 @@ def rescore_with_n_best_list(
     # nbest.fsa.scores are all 0s at this point
     nbest = nbest.intersect(lattice)
-    # Now nbest.fsa has it scores set
+    # Now nbest.fsa has its scores set
     assert hasattr(nbest.fsa, "lm_scores")

     am_scores = nbest.compute_am_scores()
@@ -639,3 +671,98 @@ def rescore_with_whole_lattice(
         key = f"lm_scale_{lm_scale}_yy"
         ans[key] = best_path
     return ans
+
+
+def rescore_with_attention_decoder(
+    lattice: k2.Fsa,
+    num_paths: int,
+    model: torch.nn.Module,
+    memory: torch.Tensor,
+    memory_key_padding_mask: Optional[torch.Tensor],
+    sos_id: int,
+    eos_id: int,
+    lattice_score_scale: float = 1.0,
+    ngram_lm_scale: Optional[float] = None,
+    attention_scale: Optional[float] = None,
+    use_double_scores: bool = True,
+) -> Dict[str, k2.Fsa]:
+    nbest = Nbest.from_lattice(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=use_double_scores,
+        lattice_score_scale=lattice_score_scale,
+    )
+    # nbest.fsa.scores are all 0s at this point
+
+    nbest = nbest.intersect(lattice)
+    # Now nbest.fsa has its scores set.
+    # Also, nbest.fsa inherits the attributes from `lattice`.
+    assert hasattr(nbest.fsa, "lm_scores")
+
+    am_scores = nbest.compute_am_scores()
+    ngram_lm_scores = nbest.compute_lm_scores()
+
+    # The `tokens` attribute is set inside `compile_hlg.py`
+    assert hasattr(nbest.fsa, "tokens")
+    assert isinstance(nbest.fsa.tokens, torch.Tensor)
+
+    path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
+
+    # the shape of memory is (T, N, C), so we use axis=1 here
+    expanded_memory = memory.index_select(1, path_to_utt_map)
+
+    if memory_key_padding_mask is not None:
+        # The shape of memory_key_padding_mask is (N, T), so we
+        # use axis=0 here.
+        expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
+            0, path_to_utt_map
+        )
+    else:
+        expanded_memory_key_padding_mask = None
+
+    # remove axis corresponding to states.
+    tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
+    tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
+    tokens = tokens.remove_values_leq(0)
+    token_ids = tokens.tolist()
+
+    nll = model.decoder_nll(
+        memory=expanded_memory,
+        memory_key_padding_mask=expanded_memory_key_padding_mask,
+        token_ids=token_ids,
+        sos_id=sos_id,
+        eos_id=eos_id,
+    )
+    assert nll.ndim == 2
+    assert nll.shape[0] == len(token_ids)
+
+    attention_scores = -nll.sum(dim=1)
+
+    if ngram_lm_scale is None:
+        ngram_lm_scale_list = [0.01, 0.05, 0.08]
+        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+    else:
+        ngram_lm_scale_list = [ngram_lm_scale]
+
+    if attention_scale is None:
+        attention_scale_list = [0.01, 0.05, 0.08]
+        attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+    else:
+        attention_scale_list = [attention_scale]
+
+    ans = dict()
+    for n_scale in ngram_lm_scale_list:
+        for a_scale in attention_scale_list:
+            tot_scores = (
+                am_scores.values
+                + n_scale * ngram_lm_scores.values
+                + a_scale * attention_scores
+            )
+            ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+            max_indexes = ragged_tot_scores.argmax()
+            best_path = k2.index_fsa(nbest.fsa, max_indexes)
+            key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
+            ans[key] = best_path
+    return ans
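
As a small illustration of the selection step at the end of this function, the snippet below uses made-up scores; only the combination formula and the per-utterance argmax mirror the code above.

import k2
import torch

# Two utterances: the first has 3 candidate paths, the second has 2.
shape = k2.RaggedTensor("[[0 0 0] [0 0]]").shape
am = torch.tensor([-10.0, -11.0, -9.5, -7.0, -8.0], dtype=torch.float64)
lm = torch.tensor([-4.0, -3.0, -5.0, -2.5, -2.0], dtype=torch.float64)
att = torch.tensor([-6.0, -5.0, -7.0, -3.0, -4.5], dtype=torch.float64)

n_scale, a_scale = 0.5, 0.7
tot = am + n_scale * lm + a_scale * att

ragged_tot = k2.RaggedTensor(shape, tot)
# argmax() picks one path index per utterance (indexing into the flat
# list of paths); rescore_with_attention_decoder feeds these indexes to
# k2.index_fsa() to extract the best path of each utterance.
print(ragged_tot.argmax())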