Add scale to all nbest based decoding/rescoring methods.

Fangjun Kuang 2021-08-18 18:42:30 +08:00
parent 401c1c5143
commit 38d06049de
2 changed files with 83 additions and 16 deletions
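
The core trick in this commit: `lattice.scores` is multiplied by a scale (default 1.0) before `k2.random_paths` samples the n-best paths, so a smaller scale flattens the score distribution and sampling returns more unique word sequences. A minimal PyTorch sketch of this temperature-like effect (illustrative only, not part of the commit; the toy scores stand in for a real lattice):

import torch

torch.manual_seed(0)

# Toy log-scores for five competing paths; the first one dominates.
log_scores = torch.tensor([5.0, 1.0, 0.5, 0.2, 0.1])

for scale in (1.0, 0.5, 0.1):
    probs = torch.softmax(log_scores * scale, dim=0)
    # Draw 100 paths; a smaller scale spreads the probability mass
    # and therefore yields more unique samples.
    samples = torch.multinomial(probs, num_samples=100, replacement=True)
    print(f"scale={scale}: {samples.unique().numel()} unique paths")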


@@ -63,7 +63,10 @@ def get_parser():
         type=float,
         default=1.0,
         help="The scale to be applied to `lattice.scores`."
-        "A smaller value results in more unique paths",
+        "It's needed if you use any kind of n-best based rescoring. "
+        "Currently, it is used when the decoding method is one of: nbest, "
+        "nbest-rescoring, attention-decoder, or nbest-oracle. "
+        "A smaller value results in more unique paths.",
     )

     return parser
@@ -96,6 +99,8 @@ def get_params() -> AttributeDict:
             # - whole-lattice-rescoring
             # - attention-decoder
             # - nbest-oracle
+            # "method": "nbest",
+            # "method": "nbest-rescoring",
            # "method": "whole-lattice-rescoring",
             "method": "attention-decoder",
             # "method": "nbest-oracle",
@ -215,8 +220,9 @@ def decode_one_batch(
lattice=lattice, lattice=lattice,
num_paths=params.num_paths, num_paths=params.num_paths,
use_double_scores=params.use_double_scores, use_double_scores=params.use_double_scores,
scale=params.scale,
) )
key = f"no_rescore-{params.num_paths}" key = f"no_rescore-scale-{params.scale}-{params.num_paths}"
hyps = get_texts(best_path) hyps = get_texts(best_path)
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
@@ -237,6 +243,7 @@
             G=G,
             num_paths=params.num_paths,
             lm_scale_list=lm_scale_list,
+            scale=params.scale,
         )
     elif params.method == "whole-lattice-rescoring":
         best_path_dict = rescore_with_whole_lattice(
@ -256,6 +263,7 @@ def decode_one_batch(
memory_key_padding_mask=memory_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id, sos_id=sos_id,
eos_id=eos_id, eos_id=eos_id,
scale=params.scale,
) )
else: else:
assert False, f"Unsupported decoding method: {params.method}" assert False, f"Unsupported decoding method: {params.method}"
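
The new key format embeds the scale, so results obtained with different scales do not collide in the per-method results dict. A tiny self-contained illustration (a SimpleNamespace stands in for the recipe's AttributeDict; the values are hypothetical):

from types import SimpleNamespace

params = SimpleNamespace(scale=0.5, num_paths=100)
key = f"no_rescore-scale-{params.scale}-{params.num_paths}"
print(key)  # no_rescore-scale-0.5-100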


@@ -9,6 +9,36 @@ import torch.nn as nn
 from icefall.lexicon import Lexicon


+def _get_random_paths(
+    lattice: k2.Fsa,
+    num_paths: int,
+    use_double_scores: bool = True,
+    scale: float = 1.0,
+):
+    """
+    Args:
+      lattice:
+        The decoding lattice, returned by :func:`get_lattice`.
+      num_paths:
+        It specifies the size `n` in n-best. Note: Paths are selected randomly
+        and those containing identical word sequences are removed and only one
+        of them is kept.
+      use_double_scores:
+        True to use double precision floating point in the computation.
+        False to use single precision.
+      scale:
+        It's the scale applied to `lattice.scores`. A smaller value
+        yields more unique paths.
+    Returns:
+      Return a k2.RaggedInt with 3 axes: [seq][path][arc_pos].
+    """
+    saved_scores = lattice.scores.clone()
+    lattice.scores *= scale
+    path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=use_double_scores)
+    lattice.scores = saved_scores
+    return path
+
+
 def _intersect_device(
     a_fsas: k2.Fsa,
     b_fsas: k2.Fsa,
@@ -132,7 +162,10 @@ def one_best_decoding(


 def nbest_decoding(
-    lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True
+    lattice: k2.Fsa,
+    num_paths: int,
+    use_double_scores: bool = True,
+    scale: float = 1.0,
 ) -> k2.Fsa:
     """It implements something like CTC prefix beam search using n-best lists.

@@ -155,12 +188,18 @@
       use_double_scores:
         True to use double precision floating point in the computation.
         False to use single precision.
+      scale:
+        It's the scale applied to `lattice.scores`. A smaller value
+        yields more unique paths.
     Returns:
       An FsaVec containing linear FSAs.
     """
-    # First, extract `num_paths` paths for each sequence.
-    # path is a k2.RaggedInt with axes [seq][path][arc_pos]
-    path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
+    path = _get_random_paths(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=use_double_scores,
+        scale=scale,
+    )

     # word_seq is a k2.RaggedInt sharing the same shape as `path`
     # but it contains word IDs. Note that it also contains 0s and -1s.
@ -323,7 +362,11 @@ def compute_am_and_lm_scores(
def rescore_with_n_best_list( def rescore_with_n_best_list(
lattice: k2.Fsa, G: k2.Fsa, num_paths: int, lm_scale_list: List[float] lattice: k2.Fsa,
G: k2.Fsa,
num_paths: int,
lm_scale_list: List[float],
scale: float = 1.0,
) -> Dict[str, k2.Fsa]: ) -> Dict[str, k2.Fsa]:
"""Decode using n-best list with LM rescoring. """Decode using n-best list with LM rescoring.
@@ -345,6 +388,9 @@
         It is the size `n` in `n-best` list.
       lm_scale_list:
         A list containing lm_scale values.
+      scale:
+        It's the scale applied to `lattice.scores`. A smaller value
+        yields more unique paths.
     Returns:
       A dict of FsaVec, whose key is an lm_scale and the value is the
       best decoding path for each sequence in the lattice.
@@ -359,9 +405,12 @@
     assert G.device == device
     assert hasattr(G, "aux_labels") is False

-    # First, extract `num_paths` paths for each sequence.
-    # path is a k2.RaggedInt with axes [seq][path][arc_pos]
-    path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
+    path = _get_random_paths(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=True,
+        scale=scale,
+    )

     # word_seq is a k2.RaggedInt sharing the same shape as `path`
     # but it contains word IDs. Note that it also contains 0s and -1s.
@@ -587,11 +636,12 @@ def nbest_oracle(
       when calling this function, while its value contains the decoding output.
       `len(ans_dict) == len(ref_texts)`
     """
-    saved_scores = lattice.scores.clone()
-
-    lattice.scores *= scale
-    path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
-    lattice.scores = saved_scores
+    path = _get_random_paths(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=True,
+        scale=scale,
+    )

     word_seq = k2.index(lattice.aux_labels, path)
     word_seq = k2.ragged.remove_values_leq(word_seq, 0)
@ -630,6 +680,7 @@ def rescore_with_attention_decoder(
memory_key_padding_mask: torch.Tensor, memory_key_padding_mask: torch.Tensor,
sos_id: int, sos_id: int,
eos_id: int, eos_id: int,
scale: float = 1.0,
) -> Dict[str, k2.Fsa]: ) -> Dict[str, k2.Fsa]:
"""This function extracts n paths from the given lattice and uses """This function extracts n paths from the given lattice and uses
an attention decoder to rescore them. The path with the highest an attention decoder to rescore them. The path with the highest
@@ -653,6 +704,9 @@
         The token ID for SOS.
       eos_id:
         The token ID for EOS.
+      scale:
+        It's the scale applied to `lattice.scores`. A smaller value
+        yields more unique paths.
     Returns:
       A dict of FsaVec, whose key contains a string
       ngram_lm_scale_attention_scale and the value is the
@@ -660,7 +714,12 @@
     """
     # First, extract `num_paths` paths for each sequence.
    # path is a k2.RaggedInt with axes [seq][path][arc_pos]
-    path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
+    path = _get_random_paths(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=True,
+        scale=scale,
+    )

     # word_seq is a k2.RaggedInt sharing the same shape as `path`
     # but it contains word IDs. Note that it also contains 0s and -1s.
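
One detail worth noting about `_get_random_paths`: it saves a clone of `lattice.scores`, scales the scores in place for sampling, and restores the saved copy afterwards, so the caller's lattice comes back unmodified. A quick sketch of that contract (the toy class stands in for a k2.Fsa; only the save/scale/restore pattern is taken from the commit):

import torch

class FakeLattice:
    def __init__(self):
        self.scores = torch.tensor([0.5, 1.5, 2.5])

def scaled_sampling_stub(lattice, scale=0.5):
    # Mirrors the save/scale/restore pattern of _get_random_paths.
    saved_scores = lattice.scores.clone()
    lattice.scores *= scale
    # ... k2.random_paths would sample from the scaled lattice here ...
    lattice.scores = saved_scores

lattice = FakeLattice()
before = lattice.scores.clone()
scaled_sampling_stub(lattice)
assert torch.equal(lattice.scores, before)  # unchanged for the caller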