Rename lattice_score_scale to nbest_scale.

Fangjun Kuang 2021-09-26 11:43:03 +08:00
parent 455693aede
commit cd7a36b0a2
7 changed files with 32 additions and 32 deletions

View File

@@ -299,9 +299,9 @@ The commonly used options are:
 .. code-block::
 
   $ cd egs/librispeech/ASR
-  $ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --lattice-score-scale 0.5
+  $ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --nbest-scale 0.5
 
-- ``--lattice-score-scale``
+- ``--nbest-scale``
 
   It is used to scale down lattice scores so that there are more unique
   paths for rescoring.
@@ -577,7 +577,7 @@ The command to run HLG decoding + LM rescoring + attention decoder rescoring is:
     --G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \
     --ngram-lm-scale 1.3 \
     --attention-decoder-scale 1.2 \
-    --lattice-score-scale 0.5 \
+    --nbest-scale 0.5 \
     --num-paths 100 \
     --sos-id 1 \
     --eos-id 1 \

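The `--nbest-scale` option documented above (formerly `--lattice-score-scale`) scales lattice scores down before paths are sampled, which flattens the arc score distribution so that repeated draws collide less often. A minimal, self-contained sketch of the effect, using a plain categorical distribution as a stand-in for lattice arcs (the scores below are made up for illustration and are not from the recipe):

    import torch

    torch.manual_seed(0)
    scores = torch.tensor([8.0, 6.0, 2.0, 1.0])  # stand-in for lattice arc scores

    for scale in (1.0, 0.5):
        # Multiplying scores by s < 1 before the softmax flattens the
        # distribution, exactly like sampling at temperature 1/s.
        probs = torch.softmax(scores * scale, dim=0)
        draws = torch.multinomial(probs, num_samples=100, replacement=True)
        print(f"scale={scale}: {draws.unique().numel()} unique out of 100 draws")

With scale 1.0 almost every draw lands on the top-scoring candidate; with the default 0.5 the draws spread across more candidates, which is what gives the rescoring stage more unique paths to work with.
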
View File

@@ -40,7 +40,7 @@ python conformer_ctc/train.py --bucketing-sampler True \
   --full-libri True \
   --world-size 4
 
-python conformer_ctc/decode.py --lattice-score-scale 0.5 \
+python conformer_ctc/decode.py --nbest-scale 0.5 \
   --epoch 34 \
   --avg 20 \
   --method attention-decoder \

View File

@@ -106,7 +106,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lattice-score-scale",
+        "--nbest-scale",
         type=float,
         default=0.5,
         help="""The scale to be applied to `lattice.scores`.
@@ -250,12 +250,12 @@ def decode_one_batch(
             num_paths=params.num_paths,
             ref_texts=supervisions["text"],
             word_table=word_table,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
             oov="<UNK>",
         )
         hyps = get_texts(best_path)
         hyps = [[word_table[i] for i in ids] for ids in hyps]
-        key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}"  # noqa
+        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
         return {key: hyps}
 
     if params.method in ["1best", "nbest"]:
@@ -269,9 +269,9 @@ def decode_one_batch(
             lattice=lattice,
             num_paths=params.num_paths,
             use_double_scores=params.use_double_scores,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
         )
-        key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}"  # noqa
+        key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
 
         hyps = get_texts(best_path)
         hyps = [[word_table[i] for i in ids] for ids in hyps]
@@ -293,7 +293,7 @@ def decode_one_batch(
             G=G,
             num_paths=params.num_paths,
             lm_scale_list=lm_scale_list,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
         )
     elif params.method == "whole-lattice-rescoring":
         best_path_dict = rescore_with_whole_lattice(
@@ -319,7 +319,7 @@ def decode_one_batch(
             memory_key_padding_mask=memory_key_padding_mask,
             sos_id=sos_id,
             eos_id=eos_id,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
         )
     else:
         assert False, f"Unsupported decoding method: {params.method}"

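One detail worth noting in `get_parser()` above: argparse converts the dashes in an option string to underscores when it forms the namespace attribute, so the new flag `--nbest-scale` is read back as `nbest_scale`, matching the keyword argument passed to the decoding functions. A small self-contained sketch (the help text is paraphrased):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--nbest-scale",
        type=float,
        default=0.5,
        help="Scale applied to lattice.scores before sampling n-best paths.",
    )

    # argparse replaces "-" with "_" when forming the attribute name.
    args = parser.parse_args(["--nbest-scale", "0.3"])
    assert args.nbest_scale == 0.3

This is why renaming the command-line flag entails renaming the attribute at every use site, as the hunks above and in the following files show.
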
View File

@@ -125,7 +125,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lattice-score-scale",
+        "--nbest-scale",
         type=float,
         default=0.5,
         help="""
@@ -336,7 +336,7 @@ def main():
             memory_key_padding_mask=memory_key_padding_mask,
             sos_id=params.sos_id,
             eos_id=params.eos_id,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
             ngram_lm_scale=params.ngram_lm_scale,
             attention_scale=params.attention_decoder_scale,
         )

View File

@@ -97,7 +97,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lattice-score-scale",
+        "--nbest-scale",
         type=float,
         default=0.5,
         help="""The scale to be applied to `lattice.scores`.
@@ -229,7 +229,7 @@ def decode_one_batch(
             lattice=lattice,
             num_paths=params.num_paths,
             use_double_scores=params.use_double_scores,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
         )
         key = f"no_rescore-{params.num_paths}"
         hyps = get_texts(best_path)
@@ -248,7 +248,7 @@ def decode_one_batch(
             G=G,
             num_paths=params.num_paths,
             lm_scale_list=lm_scale_list,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
         )
     else:
         best_path_dict = rescore_with_whole_lattice(

View File

@@ -180,7 +180,7 @@ class Nbest(object):
         lattice: k2.Fsa,
         num_paths: int,
         use_double_scores: bool = True,
-        lattice_score_scale: float = 0.5,
+        nbest_scale: float = 0.5,
     ) -> "Nbest":
         """Construct an Nbest object by **sampling** `num_paths` from a lattice.
 
@@ -206,7 +206,7 @@ class Nbest(object):
           Return an Nbest instance.
         """
         saved_scores = lattice.scores.clone()
-        lattice.scores *= lattice_score_scale
+        lattice.scores *= nbest_scale
         # path is a ragged tensor with dtype torch.int32.
         # It has three axes [utt][path][arc_pos]
         path = k2.random_paths(
@@ -446,7 +446,7 @@ def nbest_decoding(
     lattice: k2.Fsa,
     num_paths: int,
     use_double_scores: bool = True,
-    lattice_score_scale: float = 1.0,
+    nbest_scale: float = 1.0,
 ) -> k2.Fsa:
     """It implements something like CTC prefix beam search using n-best lists.
 
@@ -474,7 +474,7 @@ def nbest_decoding(
       use_double_scores:
         True to use double precision floating point in the computation.
        False to use single precision.
-      lattice_score_scale:
+      nbest_scale:
        It's the scale applied to the `lattice.scores`. A smaller value
        leads to more unique paths at the risk of missing the correct path.
     Returns:
@@ -484,7 +484,7 @@ def nbest_decoding(
         lattice=lattice,
         num_paths=num_paths,
         use_double_scores=use_double_scores,
-        lattice_score_scale=lattice_score_scale,
+        nbest_scale=nbest_scale,
     )
 
     # nbest.fsa.scores contains 0s
@@ -505,7 +505,7 @@ def nbest_oracle(
     ref_texts: List[str],
     word_table: k2.SymbolTable,
     use_double_scores: bool = True,
-    lattice_score_scale: float = 0.5,
+    nbest_scale: float = 0.5,
     oov: str = "<UNK>",
 ) -> Dict[str, List[List[int]]]:
     """Select the best hypothesis given a lattice and a reference transcript.
@@ -517,7 +517,7 @@ def nbest_oracle(
     The decoding result returned from this function is the best result that
     we can obtain using n-best decoding with all kinds of rescoring techniques.
 
-    This function is useful to tune the value of `lattice_score_scale`.
+    This function is useful to tune the value of `nbest_scale`.
 
     Args:
       lattice:
@@ -533,7 +533,7 @@ def nbest_oracle(
      use_double_scores:
        True to use double precision for computation. False to use
        single precision.
-      lattice_score_scale:
+      nbest_scale:
        It's the scale applied to the lattice.scores. A smaller value
        yields more unique paths.
      oov:
@@ -549,7 +549,7 @@ def nbest_oracle(
         lattice=lattice,
         num_paths=num_paths,
         use_double_scores=use_double_scores,
-        lattice_score_scale=lattice_score_scale,
+        nbest_scale=nbest_scale,
     )
 
     hyps = nbest.build_levenshtein_graphs()
@@ -590,7 +590,7 @@ def rescore_with_n_best_list(
     G: k2.Fsa,
     num_paths: int,
     lm_scale_list: List[float],
-    lattice_score_scale: float = 1.0,
+    nbest_scale: float = 1.0,
     use_double_scores: bool = True,
 ) -> Dict[str, k2.Fsa]:
     """Rescore an n-best list with an n-gram LM.
@@ -607,7 +607,7 @@ def rescore_with_n_best_list(
        Size of nbest list.
      lm_scale_list:
        A list of float representing LM score scales.
-      lattice_score_scale:
+      nbest_scale:
        Scale to be applied to ``lattice.score`` when sampling paths
        using ``k2.random_paths``.
      use_double_scores:
@@ -631,7 +631,7 @@ def rescore_with_n_best_list(
         lattice=lattice,
         num_paths=num_paths,
         use_double_scores=use_double_scores,
-        lattice_score_scale=lattice_score_scale,
+        nbest_scale=nbest_scale,
     )
 
     # nbest.fsa.scores are all 0s at this point
@@ -769,7 +769,7 @@ def rescore_with_attention_decoder(
     memory_key_padding_mask: Optional[torch.Tensor],
     sos_id: int,
     eos_id: int,
-    lattice_score_scale: float = 1.0,
+    nbest_scale: float = 1.0,
     ngram_lm_scale: Optional[float] = None,
     attention_scale: Optional[float] = None,
     use_double_scores: bool = True,
@@ -796,7 +796,7 @@ def rescore_with_attention_decoder(
        The token ID for SOS.
      eos_id:
        The token ID for EOS.
-      lattice_score_scale:
+      nbest_scale:
        It's the scale applied to `lattice.scores`. A smaller value
        leads to more unique paths at the risk of missing the correct path.
      ngram_lm_scale:
@@ -812,7 +812,7 @@ def rescore_with_attention_decoder(
         lattice=lattice,
         num_paths=num_paths,
         use_double_scores=use_double_scores,
-        lattice_score_scale=lattice_score_scale,
+        nbest_scale=nbest_scale,
     )
 
     # nbest.fsa.scores are all 0s at this point

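The heart of the rename is `Nbest.from_lattice` above: the true scores are cloned, scaled by `nbest_scale` for sampling only, and restored afterwards, so the flattening never leaks into later rescoring. A sketch of that pattern follows; only `lattice.scores` and `k2.random_paths` appear in the diff, and the wrapper function here is illustrative, not part of icefall:

    import k2

    def sample_paths(lattice: k2.Fsa, num_paths: int, nbest_scale: float = 0.5):
        saved_scores = lattice.scores.clone()  # keep the true scores
        lattice.scores *= nbest_scale          # flatten for sampling only
        # Paths are drawn in proportion to the (scaled) arc posteriors;
        # a smaller nbest_scale yields more unique paths.
        path = k2.random_paths(
            lattice, num_paths=num_paths, use_double_scores=True
        )
        lattice.scores = saved_scores          # restore before any rescoring
        return path

Per its docstring, `nbest_oracle` is the intended way to tune this value: decode with a few settings of `nbest_scale` and keep the one whose oracle WER is lowest.
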
View File

@@ -43,7 +43,7 @@ def test_nbest_from_lattice():
         lattice=lattice,
         num_paths=10,
         use_double_scores=True,
-        lattice_score_scale=0.5,
+        nbest_scale=0.5,
     )
     # each lattice has only 4 distinct paths that have different word sequences:
     # 10->30