Add prefix beam search to aishell

pkufool 2024-09-29 12:00:45 +08:00
parent 906e833361
commit 33fa9e8b00
2 changed files with 108 additions and 10 deletions

File 1 of 2: the aishell CTC decoding script (path not shown in this view).

@@ -123,6 +123,11 @@ from lhotse import set_caching_enabled
 from lhotse.cut import Cut
 from train import add_model_arguments, get_model, get_params
+
+from icefall.context_graph import ContextGraph, ContextState
+from icefall.ngram_lm import NgramLm, NgramLmStateCost
+from icefall.lm_wrapper import LmScorer
+
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -131,6 +136,9 @@ from icefall.checkpoint import (
 )
 from icefall.decode import (
     ctc_greedy_search,
+    ctc_prefix_beam_search,
+    ctc_prefix_beam_search_attention_decoder_rescoring,
+    ctc_prefix_beam_search_shallow_fussion,
     get_lattice,
     one_best_decoding,
     rescore_with_attention_decoder_no_ngram,
@@ -249,7 +257,24 @@ def get_parser():
         "--skip-scoring",
         type=str2bool,
         default=False,
-        help="""Skip scoring, but still save the ASR output (for eval sets)."""
+        help="""Skip scoring, but still save the ASR output (for eval sets).""",
+    )
+
+    parser.add_argument(
+        "--lm-type",
+        type=str,
+        default="rnn",
+        help="Type of NN lm",
+        choices=["rnn", "transformer"],
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.3,
+        help="""The scale of the neural network LM
+        Used only when `--use-shallow-fusion` is set to True.
+        """,
     )

     add_model_arguments(parser)
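
As a standalone illustration of how the two new flags parse (this mirrors only these arguments outside the full get_parser(); the command-line values are made up):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--lm-type", type=str, default="rnn",
                        choices=["rnn", "transformer"])
    parser.add_argument("--lm-scale", type=float, default=0.3)

    # Defaults apply when the flags are omitted; explicit values override.
    args = parser.parse_args(["--lm-type", "transformer", "--lm-scale", "0.5"])
    assert args.lm_type == "transformer" and args.lm_scale == 0.5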
@@ -262,8 +287,9 @@ def get_decoding_params() -> AttributeDict:
     params = AttributeDict(
         {
             "frame_shift_ms": 10,
-            "search_beam": 20,
-            "output_beam": 8,
+            "search_beam": 20,  # for k2 fsa composition
+            "output_beam": 8,  # for k2 fsa composition
+            "beam": 4,  # for prefix-beam-search
             "min_active_states": 30,
             "max_active_states": 10000,
             "use_double_scores": True,
@@ -278,6 +304,7 @@ def decode_one_batch(
     lexicon: Lexicon,
     batch: dict,
     H: Optional[k2.Fsa],
+    LM: Optional[LmScorer] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -355,6 +382,43 @@
         key = "ctc-greedy-search"
         return {key: hyps}

+    if params.decoding_method == "prefix-beam-search":
+        hyp_tokens = ctc_prefix_beam_search(
+            ctc_output=ctc_output, encoder_out_lens=encoder_out_lens
+        )
+        hyps = []
+        for i in range(batch_size):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+        key = "prefix-beam-search"
+        return {key: hyps}
+
+    if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring":
+        best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring(
+            ctc_output=ctc_output,
+            attention_decoder=model.attention_decoder,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        ans = dict()
+        for a_scale_str, hyp_tokens in best_path_dict.items():
+            hyps = []
+            for i in range(batch_size):
+                hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+            ans[a_scale_str] = hyps
+        return ans
+
+    if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+        hyp_tokens = ctc_prefix_beam_search_shallow_fussion(
+            ctc_output=ctc_output,
+            encoder_out_lens=encoder_out_lens,
+            LM=LM,
+        )
+        hyps = []
+        for i in range(batch_size):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+        key = "prefix-beam-search-shallow-fussion"
+        return {key: hyps}
+
     supervision_segments = torch.stack(
         (
             supervisions["sequence_idx"],
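
As a mental model for what ctc_prefix_beam_search does (the icefall version is batched over tensors; this is only a minimal single-utterance sketch): each candidate prefix carries two probabilities, ending-in-blank and ending-in-non-blank, every frame extends all prefixes, and only the `beam` best prefixes survive per frame (the `"beam": 4` default added above). Assuming `log_probs` is a T-by-vocab list of per-frame log-probabilities with blank id 0:

    import math
    from collections import defaultdict

    def log_add(*args):
        # Numerically stable log(sum(exp(a) for a in args)).
        m = max(args)
        if m == -math.inf:
            return -math.inf
        return m + math.log(sum(math.exp(a - m) for a in args))

    def prefix_beam_search_sketch(log_probs, beam=4, blank=0):
        # Each prefix (a tuple of token ids) maps to a pair
        # (log P(prefix, ends in blank), log P(prefix, ends in non-blank)).
        beams = {(): (0.0, -math.inf)}
        for frame in log_probs:
            nxt = defaultdict(lambda: (-math.inf, -math.inf))
            for prefix, (p_b, p_nb) in beams.items():
                for tok, p in enumerate(frame):
                    if tok == blank:
                        # Blank leaves the prefix unchanged.
                        b, nb = nxt[prefix]
                        nxt[prefix] = (log_add(b, p_b + p, p_nb + p), nb)
                    elif prefix and tok == prefix[-1]:
                        # A repeated token collapses into the same prefix...
                        b, nb = nxt[prefix]
                        nxt[prefix] = (b, log_add(nb, p_nb + p))
                        # ...unless a blank separated the two occurrences.
                        ext = prefix + (tok,)
                        b, nb = nxt[ext]
                        nxt[ext] = (b, log_add(nb, p_b + p))
                    else:
                        ext = prefix + (tok,)
                        b, nb = nxt[ext]
                        nxt[ext] = (b, log_add(nb, p_b + p, p_nb + p))
            # Keep only the `beam` most probable prefixes per frame.
            beams = dict(sorted(nxt.items(),
                                key=lambda kv: -log_add(*kv[1]))[:beam])
        return list(max(beams, key=lambda pfx: log_add(*beams[pfx])))

    # Example: 2 frames, vocab {0: blank, 1: 'a'}.
    lp = math.log
    frames = [[lp(0.4), lp(0.6)], [lp(0.5), lp(0.5)]]
    print(prefix_beam_search_sketch(frames))  # -> [1]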
@@ -428,6 +492,7 @@ def decode_dataset(
     model: nn.Module,
     lexicon: Lexicon,
     H: Optional[k2.Fsa] = None,
+    LM: Optional[LmScorer] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
@@ -476,6 +541,7 @@
             batch=batch,
             lexicon=lexicon,
             H=H,
+            LM=LM,
         )

         for name, hyps in hyps_dict.items():
@@ -530,7 +596,9 @@ def save_wer_results(
     for key, results in results_dict.items():
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
         with open(errs_filename, "w", encoding="utf8") as fd:
             wer = write_error_stats(
                 fd,
@@ -545,7 +613,9 @@
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])

-    wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    wer_filename = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
     with open(wer_filename, "w", encoding="utf8") as fd:
         print("settings\tWER", file=fd)
@@ -564,6 +634,7 @@
 def main():
     parser = get_parser()
     AishellAsrDataModule.add_arguments(parser)
+    LmScorer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
     args.lang_dir = Path(args.lang_dir)
@@ -578,15 +649,18 @@ def main():
     assert params.decoding_method in (
         "ctc-greedy-search",
+        "prefix-beam-search",
+        "ctc-prefix-beam-search-attention-decoder-rescoring",
+        "ctc-prefix-beam-search-shallow-fussion",
         "ctc-decoding",
         "attention-decoder-rescoring-no-ngram",
     )
     params.res_dir = params.exp_dir / params.decoding_method

     if params.iter > 0:
-        params.suffix = f"iter-{params.iter}_avg-{params.avg}"
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
     else:
-        params.suffix = f"epoch-{params.epoch}_avg-{params.avg}"
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

     if params.causal:
         assert (
@@ -598,6 +672,11 @@ def main():
         params.suffix += f"_chunk-{params.chunk_size}"
         params.suffix += f"_left-context-{params.left_context_frames}"

+    if "prefix-beam-search" in params.decoding_method:
+        params.suffix += f"_beam-{params.beam}"
+        if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+            params.suffix += f"_lm-scale-{params.lm_scale}"
+
     if params.use_averaged_model:
         params.suffix += "_use-averaged-model"
@@ -621,7 +700,10 @@
         params.eos_id = 1
         params.sos_id = 1

-    if params.decoding_method != "ctc-greedy-search":
+    if params.decoding_method in [
+        "ctc-decoding",
+        "attention-decoder-rescoring-no-ngram",
+    ]:
         H = k2.ctc_topo(
             max_token=max_token_id,
             modified=True,
@@ -630,6 +712,19 @@
     else:
         H = None

+    # only load the neural network LM if required
+    if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+        LM = LmScorer(
+            lm_type=params.lm_type,
+            params=params,
+            device=device,
+            lm_scale=params.lm_scale,
+        )
+        LM.to(device)
+        LM.eval()
+    else:
+        LM = None
+
     logging.info("About to create model")
     model = get_model(params)
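
Shallow fusion here means each candidate prefix is ranked by its CTC score plus the scaled neural-LM score; a sketch of the combination rule (illustrative names, not the actual signature used by ctc_prefix_beam_search_shallow_fussion):

    def fused_logp(ctc_logp: float, lm_logp: float, lm_scale: float = 0.3) -> float:
        # Combined ranking score: acoustic (CTC) log-probability plus the
        # neural-LM log-probability weighted by --lm-scale.
        return ctc_logp + lm_scale * lm_logp

    # An acoustically weaker prefix can win if the LM strongly prefers it:
    assert fused_logp(-2.2, -0.2) > fused_logp(-2.0, -1.0)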
@@ -746,6 +841,7 @@ def main():
             model=model,
             H=H,
             lexicon=lexicon,
+            LM=LM,
         )

         save_asr_output(

File 2 of 2: the GigaSpeech decoding script (path not shown in this view).

@@ -1064,11 +1064,13 @@ def main():
     gigaspeech = GigaSpeechAsrDataModule(args)

     test_cuts = gigaspeech.test_cuts()
+    dev_cuts = gigaspeech.dev_cuts()
     test_dl = gigaspeech.test_dataloaders(test_cuts)
+    dev_dl = gigaspeech.test_dataloaders(dev_cuts)

-    test_sets = ["test"]
-    test_dls = [test_dl]
+    test_sets = ["test", "dev"]
+    test_dls = [test_dl, dev_dl]

     for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(