mirror of https://github.com/k2-fsa/icefall.git

Support prefix beam search / shallow fusion / hotwords in librispeech ctc decode

parent adec8554cd
commit 154ef4cfa5
@@ -111,6 +111,7 @@ Usage:
 import argparse
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -129,8 +130,14 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
+
+from icefall.context_graph import ContextGraph, ContextState
+
 from icefall.decode import (
     ctc_greedy_search,
+    ctc_prefix_beam_search,
+    ctc_prefix_beam_search_attention_decoder_rescoring,
+    ctc_prefix_beam_search_shallow_fussion,
     get_lattice,
     nbest_decoding,
     nbest_oracle,
@@ -140,7 +147,11 @@ from icefall.decode import (
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
+
+from icefall.ngram_lm import NgramLm, NgramLmStateCost
 from icefall.lexicon import Lexicon
+from icefall.lm_wrapper import LmScorer
+
 from icefall.utils import (
     AttributeDict,
     get_texts,
@@ -255,6 +266,12 @@ def get_parser():
           lattice, rescore them with the attention decoder.
         - (9) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
           rescored lattice, rescore them with the attention decoder.
+        - (10) ctc-prefix-beam-search. Extract n paths with the given beam; the best
+          path of the n paths is the decoding result.
+        - (11) ctc-prefix-beam-search-attention-decoder-rescoring. Extract n paths with
+          the given beam, then rescore them with the attention decoder.
+        - (12) ctc-prefix-beam-search-shallow-fussion. Use NNLM shallow fusion during
+          beam search; LODR and hotwords are also supported in this decoding method.
         """,
     )

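To make the three new methods concrete, below is a minimal, self-contained sketch of CTC prefix beam search over a `T x V` matrix of per-frame log-probabilities. It is not the icefall implementation; the function name, the input layout, and `blank_id=0` are assumptions made only for this illustration. Methods (11) and (12) start from the same beam of prefixes and then add attention-decoder rescoring or NNLM/LODR/hotword scores on top.

```python
import math
from collections import defaultdict


def log_add(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b))."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def prefix_beam_search(log_probs, beam=4, blank_id=0):
    """log_probs: a T x V list of lists with log P(token v at frame t)."""
    # Each prefix keeps two scores: probability of ending in blank (pb)
    # and of ending in its last non-blank token (pnb), both in log space.
    beams = {(): (0.0, -math.inf)}
    for frame in log_probs:
        next_beams = defaultdict(lambda: (-math.inf, -math.inf))
        for prefix, (pb, pnb) in beams.items():
            for v, lp in enumerate(frame):
                if v == blank_id:
                    # Emitting blank keeps the prefix unchanged.
                    nb, nnb = next_beams[prefix]
                    next_beams[prefix] = (log_add(nb, log_add(pb, pnb) + lp), nnb)
                elif prefix and v == prefix[-1]:
                    # Repeat of the last token: it only extends the prefix
                    # if the previous path ended in a blank.
                    nb, nnb = next_beams[prefix]
                    next_beams[prefix] = (nb, log_add(nnb, pnb + lp))
                    ext = prefix + (v,)
                    eb, enb = next_beams[ext]
                    next_beams[ext] = (eb, log_add(enb, pb + lp))
                else:
                    ext = prefix + (v,)
                    eb, enb = next_beams[ext]
                    next_beams[ext] = (eb, log_add(enb, log_add(pb, pnb) + lp))
        # Keep only the `beam` best prefixes for the next frame.
        beams = dict(
            sorted(
                next_beams.items(),
                key=lambda kv: log_add(*kv[1]),
                reverse=True,
            )[:beam]
        )
    best_prefix, _ = max(beams.items(), key=lambda kv: log_add(*kv[1]))
    return list(best_prefix)


# Toy usage: 3 frames over a vocabulary {0: blank, 1, 2}.
toy = [
    [math.log(0.6), math.log(0.3), math.log(0.1)],
    [math.log(0.2), math.log(0.7), math.log(0.1)],
    [math.log(0.8), math.log(0.1), math.log(0.1)],
]
print(prefix_beam_search(toy, beam=4))  # [1]
```

The `--beam` size this commit adds to the decoding params (see the `get_decoding_params` hunk below) plays the role of `beam` here.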
@@ -280,6 +297,23 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--nnlm-type",
+        type=str,
+        default="rnn",
+        help="Type of NN LM",
+        choices=["rnn", "transformer"],
+    )
+
+    parser.add_argument(
+        "--nnlm-scale",
+        type=float,
+        default=0,
+        help="""The scale of the neural network LM; 0 means NNLM shallow fusion is disabled.
+        Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
+        """,
+    )
+
     parser.add_argument(
         "--hlg-scale",
         type=float,
@@ -297,11 +331,52 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--backoff-id",
+        type=int,
+        default=500,
+        help="ID of the backoff symbol in the ngram LM",
+    )
+
+    parser.add_argument(
+        "--lodr-ngram",
+        type=str,
+        help="The path to the LODR ngram LM",
+    )
+
+    parser.add_argument(
+        "--lodr-lm-scale",
+        type=float,
+        default=0,
+        help="The scale of the LODR ngram; it should be less than 0. 0 means LODR is disabled.",
+    )
+
+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=0,
+        help="""
+        The bonus score for each token of the context biasing words/phrases.
+        0 means contextual biasing is disabled.
+        Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path of the context biasing list; one word/phrase per line.
+        Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
+        """,
+    )
+
     parser.add_argument(
         "--skip-scoring",
         type=str2bool,
         default=False,
-        help="""Skip scoring, but still save the ASR output (for eval sets)."""
+        help="""Skip scoring, but still save the ASR output (for eval sets).""",
     )

     add_model_arguments(parser)
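The three option groups above (NNLM shallow fusion, LODR, hotwords) all act as additive terms on the log-score of a candidate token during prefix beam search. The sketch below only shows how such scales are conventionally combined; the function and argument names are assumptions, not icefall's API, and the LODR term is typically given a negative scale so that the low-order n-gram estimate of the model's internal LM is subtracted.

```python
def fused_token_score(
    ctc_log_prob: float,     # log P_ctc(token | acoustics, prefix)
    nnlm_log_prob: float,    # log P_nnlm(token | prefix), neural LM
    lodr_log_prob: float,    # log P_ngram(token | prefix), LODR n-gram
    nnlm_scale: float,       # --nnlm-scale (0 disables NNLM fusion)
    lodr_lm_scale: float,    # --lodr-lm-scale (expected to be < 0, 0 disables LODR)
    context_bonus: float,    # per-token bonus from the ContextGraph, else 0.0
) -> float:
    """Hypothetical combination of the scores steering the beam search."""
    return (
        ctc_log_prob
        + nnlm_scale * nnlm_log_prob
        + lodr_lm_scale * lodr_log_prob
        + context_bonus
    )
```

With `--nnlm-scale 0`, `--lodr-lm-scale 0`, and `--context-score 0`, this reduces to plain ctc-prefix-beam-search.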
@@ -314,11 +389,12 @@ def get_decoding_params() -> AttributeDict:
     params = AttributeDict(
         {
             "frame_shift_ms": 10,
-            "search_beam": 20,
-            "output_beam": 8,
+            "search_beam": 20,  # for k2 fsa composition
+            "output_beam": 8,  # for k2 fsa composition
             "min_active_states": 30,
             "max_active_states": 10000,
             "use_double_scores": True,
+            "beam": 4,  # for prefix-beam-search
         }
     )
     return params
@@ -333,6 +409,9 @@ def decode_one_batch(
     batch: dict,
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
+    NNLM: Optional[LmScorer] = None,
+    LODR_lm: Optional[NgramLm] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -377,10 +456,7 @@ def decode_one_batch(
       Return the decoding result. See above description for the format of
       the returned dict. Note: If it decodes to nothing, then return None.
     """
-    if HLG is not None:
-        device = HLG.device
-    else:
-        device = H.device
+    device = params.device
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
@@ -411,6 +487,51 @@ def decode_one_batch(
         key = "ctc-greedy-search"
         return {key: hyps}

+    if params.decoding_method == "ctc-prefix-beam-search":
+        token_ids = ctc_prefix_beam_search(
+            ctc_output=ctc_output, encoder_out_lens=encoder_out_lens
+        )
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "prefix-beam-search"
+        return {key: hyps}
+
+    if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring":
+        best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring(
+            ctc_output=ctc_output,
+            attention_decoder=model.attention_decoder,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        ans = dict()
+        for a_scale_str, token_ids in best_path_dict.items():
+            # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+            hyps = bpe_model.decode(token_ids)
+            # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+            hyps = [s.split() for s in hyps]
+            ans[a_scale_str] = hyps
+        return ans
+
+    if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+        token_ids = ctc_prefix_beam_search_shallow_fussion(
+            ctc_output=ctc_output,
+            encoder_out_lens=encoder_out_lens,
+            NNLM=NNLM,
+            LODR_lm=LODR_lm,
+            LODR_lm_scale=params.lodr_lm_scale,
+            context_graph=context_graph,
+        )
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "prefix-beam-search-shallow-fussion"
+        return {key: hyps}
+
     supervision_segments = torch.stack(
         (
             supervisions["sequence_idx"],
@@ -584,6 +705,9 @@ def decode_dataset(
     bpe_model: Optional[spm.SentencePieceProcessor],
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
+    NNLM: Optional[LmScorer] = None,
+    LODR_lm: Optional[NgramLm] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.

@@ -634,6 +758,9 @@ def decode_dataset(
             batch=batch,
             word_table=word_table,
             G=G,
+            NNLM=NNLM,
+            LODR_lm=LODR_lm,
+            context_graph=context_graph,
         )

         for name, hyps in hyps_dict.items():
@@ -664,9 +791,7 @@ def save_asr_output(
     """
     for key, results in results_dict.items():

-        recogs_filename = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"

         results = sorted(results)
         store_transcripts(filename=recogs_filename, texts=results)
@@ -680,7 +805,8 @@ def save_wer_results(
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
 ):
     if params.decoding_method in (
-        "attention-decoder-rescoring-with-ngram", "whole-lattice-rescoring"
+        "attention-decoder-rescoring-with-ngram",
+        "whole-lattice-rescoring",
     ):
         # Set it to False since there are too many logs.
         enable_log = False
@@ -721,6 +847,7 @@ def save_wer_results(
 def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
+    LmScorer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
     args.lang_dir = Path(args.lang_dir)
@@ -735,8 +862,11 @@ def main():
     set_caching_enabled(True)  # lhotse

     assert params.decoding_method in (
-        "ctc-greedy-search",
         "ctc-decoding",
+        "ctc-greedy-search",
+        "ctc-prefix-beam-search",
+        "ctc-prefix-beam-search-attention-decoder-rescoring",
+        "ctc-prefix-beam-search-shallow-fussion",
         "1best",
         "nbest",
         "nbest-rescoring",
@@ -762,6 +892,16 @@ def main():
         params.suffix += f"_chunk-{params.chunk_size}"
         params.suffix += f"_left-context-{params.left_context_frames}"

+    if "prefix-beam-search" in params.decoding_method:
+        params.suffix += f"_beam-{params.beam}"
+        if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+            if params.nnlm_scale != 0:
+                params.suffix += f"_nnlm-scale-{params.nnlm_scale}"
+            if params.lodr_lm_scale != 0:
+                params.suffix += f"_lodr-scale-{params.lodr_lm_scale}"
+            if params.context_score != 0:
+                params.suffix += f"_context_score-{params.context_score}"
+
     if params.use_averaged_model:
         params.suffix += "_use-averaged-model"

@@ -771,6 +911,7 @@ def main():
     device = torch.device("cpu")
     if torch.cuda.is_available():
         device = torch.device("cuda", 0)
+    params.device = device

     logging.info(f"Device: {device}")
     logging.info(params)
@@ -786,14 +927,24 @@ def main():
     params.sos_id = 1

     if params.decoding_method in [
-        "ctc-greedy-search", "ctc-decoding", "attention-decoder-rescoring-no-ngram"
+        "ctc-decoding",
+        "ctc-greedy-search",
+        "ctc-prefix-beam-search",
+        "ctc-prefix-beam-search-attention-decoder-rescoring",
+        "ctc-prefix-beam-search-shallow-fussion",
+        "attention-decoder-rescoring-no-ngram",
     ]:
         HLG = None
-        H = k2.ctc_topo(
-            max_token=max_token_id,
-            modified=False,
-            device=device,
-        )
+        H = None
+        if params.decoding_method in [
+            "ctc-decoding",
+            "attention-decoder-rescoring-no-ngram",
+        ]:
+            H = k2.ctc_topo(
+                max_token=max_token_id,
+                modified=False,
+                device=device,
+            )
         bpe_model = spm.SentencePieceProcessor()
         bpe_model.load(str(params.lang_dir / "bpe.model"))
     else:
@@ -844,7 +995,8 @@ def main():
         G = k2.Fsa.from_dict(d)

         if params.decoding_method in [
-            "whole-lattice-rescoring", "attention-decoder-rescoring-with-ngram"
+            "whole-lattice-rescoring",
+            "attention-decoder-rescoring-with-ngram",
         ]:
             # Add epsilon self-loops to G as we will compose
             # it with the whole lattice later
@@ -858,6 +1010,51 @@ def main():
     else:
         G = None

+    # Only load the neural network LM if required.
+    NNLM = None
+    if (
+        params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
+        and params.nnlm_scale != 0
+    ):
+        NNLM = LmScorer(
+            lm_type=params.nnlm_type,
+            params=params,
+            device=device,
+            lm_scale=params.nnlm_scale,
+        )
+        NNLM.to(device)
+        NNLM.eval()
+
+    LODR_lm = None
+    if (
+        params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
+        and params.lodr_lm_scale != 0
+    ):
+        assert os.path.exists(
+            params.lodr_ngram
+        ), f"LODR ngram does not exist, given path: {params.lodr_ngram}"
+        logging.info(f"Loading LODR (token level lm): {params.lodr_ngram}")
+        LODR_lm = NgramLm(
+            params.lodr_ngram,
+            backoff_id=params.backoff_id,
+            is_binary=False,
+        )
+        logging.info(f"num states: {LODR_lm.lm.num_states}")
+
+    context_graph = None
+    if (
+        params.decoding_method == "ctc-prefix-beam-search-shallow-fussion"
+        and params.context_score != 0
+    ):
+        assert os.path.exists(
+            params.context_file
+        ), f"context_file does not exist, given path: {params.context_file}"
+        contexts = []
+        for line in open(params.context_file).readlines():
+            contexts.append(bpe_model.encode(line.strip()))
+        context_graph = ContextGraph(params.context_score)
+        context_graph.build(contexts)
+
     logging.info("About to create model")
     model = get_model(params)

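As a usage illustration of the hotword path added above: the file passed via `--context-file` holds one word or phrase per line; each line is BPE-encoded and compiled into a `ContextGraph` whose per-token bonus is `--context-score`. The hotword list and the `bpe.model` path below are made-up examples, not part of the commit.

```python
import sentencepiece as spm

from icefall.context_graph import ContextGraph

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # assumed location of the BPE model

# Hypothetical contents of the file given by --context-file
# (one word/phrase per line):
hotwords = ["LIBRISPEECH", "SHALLOW FUSION"]

contexts = [sp.encode(w) for w in hotwords]  # lists of BPE token ids

context_graph = ContextGraph(2.0)  # 2.0 plays the role of --context-score
context_graph.build(contexts)
```

During ctc-prefix-beam-search-shallow-fussion, prefixes that match entries of this graph receive the bonus for every matched token, which is what the commit title refers to as hotword support.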
@@ -967,6 +1164,9 @@ def main():
             bpe_model=bpe_model,
             word_table=lexicon.word_table,
             G=G,
+            NNLM=NNLM,
+            LODR_lm=LODR_lm,
+            context_graph=context_graph,
         )

         save_asr_output(