mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
Add prefix beam search to aishell
This commit is contained in:
parent
906e833361
commit
33fa9e8b00
@ -123,6 +123,11 @@ from lhotse import set_caching_enabled
|
|||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
from train import add_model_arguments, get_model, get_params
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
|
from icefall.context_graph import ContextGraph, ContextState
|
||||||
|
from icefall.ngram_lm import NgramLm, NgramLmStateCost
|
||||||
|
from icefall.lm_wrapper import LmScorer
|
||||||
|
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
average_checkpoints_with_averaged_model,
|
average_checkpoints_with_averaged_model,
|
||||||
@ -131,6 +136,9 @@ from icefall.checkpoint import (
|
|||||||
)
|
)
|
||||||
from icefall.decode import (
|
from icefall.decode import (
|
||||||
ctc_greedy_search,
|
ctc_greedy_search,
|
||||||
|
ctc_prefix_beam_search,
|
||||||
|
ctc_prefix_beam_search_attention_decoder_rescoring,
|
||||||
|
ctc_prefix_beam_search_shallow_fussion,
|
||||||
get_lattice,
|
get_lattice,
|
||||||
one_best_decoding,
|
one_best_decoding,
|
||||||
rescore_with_attention_decoder_no_ngram,
|
rescore_with_attention_decoder_no_ngram,
|
||||||
@ -249,7 +257,24 @@ def get_parser():
|
|||||||
"--skip-scoring",
|
"--skip-scoring",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=False,
|
default=False,
|
||||||
help="""Skip scoring, but still save the ASR output (for eval sets)."""
|
help="""Skip scoring, but still save the ASR output (for eval sets).""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lm-type",
|
||||||
|
type=str,
|
||||||
|
default="rnn",
|
||||||
|
help="Type of NN lm",
|
||||||
|
choices=["rnn", "transformer"],
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.3,
|
||||||
|
help="""The scale of the neural network LM
|
||||||
|
Used only when `--use-shallow-fusion` is set to True.
|
||||||
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
@ -262,8 +287,9 @@ def get_decoding_params() -> AttributeDict:
|
|||||||
params = AttributeDict(
|
params = AttributeDict(
|
||||||
{
|
{
|
||||||
"frame_shift_ms": 10,
|
"frame_shift_ms": 10,
|
||||||
"search_beam": 20,
|
"search_beam": 20, # for k2 fsa composition
|
||||||
"output_beam": 8,
|
"output_beam": 8, # for k2 fsa composition
|
||||||
|
"beam": 4, # for prefix-beam-search
|
||||||
"min_active_states": 30,
|
"min_active_states": 30,
|
||||||
"max_active_states": 10000,
|
"max_active_states": 10000,
|
||||||
"use_double_scores": True,
|
"use_double_scores": True,
|
||||||
@ -278,6 +304,7 @@ def decode_one_batch(
|
|||||||
lexicon: Lexicon,
|
lexicon: Lexicon,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
H: Optional[k2.Fsa],
|
H: Optional[k2.Fsa],
|
||||||
|
LM: Optional[LmScorer] = None,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
@ -355,6 +382,43 @@ def decode_one_batch(
|
|||||||
key = "ctc-greedy-search"
|
key = "ctc-greedy-search"
|
||||||
return {key: hyps}
|
return {key: hyps}
|
||||||
|
|
||||||
|
if params.decoding_method == "prefix-beam-search":
|
||||||
|
hyp_tokens = ctc_prefix_beam_search(
|
||||||
|
ctc_output=ctc_output, encoder_out_lens=encoder_out_lens
|
||||||
|
)
|
||||||
|
hyps = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||||
|
key = "prefix-beam-search"
|
||||||
|
return {key: hyps}
|
||||||
|
|
||||||
|
if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring":
|
||||||
|
best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring(
|
||||||
|
ctc_output=ctc_output,
|
||||||
|
attention_decoder=model.attention_decoder,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
)
|
||||||
|
ans = dict()
|
||||||
|
for a_scale_str, hyp_tokens in best_path_dict.items():
|
||||||
|
hyps = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||||
|
ans[a_scale_str] = hyps
|
||||||
|
return ans
|
||||||
|
|
||||||
|
if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
|
||||||
|
hyp_tokens = ctc_prefix_beam_search_shallow_fussion(
|
||||||
|
ctc_output=ctc_output,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
LM=LM,
|
||||||
|
)
|
||||||
|
hyps = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||||
|
key = "prefix-beam-search-shallow-fussion"
|
||||||
|
return {key: hyps}
|
||||||
|
|
||||||
supervision_segments = torch.stack(
|
supervision_segments = torch.stack(
|
||||||
(
|
(
|
||||||
supervisions["sequence_idx"],
|
supervisions["sequence_idx"],
|
||||||
@ -428,6 +492,7 @@ def decode_dataset(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
lexicon: Lexicon,
|
lexicon: Lexicon,
|
||||||
H: Optional[k2.Fsa] = None,
|
H: Optional[k2.Fsa] = None,
|
||||||
|
LM: Optional[LmScorer] = None,
|
||||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
@ -476,6 +541,7 @@ def decode_dataset(
|
|||||||
batch=batch,
|
batch=batch,
|
||||||
lexicon=lexicon,
|
lexicon=lexicon,
|
||||||
H=H,
|
H=H,
|
||||||
|
LM=LM,
|
||||||
)
|
)
|
||||||
|
|
||||||
for name, hyps in hyps_dict.items():
|
for name, hyps in hyps_dict.items():
|
||||||
@ -530,7 +596,9 @@ def save_wer_results(
|
|||||||
for key, results in results_dict.items():
|
for key, results in results_dict.items():
|
||||||
# The following prints out WERs, per-word error statistics and aligned
|
# The following prints out WERs, per-word error statistics and aligned
|
||||||
# ref/hyp pairs.
|
# ref/hyp pairs.
|
||||||
errs_filename = params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
errs_filename = (
|
||||||
|
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
|
)
|
||||||
with open(errs_filename, "w", encoding="utf8") as fd:
|
with open(errs_filename, "w", encoding="utf8") as fd:
|
||||||
wer = write_error_stats(
|
wer = write_error_stats(
|
||||||
fd,
|
fd,
|
||||||
@ -545,7 +613,9 @@ def save_wer_results(
|
|||||||
|
|
||||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||||
|
|
||||||
wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
wer_filename = (
|
||||||
|
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
|
)
|
||||||
|
|
||||||
with open(wer_filename, "w", encoding="utf8") as fd:
|
with open(wer_filename, "w", encoding="utf8") as fd:
|
||||||
print("settings\tWER", file=fd)
|
print("settings\tWER", file=fd)
|
||||||
@ -564,6 +634,7 @@ def save_wer_results(
|
|||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
AishellAsrDataModule.add_arguments(parser)
|
AishellAsrDataModule.add_arguments(parser)
|
||||||
|
LmScorer.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
args.lang_dir = Path(args.lang_dir)
|
args.lang_dir = Path(args.lang_dir)
|
||||||
@ -578,15 +649,18 @@ def main():
|
|||||||
|
|
||||||
assert params.decoding_method in (
|
assert params.decoding_method in (
|
||||||
"ctc-greedy-search",
|
"ctc-greedy-search",
|
||||||
|
"prefix-beam-search",
|
||||||
|
"ctc-prefix-beam-search-attention-decoder-rescoring",
|
||||||
|
"ctc-prefix-beam-search-shallow-fussion",
|
||||||
"ctc-decoding",
|
"ctc-decoding",
|
||||||
"attention-decoder-rescoring-no-ngram",
|
"attention-decoder-rescoring-no-ngram",
|
||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
params.suffix = f"iter-{params.iter}_avg-{params.avg}"
|
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||||
else:
|
else:
|
||||||
params.suffix = f"epoch-{params.epoch}_avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
if params.causal:
|
if params.causal:
|
||||||
assert (
|
assert (
|
||||||
@ -598,6 +672,11 @@ def main():
|
|||||||
params.suffix += f"_chunk-{params.chunk_size}"
|
params.suffix += f"_chunk-{params.chunk_size}"
|
||||||
params.suffix += f"_left-context-{params.left_context_frames}"
|
params.suffix += f"_left-context-{params.left_context_frames}"
|
||||||
|
|
||||||
|
if "prefix-beam-search" in params.decoding_method:
|
||||||
|
params.suffix += f"_beam-{params.beam}"
|
||||||
|
if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
|
||||||
|
params.suffix += f"_lm-scale-{params.lm_scale}"
|
||||||
|
|
||||||
if params.use_averaged_model:
|
if params.use_averaged_model:
|
||||||
params.suffix += "_use-averaged-model"
|
params.suffix += "_use-averaged-model"
|
||||||
|
|
||||||
@ -621,7 +700,10 @@ def main():
|
|||||||
params.eos_id = 1
|
params.eos_id = 1
|
||||||
params.sos_id = 1
|
params.sos_id = 1
|
||||||
|
|
||||||
if params.decoding_method != "ctc-greedy-search":
|
if params.decoding_method in [
|
||||||
|
"ctc-decoding",
|
||||||
|
"attention-decoder-rescoring-no-ngram",
|
||||||
|
]:
|
||||||
H = k2.ctc_topo(
|
H = k2.ctc_topo(
|
||||||
max_token=max_token_id,
|
max_token=max_token_id,
|
||||||
modified=True,
|
modified=True,
|
||||||
@ -630,6 +712,19 @@ def main():
|
|||||||
else:
|
else:
|
||||||
H = None
|
H = None
|
||||||
|
|
||||||
|
# only load the neural network LM if required
|
||||||
|
if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
|
||||||
|
LM = LmScorer(
|
||||||
|
lm_type=params.lm_type,
|
||||||
|
params=params,
|
||||||
|
device=device,
|
||||||
|
lm_scale=params.lm_scale,
|
||||||
|
)
|
||||||
|
LM.to(device)
|
||||||
|
LM.eval()
|
||||||
|
else:
|
||||||
|
LM = None
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
model = get_model(params)
|
model = get_model(params)
|
||||||
|
|
||||||
@ -746,6 +841,7 @@ def main():
|
|||||||
model=model,
|
model=model,
|
||||||
H=H,
|
H=H,
|
||||||
lexicon=lexicon,
|
lexicon=lexicon,
|
||||||
|
LM=LM,
|
||||||
)
|
)
|
||||||
|
|
||||||
save_asr_output(
|
save_asr_output(
|
||||||
|
@ -1064,11 +1064,13 @@ def main():
|
|||||||
gigaspeech = GigaSpeechAsrDataModule(args)
|
gigaspeech = GigaSpeechAsrDataModule(args)
|
||||||
|
|
||||||
test_cuts = gigaspeech.test_cuts()
|
test_cuts = gigaspeech.test_cuts()
|
||||||
|
dev_cuts = gigaspeech.dev_cuts()
|
||||||
|
|
||||||
test_dl = gigaspeech.test_dataloaders(test_cuts)
|
test_dl = gigaspeech.test_dataloaders(test_cuts)
|
||||||
|
dev_dl = gigaspeech.test_dataloaders(dev_cuts)
|
||||||
|
|
||||||
test_sets = ["test"]
|
test_sets = ["test", "dev"]
|
||||||
test_dls = [test_dl]
|
test_dls = [test_dl, dev_dl]
|
||||||
|
|
||||||
for test_set, test_dl in zip(test_sets, test_dls):
|
for test_set, test_dl in zip(test_sets, test_dls):
|
||||||
results_dict = decode_dataset(
|
results_dict = decode_dataset(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user