Add CTC prefix beam search with attention-decoder rescoring

This commit is contained in:
pkufool 2024-09-26 15:27:03 +08:00
parent 0c096a9ab4
commit baa61723b6
2 changed files with 24 additions and 3 deletions

View File

@@ -1,9 +1,10 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# #
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, # Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang,
# Liyong Guo, # Liyong Guo,
# Quandong Wang, # Quandong Wang,
# Zengwei Yao) # Zengwei Yao,
# Wei Kang)
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../../../../LICENSE for clarification regarding multiple authors
# #
@@ -135,6 +136,7 @@ from icefall.checkpoint import (
from icefall.decode import ( from icefall.decode import (
ctc_greedy_search, ctc_greedy_search,
ctc_prefix_beam_search, ctc_prefix_beam_search,
ctc_prefix_beam_search_attention_decoder_rescoring,
get_lattice, get_lattice,
nbest_decoding, nbest_decoding,
nbest_oracle, nbest_oracle,
@@ -435,6 +437,22 @@ def decode_one_batch(
key = "prefix-beam-search" key = "prefix-beam-search"
return {key: hyps} return {key: hyps}
if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring":
best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring(
ctc_output=ctc_output,
attention_decoder=model.attention_decoder,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
ans = dict()
for a_scale_str, token_ids in best_path_dict.items():
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
ans[a_scale_str] = hyps
return ans
supervision_segments = torch.stack( supervision_segments = torch.stack(
( (
supervisions["sequence_idx"], supervisions["sequence_idx"],
@@ -760,6 +778,7 @@ def main():
assert params.decoding_method in ( assert params.decoding_method in (
"ctc-greedy-search", "ctc-greedy-search",
"prefix-beam-search", "prefix-beam-search",
"ctc-prefix-beam-search-attention-decoder-rescoring",
"ctc-decoding", "ctc-decoding",
"1best", "1best",
"nbest", "nbest",
@@ -814,6 +833,7 @@ def main():
"ctc-greedy-search", "ctc-greedy-search",
"ctc-decoding", "ctc-decoding",
"prefix-beam-search", "prefix-beam-search",
"ctc-prefix-beam-search-attention-decoder-rescoring",
"attention-decoder-rescoring-no-ngram", "attention-decoder-rescoring-no-ngram",
]: ]:
HLG = None HLG = None

View File

@@ -1,4 +1,5 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) # Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang)
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../../../../LICENSE for clarification regarding multiple authors
# #