Support pure CTC decoding requiring neither a lexicon nor an n-gram LM (#58)

* Rename lattice_score_scale to nbest_scale.

* Support pure CTC decoding requiring neither a lexicon nor an n-gram LM.

* Fix style issues.

* Fix a typo.

* Minor fixes.
Fangjun Kuang 2021-09-26 14:21:49 +08:00 committed by GitHub
parent 455693aede
commit 707d7017a7
10 changed files with 136 additions and 60 deletions


@ -299,9 +299,9 @@ The commonly used options are:
.. code-block:: .. code-block::
$ cd egs/librispeech/ASR $ cd egs/librispeech/ASR
$ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --lattice-score-scale 0.5 $ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --nbest-scale 0.5
- ``--lattice-score-scale`` - ``--nbest-scale``
It is used to scale down lattice scores so that there are more unique It is used to scale down lattice scores so that there are more unique
paths for rescoring. paths for rescoring.
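
The scaling matters because the n-best paths are drawn with k2.random_paths, which samples paths in proportion to their scores; shrinking the scores flattens that distribution, so the samples are less dominated by the single best path. A minimal sketch of the step (the helper name is ours; the calls mirror Nbest.from_lattice in icefall/decode.py):

    import k2

    def sample_paths(lattice: k2.Fsa, num_paths: int, nbest_scale: float = 0.5):
        # Scale down the scores so that k2.random_paths, which samples paths
        # with probability proportional to their total score, returns more
        # unique paths instead of the same dominant path over and over.
        saved_scores = lattice.scores.clone()
        lattice.scores *= nbest_scale
        # `path` is a ragged tensor with axes [utt][path][arc_pos].
        path = k2.random_paths(lattice, use_double_scores=True, num_paths=num_paths)
        lattice.scores = saved_scores  # restore the original scores
        return path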
@ -577,7 +577,7 @@ The command to run HLG decoding + LM rescoring + attention decoder rescoring is:
--G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \ --G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \
--ngram-lm-scale 1.3 \ --ngram-lm-scale 1.3 \
--attention-decoder-scale 1.2 \ --attention-decoder-scale 1.2 \
--lattice-score-scale 0.5 \ --nbest-scale 0.5 \
--num-paths 100 \ --num-paths 100 \
--sos-id 1 \ --sos-id 1 \
--eos-id 1 \ --eos-id 1 \


@ -40,7 +40,7 @@ python conformer_ctc/train.py --bucketing-sampler True \
--full-libri True \ --full-libri True \
--world-size 4 --world-size 4
python conformer_ctc/decode.py --lattice-score-scale 0.5 \ python conformer_ctc/decode.py --nbest-scale 0.5 \
--epoch 34 \ --epoch 34 \
--avg 20 \ --avg 20 \
--method attention-decoder \ --method attention-decoder \


@ -23,6 +23,7 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import k2 import k2
import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
@ -77,6 +78,9 @@ def get_parser():
default="attention-decoder", default="attention-decoder",
help="""Decoding method. help="""Decoding method.
Supported values are: Supported values are:
- (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
It needs neither a lexicon nor an n-gram LM.
- (1) 1best. Extract the best path from the decoding lattice as the - (1) 1best. Extract the best path from the decoding lattice as the
decoding result. decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path - (2) nbest. Extract n paths from the decoding lattice; the path
@ -106,7 +110,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--lattice-score-scale", "--nbest-scale",
type=float, type=float,
default=0.5, default=0.5,
help="""The scale to be applied to `lattice.scores`. help="""The scale to be applied to `lattice.scores`.
@ -128,14 +132,26 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe",
help="The lang dir",
)
return parser return parser
def get_params() -> AttributeDict: def get_params() -> AttributeDict:
params = AttributeDict( params = AttributeDict(
{ {
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang_bpe"),
"lm_dir": Path("data/lm"), "lm_dir": Path("data/lm"),
# parameters for conformer # parameters for conformer
"subsampling_factor": 4, "subsampling_factor": 4,
@ -159,13 +175,15 @@ def get_params() -> AttributeDict:
def decode_one_batch( def decode_one_batch(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
HLG: k2.Fsa, HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
batch: dict, batch: dict,
word_table: k2.SymbolTable, word_table: k2.SymbolTable,
sos_id: int, sos_id: int,
eos_id: int, eos_id: int,
G: Optional[k2.Fsa] = None, G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the """Decode one batch and return the result in a dict. The dict has the
following format: following format:
@ -190,7 +208,11 @@ def decode_one_batch(
model: model:
The neural model. The neural model.
HLG: HLG:
The decoding graph. The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
The ctc topo. Used only when params.method is ctc-decoding.
bpe_model:
The BPE model. Used only when params.method is ctc-decoding.
batch: batch:
It is the return value from iterating It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
@ -209,7 +231,10 @@ def decode_one_batch(
Return the decoding result. See above description for the format of Return the decoding result. See above description for the format of
the returned dict. the returned dict.
""" """
device = HLG.device if HLG is not None:
device = HLG.device
else:
device = H.device
feature = batch["inputs"] feature = batch["inputs"]
assert feature.ndim == 3 assert feature.ndim == 3
feature = feature.to(device) feature = feature.to(device)
@ -229,9 +254,17 @@ def decode_one_batch(
1, 1,
).to(torch.int32) ).to(torch.int32)
if H is None:
assert HLG is not None
decoding_graph = HLG
else:
assert HLG is None
assert bpe_model is not None
decoding_graph = H
lattice = get_lattice( lattice = get_lattice(
nnet_output=nnet_output, nnet_output=nnet_output,
HLG=HLG, decoding_graph=decoding_graph,
supervision_segments=supervision_segments, supervision_segments=supervision_segments,
search_beam=params.search_beam, search_beam=params.search_beam,
output_beam=params.output_beam, output_beam=params.output_beam,
@ -240,6 +273,24 @@ def decode_one_batch(
subsampling_factor=params.subsampling_factor, subsampling_factor=params.subsampling_factor,
) )
if params.method == "ctc-decoding":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
# Note: `best_path.aux_labels` contains token IDs, not word IDs
# since we are using H, not HLG here.
#
# token_ids is a list-of-lists of token IDs
token_ids = get_texts(best_path)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
key = "ctc-decoding"
return {key: hyps}
if params.method == "nbest-oracle": if params.method == "nbest-oracle":
# Note: You can also pass rescored lattices to it. # Note: You can also pass rescored lattices to it.
# We choose the HLG decoded lattice for speed reasons # We choose the HLG decoded lattice for speed reasons
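
To make the ctc-decoding branch above concrete, here is the SentencePiece step in isolation, with made-up token IDs and the default --lang-dir path assumed; it is an illustration, not part of the commit:

    import sentencepiece as spm

    bpe_model = spm.SentencePieceProcessor()
    bpe_model.load("data/lang_bpe/bpe.model")  # default --lang-dir in this script

    token_ids = [[23, 7, 105], [42, 9]]   # hypothetical result of get_texts(best_path)
    texts = bpe_model.decode(token_ids)   # one string per utterance, e.g. "HELLO WORLD"
    hyps = [s.split() for s in texts]     # list of word lists, as returned above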
@ -250,12 +301,12 @@ def decode_one_batch(
num_paths=params.num_paths, num_paths=params.num_paths,
ref_texts=supervisions["text"], ref_texts=supervisions["text"],
word_table=word_table, word_table=word_table,
lattice_score_scale=params.lattice_score_scale, nbest_scale=params.nbest_scale,
oov="<UNK>", oov="<UNK>",
) )
hyps = get_texts(best_path) hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps] hyps = [[word_table[i] for i in ids] for ids in hyps]
key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}" # noqa key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa
return {key: hyps} return {key: hyps}
if params.method in ["1best", "nbest"]: if params.method in ["1best", "nbest"]:
@ -269,9 +320,9 @@ def decode_one_batch(
lattice=lattice, lattice=lattice,
num_paths=params.num_paths, num_paths=params.num_paths,
use_double_scores=params.use_double_scores, use_double_scores=params.use_double_scores,
lattice_score_scale=params.lattice_score_scale, nbest_scale=params.nbest_scale,
) )
key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
hyps = get_texts(best_path) hyps = get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyps] hyps = [[word_table[i] for i in ids] for ids in hyps]
@ -293,7 +344,7 @@ def decode_one_batch(
G=G, G=G,
num_paths=params.num_paths, num_paths=params.num_paths,
lm_scale_list=lm_scale_list, lm_scale_list=lm_scale_list,
lattice_score_scale=params.lattice_score_scale, nbest_scale=params.nbest_scale,
) )
elif params.method == "whole-lattice-rescoring": elif params.method == "whole-lattice-rescoring":
best_path_dict = rescore_with_whole_lattice( best_path_dict = rescore_with_whole_lattice(
@ -319,7 +370,7 @@ def decode_one_batch(
memory_key_padding_mask=memory_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id, sos_id=sos_id,
eos_id=eos_id, eos_id=eos_id,
lattice_score_scale=params.lattice_score_scale, nbest_scale=params.nbest_scale,
) )
else: else:
assert False, f"Unsupported decoding method: {params.method}" assert False, f"Unsupported decoding method: {params.method}"
@ -340,12 +391,14 @@ def decode_dataset(
dl: torch.utils.data.DataLoader, dl: torch.utils.data.DataLoader,
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
HLG: k2.Fsa, HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
word_table: k2.SymbolTable, word_table: k2.SymbolTable,
sos_id: int, sos_id: int,
eos_id: int, eos_id: int,
G: Optional[k2.Fsa] = None, G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[int], List[int]]]]: ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset. """Decode dataset.
Args: Args:
@ -356,7 +409,11 @@ def decode_dataset(
model: model:
The neural model. The neural model.
HLG: HLG:
The decoding graph. The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
The ctc topo. Used only when params.method is ctc-decoding.
bpe_model:
The BPE model. Used only when params.method is ctc-decoding.
word_table: word_table:
It is the word symbol table. It is the word symbol table.
sos_id: sos_id:
@ -391,6 +448,8 @@ def decode_dataset(
params=params, params=params,
model=model, model=model,
HLG=HLG, HLG=HLG,
H=H,
bpe_model=bpe_model,
batch=batch, batch=batch,
word_table=word_table, word_table=word_table,
G=G, G=G,
@ -469,6 +528,8 @@ def main():
parser = get_parser() parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser) LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
@ -496,14 +557,26 @@ def main():
sos_id = graph_compiler.sos_id sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id eos_id = graph_compiler.eos_id
HLG = k2.Fsa.from_dict( if params.method == "ctc-decoding":
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") HLG = None
) H = k2.ctc_topo(
HLG = HLG.to(device) max_token=max_token_id,
assert HLG.requires_grad is False modified=False,
device=device,
)
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir / "bpe.model"))
else:
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
)
HLG = HLG.to(device)
assert HLG.requires_grad is False
if not hasattr(HLG, "lm_scores"): if not hasattr(HLG, "lm_scores"):
HLG.lm_scores = HLG.scores.clone() HLG.lm_scores = HLG.scores.clone()
if params.method in ( if params.method in (
"nbest-rescoring", "nbest-rescoring",
@ -593,6 +666,8 @@ def main():
params=params, params=params,
model=model, model=model,
HLG=HLG, HLG=HLG,
H=H,
bpe_model=bpe_model,
word_table=lexicon.word_table, word_table=lexicon.word_table,
G=G, G=G,
sos_id=sos_id, sos_id=sos_id,


@ -125,7 +125,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--lattice-score-scale", "--nbest-scale",
type=float, type=float,
default=0.5, default=0.5,
help=""" help="""
@ -301,7 +301,7 @@ def main():
lattice = get_lattice( lattice = get_lattice(
nnet_output=nnet_output, nnet_output=nnet_output,
HLG=HLG, decoding_graph=HLG,
supervision_segments=supervision_segments, supervision_segments=supervision_segments,
search_beam=params.search_beam, search_beam=params.search_beam,
output_beam=params.output_beam, output_beam=params.output_beam,
@ -336,7 +336,7 @@ def main():
memory_key_padding_mask=memory_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask,
sos_id=params.sos_id, sos_id=params.sos_id,
eos_id=params.eos_id, eos_id=params.eos_id,
lattice_score_scale=params.lattice_score_scale, nbest_scale=params.nbest_scale,
ngram_lm_scale=params.ngram_lm_scale, ngram_lm_scale=params.ngram_lm_scale,
attention_scale=params.attention_decoder_scale, attention_scale=params.attention_decoder_scale,
) )


@ -97,7 +97,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--lattice-score-scale", "--nbest-scale",
type=float, type=float,
default=0.5, default=0.5,
help="""The scale to be applied to `lattice.scores`. help="""The scale to be applied to `lattice.scores`.
@ -146,7 +146,7 @@ def decode_one_batch(
batch: dict, batch: dict,
lexicon: Lexicon, lexicon: Lexicon,
G: Optional[k2.Fsa] = None, G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the """Decode one batch and return the result in a dict. The dict has the
following format: following format:
@ -210,7 +210,7 @@ def decode_one_batch(
lattice = get_lattice( lattice = get_lattice(
nnet_output=nnet_output, nnet_output=nnet_output,
HLG=HLG, decoding_graph=HLG,
supervision_segments=supervision_segments, supervision_segments=supervision_segments,
search_beam=params.search_beam, search_beam=params.search_beam,
output_beam=params.output_beam, output_beam=params.output_beam,
@ -229,7 +229,7 @@ def decode_one_batch(
lattice=lattice, lattice=lattice,
num_paths=params.num_paths, num_paths=params.num_paths,
use_double_scores=params.use_double_scores, use_double_scores=params.use_double_scores,
lattice_score_scale=params.lattice_score_scale, nbest_scale=params.nbest_scale,
) )
key = f"no_rescore-{params.num_paths}" key = f"no_rescore-{params.num_paths}"
hyps = get_texts(best_path) hyps = get_texts(best_path)
@ -248,7 +248,7 @@ def decode_one_batch(
G=G, G=G,
num_paths=params.num_paths, num_paths=params.num_paths,
lm_scale_list=lm_scale_list, lm_scale_list=lm_scale_list,
lattice_score_scale=params.lattice_score_scale, nbest_scale=params.nbest_scale,
) )
else: else:
best_path_dict = rescore_with_whole_lattice( best_path_dict = rescore_with_whole_lattice(
@ -272,7 +272,7 @@ def decode_dataset(
HLG: k2.Fsa, HLG: k2.Fsa,
lexicon: Lexicon, lexicon: Lexicon,
G: Optional[k2.Fsa] = None, G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[int], List[int]]]]: ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset. """Decode dataset.
Args: Args:


@ -232,7 +232,7 @@ def main():
lattice = get_lattice( lattice = get_lattice(
nnet_output=nnet_output, nnet_output=nnet_output,
HLG=HLG, decoding_graph=HLG,
supervision_segments=supervision_segments, supervision_segments=supervision_segments,
search_beam=params.search_beam, search_beam=params.search_beam,
output_beam=params.output_beam, output_beam=params.output_beam,


@ -124,7 +124,7 @@ def decode_one_batch(
lattice = get_lattice( lattice = get_lattice(
nnet_output=nnet_output, nnet_output=nnet_output,
HLG=HLG, decoding_graph=HLG,
supervision_segments=supervision_segments, supervision_segments=supervision_segments,
search_beam=params.search_beam, search_beam=params.search_beam,
output_beam=params.output_beam, output_beam=params.output_beam,


@ -175,7 +175,7 @@ def main():
lattice = get_lattice( lattice = get_lattice(
nnet_output=nnet_output, nnet_output=nnet_output,
HLG=HLG, decoding_graph=HLG,
supervision_segments=supervision_segments, supervision_segments=supervision_segments,
search_beam=params.search_beam, search_beam=params.search_beam,
output_beam=params.output_beam, output_beam=params.output_beam,


@ -66,7 +66,7 @@ def _intersect_device(
def get_lattice( def get_lattice(
nnet_output: torch.Tensor, nnet_output: torch.Tensor,
HLG: k2.Fsa, decoding_graph: k2.Fsa,
supervision_segments: torch.Tensor, supervision_segments: torch.Tensor,
search_beam: float, search_beam: float,
output_beam: float, output_beam: float,
@ -79,8 +79,9 @@ def get_lattice(
Args: Args:
nnet_output: nnet_output:
It is the output of a neural model of shape `(N, T, C)`. It is the output of a neural model of shape `(N, T, C)`.
HLG: decoding_graph:
An Fsa, the decoding graph. See also `compile_HLG.py`. An Fsa, the decoding graph. It can be either an HLG
(see `compile_HLG.py`) or an H (see `k2.ctc_topo`).
supervision_segments: supervision_segments:
A 2-D **CPU** tensor of dtype `torch.int32` with 3 columns. A 2-D **CPU** tensor of dtype `torch.int32` with 3 columns.
Each row contains information for a supervision segment. Column 0 Each row contains information for a supervision segment. Column 0
@ -117,7 +118,7 @@ def get_lattice(
) )
lattice = k2.intersect_dense_pruned( lattice = k2.intersect_dense_pruned(
HLG, decoding_graph,
dense_fsa_vec, dense_fsa_vec,
search_beam=search_beam, search_beam=search_beam,
output_beam=output_beam, output_beam=output_beam,
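
Since k2.intersect_dense_pruned only needs an Fsa on one side and a DenseFsaVec on the other, the same code path now serves both HLG and H. A sketch of the core of get_lattice, with illustrative shapes and beam values:

    import k2
    import torch

    def make_lattice(nnet_output: torch.Tensor, decoding_graph: k2.Fsa) -> k2.Fsa:
        # nnet_output: log-probabilities of shape (N, T, C); here each of the
        # N utterances is assumed to span all T frames.
        N, T, _ = nnet_output.shape
        supervision_segments = torch.tensor(
            [[i, 0, T] for i in range(N)], dtype=torch.int32
        )
        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        # decoding_graph may be an HLG or a plain CTC topology H; it must live
        # on the same device as nnet_output.
        return k2.intersect_dense_pruned(
            decoding_graph,
            dense_fsa_vec,
            search_beam=20.0,        # beam and state limits are illustrative
            output_beam=8.0,
            min_active_states=30,
            max_active_states=10000,
        )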
@ -180,7 +181,7 @@ class Nbest(object):
lattice: k2.Fsa, lattice: k2.Fsa,
num_paths: int, num_paths: int,
use_double_scores: bool = True, use_double_scores: bool = True,
lattice_score_scale: float = 0.5, nbest_scale: float = 0.5,
) -> "Nbest": ) -> "Nbest":
"""Construct an Nbest object by **sampling** `num_paths` from a lattice. """Construct an Nbest object by **sampling** `num_paths` from a lattice.
@ -206,7 +207,7 @@ class Nbest(object):
Return an Nbest instance. Return an Nbest instance.
""" """
saved_scores = lattice.scores.clone() saved_scores = lattice.scores.clone()
lattice.scores *= lattice_score_scale lattice.scores *= nbest_scale
# path is a ragged tensor with dtype torch.int32. # path is a ragged tensor with dtype torch.int32.
# It has three axes [utt][path][arc_pos] # It has three axes [utt][path][arc_pos]
path = k2.random_paths( path = k2.random_paths(
@ -446,7 +447,7 @@ def nbest_decoding(
lattice: k2.Fsa, lattice: k2.Fsa,
num_paths: int, num_paths: int,
use_double_scores: bool = True, use_double_scores: bool = True,
lattice_score_scale: float = 1.0, nbest_scale: float = 1.0,
) -> k2.Fsa: ) -> k2.Fsa:
"""It implements something like CTC prefix beam search using n-best lists. """It implements something like CTC prefix beam search using n-best lists.
@ -474,7 +475,7 @@ def nbest_decoding(
use_double_scores: use_double_scores:
True to use double precision floating point in the computation. True to use double precision floating point in the computation.
False to use single precision. False to use single precision.
lattice_score_scale: nbest_scale:
It's the scale applied to the `lattice.scores`. A smaller value It's the scale applied to the `lattice.scores`. A smaller value
leads to more unique paths at the risk of missing the correct path. leads to more unique paths at the risk of missing the correct path.
Returns: Returns:
@ -484,7 +485,7 @@ def nbest_decoding(
lattice=lattice, lattice=lattice,
num_paths=num_paths, num_paths=num_paths,
use_double_scores=use_double_scores, use_double_scores=use_double_scores,
lattice_score_scale=lattice_score_scale, nbest_scale=nbest_scale,
) )
# nbest.fsa.scores contains 0s # nbest.fsa.scores contains 0s
@ -505,7 +506,7 @@ def nbest_oracle(
ref_texts: List[str], ref_texts: List[str],
word_table: k2.SymbolTable, word_table: k2.SymbolTable,
use_double_scores: bool = True, use_double_scores: bool = True,
lattice_score_scale: float = 0.5, nbest_scale: float = 0.5,
oov: str = "<UNK>", oov: str = "<UNK>",
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[List[int]]]:
"""Select the best hypothesis given a lattice and a reference transcript. """Select the best hypothesis given a lattice and a reference transcript.
@ -517,7 +518,7 @@ def nbest_oracle(
The decoding result returned from this function is the best result that The decoding result returned from this function is the best result that
we can obtain using n-best decoding with all kinds of rescoring techniques. we can obtain using n-best decoding with all kinds of rescoring techniques.
This function is useful to tune the value of `lattice_score_scale`. This function is useful to tune the value of `nbest_scale`.
Args: Args:
lattice: lattice:
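
Because the oracle gives an upper bound on what any rescoring of the sampled paths can achieve, sweeping nbest_scale through it is a cheap way to pick a value. A sketch of such a sweep (the lattice, references, and word table are assumed to come from an earlier decoding run):

    import k2
    from icefall.decode import nbest_oracle  # module path as in this repository

    def sweep_nbest_scale(lattice: k2.Fsa, ref_texts, word_table: k2.SymbolTable):
        for scale in (0.1, 0.3, 0.5, 0.8, 1.0):
            result = nbest_oracle(
                lattice=lattice,
                num_paths=100,
                ref_texts=ref_texts,
                word_table=word_table,
                nbest_scale=scale,
                oov="<UNK>",
            )
            # Score `result` against ref_texts (decode_one_batch above shows how
            # it is turned into word-level hypotheses); the scale with the
            # lowest oracle WER is a sensible default for real rescoring.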
@ -533,7 +534,7 @@ def nbest_oracle(
use_double_scores: use_double_scores:
True to use double precision for computation. False to use True to use double precision for computation. False to use
single precision. single precision.
lattice_score_scale: nbest_scale:
It's the scale applied to the lattice.scores. A smaller value It's the scale applied to the lattice.scores. A smaller value
yields more unique paths. yields more unique paths.
oov: oov:
@ -549,7 +550,7 @@ def nbest_oracle(
lattice=lattice, lattice=lattice,
num_paths=num_paths, num_paths=num_paths,
use_double_scores=use_double_scores, use_double_scores=use_double_scores,
lattice_score_scale=lattice_score_scale, nbest_scale=nbest_scale,
) )
hyps = nbest.build_levenshtein_graphs() hyps = nbest.build_levenshtein_graphs()
@ -590,7 +591,7 @@ def rescore_with_n_best_list(
G: k2.Fsa, G: k2.Fsa,
num_paths: int, num_paths: int,
lm_scale_list: List[float], lm_scale_list: List[float],
lattice_score_scale: float = 1.0, nbest_scale: float = 1.0,
use_double_scores: bool = True, use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]: ) -> Dict[str, k2.Fsa]:
"""Rescore an n-best list with an n-gram LM. """Rescore an n-best list with an n-gram LM.
@ -607,7 +608,7 @@ def rescore_with_n_best_list(
Size of nbest list. Size of nbest list.
lm_scale_list: lm_scale_list:
A list of float representing LM score scales. A list of float representing LM score scales.
lattice_score_scale: nbest_scale:
Scale to be applied to ``lattice.score`` when sampling paths Scale to be applied to ``lattice.score`` when sampling paths
using ``k2.random_paths``. using ``k2.random_paths``.
use_double_scores: use_double_scores:
@ -631,7 +632,7 @@ def rescore_with_n_best_list(
lattice=lattice, lattice=lattice,
num_paths=num_paths, num_paths=num_paths,
use_double_scores=use_double_scores, use_double_scores=use_double_scores,
lattice_score_scale=lattice_score_scale, nbest_scale=nbest_scale,
) )
# nbest.fsa.scores are all 0s at this point # nbest.fsa.scores are all 0s at this point
@ -769,7 +770,7 @@ def rescore_with_attention_decoder(
memory_key_padding_mask: Optional[torch.Tensor], memory_key_padding_mask: Optional[torch.Tensor],
sos_id: int, sos_id: int,
eos_id: int, eos_id: int,
lattice_score_scale: float = 1.0, nbest_scale: float = 1.0,
ngram_lm_scale: Optional[float] = None, ngram_lm_scale: Optional[float] = None,
attention_scale: Optional[float] = None, attention_scale: Optional[float] = None,
use_double_scores: bool = True, use_double_scores: bool = True,
@ -796,7 +797,7 @@ def rescore_with_attention_decoder(
The token ID for SOS. The token ID for SOS.
eos_id: eos_id:
The token ID for EOS. The token ID for EOS.
lattice_score_scale: nbest_scale:
It's the scale applied to `lattice.scores`. A smaller value It's the scale applied to `lattice.scores`. A smaller value
leads to more unique paths at the risk of missing the correct path. leads to more unique paths at the risk of missing the correct path.
ngram_lm_scale: ngram_lm_scale:
@ -812,7 +813,7 @@ def rescore_with_attention_decoder(
lattice=lattice, lattice=lattice,
num_paths=num_paths, num_paths=num_paths,
use_double_scores=use_double_scores, use_double_scores=use_double_scores,
lattice_score_scale=lattice_score_scale, nbest_scale=nbest_scale,
) )
# nbest.fsa.scores are all 0s at this point # nbest.fsa.scores are all 0s at this point


@ -43,7 +43,7 @@ def test_nbest_from_lattice():
lattice=lattice, lattice=lattice,
num_paths=10, num_paths=10,
use_double_scores=True, use_double_scores=True,
lattice_score_scale=0.5, nbest_scale=0.5,
) )
# each lattice has only 4 distinct paths that have different word sequences: # each lattice has only 4 distinct paths that have different word sequences:
# 10->30 # 10->30