mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Add other decoding methods (nbest, nbest oracle, nbest LG) for wenetspeech pruned rnnt2 (#482)
* add other decoding methods for wenetspeech * changes for RESULTS.md * add ngram-lm-scale=0.35 results * set ngram-lm-scale=0.35 as default * Update README.md * add nbest-scale for file name
This commit is contained in:
parent
34b4356bad
commit
1b478d3ac3
@ -257,8 +257,8 @@ We provide some models for this recipe: [Pruned stateless RNN-T_2: Conformer enc
|
|||||||
| | Dev | Test-Net | Test-Meeting |
|
| | Dev | Test-Net | Test-Meeting |
|
||||||
|----------------------|-------|----------|--------------|
|
|----------------------|-------|----------|--------------|
|
||||||
| greedy search | 7.80 | 8.75 | 13.49 |
|
| greedy search | 7.80 | 8.75 | 13.49 |
|
||||||
|
| modified beam search| 7.76 | 8.71 | 13.41 |
|
||||||
| fast beam search | 7.94 | 8.74 | 13.80 |
|
| fast beam search | 7.94 | 8.74 | 13.80 |
|
||||||
| modified beam search | 7.76 | 8.71 | 13.41 |
|
|
||||||
|
|
||||||
#### Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset)
|
#### Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset)
|
||||||
**Streaming**:
|
**Streaming**:
|
||||||
|
@ -84,7 +84,10 @@ When training with the L subset, the CERs are
|
|||||||
|------------------------------------|-------|----------|--------------|------------------------------------------|
|
|------------------------------------|-------|----------|--------------|------------------------------------------|
|
||||||
| greedy search | 7.80 | 8.75 | 13.49 | --epoch 10, --avg 2, --max-duration 100 |
|
| greedy search | 7.80 | 8.75 | 13.49 | --epoch 10, --avg 2, --max-duration 100 |
|
||||||
| modified beam search (beam size 4) | 7.76 | 8.71 | 13.41 | --epoch 10, --avg 2, --max-duration 100 |
|
| modified beam search (beam size 4) | 7.76 | 8.71 | 13.41 | --epoch 10, --avg 2, --max-duration 100 |
|
||||||
| fast beam search (set as default) | 7.94 | 8.74 | 13.80 | --epoch 10, --avg 2, --max-duration 1500 |
|
| fast beam search (1best) | 7.94 | 8.74 | 13.80 | --epoch 10, --avg 2, --max-duration 1500 |
|
||||||
|
| fast beam search (nbest) | 9.82 | 10.98 | 16.37 | --epoch 10, --avg 2, --max-duration 600 |
|
||||||
|
| fast beam search (nbest oracle) | 6.88 | 7.18 | 11.77 | --epoch 10, --avg 2, --max-duration 600 |
|
||||||
|
| fast beam search (nbest LG, ngram_lm_scale=0.35) | 8.83 | 9.88 | 15.47 | --epoch 10, --avg 2, --max-duration 600 |
|
||||||
|
|
||||||
The training command for reproducing is given below:
|
The training command for reproducing is given below:
|
||||||
|
|
||||||
@ -131,7 +134,7 @@ avg=2
|
|||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
## fast beam search
|
## fast beam search (1best)
|
||||||
./pruned_transducer_stateless2/decode.py \
|
./pruned_transducer_stateless2/decode.py \
|
||||||
--epoch $epoch \
|
--epoch $epoch \
|
||||||
--avg $avg \
|
--avg $avg \
|
||||||
@ -142,6 +145,47 @@ avg=2
|
|||||||
--beam 4 \
|
--beam 4 \
|
||||||
--max-contexts 4 \
|
--max-contexts 4 \
|
||||||
--max-states 8
|
--max-states 8
|
||||||
|
|
||||||
|
## fast beam search (nbest)
|
||||||
|
./pruned_transducer_stateless2/decode.py \
|
||||||
|
--epoch 10 \
|
||||||
|
--avg 2 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
|
--lang-dir data/lang_char \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_nbest \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--nbest-scale 0.5
|
||||||
|
|
||||||
|
## fast beam search (nbest oracle WER)
|
||||||
|
./pruned_transducer_stateless2/decode.py \
|
||||||
|
--epoch 10 \
|
||||||
|
--avg 2 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
|
--lang-dir data/lang_char \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_nbest_oracle \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--nbest-scale 0.5
|
||||||
|
|
||||||
|
## fast beam search (with LG)
|
||||||
|
./pruned_transducer_stateless2/decode.py \
|
||||||
|
--epoch 10 \
|
||||||
|
--avg 2 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
|
--lang-dir data/lang_char \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_nbest_LG \
|
||||||
|
--ngram-lm-scale 0.35 \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64
|
||||||
```
|
```
|
||||||
|
|
||||||
When training with the M subset, the CERs are
|
When training with the M subset, the CERs are
|
||||||
|
1
egs/wenetspeech/ASR/local/compile_lg.py
Symbolic link
1
egs/wenetspeech/ASR/local/compile_lg.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/local/compile_lg.py
|
@ -225,3 +225,34 @@ if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then
|
|||||||
--lang-dir data/lang_char
|
--lang-dir data/lang_char
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# If you don't want to use LG for decoding, the following steps are not necessary.
|
||||||
|
if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
|
||||||
|
log "Stage 17: Prepare G"
|
||||||
|
# It will take about 20 minutes.
|
||||||
|
# We assume you have installed kaldilm; if not, please install
|
||||||
|
# it using: pip install kaldilm
|
||||||
|
lang_char_dir=data/lang_char
|
||||||
|
if [ ! -f $lang_char_dir/3-gram.unpruned.arpa ]; then
|
||||||
|
python ./shared/make_kn_lm.py \
|
||||||
|
-ngram-order 3 \
|
||||||
|
-text $lang_char_dir/text_words_segmentation \
|
||||||
|
-lm $lang_char_dir/3-gram.unpruned.arpa
|
||||||
|
fi
|
||||||
|
|
||||||
|
mkdir -p data/lm
|
||||||
|
if [ ! -f data/lm/G_3_gram.fst.txt ]; then
|
||||||
|
# It is used in building LG
|
||||||
|
python3 -m kaldilm \
|
||||||
|
--read-symbol-table="$lang_char_dir/words.txt" \
|
||||||
|
--disambig-symbol='#0' \
|
||||||
|
--max-order=3 \
|
||||||
|
$lang_char_dir/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
|
||||||
|
log "Stage 18: Compile LG"
|
||||||
|
lang_char_dir=data/lang_char
|
||||||
|
python ./local/compile_lg.py --lang-dir $lang_char_dir
|
||||||
|
fi
|
||||||
|
@ -37,7 +37,7 @@ When training with the L subset, usage:
|
|||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
(3) fast beam search
|
(3) fast beam search (1best)
|
||||||
./pruned_transducer_stateless2/decode.py \
|
./pruned_transducer_stateless2/decode.py \
|
||||||
--epoch 10 \
|
--epoch 10 \
|
||||||
--avg 2 \
|
--avg 2 \
|
||||||
@ -48,6 +48,46 @@ When training with the L subset, usage:
|
|||||||
--beam 4 \
|
--beam 4 \
|
||||||
--max-contexts 4 \
|
--max-contexts 4 \
|
||||||
--max-states 8
|
--max-states 8
|
||||||
|
|
||||||
|
(4) fast beam search (nbest)
|
||||||
|
./pruned_transducer_stateless2/decode.py \
|
||||||
|
--epoch 10 \
|
||||||
|
--avg 2 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
|
--lang-dir data/lang_char \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_nbest \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--nbest-scale 0.5
|
||||||
|
|
||||||
|
(5) fast beam search (nbest oracle WER)
|
||||||
|
./pruned_transducer_stateless2/decode.py \
|
||||||
|
--epoch 10 \
|
||||||
|
--avg 2 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
|
--lang-dir data/lang_char \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_nbest_oracle \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--nbest-scale 0.5
|
||||||
|
|
||||||
|
(6) fast beam search (with LG)
|
||||||
|
./pruned_transducer_stateless2/decode.py \
|
||||||
|
--epoch 10 \
|
||||||
|
--avg 2 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
|
--lang-dir data/lang_char \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_nbest_LG \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -63,6 +103,9 @@ import torch.nn as nn
|
|||||||
from asr_datamodule import WenetSpeechAsrDataModule
|
from asr_datamodule import WenetSpeechAsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
|
fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_LG,
|
||||||
|
fast_beam_search_nbest_oracle,
|
||||||
fast_beam_search_one_best,
|
fast_beam_search_one_best,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
@ -70,6 +113,7 @@ from beam_search import (
|
|||||||
)
|
)
|
||||||
from train import get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
@ -151,6 +195,11 @@ def get_parser():
|
|||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
- fast_beam_search
|
- fast_beam_search
|
||||||
|
- fast_beam_search_nbest
|
||||||
|
- fast_beam_search_nbest_oracle
|
||||||
|
- fast_beam_search_nbest_LG
|
||||||
|
If you use fast_beam_search_nbest_LG, you have to
|
||||||
|
specify `--lang-dir`, which should contain `LG.pt`.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -173,6 +222,16 @@ def get_parser():
|
|||||||
Used only when --decoding-method is fast_beam_search""",
|
Used only when --decoding-method is fast_beam_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngram-lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.35,
|
||||||
|
help="""
|
||||||
|
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||||
|
It specifies the scale for n-gram LM scores.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-contexts",
|
"--max-contexts",
|
||||||
type=int,
|
type=int,
|
||||||
@ -204,6 +263,24 @@ def get_parser():
|
|||||||
Used only when --decoding_method is greedy_search""",
|
Used only when --decoding_method is greedy_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-paths",
|
||||||
|
type=int,
|
||||||
|
default=200,
|
||||||
|
help="""Number of paths for nbest decoding.
|
||||||
|
Used only when the decoding method is fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--nbest-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="""Scale applied to lattice scores when computing nbest paths.
|
||||||
|
Used only when the decoding method is fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -211,6 +288,7 @@ def decode_one_batch(
|
|||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
lexicon: Lexicon,
|
lexicon: Lexicon,
|
||||||
|
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
@ -267,6 +345,50 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
for i in range(encoder_out.size(0)):
|
for i in range(encoder_out.size(0)):
|
||||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||||
|
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||||
|
hyp_tokens = fast_beam_search_nbest_LG(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
for hyp in hyp_tokens:
|
||||||
|
sentence = "".join([lexicon.word_table[i] for i in hyp])
|
||||||
|
hyps.append(list(sentence))
|
||||||
|
elif params.decoding_method == "fast_beam_search_nbest":
|
||||||
|
hyp_tokens = fast_beam_search_nbest(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
for i in range(encoder_out.size(0)):
|
||||||
|
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||||
|
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||||
|
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
for i in range(encoder_out.size(0)):
|
||||||
|
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||||
elif (
|
elif (
|
||||||
params.decoding_method == "greedy_search"
|
params.decoding_method == "greedy_search"
|
||||||
and params.max_sym_per_frame == 1
|
and params.max_sym_per_frame == 1
|
||||||
@ -331,6 +453,7 @@ def decode_dataset(
|
|||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
lexicon: Lexicon,
|
lexicon: Lexicon,
|
||||||
|
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
@ -373,6 +496,7 @@ def decode_dataset(
|
|||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
lexicon=lexicon,
|
lexicon=lexicon,
|
||||||
|
graph_compiler=graph_compiler,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
)
|
)
|
||||||
@ -454,6 +578,9 @@ def main():
|
|||||||
"greedy_search",
|
"greedy_search",
|
||||||
"beam_search",
|
"beam_search",
|
||||||
"fast_beam_search",
|
"fast_beam_search",
|
||||||
|
"fast_beam_search_nbest",
|
||||||
|
"fast_beam_search_nbest_LG",
|
||||||
|
"fast_beam_search_nbest_oracle",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
@ -463,6 +590,13 @@ def main():
|
|||||||
params.suffix += f"-beam-{params.beam}"
|
params.suffix += f"-beam-{params.beam}"
|
||||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
params.suffix += f"-max-states-{params.max_states}"
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
|
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||||
|
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||||
|
if (
|
||||||
|
params.decoding_method == "fast_beam_search_nbest"
|
||||||
|
or params.decoding_method == "fast_beam_search_nbest_oracle"
|
||||||
|
):
|
||||||
|
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||||
elif "beam_search" in params.decoding_method:
|
elif "beam_search" in params.decoding_method:
|
||||||
params.suffix += f"-beam-{params.beam_size}"
|
params.suffix += f"-beam-{params.beam_size}"
|
||||||
else:
|
else:
|
||||||
@ -482,6 +616,11 @@ def main():
|
|||||||
params.blank_id = lexicon.token_table["<blk>"]
|
params.blank_id = lexicon.token_table["<blk>"]
|
||||||
params.vocab_size = max(lexicon.tokens) + 1
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
|
|
||||||
|
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||||
|
lexicon=lexicon,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
@ -513,8 +652,18 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
model.device = device
|
model.device = device
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if "fast_beam_search" in params.decoding_method:
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||||
|
lg_filename = params.lang_dir + "/LG.pt"
|
||||||
|
logging.info(f"Loading {lg_filename}")
|
||||||
|
decoding_graph = k2.Fsa.from_dict(
|
||||||
|
torch.load(lg_filename, map_location=device)
|
||||||
|
)
|
||||||
|
decoding_graph.scores *= params.ngram_lm_scale
|
||||||
|
else:
|
||||||
|
decoding_graph = k2.trivial_graph(
|
||||||
|
params.vocab_size - 1, device=device
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
decoding_graph = None
|
decoding_graph = None
|
||||||
|
|
||||||
@ -610,6 +759,7 @@ def main():
|
|||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
lexicon=lexicon,
|
lexicon=lexicon,
|
||||||
|
graph_compiler=graph_compiler,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
)
|
)
|
||||||
save_results(
|
save_results(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user