Support prefix beam search / shallow fusion / hotwords in librispeech ctc decode

Author: pkufool
Date:   2024-10-29 15:36:30 +08:00
Parent: adec8554cd
Commit: 154ef4cfa5

@@ -111,6 +111,7 @@ Usage:
 import argparse
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -129,8 +130,14 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
+from icefall.context_graph import ContextGraph, ContextState
 from icefall.decode import (
     ctc_greedy_search,
+    ctc_prefix_beam_search,
+    ctc_prefix_beam_search_attention_decoder_rescoring,
+    ctc_prefix_beam_search_shallow_fussion,
     get_lattice,
     nbest_decoding,
     nbest_oracle,
@@ -140,7 +147,11 @@ from icefall.decode import (
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
+from icefall.ngram_lm import NgramLm, NgramLmStateCost
 from icefall.lexicon import Lexicon
+from icefall.lm_wrapper import LmScorer
 from icefall.utils import (
     AttributeDict,
     get_texts,
@@ -255,6 +266,12 @@ def get_parser():
           lattice, rescore them with the attention decoder.
         - (9) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
           rescored lattice, rescore them with the attention decoder.
+        - (10) ctc-prefix-beam-search. Extract n paths with the given beam, the best
+          path of the n paths is the decoding result.
+        - (11) ctc-prefix-beam-search-attention-decoder-rescoring. Extract n paths with
+          the given beam, rescore them with the attention decoder.
+        - (12) ctc-prefix-beam-search-shallow-fussion. Use NNLM shallow fussion during
+          beam search, LODR and hotwords are also supported in this decoding method.
         """,
     )
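All three new methods are built on CTC prefix beam search. As a refresher, below is a minimal, unbatched sketch of the algorithm for a single utterance's (T, V) matrix of log-probabilities. It is illustrative only: icefall's ctc_prefix_beam_search is batched and vectorized, real implementations also prune each frame to the top-k tokens, and the names here are not icefall's.

import math
from collections import defaultdict


def log_add(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b))."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log1p(math.exp(min(a, b) - m))


def prefix_beam_search(log_probs, beam=4, blank=0):
    """log_probs: T x V nested list of per-frame log-probabilities."""
    # Each prefix keeps two log-scores: ending in blank (pb) and ending
    # in a non-blank (pnb), so repeated emissions collapse correctly.
    beams = {(): (0.0, -math.inf)}
    for frame in log_probs:
        nxt = defaultdict(lambda: (-math.inf, -math.inf))
        for prefix, (pb, pnb) in beams.items():
            total = log_add(pb, pnb)
            for token, lp in enumerate(frame):
                if token == blank:
                    b, nb = nxt[prefix]
                    nxt[prefix] = (log_add(b, total + lp), nb)
                elif prefix and token == prefix[-1]:
                    # Same symbol again: it either merges into the last
                    # emission (prefix unchanged) or, after a blank,
                    # starts a new emission (prefix extended).
                    b, nb = nxt[prefix]
                    nxt[prefix] = (b, log_add(nb, pnb + lp))
                    b2, nb2 = nxt[prefix + (token,)]
                    nxt[prefix + (token,)] = (b2, log_add(nb2, pb + lp))
                else:
                    b2, nb2 = nxt[prefix + (token,)]
                    nxt[prefix + (token,)] = (b2, log_add(nb2, total + lp))
        # Prune to the `beam` best prefixes.
        beams = dict(
            sorted(nxt.items(), key=lambda kv: log_add(*kv[1]), reverse=True)[:beam]
        )
    return max(beams.items(), key=lambda kv: log_add(*kv[1]))[0]

The beam width here plays the role of the `"beam": 4` entry added to get_decoding_params() further down in this diff.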
@@ -280,6 +297,23 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--nnlm-type",
+        type=str,
+        default="rnn",
+        help="Type of NN lm",
+        choices=["rnn", "transformer"],
+    )
+
+    parser.add_argument(
+        "--nnlm-scale",
+        type=float,
+        default=0,
+        help="""The scale of the neural network LM, 0 means don't use nnlm shallow fussion.
+        Used only when `--use-shallow-fusion` is set to True.
+        """,
+    )
+
     parser.add_argument(
         "--hlg-scale",
         type=float,
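Conceptually, shallow fusion extends each surviving prefix's acoustic score with scaled LM terms at every step. A hedged sketch of the per-token score combination follows; the function and argument names are illustrative, not icefall's internals.

def fused_step_score(
    ctc_logp: float,        # log P_ctc(token | prefix, acoustics)
    nnlm_logp: float,       # neural LM log-prob of the token
    lodr_logp: float,       # low-order token-level n-gram log-prob
    context_bonus: float,   # per-token hotword bonus (0 if unused)
    nnlm_scale: float,      # --nnlm-scale
    lodr_lm_scale: float,   # --lodr-lm-scale, negative by convention
) -> float:
    # LODR: the low-order n-gram approximates the label prior implicit
    # in the CTC posteriors, so it enters with a negative scale while
    # the strong neural LM is added on top (shallow fusion).
    return (
        ctc_logp
        + nnlm_scale * nnlm_logp
        + lodr_lm_scale * lodr_logp
        + context_bonus
    )

This is also why the --lodr-lm-scale help text below says the value should be less than 0: a positive scale would double-count the label prior rather than subtract it.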
@@ -297,11 +331,52 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--backoff-id",
+        type=int,
+        default=500,
+        help="ID of the backoff symbol in the ngram LM",
+    )
+
+    parser.add_argument(
+        "--lodr-ngram",
+        type=str,
+        help="The path to the lodr ngram",
+    )
+
+    parser.add_argument(
+        "--lodr-lm-scale",
+        type=float,
+        default=0,
+        help="The scale of lodr ngram, should be less than 0. 0 means don't use lodr.",
+    )
+
+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=0,
+        help="""
+        The bonus score of each token for the context biasing words/phrases.
+        0 means don't use contextual biasing.
+        Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path of the context biasing lists, one word/phrase each line
+        Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
+        """,
+    )
+
     parser.add_argument(
         "--skip-scoring",
         type=str2bool,
         default=False,
-        help="""Skip scoring, but still save the ASR output (for eval sets)."""
+        help="""Skip scoring, but still save the ASR output (for eval sets).""",
     )
 
     add_model_arguments(parser)
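The context biasing list is a plain text file with one word or phrase per line; each line is BPE-encoded and compiled into a ContextGraph, mirroring the loading code later in this diff. A small sketch, in which the phrases and the model path are made-up examples:

import sentencepiece as spm

from icefall.context_graph import ContextGraph

bpe_model = spm.SentencePieceProcessor()
bpe_model.load("data/lang_bpe_500/bpe.model")  # illustrative path

# Contents of a hypothetical --context-file, one word/phrase per line:
phrases = ["HERMIONE GRANGER", "QUIDDITCH"]

contexts = [bpe_model.encode(p) for p in phrases]  # token ids per phrase
context_graph = ContextGraph(2.0)  # 2.0 plays the role of --context-score
context_graph.build(contexts)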
@@ -314,11 +389,12 @@ def get_decoding_params() -> AttributeDict:
     params = AttributeDict(
         {
             "frame_shift_ms": 10,
-            "search_beam": 20,
-            "output_beam": 8,
+            "search_beam": 20,  # for k2 fsa composition
+            "output_beam": 8,  # for k2 fsa composition
             "min_active_states": 30,
             "max_active_states": 10000,
             "use_double_scores": True,
+            "beam": 4,  # for prefix-beam-search
         }
     )
     return params
@@ -333,6 +409,9 @@ def decode_one_batch(
     batch: dict,
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
+    NNLM: Optional[LmScorer] = None,
+    LODR_lm: Optional[NgramLm] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -377,10 +456,7 @@ def decode_one_batch(
       Return the decoding result. See above description for the format of
       the returned dict. Note: If it decodes to nothing, then return None.
     """
-    if HLG is not None:
-        device = HLG.device
-    else:
-        device = H.device
+    device = params.device
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
@@ -411,6 +487,51 @@ def decode_one_batch(
         key = "ctc-greedy-search"
         return {key: hyps}
 
+    if params.decoding_method == "ctc-prefix-beam-search":
+        token_ids = ctc_prefix_beam_search(
+            ctc_output=ctc_output, encoder_out_lens=encoder_out_lens
+        )
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "prefix-beam-search"
+        return {key: hyps}
+
+    if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring":
+        best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring(
+            ctc_output=ctc_output,
+            attention_decoder=model.attention_decoder,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        ans = dict()
+        for a_scale_str, token_ids in best_path_dict.items():
+            # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+            hyps = bpe_model.decode(token_ids)
+            # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+            hyps = [s.split() for s in hyps]
+            ans[a_scale_str] = hyps
+        return ans
+
+    if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+        token_ids = ctc_prefix_beam_search_shallow_fussion(
+            ctc_output=ctc_output,
+            encoder_out_lens=encoder_out_lens,
+            NNLM=NNLM,
+            LODR_lm=LODR_lm,
+            LODR_lm_scale=params.lodr_lm_scale,
+            context_graph=context_graph,
+        )
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "prefix-beam-search-shallow-fussion"
+        return {key: hyps}
+
     supervision_segments = torch.stack(
         (
             supervisions["sequence_idx"],
@@ -584,6 +705,9 @@ def decode_dataset(
     bpe_model: Optional[spm.SentencePieceProcessor],
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
+    NNLM: Optional[LmScorer] = None,
+    LODR_lm: Optional[NgramLm] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
@@ -634,6 +758,9 @@ def decode_dataset(
             batch=batch,
             word_table=word_table,
             G=G,
+            NNLM=NNLM,
+            LODR_lm=LODR_lm,
+            context_graph=context_graph,
         )
 
         for name, hyps in hyps_dict.items():
@@ -664,9 +791,7 @@ def save_asr_output(
     """
     for key, results in results_dict.items():
-        recogs_filename = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
 
         results = sorted(results)
         store_transcripts(filename=recogs_filename, texts=results)
@@ -680,7 +805,8 @@ def save_wer_results(
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
 ):
     if params.decoding_method in (
-        "attention-decoder-rescoring-with-ngram", "whole-lattice-rescoring"
+        "attention-decoder-rescoring-with-ngram",
+        "whole-lattice-rescoring",
     ):
         # Set it to False since there are too many logs.
         enable_log = False
@@ -721,6 +847,7 @@ def save_wer_results(
 def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
+    LmScorer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
     args.lang_dir = Path(args.lang_dir)
@@ -735,8 +862,11 @@ def main():
     set_caching_enabled(True)  # lhotse
 
     assert params.decoding_method in (
-        "ctc-greedy-search",
         "ctc-decoding",
+        "ctc-greedy-search",
+        "ctc-prefix-beam-search",
+        "ctc-prefix-beam-search-attention-decoder-rescoring",
+        "ctc-prefix-beam-search-shallow-fussion",
         "1best",
         "nbest",
         "nbest-rescoring",
@@ -762,6 +892,16 @@ def main():
         params.suffix += f"_chunk-{params.chunk_size}"
         params.suffix += f"_left-context-{params.left_context_frames}"
 
+    if "prefix-beam-search" in params.decoding_method:
+        params.suffix += f"_beam-{params.beam}"
+        if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+            if params.nnlm_scale != 0:
+                params.suffix += f"_nnlm-scale-{params.nnlm_scale}"
+            if params.lodr_lm_scale != 0:
+                params.suffix += f"_lodr-scale-{params.lodr_lm_scale}"
+            if params.context_score != 0:
+                params.suffix += f"_context_score-{params.context_score}"
+
     if params.use_averaged_model:
         params.suffix += "_use-averaged-model"
@@ -771,6 +911,7 @@ def main():
     device = torch.device("cpu")
     if torch.cuda.is_available():
         device = torch.device("cuda", 0)
+    params.device = device
 
     logging.info(f"Device: {device}")
     logging.info(params)
@@ -786,9 +927,19 @@ def main():
     params.sos_id = 1
 
     if params.decoding_method in [
-        "ctc-greedy-search", "ctc-decoding", "attention-decoder-rescoring-no-ngram"
+        "ctc-decoding",
+        "ctc-greedy-search",
+        "ctc-prefix-beam-search",
+        "ctc-prefix-beam-search-attention-decoder-rescoring",
+        "ctc-prefix-beam-search-shallow-fussion",
+        "attention-decoder-rescoring-no-ngram",
     ]:
         HLG = None
-        H = k2.ctc_topo(
-            max_token=max_token_id,
-            modified=False,
-        )
+        H = None
+        if params.decoding_method in [
+            "ctc-decoding",
+            "attention-decoder-rescoring-no-ngram",
+        ]:
+            H = k2.ctc_topo(
+                max_token=max_token_id,
+                modified=False,
+            )
@@ -844,7 +995,8 @@ def main():
         G = k2.Fsa.from_dict(d)
 
         if params.decoding_method in [
-            "whole-lattice-rescoring", "attention-decoder-rescoring-with-ngram"
+            "whole-lattice-rescoring",
+            "attention-decoder-rescoring-with-ngram",
         ]:
             # Add epsilon self-loops to G as we will compose
             # it with the whole lattice later
@@ -858,6 +1010,51 @@ def main():
     else:
         G = None
 
+    # only load the neural network LM if required
+    NNLM = None
+    if (
+        params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
+        and params.nnlm_scale != 0
+    ):
+        NNLM = LmScorer(
+            lm_type=params.nnlm_type,
+            params=params,
+            device=device,
+            lm_scale=params.nnlm_scale,
+        )
+        NNLM.to(device)
+        NNLM.eval()
+
+    LODR_lm = None
+    if (
+        params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
+        and params.lodr_lm_scale != 0
+    ):
+        assert os.path.exists(
+            params.lodr_ngram
+        ), f"LODR ngram does not exists, given path : {params.lodr_ngram}"
+        logging.info(f"Loading LODR (token level lm): {params.lodr_ngram}")
+        LODR_lm = NgramLm(
+            params.lodr_ngram,
+            backoff_id=params.backoff_id,
+            is_binary=False,
+        )
+        logging.info(f"num states: {LODR_lm.lm.num_states}")
+
+    context_graph = None
+    if (
+        params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
+        and params.context_score != 0
+    ):
+        assert os.path.exists(
+            params.context_file
+        ), f"context_file does not exists, given path : {params.context_file}"
+        contexts = []
+        for line in open(params.context_file).readlines():
+            contexts.append(bpe_model.encode(line.strip()))
+        context_graph = ContextGraph(params.context_score)
+        context_graph.build(contexts)
+
     logging.info("About to create model")
     model = get_model(params)
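For intuition about the hotwords mechanism, here is a toy, self-contained version of per-token context biasing with rollback, the idea behind ContextGraph. It is a simplification: icefall's real graph is an Aho-Corasick-style automaton with fail links, applied incrementally during the beam search rather than post-hoc, and it also handles overlapping and suffix matches that this sketch ignores.

class TrieNode:
    def __init__(self):
        self.children = {}
        self.is_end = False


def build_trie(phrases):
    """phrases: list of token-id sequences (the encoded hotwords)."""
    root = TrieNode()
    for tokens in phrases:
        node = root
        for t in tokens:
            node = node.children.setdefault(t, TrieNode())
        node.is_end = True
    return root


def score_hyp(root, tokens, context_score):
    """Total hotword bonus for one hypothesis, with rollback."""
    total, pending, node = 0.0, 0.0, root
    for t in tokens:
        if t not in node.children:
            # Match died: forget the pending bonus, restart at the root.
            pending, node = 0.0, root
            if t not in node.children:
                continue
        node = node.children[t]
        pending += context_score
        if node.is_end:  # a full phrase matched: keep its bonus
            total += pending
            pending, node = 0.0, root
    return total  # bonuses of unfinished matches are never kept

A hypothesis earns +context_score for every token that extends a listed phrase, but the accumulated bonus is discarded if the match dies before a full phrase completes, so partial matches do not distort the search.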
@@ -967,6 +1164,9 @@ def main():
             bpe_model=bpe_model,
             word_table=lexicon.word_table,
             G=G,
+            NNLM=NNLM,
+            LODR_lm=LODR_lm,
+            context_graph=context_graph,
         )
 
         save_asr_output(