add more decoding methods

This commit is contained in:
Yuekai Zhang 2022-07-13 02:21:47 +00:00
parent 2e619be9b0
commit 3016440035
4 changed files with 113 additions and 10 deletions

View File

@ -69,3 +69,5 @@ When training with context size equals to 2, the WERs are
The tensorboard training log can be found at
https://tensorboard.dev/experiment/5AxJ8LHoSre8kDAuLp4L7Q/#scalars
A pre-trained model and decoding logs can be found at <https://huggingface.co/yuekai/icefall-asr-aishell2-pruned-transducer-stateless5-A-2022-07-12>

View File

@ -32,7 +32,7 @@ def main():
paths = [
"./data/fbank/aishell2_cuts_train.jsonl.gz",
"./data/fbank/aishell2_cuts_dev.jsonl.gz",
"./data/fbank/aishell2_cuts_test.jsonl.gz"
"./data/fbank/aishell2_cuts_test.jsonl.gz",
]
for path in paths:
@ -44,7 +44,7 @@ def main():
if __name__ == "__main__":
main()
'''
"""
Starting display the statistics for ./data/fbank/aishell2_cuts_train.jsonl.gz
Cuts count: 3026106
Total duration (hours): 3021.2
@ -93,4 +93,4 @@ min 1.1
99.5% 6.6
99.9% 7.7
max 8.5
'''
"""

View File

@ -9,6 +9,7 @@ stop_stage=5
# We assume dl_dir (download dir) contains the following
# directories and files. If not, you need to apply aishell2 through
# their official website.
# https://www.aishelltech.com/aishell_2
#
# - $dl_dir/aishell2
#

View File

@ -24,28 +24,80 @@ Usage:
--avg 5 \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir data/lang_char \
--max-duration 100 \
--max-duration 600 \
--decoding-method greedy_search
(2) modified beam search
(2) beam search (not recommended)
./pruned_transducer_stateless5/decode.py \
--epoch 25 \
--avg 5 \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless5/decode.py \
--epoch 25 \
--avg 5 \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir data/lang_char \
--max-duration 100 \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(3) fast beam search
(4) fast beam search (one best)
./pruned_transducer_stateless5/decode.py \
--epoch 25 \
--avg 5 \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir data/lang_char \
--max-duration 1500 \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
(5) fast beam search (nbest)
./pruned_transducer_stateless5/decode.py \
--epoch 25 \
--avg 5 \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless5/decode.py \
--epoch 25 \
--avg 5 \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless5/decode.py \
--epoch 25 \
--avg 5 \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
@ -61,6 +113,9 @@ import torch.nn as nn
from asr_datamodule import AiShell2AsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
@ -273,8 +328,6 @@ def decode_one_batch(
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
@ -310,6 +363,49 @@ def decode_one_batch(
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([lexicon.word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=supervisions["text"],
nbest_scale=params.nbest_scale,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
@ -498,7 +594,11 @@ def main():
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method