From baa61723b69eeb4ad47513fc1ca0a0028c263e75 Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 26 Sep 2024 15:27:03 +0800 Subject: [PATCH] add ctc prefix beam search attention decoder rescoring --- egs/gigaspeech/ASR/zipformer/ctc_decode.py | 24 ++++++++++++++++++++-- icefall/decode.py | 3 ++- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/egs/gigaspeech/ASR/zipformer/ctc_decode.py b/egs/gigaspeech/ASR/zipformer/ctc_decode.py index a3405d4b9..ee55bff76 100755 --- a/egs/gigaspeech/ASR/zipformer/ctc_decode.py +++ b/egs/gigaspeech/ASR/zipformer/ctc_decode.py @@ -1,9 +1,10 @@ #!/usr/bin/env python3 # -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, # Liyong Guo, # Quandong Wang, -# Zengwei Yao) +# Zengwei Yao, +# Wei Kang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -135,6 +136,7 @@ from icefall.checkpoint import ( from icefall.decode import ( ctc_greedy_search, ctc_prefix_beam_search, + ctc_prefix_beam_search_attention_decoder_rescoring, get_lattice, nbest_decoding, nbest_oracle, @@ -435,6 +437,22 @@ def decode_one_batch( key = "prefix-beam-search" return {key: hyps} + if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring": + best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring( + ctc_output=ctc_output, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + ans = dict() + for a_scale_str, token_ids in best_path_dict.items(): + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + ans[a_scale_str] = hyps + return ans + supervision_segments = torch.stack( ( supervisions["sequence_idx"], @@ -760,6 +778,7 @@ def main(): assert params.decoding_method in ( "ctc-greedy-search", "prefix-beam-search", + "ctc-prefix-beam-search-attention-decoder-rescoring", "ctc-decoding", "1best", "nbest", @@ -814,6 +833,7 @@ def main(): "ctc-greedy-search", "ctc-decoding", "prefix-beam-search", + "ctc-prefix-beam-search-attention-decoder-rescoring", "attention-decoder-rescoring-no-ngram", ]: HLG = None diff --git a/icefall/decode.py b/icefall/decode.py index addbc3ff7..5ec9296e1 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -1,4 +1,5 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang) # # See ../../../../LICENSE for clarification regarding multiple authors #