Mirror of https://github.com/k2-fsa/icefall.git
Add rescoring with attention decoder.

parent b0b942b355
commit 38cfd06ccb
@@ -33,9 +33,9 @@ from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.decode import get_lattice
 from icefall.decode import (
     one_best_decoding,  # done
-    rescore_with_attention_decoder,
+    rescore_with_attention_decoder,  # done
     rescore_with_n_best_list,  # done
-    rescore_with_whole_lattice,
+    rescore_with_whole_lattice,  # done
     nbest_oracle,  # done
 )
 from icefall.decode2 import (
@@ -43,6 +43,7 @@ from icefall.decode2 import (
     nbest_oracle as nbest_oracle2,
     rescore_with_n_best_list as rescore_with_n_best_list2,
     rescore_with_whole_lattice as rescore_with_whole_lattice2,
+    rescore_with_attention_decoder as rescore_with_attention_decoder2,
 )
 from icefall.lexicon import Lexicon
 from icefall.utils import (
@@ -340,6 +341,18 @@ def decode_one_batch(
             lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
         )

+        if True:
+            best_path_dict = rescore_with_attention_decoder2(
+                lattice=rescored_lattice,
+                num_paths=params.num_paths,
+                model=model,
+                memory=memory,
+                memory_key_padding_mask=memory_key_padding_mask,
+                sos_id=sos_id,
+                eos_id=eos_id,
+                lattice_score_scale=params.lattice_score_scale,
+            )
+        else:
             best_path_dict = rescore_with_attention_decoder(
                 lattice=rescored_lattice,
                 num_paths=params.num_paths,
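For context, a minimal sketch (not part of this commit) of how `decode_one_batch` typically turns the returned `best_path_dict` into word-level hypotheses; `get_texts` and `lexicon.word_table` are assumed to come from the surrounding decoding script:

from icefall.utils import get_texts

# Each key encodes the (ngram_lm_scale, attention_scale) pair used during rescoring.
ans = dict()
for lm_scale_str, best_path in best_path_dict.items():
    # get_texts() returns the word-ID sequence of the best path for each utterance.
    hyps = [[lexicon.word_table[i] for i in ids] for ids in get_texts(best_path)]
    ans[lm_scale_str] = hyps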
@@ -857,13 +857,15 @@ def rescore_with_attention_decoder(
     assert attention_scores.numel() == num_word_seqs

     if ngram_lm_scale is None:
-        ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        ngram_lm_scale_list = [0.01, 0.05, 0.08]
+        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
     else:
         ngram_lm_scale_list = [ngram_lm_scale]

     if attention_scale is None:
-        attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        attention_scale_list = [0.01, 0.05, 0.08]
+        attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
     else:
         attention_scale_list = [attention_scale]
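The finer grid means each decoding run now reports one result per (ngram_lm_scale, attention_scale) pair. A small sketch (not from the commit) of how many entries that produces, using the key format of the rescoring functions in this diff:

ngram_lm_scale_list = [0.01, 0.05, 0.08]
ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
attention_scale_list = list(ngram_lm_scale_list)

# One best-path FSA per scale pair: 17 * 17 = 289 keys.
keys = [
    f"ngram_lm_scale_{n}_attention_scale_{a}"
    for n in ngram_lm_scale_list
    for a in attention_scale_list
]
assert len(keys) == 289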
@@ -310,6 +310,9 @@ class Nbest(object):
         Hint:
           `self.fsa.scores` contains two parts: am scores and lm scores.

+        Caution:
+          We require that ``self.fsa`` has an attribute ``lm_scores``.
+
         Returns:
           Return a ragged tensor with 2 axes [utt][path_scores].
           Its dtype is torch.float64.
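The Hint and Caution above describe the score decomposition on each arc: scores = am_scores + lm_scores, with the LM part also stored separately as `lm_scores`. A minimal sketch (not from the commit) of how the AM part could be isolated, assuming `compute_am_scores` mirrors the `compute_lm_scores` method added in the next hunk:

import k2

# `nbest` is assumed to be an Nbest whose FSA carries a per-arc `lm_scores` attribute.
saved_scores = nbest.fsa.scores

# Keep only the acoustic part of each arc's score.
nbest.fsa.scores = saved_scores - nbest.fsa.lm_scores
am_scores = nbest.fsa.get_tot_scores(use_double_scores=True, log_semiring=False)
nbest.fsa.scores = saved_scores

# Group the per-path sums by utterance: axes [utt][path_scores].
am_ragged = k2.RaggedTensor(nbest.shape, am_scores)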
@@ -326,6 +329,35 @@ class Nbest(object):

         return k2.RaggedTensor(self.shape, am_scores)

+    def compute_lm_scores(self) -> k2.RaggedTensor:
+        """Compute LM scores of each linear FSA (i.e., each path within
+        an utterance).
+
+        Hint:
+          `self.fsa.scores` contains two parts: am scores and lm scores.
+
+        Caution:
+          We require that ``self.fsa`` has an attribute ``lm_scores``.
+
+        Returns:
+          Return a ragged tensor with 2 axes [utt][path_scores].
+          Its dtype is torch.float64.
+        """
+        saved_scores = self.fsa.scores
+
+        # The `scores` of every arc consists of `am_scores` and `lm_scores`
+        self.fsa.scores = self.fsa.lm_scores
+
+        # Caution: self.fsa.lm_scores is per arc,
+        # while lm_scores in the following is per path.
+        lm_scores = self.fsa.get_tot_scores(
+            use_double_scores=True, log_semiring=False
+        )
+        self.fsa.scores = saved_scores
+
+        return k2.RaggedTensor(self.shape, lm_scores)
+
     def tot_scores(self) -> k2.RaggedTensor:
         """Get total scores of the FSAs in this Nbest.

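Together with `compute_am_scores`, the new method gives per-path score components that can later be mixed with attention-decoder scores. A short usage sketch (not part of the commit) mirroring the call sequence used by `rescore_with_attention_decoder` below; `lattice` is assumed to be an HLG decoding lattice whose arcs carry `lm_scores`, and the numeric arguments are example values:

nbest = Nbest.from_lattice(
    lattice=lattice,
    num_paths=100,
    use_double_scores=True,
    lattice_score_scale=0.5,
)
# nbest.fsa.scores are all 0s here; intersecting with the lattice copies the
# real scores (and attributes such as lm_scores) back onto each sampled path.
nbest = nbest.intersect(lattice)

am_scores = nbest.compute_am_scores()        # ragged, axes [utt][path_scores]
ngram_lm_scores = nbest.compute_lm_scores()  # ragged, axes [utt][path_scores]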
@@ -547,7 +579,7 @@ def rescore_with_n_best_list(
     # nbest.fsa.scores are all 0s at this point

     nbest = nbest.intersect(lattice)
-    # Now nbest.fsa has it scores set
+    # Now nbest.fsa has its scores set
     assert hasattr(nbest.fsa, "lm_scores")

     am_scores = nbest.compute_am_scores()
@@ -639,3 +671,98 @@ def rescore_with_whole_lattice(
         key = f"lm_scale_{lm_scale}_yy"
         ans[key] = best_path
     return ans
+
+
+def rescore_with_attention_decoder(
+    lattice: k2.Fsa,
+    num_paths: int,
+    model: torch.nn.Module,
+    memory: torch.Tensor,
+    memory_key_padding_mask: Optional[torch.Tensor],
+    sos_id: int,
+    eos_id: int,
+    lattice_score_scale: float = 1.0,
+    ngram_lm_scale: Optional[float] = None,
+    attention_scale: Optional[float] = None,
+    use_double_scores: bool = True,
+) -> Dict[str, k2.Fsa]:
+    nbest = Nbest.from_lattice(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=use_double_scores,
+        lattice_score_scale=lattice_score_scale,
+    )
+    # nbest.fsa.scores are all 0s at this point
+
+    nbest = nbest.intersect(lattice)
+    # Now nbest.fsa has its scores set.
+    # Also, nbest.fsa inherits the attributes from `lattice`.
+    assert hasattr(nbest.fsa, "lm_scores")
+
+    am_scores = nbest.compute_am_scores()
+    ngram_lm_scores = nbest.compute_lm_scores()
+
+    # The `tokens` attribute is set inside `compile_hlg.py`
+    assert hasattr(nbest.fsa, "tokens")
+    assert isinstance(nbest.fsa.tokens, torch.Tensor)
+
+    path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
+    # The shape of memory is (T, N, C), so we use axis=1 here
+    expanded_memory = memory.index_select(1, path_to_utt_map)
+
+    if memory_key_padding_mask is not None:
+        # The shape of memory_key_padding_mask is (N, T), so we
+        # use axis=0 here.
+        expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
+            0, path_to_utt_map
+        )
+    else:
+        expanded_memory_key_padding_mask = None
+
+    # Remove the axis corresponding to states.
+    tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
+    tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
+    tokens = tokens.remove_values_leq(0)
+    token_ids = tokens.tolist()
+
+    nll = model.decoder_nll(
+        memory=expanded_memory,
+        memory_key_padding_mask=expanded_memory_key_padding_mask,
+        token_ids=token_ids,
+        sos_id=sos_id,
+        eos_id=eos_id,
+    )
+    assert nll.ndim == 2
+    assert nll.shape[0] == len(token_ids)
+
+    attention_scores = -nll.sum(dim=1)
+
+    if ngram_lm_scale is None:
+        ngram_lm_scale_list = [0.01, 0.05, 0.08]
+        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+    else:
+        ngram_lm_scale_list = [ngram_lm_scale]
+
+    if attention_scale is None:
+        attention_scale_list = [0.01, 0.05, 0.08]
+        attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+    else:
+        attention_scale_list = [attention_scale]
+
+    ans = dict()
+    for n_scale in ngram_lm_scale_list:
+        for a_scale in attention_scale_list:
+            tot_scores = (
+                am_scores.values
+                + n_scale * ngram_lm_scores.values
+                + a_scale * attention_scores
+            )
+            ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+            max_indexes = ragged_tot_scores.argmax()
+            best_path = k2.index_fsa(nbest.fsa, max_indexes)
+
+            key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
+            ans[key] = best_path
+    return ans
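The subtle step in the function above is expanding the encoder output from one entry per utterance to one entry per path, so the attention decoder can score every candidate path in a single batch. A toy illustration (not from the commit) of that `index_select`, with made-up sizes:

import torch

T, N, C = 4, 2, 3
memory = torch.randn(T, N, C)  # encoder output: (time, batch, channels)

# row_ids(1) of the Nbest shape maps each path to its utterance index;
# here utterance 0 has 3 paths and utterance 1 has 2.
path_to_utt_map = torch.tensor([0, 0, 0, 1, 1])

# Select along the batch axis (dim=1): each utterance's memory is replicated
# once per path, giving one (T, C) slice per candidate path.
expanded_memory = memory.index_select(1, path_to_utt_map)
assert expanded_memory.shape == (T, 5, C)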
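The final loop then picks, for every scale pair, the best path of each utterance via an argmax over a ragged tensor of total scores. A toy illustration (not from the commit), assuming k2's string constructor for `RaggedTensor`:

import k2

# Two utterances: the first has 3 candidate paths, the second has 2.
tot_scores = k2.RaggedTensor("[[1.0 3.0 2.0] [-1.0 0.5]]")

# argmax() returns, per utterance, the index of its highest-scoring path into
# the flattened values; those indexes feed straight into k2.index_fsa().
max_indexes = tot_scores.argmax()  # -> indexes 1 and 4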