Support LG for fast beam search.

This commit is contained in:
Fangjun Kuang 2022-06-21 22:55:18 +08:00
parent f5af662b7b
commit 284cbf7ed1
5 changed files with 361 additions and 153 deletions

View File

@ -482,7 +482,8 @@ def decode_dataset(
The word symbol table. The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search. only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns: Returns:
Return a dict, whose key may be "greedy_search" if greedy search Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used. is used, or it may be "beam_7" if beam size of 7 is used.

View File

@ -177,6 +177,13 @@ def get_parser():
help="Path to the BPE model", help="Path to the BPE model",
) )
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument( parser.add_argument(
"--decoding-method", "--decoding-method",
type=str, type=str,
@ -482,8 +489,8 @@ def decode_dataset(
The word symbol table. The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest, or fast_beam_search_nbest_oracle. fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns: Returns:
Return a dict, whose key may be "greedy_search" if greedy search Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used. is used, or it may be "beam_7" if beam size of 7 is used.
@ -726,7 +733,6 @@ def main():
test_set_name=test_set, test_set_name=test_set,
results_dict=results_dict, results_dict=results_dict,
) )
break
logging.info("Done!") logging.info("Done!")

View File

@ -50,9 +50,9 @@ Usage:
--exp-dir ./pruned_transducer_stateless3/exp \ --exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method fast_beam_search \ --decoding-method fast_beam_search \
--beam 4 \ --beam 20.0 \
--max-contexts 4 \ --max-contexts 8 \
--max-states 8 --max-states 64
(5) fast beam search (nbest) (5) fast beam search (nbest)
./pruned_transducer_stateless3/decode.py \ ./pruned_transducer_stateless3/decode.py \
@ -61,9 +61,9 @@ Usage:
--exp-dir ./pruned_transducer_stateless3/exp \ --exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method fast_beam_search_nbest \ --decoding-method fast_beam_search_nbest \
--beam 4 \ --beam 20.0 \
--max-contexts 4 \ --max-contexts 8 \
--max-states 8 \ --max-states 64 \
--num-paths 200 \ --num-paths 200 \
--nbest-scale 0.5 --nbest-scale 0.5
@ -74,11 +74,22 @@ Usage:
--exp-dir ./pruned_transducer_stateless3/exp \ --exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \ --decoding-method fast_beam_search_nbest_oracle \
--beam 4 \ --beam 20.0 \
--max-contexts 4 \ --max-contexts 8 \
--max-states 8 \ --max-states 64 \
--num-paths 200 \ --num-paths 200 \
--nbest-scale 0.5 --nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
""" """
@ -96,6 +107,7 @@ from asr_datamodule import AsrDataModule
from beam_search import ( from beam_search import (
beam_search, beam_search,
fast_beam_search_nbest, fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle, fast_beam_search_nbest_oracle,
fast_beam_search_one_best, fast_beam_search_one_best,
greedy_search, greedy_search,
@ -110,6 +122,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
setup_logger, setup_logger,
@ -165,6 +178,13 @@ def get_parser():
help="Path to the BPE model", help="Path to the BPE model",
) )
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument( parser.add_argument(
"--decoding-method", "--decoding-method",
type=str, type=str,
@ -176,6 +196,9 @@ def get_parser():
- fast_beam_search - fast_beam_search
- fast_beam_search_nbest - fast_beam_search_nbest
- fast_beam_search_nbest_oracle - fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""", """,
) )
@ -191,31 +214,42 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--beam", "--beam",
type=float, type=float,
default=4, default=20.0,
help="""A floating point value to calculate the cutoff score during beam help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi. `beam` in Kaldi.
Used only when --decoding-method is Used only when --decoding-method is fast_beam_search,
fast_beam_search, fast_beam_search_nbest, or fast_beam_search_nbest, fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle""", and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
) )
parser.add_argument( parser.add_argument(
"--max-contexts", "--max-contexts",
type=int, type=int,
default=4, default=8,
help="""Used only when --decoding-method is help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, or fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle""", and fast_beam_search_nbest_oracle""",
) )
parser.add_argument( parser.add_argument(
"--max-states", "--max-states",
type=int, type=int,
default=8, default=64,
help="""Used only when --decoding-method is help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, or fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle""", and fast_beam_search_nbest_oracle""",
) )
parser.add_argument( parser.add_argument(
@ -238,9 +272,8 @@ def get_parser():
type=int, type=int,
default=200, default=200,
help="""Number of paths for nbest decoding. help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest or Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_oracle fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
""",
) )
parser.add_argument( parser.add_argument(
@ -248,9 +281,8 @@ def get_parser():
type=float, type=float,
default=0.5, default=0.5,
help="""Scale applied to lattice scores when computing nbest paths. help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest or Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_oracle fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
""",
) )
return parser return parser
@ -261,6 +293,7 @@ def decode_one_batch(
model: nn.Module, model: nn.Module,
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
batch: dict, batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]: ) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the """Decode one batch and return the result in a dict. The dict has the
@ -284,10 +317,12 @@ def decode_one_batch(
It is the return value from iterating It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`. for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest, or fast_beam_search_nbest_oracle. fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns: Returns:
Return the decoding result. See above description for the format of Return the decoding result. See above description for the format of
the returned dict. the returned dict.
@ -319,6 +354,20 @@ def decode_one_batch(
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest": elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest( hyp_tokens = fast_beam_search_nbest(
model=model, model=model,
@ -403,16 +452,25 @@ def decode_one_batch(
f"max_states_{params.max_states}" f"max_states_{params.max_states}"
): hyps ): hyps
} }
elif "fast_beam_search_nbest" in params.decoding_method: elif params.decoding_method == "fast_beam_search":
return { return {
( (
f"beam_{params.beam}_" f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_" f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}_" f"max_states_{params.max_states}"
f"num_paths_{params.num_paths}_"
f"nbest_scale_{params.nbest_scale}"
): hyps ): hyps
} }
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": hyps}
@ -422,6 +480,7 @@ def decode_dataset(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]: ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset. """Decode dataset.
@ -435,10 +494,12 @@ def decode_dataset(
The neural model. The neural model.
sp: sp:
The BPE model. The BPE model.
word_table:
The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest, or fast_beam_search_nbest_oracle. fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns: Returns:
Return a dict, whose key may be "greedy_search" if greedy search Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used. is used, or it may be "beam_7" if beam size of 7 is used.
@ -466,6 +527,7 @@ def decode_dataset(
params=params, params=params,
model=model, model=model,
sp=sp, sp=sp,
word_table=word_table,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
batch=batch, batch=batch,
) )
@ -549,6 +611,7 @@ def main():
"beam_search", "beam_search",
"fast_beam_search", "fast_beam_search",
"fast_beam_search_nbest", "fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle", "fast_beam_search_nbest_oracle",
"modified_beam_search", "modified_beam_search",
) )
@ -559,16 +622,15 @@ def main():
else: else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.decoding_method == "fast_beam_search": if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}" params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-max-states-{params.max_states}"
elif "fast_beam_search_nbest" in params.decoding_method: if "nbest" in params.decoding_method:
params.suffix += f"-beam-{params.beam}" params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-num-paths-{params.num_paths}"
params.suffix += f"-max-states-{params.max_states}" if "LG" in params.decoding_method:
params.suffix += f"-num-paths-{params.num_paths}" params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
params.suffix += f"-nbest-scale-{params.nbest_scale}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += ( params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}" f"-{params.decoding_method}-beam-size-{params.beam_size}"
@ -634,9 +696,23 @@ def main():
model.unk_id = params.unk_id model.unk_id = params.unk_id
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else: else:
decoding_graph = None decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
@ -659,6 +735,7 @@ def main():
params=params, params=params,
model=model, model=model,
sp=sp, sp=sp,
word_table=word_table,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
) )

View File

@ -51,9 +51,9 @@ Usage:
--exp-dir ./pruned_transducer_stateless4/exp \ --exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method fast_beam_search \ --decoding-method fast_beam_search \
--beam 4 \ --beam 20.0 \
--max-contexts 4 \ --max-contexts 8 \
--max-states 8 --max-states 64
(5) fast beam search (nbest) (5) fast beam search (nbest)
./pruned_transducer_stateless4/decode.py \ ./pruned_transducer_stateless4/decode.py \
@ -62,9 +62,9 @@ Usage:
--exp-dir ./pruned_transducer_stateless3/exp \ --exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method fast_beam_search_nbest \ --decoding-method fast_beam_search_nbest \
--beam 4 \ --beam 20.0 \
--max-contexts 4 \ --max-contexts 8 \
--max-states 8 \ --max-states 64 \
--num-paths 200 \ --num-paths 200 \
--nbest-scale 0.5 --nbest-scale 0.5
@ -75,11 +75,22 @@ Usage:
--exp-dir ./pruned_transducer_stateless4/exp \ --exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \ --decoding-method fast_beam_search_nbest_oracle \
--beam 4 \ --beam 20.0 \
--max-contexts 4 \ --max-contexts 8 \
--max-states 8 \ --max-states 64 \
--num-paths 200 \ --num-paths 200 \
--nbest-scale 0.5 --nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless4/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
""" """
@ -97,6 +108,7 @@ from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import ( from beam_search import (
beam_search, beam_search,
fast_beam_search_nbest, fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle, fast_beam_search_nbest_oracle,
fast_beam_search_one_best, fast_beam_search_one_best,
greedy_search, greedy_search,
@ -111,6 +123,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
setup_logger, setup_logger,
@ -178,6 +191,13 @@ def get_parser():
help="Path to the BPE model", help="Path to the BPE model",
) )
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument( parser.add_argument(
"--decoding-method", "--decoding-method",
type=str, type=str,
@ -189,6 +209,9 @@ def get_parser():
- fast_beam_search - fast_beam_search
- fast_beam_search_nbest - fast_beam_search_nbest
- fast_beam_search_nbest_oracle - fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""", """,
) )
@ -204,31 +227,42 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--beam", "--beam",
type=float, type=float,
default=4, default=20.0,
help="""A floating point value to calculate the cutoff score during beam help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi. `beam` in Kaldi.
Used only when --decoding-method is Used only when --decoding-method is fast_beam_search,
fast_beam_search, fast_beam_search_nbest, or fast_beam_search_nbest, fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle""", and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
) )
parser.add_argument( parser.add_argument(
"--max-contexts", "--max-contexts",
type=int, type=int,
default=4, default=8,
help="""Used only when --decoding-method is help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, or fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle""", and fast_beam_search_nbest_oracle""",
) )
parser.add_argument( parser.add_argument(
"--max-states", "--max-states",
type=int, type=int,
default=8, default=64,
help="""Used only when --decoding-method is help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, or fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle""", and fast_beam_search_nbest_oracle""",
) )
parser.add_argument( parser.add_argument(
@ -251,9 +285,8 @@ def get_parser():
type=int, type=int,
default=200, default=200,
help="""Number of paths for nbest decoding. help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest or Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_oracle fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
""",
) )
parser.add_argument( parser.add_argument(
@ -261,9 +294,8 @@ def get_parser():
type=float, type=float,
default=0.5, default=0.5,
help="""Scale applied to lattice scores when computing nbest paths. help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest or Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_oracle fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
""",
) )
return parser return parser
@ -274,6 +306,7 @@ def decode_one_batch(
model: nn.Module, model: nn.Module,
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
batch: dict, batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]: ) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the """Decode one batch and return the result in a dict. The dict has the
@ -297,9 +330,12 @@ def decode_one_batch(
It is the return value from iterating It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`. for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search. only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns: Returns:
Return the decoding result. See above description for the format of Return the decoding result. See above description for the format of
the returned dict. the returned dict.
@ -331,6 +367,20 @@ def decode_one_batch(
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest": elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest( hyp_tokens = fast_beam_search_nbest(
model=model, model=model,
@ -407,24 +457,17 @@ def decode_one_batch(
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
return {"greedy_search": hyps} return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search": elif "fast_beam_search" in params.decoding_method:
return { key = f"beam_{params.beam}_"
( key += f"max_contexts_{params.max_contexts}_"
f"beam_{params.beam}_" key += f"max_states_{params.max_states}"
f"max_contexts_{params.max_contexts}_" if "nbest" in params.decoding_method:
f"max_states_{params.max_states}" key += f"num_paths_{params.num_paths}_"
): hyps key += f"nbest_scale_{params.nbest_scale}"
} if "LG" in params.decoding_method:
elif "fast_beam_search_nbest" in params.decoding_method: key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {
( return {key: hyps}
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}_"
f"num_paths_{params.num_paths}_"
f"nbest_scale_{params.nbest_scale}"
): hyps
}
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": hyps}
@ -434,6 +477,7 @@ def decode_dataset(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]: ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset. """Decode dataset.
@ -447,10 +491,12 @@ def decode_dataset(
The neural model. The neural model.
sp: sp:
The BPE model. The BPE model.
word_table:
The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest, or fast_beam_search_nbest_oracle. fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns: Returns:
Return a dict, whose key may be "greedy_search" if greedy search Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used. is used, or it may be "beam_7" if beam size of 7 is used.
@ -479,6 +525,7 @@ def decode_dataset(
model=model, model=model,
sp=sp, sp=sp,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
word_table=word_table,
batch=batch, batch=batch,
) )
@ -561,6 +608,7 @@ def main():
"beam_search", "beam_search",
"fast_beam_search", "fast_beam_search",
"fast_beam_search_nbest", "fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle", "fast_beam_search_nbest_oracle",
"modified_beam_search", "modified_beam_search",
) )
@ -571,16 +619,15 @@ def main():
else: else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.decoding_method == "fast_beam_search": if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}" params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-max-states-{params.max_states}"
elif "fast_beam_search_nbest" in params.decoding_method: if "nbest" in params.decoding_method:
params.suffix += f"-beam-{params.beam}" params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-num-paths-{params.num_paths}"
params.suffix += f"-max-states-{params.max_states}" if "LG" in params.decoding_method:
params.suffix += f"-num-paths-{params.num_paths}" params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
params.suffix += f"-nbest-scale-{params.nbest_scale}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += ( params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}" f"-{params.decoding_method}-beam-size-{params.beam_size}"
@ -695,9 +742,23 @@ def main():
model.eval() model.eval()
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else: else:
decoding_graph = None decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
@ -719,6 +780,7 @@ def main():
params=params, params=params,
model=model, model=model,
sp=sp, sp=sp,
word_table=word_table,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
) )

View File

@ -51,9 +51,9 @@ Usage:
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method fast_beam_search \ --decoding-method fast_beam_search \
--beam 4 \ --beam 20.0 \
--max-contexts 4 \ --max-contexts 8 \
--max-states 8 --max-states 64
(5) fast beam search (nbest) (5) fast beam search (nbest)
./pruned_transducer_stateless5/decode.py \ ./pruned_transducer_stateless5/decode.py \
@ -62,9 +62,9 @@ Usage:
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method fast_beam_search_nbest \ --decoding-method fast_beam_search_nbest \
--beam 4 \ --beam 20.0 \
--max-contexts 4 \ --max-contexts 8 \
--max-states 8 \ --max-states 64 \
--num-paths 200 \ --num-paths 200 \
--nbest-scale 0.5 --nbest-scale 0.5
@ -75,11 +75,22 @@ Usage:
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \ --decoding-method fast_beam_search_nbest_oracle \
--beam 4 \ --beam 20.0 \
--max-contexts 4 \ --max-contexts 8 \
--max-states 8 \ --max-states 64 \
--num-paths 200 \ --num-paths 200 \
--nbest-scale 0.5 --nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
""" """
@ -97,6 +108,7 @@ from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import ( from beam_search import (
beam_search, beam_search,
fast_beam_search_nbest, fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle, fast_beam_search_nbest_oracle,
fast_beam_search_one_best, fast_beam_search_one_best,
greedy_search, greedy_search,
@ -111,6 +123,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
setup_logger, setup_logger,
@ -178,6 +191,13 @@ def get_parser():
help="Path to the BPE model", help="Path to the BPE model",
) )
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument( parser.add_argument(
"--decoding-method", "--decoding-method",
type=str, type=str,
@ -189,6 +209,9 @@ def get_parser():
- fast_beam_search - fast_beam_search
- fast_beam_search_nbest - fast_beam_search_nbest
- fast_beam_search_nbest_oracle - fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""", """,
) )
@ -204,31 +227,42 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--beam", "--beam",
type=float, type=float,
default=4, default=20.0,
help="""A floating point value to calculate the cutoff score during beam help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi. `beam` in Kaldi.
Used only when --decoding-method is Used only when --decoding-method is fast_beam_search,
fast_beam_search, fast_beam_search_nbest, or fast_beam_search_nbest, fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle""", and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
) )
parser.add_argument( parser.add_argument(
"--max-contexts", "--max-contexts",
type=int, type=int,
default=4, default=8,
help="""Used only when --decoding-method is help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, or fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle""", and fast_beam_search_nbest_oracle""",
) )
parser.add_argument( parser.add_argument(
"--max-states", "--max-states",
type=int, type=int,
default=8, default=64,
help="""Used only when --decoding-method is help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, or fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle""", and fast_beam_search_nbest_oracle""",
) )
parser.add_argument( parser.add_argument(
@ -251,9 +285,8 @@ def get_parser():
type=int, type=int,
default=200, default=200,
help="""Number of paths for nbest decoding. help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest or Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_oracle fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
""",
) )
parser.add_argument( parser.add_argument(
@ -261,9 +294,8 @@ def get_parser():
type=float, type=float,
default=0.5, default=0.5,
help="""Scale applied to lattice scores when computing nbest paths. help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest or Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_oracle fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
""",
) )
add_model_arguments(parser) add_model_arguments(parser)
@ -276,6 +308,7 @@ def decode_one_batch(
model: nn.Module, model: nn.Module,
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
batch: dict, batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]: ) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the """Decode one batch and return the result in a dict. The dict has the
@ -299,9 +332,12 @@ def decode_one_batch(
It is the return value from iterating It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`. for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search. only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns: Returns:
Return the decoding result. See above description for the format of Return the decoding result. See above description for the format of
the returned dict. the returned dict.
@ -333,6 +369,20 @@ def decode_one_batch(
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest": elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest( hyp_tokens = fast_beam_search_nbest(
model=model, model=model,
@ -409,24 +459,17 @@ def decode_one_batch(
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
return {"greedy_search": hyps} return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search": elif "fast_beam_search" in params.decoding_method:
return { key = f"beam_{params.beam}_"
( key += f"max_contexts_{params.max_contexts}_"
f"beam_{params.beam}_" key += f"max_states_{params.max_states}"
f"max_contexts_{params.max_contexts}_" if "nbest" in params.decoding_method:
f"max_states_{params.max_states}" key += f"num_paths_{params.num_paths}_"
): hyps key += f"nbest_scale_{params.nbest_scale}"
} if "LG" in params.decoding_method:
elif "fast_beam_search_nbest" in params.decoding_method: key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {
( return {key: hyps}
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}_"
f"num_paths_{params.num_paths}_"
f"nbest_scale_{params.nbest_scale}"
): hyps
}
else: else:
return {f"beam_size_{params.beam_size}": hyps} return {f"beam_size_{params.beam_size}": hyps}
@ -436,6 +479,7 @@ def decode_dataset(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]: ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset. """Decode dataset.
@ -449,10 +493,12 @@ def decode_dataset(
The neural model. The neural model.
sp: sp:
The BPE model. The BPE model.
word_table:
The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest, or fast_beam_search_nbest_oracle. fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns: Returns:
Return a dict, whose key may be "greedy_search" if greedy search Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used. is used, or it may be "beam_7" if beam size of 7 is used.
@ -481,6 +527,7 @@ def decode_dataset(
model=model, model=model,
sp=sp, sp=sp,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
word_table=word_table,
batch=batch, batch=batch,
) )
@ -563,6 +610,7 @@ def main():
"beam_search", "beam_search",
"fast_beam_search", "fast_beam_search",
"fast_beam_search_nbest", "fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle", "fast_beam_search_nbest_oracle",
"modified_beam_search", "modified_beam_search",
) )
@ -573,16 +621,15 @@ def main():
else: else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.decoding_method == "fast_beam_search": if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}" params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-max-states-{params.max_states}"
elif "fast_beam_search_nbest" in params.decoding_method: if "nbest" in params.decoding_method:
params.suffix += f"-beam-{params.beam}" params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-num-paths-{params.num_paths}"
params.suffix += f"-max-states-{params.max_states}" if "LG" in params.decoding_method:
params.suffix += f"-num-paths-{params.num_paths}" params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
params.suffix += f"-nbest-scale-{params.nbest_scale}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += ( params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}" f"-{params.decoding_method}-beam-size-{params.beam_size}"
@ -697,9 +744,23 @@ def main():
model.eval() model.eval()
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else: else:
decoding_graph = None decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
@ -721,6 +782,7 @@ def main():
params=params, params=params,
model=model, model=model,
sp=sp, sp=sp,
word_table=word_table,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
) )