Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 17:42:21 +00:00)

Use attention decoder for rescoring.

Commit bd69e4be32 (parent f65854cca5)
@@ -21,6 +21,7 @@ from icefall.decode import (
     get_lattice,
     nbest_decoding,
     one_best_decoding,
+    rescore_with_attention_decoder,
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
@@ -82,9 +83,12 @@ def get_params() -> AttributeDict:
             # - nbest
             # - nbest-rescoring
             # - whole-lattice-rescoring
-            "method": "nbest-rescoring",
-            # num_paths is used when method is "nbest" and "nbest-rescoring"
-            "num_paths": 100,
+            # - attention-decoder
+            # "method": "whole-lattice-rescoring",
+            "method": "attention-decoder",
+            # num_paths is used when method is "nbest", "nbest-rescoring",
+            # and attention-decoder
+            "num_paths": 1000,
         }
     )
     return params
@@ -147,7 +151,7 @@ def decode_one_batch(

     supervisions = batch["supervisions"]

-    nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
+    nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
     # nnet_output is [N, C, T]

     nnet_output = nnet_output.permute(0, 2, 1)
@@ -191,7 +195,11 @@ def decode_one_batch(
         hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
         return {key: hyps}

-    assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"]
+    assert params.method in [
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+    ]

     lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
     lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
@@ -203,10 +211,25 @@ def decode_one_batch(
             num_paths=params.num_paths,
             lm_scale_list=lm_scale_list,
         )
-    else:
+    elif params.method == "whole-lattice-rescoring":
         best_path_dict = rescore_with_whole_lattice(
             lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
         )
+    elif params.method == "attention-decoder":
+        # lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
+        rescored_lattice = rescore_with_whole_lattice(
+            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
+        )
+
+        best_path_dict = rescore_with_attention_decoder(
+            lattice=rescored_lattice,
+            num_paths=params.num_paths,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+        )
+    else:
+        assert False, f"Unsupported decoding method: {params.method}"

     ans = dict()
     for lm_scale_str, best_path in best_path_dict.items():
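Taken together, the attention-decoder branch above is a two-stage rescoring: the lattice (decoded with the 3-gram HLG) is first rescored with the 4-gram G by passing lm_scale_list=None, which makes rescore_with_whole_lattice return the rescored lattice itself instead of a dict of best paths, and that lattice is then handed to rescore_with_attention_decoder together with the encoder memory. A minimal sketch of the call sequence, assuming lattice, G, model, memory and memory_key_padding_mask are already prepared as in decode_one_batch:

from icefall.decode import (
    rescore_with_attention_decoder,
    rescore_with_whole_lattice,
)


def attention_rescoring_sketch(
    lattice, G, model, memory, memory_key_padding_mask, num_paths=1000
):
    # Stage 1: replace the 3-gram LM scores in the lattice with 4-gram scores.
    # With lm_scale_list=None the function returns the rescored lattice
    # rather than a dict of best paths.
    rescored_lattice = rescore_with_whole_lattice(
        lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
    )

    # Stage 2: extract n-best paths and let the attention decoder score them.
    # Keys of the result look like "ngram_lm_scale_0.5_attention_scale_1.0".
    best_path_dict = rescore_with_attention_decoder(
        lattice=rescored_lattice,
        num_paths=num_paths,
        model=model,
        memory=memory,
        memory_key_padding_mask=memory_key_padding_mask,
    )
    return best_path_dict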
@@ -351,7 +374,11 @@ def main():
     if not hasattr(HLG, "lm_scores"):
         HLG.lm_scores = HLG.scores.clone()

-    if params.method in ["nbest-rescoring", "whole-lattice-rescoring"]:
+    if params.method in (
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+    ):
         if not (params.lm_dir / "G_4_gram.pt").is_file():
             logging.info("Loading G_4_gram.fst.txt")
             logging.warning("It may take 8 minutes.")
@@ -374,7 +401,7 @@ def main():
             d = torch.load(params.lm_dir / "G_4_gram.pt")
             G = k2.Fsa.from_dict(d).to(device)

-        if params.method == "whole-lattice-rescoring":
+        if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
             # Add epsilon self-loops to G as we will compose
             # it with the whole lattice later
             G = k2.add_epsilon_self_loops(G)
@@ -259,7 +259,7 @@ class Transformer(nn.Module):
         return decoder_loss

     def decoder_nll(
-        self, x: Tensor, encoder_mask: Tensor, token_ids: List[int] = None
+        self, x: Tensor, encoder_mask: Tensor, token_ids: List[List[int]] = None
     ) -> Tensor:
         """
         Args:
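The signature change above turns token_ids into a list of token-ID lists, one per hypothesis path, so the decoder can score a whole n-best batch at once. The diff does not show how decoder_nll prepares these lists internally; purely as an illustration (a hypothetical helper with assumed sos/eos/padding conventions), batching variable-length token lists typically looks like this:

from typing import List, Tuple

import torch


def pad_token_id_lists(
    token_ids: List[List[int]], sos_id: int, eos_id: int, ignore_id: int = -1
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical illustration only: prepend sos for the decoder input,
    # append eos for the target, and pad both to the longest sequence.
    ys_in = [[sos_id] + y for y in token_ids]
    ys_out = [y + [eos_id] for y in token_ids]
    max_len = max(len(y) for y in ys_in)

    batch_in = torch.full((len(ys_in), max_len), eos_id, dtype=torch.long)
    batch_out = torch.full((len(ys_out), max_len), ignore_id, dtype=torch.long)
    for i, (yi, yo) in enumerate(zip(ys_in, ys_out)):
        batch_in[i, : len(yi)] = torch.tensor(yi, dtype=torch.long)
        batch_out[i, : len(yo)] = torch.tensor(yo, dtype=torch.long)
    return batch_in, batch_out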
|
9
egs/librispeech/ASR/local/compile_hlg.py
Normal file → Executable file
9
egs/librispeech/ASR/local/compile_hlg.py
Normal file → Executable file
@@ -32,7 +32,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     """
     lexicon = Lexicon(lang_dir)
     max_token_id = max(lexicon.tokens)
-    print(f"building ctc_top. max_token_id: {max_token_id}")
+    print(f"Building ctc_topo. max_token_id: {max_token_id}")
     H = k2.ctc_topo(max_token_id)
     L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

@@ -86,13 +86,18 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     LG = k2.arc_sort(LG)

     print("Composing H and LG")
-    HLG = k2.compose(H, LG, inner_labels="phones")
+    # CAUTION: The name of the inner_labels is fixed
+    # to `tokens`. If you want to change it, please
+    # also change other places in icefall that are using
+    # it.
+    HLG = k2.compose(H, LG, inner_labels="tokens")

     print("Connecting LG")
     HLG = k2.connect(HLG)

     print("Arc sorting LG")
     HLG = k2.arc_sort(HLG)
+    print(f"HLG.shape: {HLG.shape}")

     return HLG

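Because rescore_with_attention_decoder (added later in this commit) reads lattice.tokens via k2.index(lattice.tokens, path), the inner labels must keep the name `tokens`; lattices produced by intersecting with HLG inherit that attribute. A quick sanity check after compiling, assuming HLG has been saved to a path such as data/lang_bpe/HLG.pt (the path is only an assumed example):

import k2
import torch

# Load the compiled graph and verify it carries the `tokens` attribute
# that attention-decoder rescoring relies on.
HLG = k2.Fsa.from_dict(torch.load("data/lang_bpe/HLG.pt"))
assert hasattr(HLG, "tokens"), "HLG must expose token labels as `tokens`"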
@@ -1,8 +1,9 @@
 import logging
-from typing import Dict, List
+from typing import Dict, List, Optional, Tuple, Union

 import k2
 import torch
+import torch.nn as nn


 def _intersect_device(
@@ -11,7 +12,7 @@ def _intersect_device(
     b_to_a_map: torch.Tensor,
     sorted_match_a: bool,
     batch_size: int = 50,
-):
+) -> k2.Fsa:
     """This is a wrapper of k2.intersect_device and its purpose is to split
     b_fsas into several batches and process each batch separately to avoid
     CUDA OOM error.
@@ -55,7 +56,7 @@ def get_lattice(
     min_active_states: int,
     max_active_states: int,
     subsampling_factor: int = 1,
-):
+) -> k2.Fsa:
     """Get the decoding lattice from a decoding graph and neural
     network output.

@@ -129,7 +130,7 @@ def one_best_decoding(

 def nbest_decoding(
     lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True
-):
+) -> k2.Fsa:
     """It implements something like CTC prefix beam search using n-best lists.

     The basic idea is to first extract n-best paths from the given lattice,
@@ -253,11 +254,11 @@ def nbest_decoding(
     return best_path_fsa


-def compute_am_scores(
+def compute_am_and_lm_scores(
     lattice: k2.Fsa,
     word_fsa_with_epsilon_loops: k2.Fsa,
     path_to_seq_map: torch.Tensor,
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """Compute AM scores of n-best lists (represented as word_fsas).

     Args:
@@ -272,8 +273,8 @@ def compute_am_scores(
       which sequence the i-th Fsa in word_fsa_with_epsilon_loops belongs to.
       path_to_seq_map.numel() == word_fsas_with_epsilon_loops.arcs.dim0().
     Returns:
-      Return a 1-D torch.Tensor containing the AM scores of each path.
-      `ans.numel() == word_fsas_with_epsilon_loops.shape[0]`
+      Return a tuple containing two 1-D torch.Tensors: (am_scores, lm_scores).
+      Each tensor's `numel()` equals `word_fsas_with_epsilon_loops.shape[0]`.
     """
     assert len(lattice.shape) == 3
     assert hasattr(lattice, "lm_scores")
@@ -293,23 +294,29 @@ def compute_am_scores(
     del inv_lattice.aux_labels
     inv_lattice = k2.arc_sort(inv_lattice)

-    am_path_lattice = _intersect_device(
+    path_lattice = _intersect_device(
         inv_lattice,
         word_fsa_with_epsilon_loops,
         b_to_a_map=path_to_seq_map,
         sorted_match_a=True,
     )

-    am_path_lattice = k2.top_sort(k2.connect(am_path_lattice))
+    path_lattice = k2.top_sort(k2.connect(path_lattice))

     # The `scores` of every arc consists of `am_scores` and `lm_scores`
-    am_path_lattice.scores = am_path_lattice.scores - am_path_lattice.lm_scores
+    path_lattice.scores = path_lattice.scores - path_lattice.lm_scores

-    am_scores = am_path_lattice.get_tot_scores(
+    am_scores = path_lattice.get_tot_scores(
         use_double_scores=True, log_semiring=False
     )

-    return am_scores
+    path_lattice.scores = path_lattice.lm_scores
+
+    lm_scores = path_lattice.get_tot_scores(
+        use_double_scores=True, log_semiring=False
+    )
+
+    return am_scores.to(torch.float32), lm_scores.to(torch.float32)


 def rescore_with_n_best_list(
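The decomposition above relies on each lattice arc's score being the sum of an acoustic part and an LM part, with the LM part kept separately in lm_scores: subtracting lm_scores and summing over a path gives its AM score, while summing lm_scores alone gives its LM score. A toy per-path illustration in plain PyTorch (made-up numbers):

import torch

# Combined (AM + LM) scores of the arcs along one toy path, and the
# LM component stored separately on the same arcs.
tot = torch.tensor([-1.2, -0.7, -2.0])
lm = torch.tensor([-0.5, -0.2, -1.0])

am_score = (tot - lm).sum()  # acoustic contribution of the path
lm_score = lm.sum()          # LM contribution of the path
print(am_score.item(), lm_score.item())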
@@ -395,7 +402,7 @@ def rescore_with_n_best_list(

     word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)

-    am_scores = compute_am_scores(
+    am_scores, _ = compute_am_and_lm_scores(
         lattice, word_fsa_with_epsilon_loops, path_to_seq_map
     )

@@ -456,8 +463,10 @@ def rescore_with_n_best_list(


 def rescore_with_whole_lattice(
-    lattice: k2.Fsa, G_with_epsilon_loops: k2.Fsa, lm_scale_list: List[float]
-) -> Dict[str, k2.Fsa]:
+    lattice: k2.Fsa,
+    G_with_epsilon_loops: k2.Fsa,
+    lm_scale_list: Optional[List[float]] = None,
+) -> Union[k2.Fsa, Dict[str, k2.Fsa]]:
     """Use whole lattice to rescore.

     Args:
@@ -467,10 +476,13 @@ def rescore_with_whole_lattice(
       An FsaVec representing the language model (LM). Note that it
       is an FsaVec, but it contains only one Fsa.
     lm_scale_list:
-      A list containing lm_scale values.
+      A list containing lm_scale values or None.
     Returns:
-      A dict of FsaVec, whose key is a lm_scale and the value represents the
-      best decoding path for each sequence in the lattice.
+      If lm_scale_list is not None, return a dict of FsaVec, whose key
+      is a lm_scale and the value represents the best decoding path for
+      each sequence in the lattice.
+      If lm_scale_list is None, return a lattice that is rescored
+      with the given LM.
     """
     assert len(lattice.shape) == 3
     assert hasattr(lattice, "lm_scores")
@@ -517,6 +529,9 @@ def rescore_with_whole_lattice(
     # and word IDs as aux_labels.
     lat = k2.invert(rescoring_lattice)

+    if lm_scale_list is None:
+        return lat
+
     ans = dict()
     #
     # The following implements
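The new early return gives the function two call patterns: with a list of scales it behaves as before and returns a dict of best paths, while with lm_scale_list=None it stops right after the LM scores have been replaced and hands back the rescored lattice, which is what the attention-decoder branch shown earlier in this commit uses. A minimal sketch, assuming lattice and G are already built:

from icefall.decode import rescore_with_whole_lattice


def whole_lattice_rescoring_patterns(lattice, G):
    # Pattern 1: grid over LM scales; returns a dict like
    # {"lm_scale_0.8": best_path_fsa, ...}
    best_path_dict = rescore_with_whole_lattice(
        lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=[0.8, 1.0, 1.2]
    )

    # Pattern 2: no scale list; returns the lattice rescored with G,
    # ready to be passed to rescore_with_attention_decoder.
    rescored_lattice = rescore_with_whole_lattice(
        lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
    )
    return best_path_dict, rescored_lattice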
@@ -532,3 +547,165 @@ def rescore_with_whole_lattice(
         key = f"lm_scale_{lm_scale}"
         ans[key] = best_path
     return ans
+
+
+def rescore_with_attention_decoder(
+    lattice: k2.Fsa,
+    num_paths: int,
+    model: nn.Module,
+    memory: torch.Tensor,
+    memory_key_padding_mask: torch.Tensor,
+) -> Dict[str, k2.Fsa]:
+    """This function extracts n paths from the given lattice and uses
+    an attention decoder to rescore them. The path with the highest
+    score is used as the decoding output.
+
+    lattice:
+      An FsaVec. It can be the return value of :func:`get_lattice`.
+    num_paths:
+      Number of paths to extract from the given lattice for rescoring.
+    model:
+      A transformer model. See the class "Transformer" in
+      conformer_ctc/transformer.py for its interface.
+    memory:
+      The encoder memory of the given model. It is the output of
+      the last torch.nn.TransformerEncoder layer in the given model.
+      Its shape is `[T, N, C]`.
+    memory_key_padding_mask:
+      The padding mask for memory with shape [N, T].
+    Returns:
+      A dict of FsaVec, whose key contains a string
+      ngram_lm_scale_attention_scale and the value is the
+      best decoding path for each sequence in the lattice.
+    """
+    # First, extract `num_paths` paths for each sequence.
+    # path is a k2.RaggedInt with axes [seq][path][arc_pos]
+    path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
+
+    # word_seq is a k2.RaggedInt sharing the same shape as `path`
+    # but it contains word IDs. Note that it also contains 0s and -1s.
+    # The last entry in each sublist is -1.
+    word_seq = k2.index(lattice.aux_labels, path)
+
+    # Remove epsilons and -1 from word_seq
+    word_seq = k2.ragged.remove_values_leq(word_seq, 0)
+
+    # Remove paths that have identical word sequences.
+    #
+    # unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word]
+    # except that there are no repeated paths with the same word_seq
+    # within a sequence.
+    #
+    # num_repeats is also a k2.RaggedInt with 2 axes containing the
+    # multiplicities of each path.
+    # num_repeats.num_elements() == unique_word_seqs.num_elements()
+    #
+    # Since k2.ragged.unique_sequences will reorder paths within a seq,
+    # `new2old` is a 1-D torch.Tensor mapping from the output path index
+    # to the input path index.
+    # new2old.numel() == unique_word_seqs.tot_size(1)
+    unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
+        word_seq, need_num_repeats=True, need_new2old_indexes=True
+    )
+
+    seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)
+
+    # path_to_seq_map is a 1-D torch.Tensor.
+    # path_to_seq_map[i] is the seq to which the i-th path
+    # belongs.
+    path_to_seq_map = seq_to_path_shape.row_ids(1)
+
+    # Remove the seq axis.
+    # Now unique_word_seq has only two axes [path][word]
+    unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)
+
+    # word_fsa is an FsaVec with axes [path][state][arc]
+    word_fsa = k2.linear_fsa(unique_word_seq)
+
+    word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
+
+    am_scores, ngram_lm_scores = compute_am_and_lm_scores(
+        lattice, word_fsa_with_epsilon_loops, path_to_seq_map
+    )
+    # Now we use the attention decoder to compute another
+    # score: attention_scores.
+    #
+    # To do that, we have to get the input and output for the attention
+    # decoder.
+
+    # CAUTION: The "tokens" attribute is set in the file
+    # local/compile_hlg.py
+    token_seq = k2.index(lattice.tokens, path)
+
+    # Remove epsilons and -1 from token_seq
+    token_seq = k2.ragged.remove_values_leq(token_seq, 0)
+
+    # Remove the seq axis.
+    token_seq = k2.ragged.remove_axis(token_seq, 0)
+
+    token_seq, _ = k2.ragged.index(
+        token_seq, indexes=new2old, axis=0, need_value_indexes=False
+    )
+
+    # Now each word in unique_word_seq has its corresponding token IDs.
+    token_ids = k2.ragged.to_list(token_seq)
+
+    num_word_seqs = new2old.numel()
+
+    path_to_seq_map_long = path_to_seq_map.to(torch.long)
+    expanded_memory = memory.index_select(1, path_to_seq_map_long)
+
+    expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
+        0, path_to_seq_map_long
+    )
+
+    nll = model.decoder_nll(
+        expanded_memory, expanded_memory_key_padding_mask, token_ids
+    )
+    assert nll.ndim == 2
+    assert nll.shape[0] == num_word_seqs
+
+    attention_scores = -nll.sum(dim=1)
+    assert attention_scores.ndim == 1
+    assert attention_scores.numel() == num_word_seqs
+
+    ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+    ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+
+    attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+    attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+
+    path_2axes = k2.ragged.remove_axis(path, 0)
+
+    ans = dict()
+    for n_scale in ngram_lm_scale_list:
+        for a_scale in attention_scale_list:
+            tot_scores = (
+                am_scores
+                + n_scale * ngram_lm_scores
+                + a_scale * attention_scores
+            )
+            ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape, tot_scores)
+            argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)
+
+            best_path_indexes = k2.index(new2old, argmax_indexes)
+
+            # best_path is a k2.RaggedInt with 2 axes [path][arc_pos]
+            best_path = k2.index(path_2axes, best_path_indexes)
+
+            # labels is a k2.RaggedInt with 2 axes [path][token_id]
+            # Note that it contains -1s.
+            labels = k2.index(lattice.labels.contiguous(), best_path)
+
+            labels = k2.ragged.remove_values_eq(labels, -1)
+
+            # lattice.aux_labels is a k2.RaggedInt tensor with 2 axes, so
+            # aux_labels is also a k2.RaggedInt with 2 axes
+            aux_labels = k2.index(lattice.aux_labels, best_path.values())
+
+            best_path_fsa = k2.linear_fsa(labels)
+            best_path_fsa.aux_labels = aux_labels
+
+            key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
+            ans[key] = best_path_fsa
+    return ans