mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Add scale to all nbest based decoding/rescoring methods.
This commit is contained in:
parent
401c1c5143
commit
38d06049de
@ -63,7 +63,10 @@ def get_parser():
|
|||||||
type=float,
|
type=float,
|
||||||
default=1.0,
|
default=1.0,
|
||||||
help="The scale to be applied to `lattice.scores`."
|
help="The scale to be applied to `lattice.scores`."
|
||||||
"A smaller value results in more unique paths",
|
"It's needed if you use any kinds of n-best based rescoring. "
|
||||||
|
"Currently, it is used when the decoding method is: nbest, "
|
||||||
|
"nbest-rescoring, attention-decoder, and nbest-oracle. "
|
||||||
|
"A smaller value results in more unique paths.",
|
||||||
)
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -96,6 +99,8 @@ def get_params() -> AttributeDict:
|
|||||||
# - whole-lattice-rescoring
|
# - whole-lattice-rescoring
|
||||||
# - attention-decoder
|
# - attention-decoder
|
||||||
# - nbest-oracle
|
# - nbest-oracle
|
||||||
|
# "method": "nbest",
|
||||||
|
# "method": "nbest-rescoring",
|
||||||
# "method": "whole-lattice-rescoring",
|
# "method": "whole-lattice-rescoring",
|
||||||
"method": "attention-decoder",
|
"method": "attention-decoder",
|
||||||
# "method": "nbest-oracle",
|
# "method": "nbest-oracle",
|
||||||
@ -215,8 +220,9 @@ def decode_one_batch(
|
|||||||
lattice=lattice,
|
lattice=lattice,
|
||||||
num_paths=params.num_paths,
|
num_paths=params.num_paths,
|
||||||
use_double_scores=params.use_double_scores,
|
use_double_scores=params.use_double_scores,
|
||||||
|
scale=params.scale,
|
||||||
)
|
)
|
||||||
key = f"no_rescore-{params.num_paths}"
|
key = f"no_rescore-scale-{params.scale}-{params.num_paths}"
|
||||||
|
|
||||||
hyps = get_texts(best_path)
|
hyps = get_texts(best_path)
|
||||||
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
|
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
|
||||||
@ -237,6 +243,7 @@ def decode_one_batch(
|
|||||||
G=G,
|
G=G,
|
||||||
num_paths=params.num_paths,
|
num_paths=params.num_paths,
|
||||||
lm_scale_list=lm_scale_list,
|
lm_scale_list=lm_scale_list,
|
||||||
|
scale=params.scale,
|
||||||
)
|
)
|
||||||
elif params.method == "whole-lattice-rescoring":
|
elif params.method == "whole-lattice-rescoring":
|
||||||
best_path_dict = rescore_with_whole_lattice(
|
best_path_dict = rescore_with_whole_lattice(
|
||||||
@ -256,6 +263,7 @@ def decode_one_batch(
|
|||||||
memory_key_padding_mask=memory_key_padding_mask,
|
memory_key_padding_mask=memory_key_padding_mask,
|
||||||
sos_id=sos_id,
|
sos_id=sos_id,
|
||||||
eos_id=eos_id,
|
eos_id=eos_id,
|
||||||
|
scale=params.scale,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert False, f"Unsupported decoding method: {params.method}"
|
assert False, f"Unsupported decoding method: {params.method}"
|
||||||
|
@ -9,6 +9,36 @@ import torch.nn as nn
|
|||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
|
|
||||||
|
|
||||||
|
def _get_random_paths(
|
||||||
|
lattice: k2.Fsa,
|
||||||
|
num_paths: int,
|
||||||
|
use_double_scores: bool = True,
|
||||||
|
scale: float = 1.0,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
lattice:
|
||||||
|
The decoding lattice, returned by :func:`get_lattice`.
|
||||||
|
num_paths:
|
||||||
|
It specifies the size `n` in n-best. Note: Paths are selected randomly
|
||||||
|
and those containing identical word sequences are remove dand only one
|
||||||
|
of them is kept.
|
||||||
|
use_double_scores:
|
||||||
|
True to use double precision floating point in the computation.
|
||||||
|
False to use single precision.
|
||||||
|
scale:
|
||||||
|
It's the scale applied to the lattice.scores. A smaller value
|
||||||
|
yields more unique paths.
|
||||||
|
Returns:
|
||||||
|
Return a k2.RaggedInt with 3 axes [seq][path][arc_pos]
|
||||||
|
"""
|
||||||
|
saved_scores = lattice.scores.clone()
|
||||||
|
lattice.scores *= scale
|
||||||
|
path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
|
||||||
|
lattice.scores = saved_scores
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
def _intersect_device(
|
def _intersect_device(
|
||||||
a_fsas: k2.Fsa,
|
a_fsas: k2.Fsa,
|
||||||
b_fsas: k2.Fsa,
|
b_fsas: k2.Fsa,
|
||||||
@ -132,7 +162,10 @@ def one_best_decoding(
|
|||||||
|
|
||||||
|
|
||||||
def nbest_decoding(
|
def nbest_decoding(
|
||||||
lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True
|
lattice: k2.Fsa,
|
||||||
|
num_paths: int,
|
||||||
|
use_double_scores: bool = True,
|
||||||
|
scale: float = 1.0,
|
||||||
) -> k2.Fsa:
|
) -> k2.Fsa:
|
||||||
"""It implements something like CTC prefix beam search using n-best lists.
|
"""It implements something like CTC prefix beam search using n-best lists.
|
||||||
|
|
||||||
@ -155,12 +188,18 @@ def nbest_decoding(
|
|||||||
use_double_scores:
|
use_double_scores:
|
||||||
True to use double precision floating point in the computation.
|
True to use double precision floating point in the computation.
|
||||||
False to use single precision.
|
False to use single precision.
|
||||||
|
scale:
|
||||||
|
It's the scale applied to the lattice.scores. A smaller value
|
||||||
|
yields more unique paths.
|
||||||
Returns:
|
Returns:
|
||||||
An FsaVec containing linear FSAs.
|
An FsaVec containing linear FSAs.
|
||||||
"""
|
"""
|
||||||
# First, extract `num_paths` paths for each sequence.
|
path = _get_random_paths(
|
||||||
# path is a k2.RaggedInt with axes [seq][path][arc_pos]
|
lattice=lattice,
|
||||||
path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
|
num_paths=num_paths,
|
||||||
|
use_double_scores=use_double_scores,
|
||||||
|
scale=scale,
|
||||||
|
)
|
||||||
|
|
||||||
# word_seq is a k2.RaggedInt sharing the same shape as `path`
|
# word_seq is a k2.RaggedInt sharing the same shape as `path`
|
||||||
# but it contains word IDs. Note that it also contains 0s and -1s.
|
# but it contains word IDs. Note that it also contains 0s and -1s.
|
||||||
@ -323,7 +362,11 @@ def compute_am_and_lm_scores(
|
|||||||
|
|
||||||
|
|
||||||
def rescore_with_n_best_list(
|
def rescore_with_n_best_list(
|
||||||
lattice: k2.Fsa, G: k2.Fsa, num_paths: int, lm_scale_list: List[float]
|
lattice: k2.Fsa,
|
||||||
|
G: k2.Fsa,
|
||||||
|
num_paths: int,
|
||||||
|
lm_scale_list: List[float],
|
||||||
|
scale: float = 1.0,
|
||||||
) -> Dict[str, k2.Fsa]:
|
) -> Dict[str, k2.Fsa]:
|
||||||
"""Decode using n-best list with LM rescoring.
|
"""Decode using n-best list with LM rescoring.
|
||||||
|
|
||||||
@ -345,6 +388,9 @@ def rescore_with_n_best_list(
|
|||||||
It is the size `n` in `n-best` list.
|
It is the size `n` in `n-best` list.
|
||||||
lm_scale_list:
|
lm_scale_list:
|
||||||
A list containing lm_scale values.
|
A list containing lm_scale values.
|
||||||
|
scale:
|
||||||
|
It's the scale applied to the lattice.scores. A smaller value
|
||||||
|
yields more unique paths.
|
||||||
Returns:
|
Returns:
|
||||||
A dict of FsaVec, whose key is an lm_scale and the value is the
|
A dict of FsaVec, whose key is an lm_scale and the value is the
|
||||||
best decoding path for each sequence in the lattice.
|
best decoding path for each sequence in the lattice.
|
||||||
@ -359,9 +405,12 @@ def rescore_with_n_best_list(
|
|||||||
assert G.device == device
|
assert G.device == device
|
||||||
assert hasattr(G, "aux_labels") is False
|
assert hasattr(G, "aux_labels") is False
|
||||||
|
|
||||||
# First, extract `num_paths` paths for each sequence.
|
path = _get_random_paths(
|
||||||
# path is a k2.RaggedInt with axes [seq][path][arc_pos]
|
lattice=lattice,
|
||||||
path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
|
num_paths=num_paths,
|
||||||
|
use_double_scores=True,
|
||||||
|
scale=scale,
|
||||||
|
)
|
||||||
|
|
||||||
# word_seq is a k2.RaggedInt sharing the same shape as `path`
|
# word_seq is a k2.RaggedInt sharing the same shape as `path`
|
||||||
# but it contains word IDs. Note that it also contains 0s and -1s.
|
# but it contains word IDs. Note that it also contains 0s and -1s.
|
||||||
@ -587,11 +636,12 @@ def nbest_oracle(
|
|||||||
when calling this function, while its value contains the decoding output.
|
when calling this function, while its value contains the decoding output.
|
||||||
`len(ans_dict) == len(ref_texts)`
|
`len(ans_dict) == len(ref_texts)`
|
||||||
"""
|
"""
|
||||||
saved_scores = lattice.scores.clone()
|
path = _get_random_paths(
|
||||||
|
lattice=lattice,
|
||||||
lattice.scores *= scale
|
num_paths=num_paths,
|
||||||
path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
|
use_double_scores=True,
|
||||||
lattice.scores = saved_scores
|
scale=scale,
|
||||||
|
)
|
||||||
|
|
||||||
word_seq = k2.index(lattice.aux_labels, path)
|
word_seq = k2.index(lattice.aux_labels, path)
|
||||||
word_seq = k2.ragged.remove_values_leq(word_seq, 0)
|
word_seq = k2.ragged.remove_values_leq(word_seq, 0)
|
||||||
@ -630,6 +680,7 @@ def rescore_with_attention_decoder(
|
|||||||
memory_key_padding_mask: torch.Tensor,
|
memory_key_padding_mask: torch.Tensor,
|
||||||
sos_id: int,
|
sos_id: int,
|
||||||
eos_id: int,
|
eos_id: int,
|
||||||
|
scale: float = 1.0,
|
||||||
) -> Dict[str, k2.Fsa]:
|
) -> Dict[str, k2.Fsa]:
|
||||||
"""This function extracts n paths from the given lattice and uses
|
"""This function extracts n paths from the given lattice and uses
|
||||||
an attention decoder to rescore them. The path with the highest
|
an attention decoder to rescore them. The path with the highest
|
||||||
@ -653,6 +704,9 @@ def rescore_with_attention_decoder(
|
|||||||
The token ID for SOS.
|
The token ID for SOS.
|
||||||
eos_id:
|
eos_id:
|
||||||
The token ID for EOS.
|
The token ID for EOS.
|
||||||
|
scale:
|
||||||
|
It's the scale applied to the lattice.scores. A smaller value
|
||||||
|
yields more unique paths.
|
||||||
Returns:
|
Returns:
|
||||||
A dict of FsaVec, whose key contains a string
|
A dict of FsaVec, whose key contains a string
|
||||||
ngram_lm_scale_attention_scale and the value is the
|
ngram_lm_scale_attention_scale and the value is the
|
||||||
@ -660,7 +714,12 @@ def rescore_with_attention_decoder(
|
|||||||
"""
|
"""
|
||||||
# First, extract `num_paths` paths for each sequence.
|
# First, extract `num_paths` paths for each sequence.
|
||||||
# path is a k2.RaggedInt with axes [seq][path][arc_pos]
|
# path is a k2.RaggedInt with axes [seq][path][arc_pos]
|
||||||
path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
|
path = _get_random_paths(
|
||||||
|
lattice=lattice,
|
||||||
|
num_paths=num_paths,
|
||||||
|
use_double_scores=True,
|
||||||
|
scale=scale,
|
||||||
|
)
|
||||||
|
|
||||||
# word_seq is a k2.RaggedInt sharing the same shape as `path`
|
# word_seq is a k2.RaggedInt sharing the same shape as `path`
|
||||||
# but it contains word IDs. Note that it also contains 0s and -1s.
|
# but it contains word IDs. Note that it also contains 0s and -1s.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user