https://github.com/k2-fsa/icefall.git
commit 38cfd06ccb
parent b0b942b355

    Add rescoring with attention decoder.
@@ -33,9 +33,9 @@ from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.decode import get_lattice
 from icefall.decode import (
     one_best_decoding,  # done
-    rescore_with_attention_decoder,
+    rescore_with_attention_decoder,  # done
     rescore_with_n_best_list,  # done
-    rescore_with_whole_lattice,
+    rescore_with_whole_lattice,  # done
     nbest_oracle,  # done
 )
 from icefall.decode2 import (
@@ -43,6 +43,7 @@ from icefall.decode2 import (
     nbest_oracle as nbest_oracle2,
     rescore_with_n_best_list as rescore_with_n_best_list2,
     rescore_with_whole_lattice as rescore_with_whole_lattice2,
+    rescore_with_attention_decoder as rescore_with_attention_decoder2,
 )
 from icefall.lexicon import Lexicon
 from icefall.utils import (
@@ -340,16 +341,28 @@ def decode_one_batch(
             lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
         )
 
-        best_path_dict = rescore_with_attention_decoder(
-            lattice=rescored_lattice,
-            num_paths=params.num_paths,
-            model=model,
-            memory=memory,
-            memory_key_padding_mask=memory_key_padding_mask,
-            sos_id=sos_id,
-            eos_id=eos_id,
-            scale=params.lattice_score_scale,
-        )
+        if True:
+            best_path_dict = rescore_with_attention_decoder2(
+                lattice=rescored_lattice,
+                num_paths=params.num_paths,
+                model=model,
+                memory=memory,
+                memory_key_padding_mask=memory_key_padding_mask,
+                sos_id=sos_id,
+                eos_id=eos_id,
+                lattice_score_scale=params.lattice_score_scale,
+            )
+        else:
+            best_path_dict = rescore_with_attention_decoder(
+                lattice=rescored_lattice,
+                num_paths=params.num_paths,
+                model=model,
+                memory=memory,
+                memory_key_padding_mask=memory_key_padding_mask,
+                sos_id=sos_id,
+                eos_id=eos_id,
+                scale=params.lattice_score_scale,
+            )
     else:
         assert False, f"Unsupported decoding method: {params.method}"
 
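The new branch routes attention-decoder rescoring through rescore_with_attention_decoder2 from icefall.decode2 (note the keyword rename: `scale` becomes `lattice_score_scale`), while the hard-coded `if True:` keeps the old icefall.decode implementation reachable for side-by-side comparison. A minimal sketch of how the same toggle could be driven by a parameter instead of a literal; `use_decode2` is a hypothetical attribute and is not part of this commit:

# Hypothetical sketch: drive the decode2 toggle with a flag.
# `params.use_decode2` is an assumed attribute, not in this commit.
use_decode2 = getattr(params, "use_decode2", True)
rescorer = (
    rescore_with_attention_decoder2 if use_decode2
    else rescore_with_attention_decoder
)
# The two implementations spell the scale keyword differently.
scale_kwarg = (
    {"lattice_score_scale": params.lattice_score_scale} if use_decode2
    else {"scale": params.lattice_score_scale}
)
best_path_dict = rescorer(
    lattice=rescored_lattice,
    num_paths=params.num_paths,
    model=model,
    memory=memory,
    memory_key_padding_mask=memory_key_padding_mask,
    sos_id=sos_id,
    eos_id=eos_id,
    **scale_kwarg,
)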
@@ -857,13 +857,15 @@ def rescore_with_attention_decoder(
     assert attention_scores.numel() == num_word_seqs
 
     if ngram_lm_scale is None:
-        ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        ngram_lm_scale_list = [0.01, 0.05, 0.08]
+        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
     else:
         ngram_lm_scale_list = [ngram_lm_scale]
 
     if attention_scale is None:
-        attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        attention_scale_list = [0.01, 0.05, 0.08]
+        attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
     else:
         attention_scale_list = [attention_scale]
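Both grids gain three points near zero (0.01, 0.05, 0.08), so the search can nearly disable the n-gram LM or the attention contribution. A quick sanity check of the resulting grid (same 17 values for both scales):

scales = [0.01, 0.05, 0.08]
scales += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
scales += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
assert len(scales) == 17 and scales == sorted(scales)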
@@ -310,6 +310,9 @@ class Nbest(object):
         Hint:
           `self.fsa.scores` contains two parts: am scores and lm scores.
 
+        Caution:
+          We require that ``self.fsa`` has an attribute ``lm_scores``.
+
         Returns:
           Return a ragged tensor with 2 axes [utt][path_scores].
           Its dtype is torch.float64.
@@ -326,6 +329,35 @@ class Nbest(object):
 
         return k2.RaggedTensor(self.shape, am_scores)
 
+    def compute_lm_scores(self) -> k2.RaggedTensor:
+        """Compute LM scores of each linear FSA (i.e., each path within
+        an utterance).
+
+        Hint:
+          `self.fsa.scores` contains two parts: am scores and lm scores.
+
+        Caution:
+          We require that ``self.fsa`` has an attribute ``lm_scores``.
+
+        Returns:
+          Return a ragged tensor with 2 axes [utt][path_scores].
+          Its dtype is torch.float64.
+        """
+        saved_scores = self.fsa.scores
+
+        # The `scores` of every arc consists of `am_scores` and `lm_scores`
+        self.fsa.scores = self.fsa.lm_scores
+
+        # Caution: self.fsa.lm_scores is per arc
+        # while lm_scores in the following is per path
+        #
+        lm_scores = self.fsa.get_tot_scores(
+            use_double_scores=True, log_semiring=False
+        )
+        self.fsa.scores = saved_scores
+
+        return k2.RaggedTensor(self.shape, lm_scores)
+
     def tot_scores(self) -> k2.RaggedTensor:
         """Get total scores of the FSAs in this Nbest.
 
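compute_lm_scores mirrors compute_am_scores: it temporarily points `self.fsa.scores` at the per-arc `lm_scores`, sums each linear path in the tropical semiring via `get_tot_scores`, then restores the original scores. A minimal self-contained sketch of the same score-swapping trick on a toy FSA; it assumes a working k2 installation, and the FSA and its LM scores are made up for illustration:

import k2
import torch

# A toy acceptor with a single path of three arcs.
# Each line is "src_state dest_state label score"; the last line is the
# final state.
s = """
0 1 1 0.1
1 2 2 0.2
2 3 -1 0.3
3
"""
fsa = k2.Fsa.from_str(s)
fsa.lm_scores = torch.tensor([1.0, 2.0, 3.0])  # made-up per-arc LM scores
fsa_vec = k2.create_fsa_vec([fsa])

saved = fsa_vec.scores              # keep the combined am+lm scores
fsa_vec.scores = fsa_vec.lm_scores  # swap in the per-arc LM scores
lm_tot = fsa_vec.get_tot_scores(use_double_scores=True, log_semiring=False)
fsa_vec.scores = saved              # restore the original scores

print(lm_tot)  # tensor([6.], dtype=torch.float64): 1 + 2 + 3 along the path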
@@ -547,7 +579,7 @@ def rescore_with_n_best_list(
     # nbest.fsa.scores are all 0s at this point
 
     nbest = nbest.intersect(lattice)
-    # Now nbest.fsa has it scores set
+    # Now nbest.fsa has its scores set
     assert hasattr(nbest.fsa, "lm_scores")
 
     am_scores = nbest.compute_am_scores()
@@ -639,3 +671,98 @@ def rescore_with_whole_lattice(
         key = f"lm_scale_{lm_scale}_yy"
         ans[key] = best_path
     return ans
+
+
+def rescore_with_attention_decoder(
+    lattice: k2.Fsa,
+    num_paths: int,
+    model: torch.nn.Module,
+    memory: torch.Tensor,
+    memory_key_padding_mask: Optional[torch.Tensor],
+    sos_id: int,
+    eos_id: int,
+    lattice_score_scale: float = 1.0,
+    ngram_lm_scale: Optional[float] = None,
+    attention_scale: Optional[float] = None,
+    use_double_scores: bool = True,
+) -> Dict[str, k2.Fsa]:
+    nbest = Nbest.from_lattice(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=use_double_scores,
+        lattice_score_scale=lattice_score_scale,
+    )
+    # nbest.fsa.scores are all 0s at this point
+
+    nbest = nbest.intersect(lattice)
+    # Now nbest.fsa has its scores set.
+    # Also, nbest.fsa inherits the attributes from `lattice`.
+    assert hasattr(nbest.fsa, "lm_scores")
+
+    am_scores = nbest.compute_am_scores()
+    ngram_lm_scores = nbest.compute_lm_scores()
+
+    # The `tokens` attribute is set inside `compile_hlg.py`
+    assert hasattr(nbest.fsa, "tokens")
+    assert isinstance(nbest.fsa.tokens, torch.Tensor)
+
+    path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
+    # the shape of memory is (T, N, C), so we use axis=1 here
+    expanded_memory = memory.index_select(1, path_to_utt_map)
+
+    if memory_key_padding_mask is not None:
+        # The shape of memory_key_padding_mask is (N, T), so we
+        # use axis=0 here.
+        expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
+            0, path_to_utt_map
+        )
+    else:
+        expanded_memory_key_padding_mask = None
+
+    # remove axis corresponding to states.
+    tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
+    tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
+    tokens = tokens.remove_values_leq(0)
+    token_ids = tokens.tolist()
+
+    nll = model.decoder_nll(
+        memory=expanded_memory,
+        memory_key_padding_mask=expanded_memory_key_padding_mask,
+        token_ids=token_ids,
+        sos_id=sos_id,
+        eos_id=eos_id,
+    )
+    assert nll.ndim == 2
+    assert nll.shape[0] == len(token_ids)
+
+    attention_scores = -nll.sum(dim=1)
+
+    if ngram_lm_scale is None:
+        ngram_lm_scale_list = [0.01, 0.05, 0.08]
+        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+    else:
+        ngram_lm_scale_list = [ngram_lm_scale]
+
+    if attention_scale is None:
+        attention_scale_list = [0.01, 0.05, 0.08]
+        attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+    else:
+        attention_scale_list = [attention_scale]
+
+    ans = dict()
+    for n_scale in ngram_lm_scale_list:
+        for a_scale in attention_scale_list:
+            tot_scores = (
+                am_scores.values
+                + n_scale * ngram_lm_scores.values
+                + a_scale * attention_scores
+            )
+            ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+            max_indexes = ragged_tot_scores.argmax()
+            best_path = k2.index_fsa(nbest.fsa, max_indexes)
+
+            key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
+            ans[key] = best_path
+    return ans
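The new function extracts an n-best list from the lattice, repeats each utterance's encoder output once per path, scores every path with the attention decoder, and then grid-searches the combination scales. A toy walk-through of the two key steps, with made-up sizes and scores:

import torch

# (1) memory has shape (T, N, C); path_to_utt_map maps each path to its
# utterance, and index_select along the batch axis (dim=1) repeats each
# utterance's memory once per n-best path.
memory = torch.randn(5, 2, 8)              # T=5, N=2 utterances, C=8
path_to_utt_map = torch.tensor([0, 0, 1])  # utt0 has 2 paths, utt1 has 1
expanded_memory = memory.index_select(1, path_to_utt_map)
assert expanded_memory.shape == (5, 3, 8)  # one memory copy per path

# (2) per-path scores; in the real code attention_scores = -nll.sum(dim=1).
am_scores = torch.tensor([-10.0, -12.0, -9.0])
ngram_lm_scores = torch.tensor([-4.0, -3.5, -5.0])
attention_scores = torch.tensor([-6.0, -5.0, -7.0])
n_scale, a_scale = 0.5, 1.0
tot_scores = am_scores + n_scale * ngram_lm_scores + a_scale * attention_scores
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
print(key, tot_scores)

In the real function, the per-utterance argmax over tot_scores is taken with k2.RaggedTensor(nbest.shape, tot_scores).argmax(), and the winning paths are gathered into an FsaVec with k2.index_fsa, one entry in the returned dict per (n_scale, a_scale) pair.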