mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Add other decoding methods (nbest, nbest oracle, nbest LG) for wenetspeech pruned rnnt2 (#482)
* add other decoding methods for wenetspeech * changes for RESULTS.md * add ngram-lm-scale=0.35 results * set ngram-lm-scale=0.35 as default * Update README.md * add nbest-scale for file name
This commit is contained in:
parent
34b4356bad
commit
1b478d3ac3
@ -257,8 +257,8 @@ We provide some models for this recipe: [Pruned stateless RNN-T_2: Conformer enc
|
||||
| | Dev | Test-Net | Test-Meeting |
|
||||
|----------------------|-------|----------|--------------|
|
||||
| greedy search | 7.80 | 8.75 | 13.49 |
|
||||
| modified beam search| 7.76 | 8.71 | 13.41 |
|
||||
| fast beam search | 7.94 | 8.74 | 13.80 |
|
||||
| modified beam search | 7.76 | 8.71 | 13.41 |
|
||||
|
||||
#### Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset)
|
||||
**Streaming**:
|
||||
|
@ -84,7 +84,10 @@ When training with the L subset, the CERs are
|
||||
|------------------------------------|-------|----------|--------------|------------------------------------------|
|
||||
| greedy search | 7.80 | 8.75 | 13.49 | --epoch 10, --avg 2, --max-duration 100 |
|
||||
| modified beam search (beam size 4) | 7.76 | 8.71 | 13.41 | --epoch 10, --avg 2, --max-duration 100 |
|
||||
| fast beam search (set as default) | 7.94 | 8.74 | 13.80 | --epoch 10, --avg 2, --max-duration 1500 |
|
||||
| fast beam search (1best) | 7.94 | 8.74 | 13.80 | --epoch 10, --avg 2, --max-duration 1500 |
|
||||
| fast beam search (nbest) | 9.82 | 10.98 | 16.37 | --epoch 10, --avg 2, --max-duration 600 |
|
||||
| fast beam search (nbest oracle) | 6.88 | 7.18 | 11.77 | --epoch 10, --avg 2, --max-duration 600 |
|
||||
| fast beam search (nbest LG, ngram_lm_scale=0.35) | 8.83 | 9.88 | 15.47 | --epoch 10, --avg 2, --max-duration 600 |
|
||||
|
||||
The training command for reproducing is given below:
|
||||
|
||||
@ -131,7 +134,7 @@ avg=2
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
## fast beam search
|
||||
## fast beam search (1best)
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch $epoch \
|
||||
--avg $avg \
|
||||
@ -142,6 +145,47 @@ avg=2
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
|
||||
## fast beam search (nbest)
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 10 \
|
||||
--avg 2 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
## fast beam search (nbest oracle WER)
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 10 \
|
||||
--avg 2 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_oracle \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
## fast beam search (with LG)
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 10 \
|
||||
--avg 2 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_LG \
|
||||
--ngram-lm-scale 0.35 \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
```
|
||||
|
||||
When training with the M subset, the CERs are
|
||||
|
1
egs/wenetspeech/ASR/local/compile_lg.py
Symbolic link
1
egs/wenetspeech/ASR/local/compile_lg.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/compile_lg.py
|
@ -225,3 +225,34 @@ if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then
|
||||
--lang-dir data/lang_char
|
||||
fi
|
||||
fi
|
||||
|
||||
# If you don't want to use LG for decoding, the following steps are not necessary.
|
||||
if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
|
||||
log "Stage 17: Prepare G"
|
||||
# It will take about 20 minutes.
|
||||
# We assume you have installed kaldilm; if not, please install
|
||||
# it using: pip install kaldilm
|
||||
lang_char_dir=data/lang_char
|
||||
if [ ! -f $lang_char_dir/3-gram.unpruned.arpa ]; then
|
||||
python ./shared/make_kn_lm.py \
|
||||
-ngram-order 3 \
|
||||
-text $lang_char_dir/text_words_segmentation \
|
||||
-lm $lang_char_dir/3-gram.unpruned.arpa
|
||||
fi
|
||||
|
||||
mkdir -p data/lm
|
||||
if [ ! -f data/lm/G_3_gram.fst.txt ]; then
|
||||
# It is used in building LG
|
||||
python3 -m kaldilm \
|
||||
--read-symbol-table="$lang_char_dir/words.txt" \
|
||||
--disambig-symbol='#0' \
|
||||
--max-order=3 \
|
||||
$lang_char_dir/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
|
||||
log "Stage 18: Compile LG"
|
||||
lang_char_dir=data/lang_char
|
||||
python ./local/compile_lg.py --lang-dir $lang_char_dir
|
||||
fi
|
||||
|
@ -37,7 +37,7 @@ When training with the L subset, usage:
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) fast beam search
|
||||
(3) fast beam search (1best)
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 10 \
|
||||
--avg 2 \
|
||||
@ -48,6 +48,46 @@ When training with the L subset, usage:
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
|
||||
(4) fast beam search (nbest)
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 10 \
|
||||
--avg 2 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(5) fast beam search (nbest oracle WER)
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 10 \
|
||||
--avg 2 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_oracle \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(6) fast beam search (with LG)
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 10 \
|
||||
--avg 2 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_LG \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
"""
|
||||
|
||||
|
||||
@ -63,6 +103,9 @@ import torch.nn as nn
|
||||
from asr_datamodule import WenetSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
fast_beam_search_nbest_oracle,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
@ -70,6 +113,7 @@ from beam_search import (
|
||||
)
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
find_checkpoints,
|
||||
@ -151,6 +195,11 @@ def get_parser():
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
If you use fast_beam_search_nbest_LG, you have to
|
||||
specify `--lang-dir`, which should contain `LG.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
@ -173,6 +222,16 @@ def get_parser():
|
||||
Used only when --decoding-method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.35,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
@ -204,6 +263,24 @@ def get_parser():
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=200,
|
||||
help="""Number of paths for nbest decoding.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -211,6 +288,7 @@ def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
lexicon: Lexicon,
|
||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||
batch: dict,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
@ -267,6 +345,50 @@ def decode_one_batch(
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
hyp_tokens = fast_beam_search_nbest_LG(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
sentence = "".join([lexicon.word_table[i] for i in hyp])
|
||||
hyps.append(list(sentence))
|
||||
elif params.decoding_method == "fast_beam_search_nbest":
|
||||
hyp_tokens = fast_beam_search_nbest(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
@ -331,6 +453,7 @@ def decode_dataset(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
lexicon: Lexicon,
|
||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
@ -373,6 +496,7 @@ def decode_dataset(
|
||||
params=params,
|
||||
model=model,
|
||||
lexicon=lexicon,
|
||||
graph_compiler=graph_compiler,
|
||||
decoding_graph=decoding_graph,
|
||||
batch=batch,
|
||||
)
|
||||
@ -454,6 +578,9 @@ def main():
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
@ -463,6 +590,13 @@ def main():
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
if (
|
||||
params.decoding_method == "fast_beam_search_nbest"
|
||||
or params.decoding_method == "fast_beam_search_nbest_oracle"
|
||||
):
|
||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam_size}"
|
||||
else:
|
||||
@ -482,6 +616,11 @@ def main():
|
||||
params.blank_id = lexicon.token_table["<blk>"]
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||
lexicon=lexicon,
|
||||
device=device,
|
||||
)
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
@ -513,8 +652,18 @@ def main():
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
lg_filename = params.lang_dir + "/LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
decoding_graph = k2.trivial_graph(
|
||||
params.vocab_size - 1, device=device
|
||||
)
|
||||
else:
|
||||
decoding_graph = None
|
||||
|
||||
@ -610,6 +759,7 @@ def main():
|
||||
params=params,
|
||||
model=model,
|
||||
lexicon=lexicon,
|
||||
graph_compiler=graph_compiler,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
save_results(
|
||||
|
Loading…
x
Reference in New Issue
Block a user