mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
add ctc prefix beam search attention decoder rescoring
This commit is contained in:
parent
0c096a9ab4
commit
baa61723b6
@ -1,9 +1,10 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
#
|
#
|
||||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||||
# Liyong Guo,
|
# Liyong Guo,
|
||||||
# Quandong Wang,
|
# Quandong Wang,
|
||||||
# Zengwei Yao)
|
# Zengwei Yao,
|
||||||
|
# Wei Kang)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -135,6 +136,7 @@ from icefall.checkpoint import (
|
|||||||
from icefall.decode import (
|
from icefall.decode import (
|
||||||
ctc_greedy_search,
|
ctc_greedy_search,
|
||||||
ctc_prefix_beam_search,
|
ctc_prefix_beam_search,
|
||||||
|
ctc_prefix_beam_search_attention_decoder_rescoring,
|
||||||
get_lattice,
|
get_lattice,
|
||||||
nbest_decoding,
|
nbest_decoding,
|
||||||
nbest_oracle,
|
nbest_oracle,
|
||||||
@ -435,6 +437,22 @@ def decode_one_batch(
|
|||||||
key = "prefix-beam-search"
|
key = "prefix-beam-search"
|
||||||
return {key: hyps}
|
return {key: hyps}
|
||||||
|
|
||||||
|
if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring":
|
||||||
|
best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring(
|
||||||
|
ctc_output=ctc_output,
|
||||||
|
attention_decoder=model.attention_decoder,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
)
|
||||||
|
ans = dict()
|
||||||
|
for a_scale_str, token_ids in best_path_dict.items():
|
||||||
|
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
||||||
|
hyps = bpe_model.decode(token_ids)
|
||||||
|
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
|
||||||
|
hyps = [s.split() for s in hyps]
|
||||||
|
ans[a_scale_str] = hyps
|
||||||
|
return ans
|
||||||
|
|
||||||
supervision_segments = torch.stack(
|
supervision_segments = torch.stack(
|
||||||
(
|
(
|
||||||
supervisions["sequence_idx"],
|
supervisions["sequence_idx"],
|
||||||
@ -760,6 +778,7 @@ def main():
|
|||||||
assert params.decoding_method in (
|
assert params.decoding_method in (
|
||||||
"ctc-greedy-search",
|
"ctc-greedy-search",
|
||||||
"prefix-beam-search",
|
"prefix-beam-search",
|
||||||
|
"ctc-prefix-beam-search-attention-decoder-rescoring",
|
||||||
"ctc-decoding",
|
"ctc-decoding",
|
||||||
"1best",
|
"1best",
|
||||||
"nbest",
|
"nbest",
|
||||||
@ -814,6 +833,7 @@ def main():
|
|||||||
"ctc-greedy-search",
|
"ctc-greedy-search",
|
||||||
"ctc-decoding",
|
"ctc-decoding",
|
||||||
"prefix-beam-search",
|
"prefix-beam-search",
|
||||||
|
"ctc-prefix-beam-search-attention-decoder-rescoring",
|
||||||
"attention-decoder-rescoring-no-ngram",
|
"attention-decoder-rescoring-no-ngram",
|
||||||
]:
|
]:
|
||||||
HLG = None
|
HLG = None
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
|
# Wei Kang)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
|
Loading…
x
Reference in New Issue
Block a user