mirror of https://github.com/k2-fsa/icefall.git
Add prefix beam search to aishell
commit 33fa9e8b00
parent 906e833361
@@ -123,6 +123,11 @@ from lhotse import set_caching_enabled
 from lhotse.cut import Cut
 from train import add_model_arguments, get_model, get_params

+from icefall.context_graph import ContextGraph, ContextState
+from icefall.ngram_lm import NgramLm, NgramLmStateCost
+from icefall.lm_wrapper import LmScorer
+
+
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -131,6 +136,9 @@ from icefall.checkpoint import (
 )
 from icefall.decode import (
     ctc_greedy_search,
+    ctc_prefix_beam_search,
+    ctc_prefix_beam_search_attention_decoder_rescoring,
+    ctc_prefix_beam_search_shallow_fussion,
     get_lattice,
     one_best_decoding,
     rescore_with_attention_decoder_no_ngram,
@@ -249,7 +257,24 @@ def get_parser():
         "--skip-scoring",
         type=str2bool,
         default=False,
-        help="""Skip scoring, but still save the ASR output (for eval sets)."""
+        help="""Skip scoring, but still save the ASR output (for eval sets).""",
     )

+    parser.add_argument(
+        "--lm-type",
+        type=str,
+        default="rnn",
+        help="Type of NN lm",
+        choices=["rnn", "transformer"],
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.3,
+        help="""The scale of the neural network LM
+        Used only when `--use-shallow-fusion` is set to True.
+        """,
+    )
+
     add_model_arguments(parser)
@@ -262,8 +287,9 @@ def get_decoding_params() -> AttributeDict:
     params = AttributeDict(
         {
             "frame_shift_ms": 10,
-            "search_beam": 20,
-            "output_beam": 8,
+            "search_beam": 20,  # for k2 fsa composition
+            "output_beam": 8,  # for k2 fsa composition
+            "beam": 4,  # for prefix-beam-search
             "min_active_states": 30,
             "max_active_states": 10000,
             "use_double_scores": True,
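A note on the three beam-like settings above, which are easy to conflate: search_beam and output_beam bound k2's FSA composition used by the lattice-based methods, while the new beam entry caps how many prefixes the prefix beam search keeps alive per frame (a full sketch of the algorithm follows the decode_one_batch hunk below). A toy illustration of that pruning rule, with invented scores:

# With "beam": 4, only the four best-scoring prefixes survive the frame.
scores = {"hello": -1.2, "hallo": -2.5, "hell": -3.1, "help": -3.4, "halo": -5.0}
beam = 4
kept = dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:beam])
print(kept)  # "halo", the worst-scoring prefix, is dropped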
@@ -278,6 +304,7 @@ def decode_one_batch(
     lexicon: Lexicon,
     batch: dict,
     H: Optional[k2.Fsa],
+    LM: Optional[LmScorer] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -355,6 +382,43 @@ def decode_one_batch(
         key = "ctc-greedy-search"
         return {key: hyps}

+    if params.decoding_method == "prefix-beam-search":
+        hyp_tokens = ctc_prefix_beam_search(
+            ctc_output=ctc_output, encoder_out_lens=encoder_out_lens
+        )
+        hyps = []
+        for i in range(batch_size):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+        key = "prefix-beam-search"
+        return {key: hyps}
+
+    if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring":
+        best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring(
+            ctc_output=ctc_output,
+            attention_decoder=model.attention_decoder,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        ans = dict()
+        for a_scale_str, hyp_tokens in best_path_dict.items():
+            hyps = []
+            for i in range(batch_size):
+                hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+            ans[a_scale_str] = hyps
+        return ans
+
+    if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+        hyp_tokens = ctc_prefix_beam_search_shallow_fussion(
+            ctc_output=ctc_output,
+            encoder_out_lens=encoder_out_lens,
+            LM=LM,
+        )
+        hyps = []
+        for i in range(batch_size):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+        key = "prefix-beam-search-shallow-fussion"
+        return {key: hyps}
+
     supervision_segments = torch.stack(
         (
             supervisions["sequence_idx"],
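For readers new to the algorithm behind these three branches, here is a minimal single-utterance sketch of CTC prefix beam search (after Hannun et al., 2014). It is an illustration only, not icefall's implementation: the ctc_prefix_beam_search imported above is batched over torch tensors, whereas this sketch takes a (T, vocab_size) matrix of per-frame token log-probabilities, assumes blank id 0, and keeps at most beam prefixes per frame (cf. the new "beam": 4 decoding param).

import math
from collections import defaultdict
from typing import List


def log_add(a: float, b: float) -> float:
    # Numerically stable log(exp(a) + exp(b)).
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def prefix_beam_search(log_probs, beam: int = 4, blank: int = 0) -> List[int]:
    # Each prefix carries two log-scores: ending in blank (pb) and ending
    # in a non-blank (pnb). The empty prefix starts with pb = log(1).
    beams = {(): (0.0, -math.inf)}
    for frame in log_probs:
        next_beams = defaultdict(lambda: (-math.inf, -math.inf))
        for prefix, (pb, pnb) in beams.items():
            for token, lp in enumerate(frame):
                lp = float(lp)
                if token == blank:
                    # Blank keeps the prefix unchanged.
                    nb, nnb = next_beams[prefix]
                    next_beams[prefix] = (log_add(nb, log_add(pb, pnb) + lp), nnb)
                elif prefix and token == prefix[-1]:
                    # Repeat: either collapses into the same prefix (no blank
                    # in between) or extends it (coming from a blank ending).
                    nb, nnb = next_beams[prefix]
                    next_beams[prefix] = (nb, log_add(nnb, pnb + lp))
                    ext = prefix + (token,)
                    nb, nnb = next_beams[ext]
                    next_beams[ext] = (nb, log_add(nnb, pb + lp))
                else:
                    ext = prefix + (token,)
                    nb, nnb = next_beams[ext]
                    next_beams[ext] = (nb, log_add(nnb, log_add(pb, pnb) + lp))
        # Prune to the `beam` best prefixes by total probability.
        beams = dict(
            sorted(
                next_beams.items(),
                key=lambda kv: log_add(kv[1][0], kv[1][1]),
                reverse=True,
            )[:beam]
        )
    best_prefix, _ = max(beams.items(), key=lambda kv: log_add(kv[1][0], kv[1][1]))
    return list(best_prefix)

Given per-frame log-probabilities, this returns the best token-id sequence, which the code above then maps to symbols via lexicon.token_table.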
@@ -428,6 +492,7 @@ def decode_dataset(
     model: nn.Module,
     lexicon: Lexicon,
     H: Optional[k2.Fsa] = None,
+    LM: Optional[LmScorer] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.

@@ -476,6 +541,7 @@ def decode_dataset(
             batch=batch,
             lexicon=lexicon,
             H=H,
+            LM=LM,
         )

         for name, hyps in hyps_dict.items():
@@ -530,7 +596,9 @@ def save_wer_results(
     for key, results in results_dict.items():
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
         with open(errs_filename, "w", encoding="utf8") as fd:
             wer = write_error_stats(
                 fd,
@@ -545,7 +613,9 @@ def save_wer_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])

-    wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    wer_filename = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )

     with open(wer_filename, "w", encoding="utf8") as fd:
         print("settings\tWER", file=fd)
@@ -564,6 +634,7 @@ def save_wer_results(
 def main():
     parser = get_parser()
     AishellAsrDataModule.add_arguments(parser)
+    LmScorer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
     args.lang_dir = Path(args.lang_dir)
@@ -578,15 +649,18 @@ def main():

     assert params.decoding_method in (
         "ctc-greedy-search",
+        "prefix-beam-search",
+        "ctc-prefix-beam-search-attention-decoder-rescoring",
+        "ctc-prefix-beam-search-shallow-fussion",
         "ctc-decoding",
         "attention-decoder-rescoring-no-ngram",
     )
     params.res_dir = params.exp_dir / params.decoding_method

     if params.iter > 0:
-        params.suffix = f"iter-{params.iter}_avg-{params.avg}"
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
     else:
-        params.suffix = f"epoch-{params.epoch}_avg-{params.avg}"
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

     if params.causal:
         assert (
@@ -598,6 +672,11 @@ def main():
         params.suffix += f"_chunk-{params.chunk_size}"
         params.suffix += f"_left-context-{params.left_context_frames}"

+    if "prefix-beam-search" in params.decoding_method:
+        params.suffix += f"_beam-{params.beam}"
+        if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+            params.suffix += f"_lm-scale-{params.lm_scale}"
+
     if params.use_averaged_model:
         params.suffix += "_use-averaged-model"

@@ -621,7 +700,10 @@ def main():
     params.eos_id = 1
     params.sos_id = 1

-    if params.decoding_method != "ctc-greedy-search":
+    if params.decoding_method in [
+        "ctc-decoding",
+        "attention-decoder-rescoring-no-ngram",
+    ]:
         H = k2.ctc_topo(
             max_token=max_token_id,
             modified=True,
@@ -630,6 +712,19 @@ def main():
     else:
         H = None

+    # only load the neural network LM if required
+    if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+        LM = LmScorer(
+            lm_type=params.lm_type,
+            params=params,
+            device=device,
+            lm_scale=params.lm_scale,
+        )
+        LM.to(device)
+        LM.eval()
+    else:
+        LM = None
+
     logging.info("About to create model")
     model = get_model(params)

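On the shallow-fusion wiring above: the LmScorer is consulted inside the prefix beam search, adding a scaled neural-LM log-probability whenever a prefix is extended by a non-blank token. Below is a hedged sketch of the combination rule only; lm_log_prob is a hypothetical stand-in, since icefall's LmScorer carries per-hypothesis LM state rather than exposing one stateless call.

# Shallow fusion in the log domain: CTC acoustic score plus a scaled LM score.
def fused_extension_score(
    prefix_score: float,   # running score of the prefix so far
    ctc_log_prob: float,   # log P_ctc(token | current frame)
    lm_log_prob: float,    # log P_lm(token | prefix), hypothetical stand-in
    lm_scale: float = 0.3, # the new --lm-scale argument
) -> float:
    return prefix_score + ctc_log_prob + lm_scale * lm_log_prob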
@@ -746,6 +841,7 @@ def main():
             model=model,
             H=H,
             lexicon=lexicon,
+            LM=LM,
         )

     save_asr_output(
@@ -1064,11 +1064,13 @@ def main():
     gigaspeech = GigaSpeechAsrDataModule(args)

     test_cuts = gigaspeech.test_cuts()
+    dev_cuts = gigaspeech.dev_cuts()

     test_dl = gigaspeech.test_dataloaders(test_cuts)
+    dev_dl = gigaspeech.test_dataloaders(dev_cuts)

-    test_sets = ["test"]
-    test_dls = [test_dl]
+    test_sets = ["test", "dev"]
+    test_dls = [test_dl, dev_dl]

     for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(