Add other decoding methods (nbest, nbest oracle, nbest LG) for wenetspeech pruned rnnt2 (#482)

* add other decoding methods for wenetspeech

* changes for RESULTS.md

* add ngram-lm-scale=0.35 results

* set ngram-lm-scale=0.35 as default

* Update README.md

* add nbest-scale for flie name
This commit is contained in:
Mingshuang Luo 2022-07-29 12:03:08 +08:00 committed by GitHub
parent 34b4356bad
commit 1b478d3ac3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 232 additions and 6 deletions

View File

@ -257,8 +257,8 @@ We provide some models for this recipe: [Pruned stateless RNN-T_2: Conformer enc
| | Dev | Test-Net | Test-Meeting | | | Dev | Test-Net | Test-Meeting |
|----------------------|-------|----------|--------------| |----------------------|-------|----------|--------------|
| greedy search | 7.80 | 8.75 | 13.49 | | greedy search | 7.80 | 8.75 | 13.49 |
| fast beam search | 7.94 | 8.74 | 13.80 |
| modified beam search| 7.76 | 8.71 | 13.41 | | modified beam search| 7.76 | 8.71 | 13.41 |
| fast beam search | 7.94 | 8.74 | 13.80 |
#### Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset) #### Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset)
**Streaming**: **Streaming**:

View File

@ -84,7 +84,10 @@ When training with the L subset, the CERs are
|------------------------------------|-------|----------|--------------|------------------------------------------| |------------------------------------|-------|----------|--------------|------------------------------------------|
| greedy search | 7.80 | 8.75 | 13.49 | --epoch 10, --avg 2, --max-duration 100 | | greedy search | 7.80 | 8.75 | 13.49 | --epoch 10, --avg 2, --max-duration 100 |
| modified beam search (beam size 4) | 7.76 | 8.71 | 13.41 | --epoch 10, --avg 2, --max-duration 100 | | modified beam search (beam size 4) | 7.76 | 8.71 | 13.41 | --epoch 10, --avg 2, --max-duration 100 |
| fast beam search (set as default) | 7.94 | 8.74 | 13.80 | --epoch 10, --avg 2, --max-duration 1500 | | fast beam search (1best) | 7.94 | 8.74 | 13.80 | --epoch 10, --avg 2, --max-duration 1500 |
| fast beam search (nbest) | 9.82 | 10.98 | 16.37 | --epoch 10, --avg 2, --max-duration 600 |
| fast beam search (nbest oracle) | 6.88 | 7.18 | 11.77 | --epoch 10, --avg 2, --max-duration 600 |
| fast beam search (nbest LG, ngram_lm_scale=0.35) | 8.83 | 9.88 | 15.47 | --epoch 10, --avg 2, --max-duration 600 |
The training command for reproducing is given below: The training command for reproducing is given below:
@ -131,7 +134,7 @@ avg=2
--decoding-method modified_beam_search \ --decoding-method modified_beam_search \
--beam-size 4 --beam-size 4
## fast beam search ## fast beam search (1best)
./pruned_transducer_stateless2/decode.py \ ./pruned_transducer_stateless2/decode.py \
--epoch $epoch \ --epoch $epoch \
--avg $avg \ --avg $avg \
@ -142,6 +145,47 @@ avg=2
--beam 4 \ --beam 4 \
--max-contexts 4 \ --max-contexts 4 \
--max-states 8 --max-states 8
## fast beam search (nbest)
./pruned_transducer_stateless2/decode.py \
--epoch 10 \
--avg 2 \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
## fast beam search (nbest oracle WER)
./pruned_transducer_stateless2/decode.py \
--epoch 10 \
--avg 2 \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
## fast beam search (with LG)
./pruned_transducer_stateless2/decode.py \
--epoch 10 \
--avg 2 \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--ngram-lm-scale 0.35 \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
``` ```
When training with the M subset, the CERs are When training with the M subset, the CERs are

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compile_lg.py

View File

@ -225,3 +225,34 @@ if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then
--lang-dir data/lang_char --lang-dir data/lang_char
fi fi
fi fi
# If you don't want to use LG for decoding, the following steps are not necessary.
if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
log "Stage 17: Prepare G"
# It will take about 20 minutes.
# We assume you have install kaldilm, if not, please install
# it using: pip install kaldilm
lang_char_dir=data/lang_char
if [ ! -f $lang_char_dir/3-gram.unpruned.arpa ]; then
python ./shared/make_kn_lm.py \
-ngram-order 3 \
-text $lang_char_dir/text_words_segmentation \
-lm $lang_char_dir/3-gram.unpruned.arpa
fi
mkdir -p data/lm
if [ ! -f data/lm/G_3_gram.fst.txt ]; then
# It is used in building LG
python3 -m kaldilm \
--read-symbol-table="$lang_char_dir/words.txt" \
--disambig-symbol='#0' \
--max-order=3 \
$lang_char_dir/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt
fi
fi
if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
log "Stage 18: Compile LG"
lang_char_dir=data/lang_char
python ./local/compile_lg.py --lang-dir $lang_char_dir
fi

View File

@ -37,7 +37,7 @@ When training with the L subset, usage:
--decoding-method modified_beam_search \ --decoding-method modified_beam_search \
--beam-size 4 --beam-size 4
(3) fast beam search (3) fast beam search (1best)
./pruned_transducer_stateless2/decode.py \ ./pruned_transducer_stateless2/decode.py \
--epoch 10 \ --epoch 10 \
--avg 2 \ --avg 2 \
@ -48,6 +48,46 @@ When training with the L subset, usage:
--beam 4 \ --beam 4 \
--max-contexts 4 \ --max-contexts 4 \
--max-states 8 --max-states 8
(4) fast beam search (nbest)
./pruned_transducer_stateless2/decode.py \
--epoch 10 \
--avg 2 \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(5) fast beam search (nbest oracle WER)
./pruned_transducer_stateless2/decode.py \
--epoch 10 \
--avg 2 \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (with LG)
./pruned_transducer_stateless2/decode.py \
--epoch 10 \
--avg 2 \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
""" """
@ -63,6 +103,9 @@ import torch.nn as nn
from asr_datamodule import WenetSpeechAsrDataModule from asr_datamodule import WenetSpeechAsrDataModule
from beam_search import ( from beam_search import (
beam_search, beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best, fast_beam_search_one_best,
greedy_search, greedy_search,
greedy_search_batch, greedy_search_batch,
@ -70,6 +113,7 @@ from beam_search import (
) )
from train import get_params, get_transducer_model from train import get_params, get_transducer_model
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,
find_checkpoints, find_checkpoints,
@ -151,6 +195,11 @@ def get_parser():
- beam_search - beam_search
- modified_beam_search - modified_beam_search
- fast_beam_search - fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to
specify `--lang-dir`, which should contain `LG.pt`.
""", """,
) )
@ -173,6 +222,16 @@ def get_parser():
Used only when --decoding-method is fast_beam_search""", Used only when --decoding-method is fast_beam_search""",
) )
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.35,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument( parser.add_argument(
"--max-contexts", "--max-contexts",
type=int, type=int,
@ -204,6 +263,24 @@ def get_parser():
Used only when --decoding_method is greedy_search""", Used only when --decoding_method is greedy_search""",
) )
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
return parser return parser
@ -211,6 +288,7 @@ def decode_one_batch(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
lexicon: Lexicon, lexicon: Lexicon,
graph_compiler: CharCtcTrainingGraphCompiler,
batch: dict, batch: dict,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]: ) -> Dict[str, List[List[str]]]:
@ -267,6 +345,50 @@ def decode_one_batch(
) )
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
sentence = "".join([lexicon.word_table[i] for i in hyp])
hyps.append(list(sentence))
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif ( elif (
params.decoding_method == "greedy_search" params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1 and params.max_sym_per_frame == 1
@ -331,6 +453,7 @@ def decode_dataset(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
lexicon: Lexicon, lexicon: Lexicon,
graph_compiler: CharCtcTrainingGraphCompiler,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]: ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset. """Decode dataset.
@ -373,6 +496,7 @@ def decode_dataset(
params=params, params=params,
model=model, model=model,
lexicon=lexicon, lexicon=lexicon,
graph_compiler=graph_compiler,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
batch=batch, batch=batch,
) )
@ -454,6 +578,9 @@ def main():
"greedy_search", "greedy_search",
"beam_search", "beam_search",
"fast_beam_search", "fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search", "modified_beam_search",
) )
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
@ -463,6 +590,13 @@ def main():
params.suffix += f"-beam-{params.beam}" params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-max-states-{params.max_states}"
if params.decoding_method == "fast_beam_search_nbest_LG":
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
if (
params.decoding_method == "fast_beam_search_nbest"
or params.decoding_method == "fast_beam_search_nbest_oracle"
):
params.suffix += f"-nbest-scale-{params.nbest_scale}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}" params.suffix += f"-beam-{params.beam_size}"
else: else:
@ -482,6 +616,11 @@ def main():
params.blank_id = lexicon.token_table["<blk>"] params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1 params.vocab_size = max(lexicon.tokens) + 1
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
logging.info(params) logging.info(params)
logging.info("About to create model") logging.info("About to create model")
@ -513,8 +652,18 @@ def main():
model.eval() model.eval()
model.device = device model.device = device
if params.decoding_method == "fast_beam_search": if "fast_beam_search" in params.decoding_method:
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) if params.decoding_method == "fast_beam_search_nbest_LG":
lg_filename = params.lang_dir + "/LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else: else:
decoding_graph = None decoding_graph = None
@ -610,6 +759,7 @@ def main():
params=params, params=params,
model=model, model=model,
lexicon=lexicon, lexicon=lexicon,
graph_compiler=graph_compiler,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
) )
save_results( save_results(