add decoding method of ctc-greedy-search in zipformer recipe (#1690)

commit d47c078286
parent 334beed2af
zipformer/ctc_decode.py

@@ -21,7 +21,16 @@
 """
 Usage:
 
-(1) ctc-decoding
+(1) ctc-greedy-search
+./zipformer/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer/exp \
+    --use-ctc 1 \
+    --max-duration 600 \
+    --decoding-method ctc-greedy-search
+
+(2) ctc-decoding
 ./zipformer/ctc_decode.py \
     --epoch 30 \
     --avg 15 \
@@ -30,7 +39,7 @@ Usage:
     --max-duration 600 \
     --decoding-method ctc-decoding
 
-(2) 1best
+(3) 1best
 ./zipformer/ctc_decode.py \
     --epoch 30 \
     --avg 15 \
@@ -40,7 +49,7 @@ Usage:
     --hlg-scale 0.6 \
     --decoding-method 1best
 
-(3) nbest
+(4) nbest
 ./zipformer/ctc_decode.py \
     --epoch 30 \
     --avg 15 \
@@ -50,7 +59,7 @@ Usage:
     --hlg-scale 0.6 \
     --decoding-method nbest
 
-(4) nbest-rescoring
+(5) nbest-rescoring
 ./zipformer/ctc_decode.py \
     --epoch 30 \
     --avg 15 \
@@ -62,7 +71,7 @@ Usage:
     --lm-dir data/lm \
     --decoding-method nbest-rescoring
 
-(5) whole-lattice-rescoring
+(6) whole-lattice-rescoring
 ./zipformer/ctc_decode.py \
     --epoch 30 \
     --avg 15 \
@@ -74,7 +83,7 @@ Usage:
     --lm-dir data/lm \
     --decoding-method whole-lattice-rescoring
 
-(6) attention-decoder-rescoring-no-ngram
+(7) attention-decoder-rescoring-no-ngram
 ./zipformer/ctc_decode.py \
     --epoch 30 \
     --avg 15 \
@@ -84,7 +93,7 @@ Usage:
     --max-duration 100 \
     --decoding-method attention-decoder-rescoring-no-ngram
 
-(7) attention-decoder-rescoring-with-ngram
+(8) attention-decoder-rescoring-with-ngram
 ./zipformer/ctc_decode.py \
     --epoch 30 \
     --avg 15 \
@@ -120,6 +129,7 @@ from icefall.checkpoint import (
     load_checkpoint,
 )
 from icefall.decode import (
+    ctc_greedy_search,
     get_lattice,
     nbest_decoding,
     nbest_oracle,
@@ -220,26 +230,29 @@ def get_parser():
         default="ctc-decoding",
         help="""Decoding method.
         Supported values are:
-        - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece
+        - (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece
           model, i.e., lang_dir/bpe.model, to convert word pieces to words.
           It needs neither a lexicon nor an n-gram LM.
-        - (2) 1best. Extract the best path from the decoding lattice as the
+        - (2) ctc-decoding. Use CTC decoding. It uses a sentence piece
+          model, i.e., lang_dir/bpe.model, to convert word pieces to words.
+          It needs neither a lexicon nor an n-gram LM.
+        - (3) 1best. Extract the best path from the decoding lattice as the
           decoding result.
-        - (3) nbest. Extract n paths from the decoding lattice; the path
+        - (4) nbest. Extract n paths from the decoding lattice; the path
           with the highest score is the decoding result.
-        - (4) nbest-rescoring. Extract n paths from the decoding lattice,
+        - (5) nbest-rescoring. Extract n paths from the decoding lattice,
           rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
           the highest score is the decoding result.
-        - (5) whole-lattice-rescoring. Rescore the decoding lattice with an
+        - (6) whole-lattice-rescoring. Rescore the decoding lattice with an
           n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
           is the decoding result.
           you have trained an RNN LM using ./rnn_lm/train.py
-        - (6) nbest-oracle. Its WER is the lower bound of any n-best
+        - (7) nbest-oracle. Its WER is the lower bound of any n-best
           rescoring method can achieve. Useful for debugging n-best
           rescoring method.
-        - (7) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
+        - (8) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
           lattice, rescore them with the attention decoder.
-        - (8) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
+        - (9) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
           rescored lattice, rescore them with the attention decoder.
         """,
     )
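For intuition on why ctc-greedy-search needs neither a lexicon nor an n-gram LM: it takes the highest-scoring token per frame and collapses repeats and blanks, with no lattice or decoding graph involved. A minimal hand-worked sketch, not part of the diff (the frame values are made up; blank id assumed to be 0):

# Hypothetical per-frame argmax ids from a CTC output; 0 is the blank.
frames = [0, 25, 25, 0, 0, 104, 104, 104, 0, 25]

collapsed = []
prev = None
for t in frames:
    if t != prev and t != 0:  # merge consecutive repeats, then drop blanks
        collapsed.append(t)
    prev = t

print(collapsed)  # [25, 104, 25]; bpe.model then maps ids to word pieces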
@@ -381,6 +394,15 @@ def decode_one_batch(
     encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
     ctc_output = model.ctc_output(encoder_out)  # (N, T, C)
 
+    if params.decoding_method == "ctc-greedy-search":
+        hyps = ctc_greedy_search(ctc_output, encoder_out_lens)
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(hyps)
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "ctc-greedy-search"
+        return {key: hyps}
+
     supervision_segments = torch.stack(
         (
             supervisions["sequence_idx"],
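A rough usage sketch of the branch added above (not part of the diff; the BPE model path and tensor shapes are illustrative assumptions):

import sentencepiece as spm
import torch

from icefall.decode import ctc_greedy_search

bpe_model = spm.SentencePieceProcessor()
bpe_model.load("data/lang_bpe_500/bpe.model")  # assumed lang-dir layout

# Dummy CTC log-probs: (N, T, C) = (2 utterances, 10 frames, 500 tokens)
ctc_output = torch.randn(2, 10, 500).log_softmax(dim=-1)
encoder_out_lens = torch.tensor([10, 7])

token_ids = ctc_greedy_search(ctc_output, encoder_out_lens)  # List[List[int]]
texts = bpe_model.decode(token_ids)  # List[str], one transcript per utterance
hyps = [s.split() for s in texts]  # List[List[str]], as decode_one_batch returns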
@@ -684,6 +706,7 @@ def main():
     params.update(vars(args))
 
     assert params.decoding_method in (
+        "ctc-greedy-search",
        "ctc-decoding",
        "1best",
        "nbest",
@@ -733,7 +756,9 @@ def main():
     params.eos_id = 1
     params.sos_id = 1
 
-    if params.decoding_method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]:
+    if params.decoding_method in [
+        "ctc-greedy-search", "ctc-decoding", "attention-decoder-rescoring-no-ngram"
+    ]:
         HLG = None
         H = k2.ctc_topo(
             max_token=max_token_id,
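Note that ctc-greedy-search is grouped here with the methods that skip HLG, yet its branch in decode_one_batch returns before any lattice is built, so the CTC topology H it constructs goes unused for that method. A minimal sketch of what this block builds (the max_token_id value is an assumption):

import k2

max_token_id = 499  # vocab_size - 1 for a 500-token BPE model (assumed)
H = k2.ctc_topo(
    max_token=max_token_id,  # largest token id in the vocabulary
    modified=False,  # standard (unmodified) CTC topology
    device="cpu",
)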
icefall/decode.py

@@ -1473,3 +1473,34 @@ def rescore_with_rnn_lm(
                 key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}_rnn_lm_scale_{r_scale}"  # noqa
                 ans[key] = best_path
     return ans
+
+
+def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
+    # from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
+    new_hyp: List[int] = []
+    cur = 0
+    while cur < len(hyp):
+        if hyp[cur] != 0:
+            new_hyp.append(hyp[cur])
+        prev = cur
+        while cur < len(hyp) and hyp[cur] == hyp[prev]:
+            cur += 1
+    return new_hyp
+
+
+def ctc_greedy_search(
+    ctc_output: torch.Tensor, encoder_out_lens: torch.Tensor
+) -> List[List[int]]:
+    """CTC greedy search.
+
+    Args:
+      ctc_output: (batch, seq_len, vocab_size)
+      encoder_out_lens: (batch,)
+    Returns:
+      List[List[int]]: greedy search result
+    """
+    batch = ctc_output.shape[0]
+    index = ctc_output.argmax(dim=-1)  # (batch, seq_len)
+    hyps = [index[i].tolist()[:encoder_out_lens[i]] for i in range(batch)]
+    hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
+    return hyps
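A small self-check sketch for the two new helpers (not part of the commit); the one-hot trick just forces a known per-frame argmax:

import torch

from icefall.decode import ctc_greedy_search, remove_duplicates_and_blank

# Repeats merge, blanks (id 0) vanish, and a blank between equal labels
# keeps both occurrences:
assert remove_duplicates_and_blank([0, 3, 3, 0, 3, 5, 5, 0]) == [3, 3, 5]

# One-hot logits make the per-frame argmax equal to the frame ids below.
frames = torch.tensor([[0, 2, 2, 0, 4, 4, 4, 0]])
ctc_output = torch.nn.functional.one_hot(frames, num_classes=6).float()
encoder_out_lens = torch.tensor([8])
assert ctc_greedy_search(ctc_output, encoder_out_lens) == [[2, 4]]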