update result

Yuekai Zhang 2022-07-13 14:10:42 +00:00
parent 3016440035
commit dfff69b160
4 changed files with 82 additions and 8 deletions

View File

@ -13,6 +13,9 @@ When training with a context size of 1, the WERs are
| greedy search | 5.57 | 5.89 | --epoch 25, --avg 5, --max-duration 600 |
| modified beam search (beam size 4) | 5.32 | 5.56 | --epoch 25, --avg 5, --max-duration 600 |
| fast beam search (set as default) | 5.50 | 5.78 | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search nbest | 5.46 | 5.74 | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search nbest oracle | 1.92 | 2.20 | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search nbest LG | 5.59 | 5.93 | --epoch 25, --avg 5, --max-duration 600 |

The training command for reproducing this result is given below:
@ -37,11 +40,13 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
The decoding command is:
```
-for method in greedy_search modified_beam_search fast_beam_search; do
+for method in greedy_search modified_beam_search \
+    fast_beam_search fast_beam_search_nbest \
+    fast_beam_search_nbest_oracle fast_beam_search_nbest_LG; do
./pruned_transducer_stateless5/decode.py \
--epoch 25 \
--avg 5 \
---exp-dir /result \
+--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method $method \
--max-sym-per-frame 1 \
@ -51,7 +56,13 @@ for method in greedy_search modified_beam_search fast_beam_search; do
--encoder-dim 384 \
--decoder-dim 512 \
--joiner-dim 512 \
---use-averaged-model True
+--context-size 1 \
+--beam 20.0 \
+--max-contexts 8 \
+--max-states 64 \
+--num-paths 200 \
+--nbest-scale 0.5 \
+--use-averaged-model False
done
```
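The three new `fast_beam_search_nbest*` methods sample `--num-paths` paths from the decoding lattice after multiplying its scores by `--nbest-scale`; a scale below 1 flattens the path distribution so the sampled n-best list covers more distinct hypotheses. A toy sketch in plain Python (not icefall code) of that flattening effect:

```
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

scores = [-1.0, -3.0, -5.0]  # toy path log-scores from a lattice
print(softmax(scores))                     # peaked: ~[0.87, 0.12, 0.02]
print(softmax([0.5 * s for s in scores]))  # flatter: ~[0.67, 0.24, 0.09]
```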
The tensorboard training log can be found at
@ -66,6 +77,9 @@ When training with a context size of 2, the WERs are
| greedy search | 5.47 | 5.81 | --epoch 25, --avg 5, --max-duration 600 |
| modified beam search (beam size 4) | 5.38 | 5.61 | --epoch 25, --avg 5, --max-duration 600 |
| fast beam search (set as default) | 5.36 | 5.61 | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search nbest | 5.37 | 5.60 | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search nbest oracle | 2.04 | 2.20 | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search nbest LG | 5.59 | 5.82 | --epoch 25, --avg 5, --max-duration 600 |

The tensorboard training log can be found at
https://tensorboard.dev/experiment/5AxJ8LHoSre8kDAuLp4L7Q/#scalars

View File

@ -112,9 +112,9 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi
fi
+lang_char_dir=data/lang_char
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare char based lang"
-lang_char_dir=data/lang_char
mkdir -p $lang_char_dir
# Prepare text.
@ -151,3 +151,31 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
python3 ./local/prepare_char.py
fi
fi
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Prepare G"
+  # We assume you have installed kaldilm. If not, please install it
+  # using: pip install kaldilm
+  if [ ! -f ${lang_char_dir}/3-gram.unpruned.arpa ]; then
+    ./shared/make_kn_lm.py \
+      -ngram-order 3 \
+      -text $lang_char_dir/text_words_segmentation \
+      -lm $lang_char_dir/3-gram.unpruned.arpa
+  fi
+
+  mkdir -p data/lm
+  if [ ! -f data/lm/G_3_gram.fst.txt ]; then
+    # It is used in building LG
+    python3 -m kaldilm \
+      --read-symbol-table="$lang_char_dir/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=3 \
+      $lang_char_dir/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt
+  fi
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Compile LG"
+  ./local/compile_lg.py --lang-dir $lang_char_dir
+fi
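Stage 7 composes the character lexicon FST (L) with the word-level trigram (G) into the LG graph that `fast_beam_search_nbest_LG` loads at decode time. A simplified sketch of the composition, assuming the standard icefall layout produced by the stages above (the real `./local/compile_lg.py` also removes disambiguation symbols and epsilon self-loops):

```
import k2
import torch

# L_disambig.pt comes from prepare_char.py; G_3_gram.fst.txt from kaldilm.
L = k2.Fsa.from_dict(torch.load("data/lang_char/L_disambig.pt"))
with open("data/lm/G_3_gram.fst.txt") as f:
    G = k2.Fsa.from_openfst(f.read(), acceptor=False)

# Compose the arc-sorted FSTs, then trim and determinize the result.
LG = k2.compose(k2.arc_sort(L), k2.arc_sort(G))
LG = k2.connect(LG)
LG = k2.determinize(LG)
torch.save(LG.as_dict(), "data/lang_char/LG.pt")
```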

View File

@ -123,6 +123,7 @@ from beam_search import (
)
from train import add_model_arguments, get_params, get_transducer_model
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
@ -306,6 +307,7 @@ def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    lexicon: Lexicon,
+    graph_compiler: CharCtcTrainingGraphCompiler,
    batch: dict,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
@ -376,7 +378,8 @@ def decode_one_batch(
            nbest_scale=params.nbest_scale,
        )
        for hyp in hyp_tokens:
-            hyps.append([lexicon.word_table[i] for i in hyp])
+            sentence = "".join([lexicon.word_table[i] for i in hyp])
+            hyps.append(list(sentence))
    elif params.decoding_method == "fast_beam_search_nbest":
        hyp_tokens = fast_beam_search_nbest(
            model=model,
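The two-line change above is needed because the LG graph emits word IDs while scoring in this recipe is per character, so each decoded word sequence is joined and re-split into characters. A hypothetical example of the conversion:

```
words = ["你好", "世界"]   # looked up via lexicon.word_table
sentence = "".join(words)  # "你好世界"
print(list(sentence))      # ['你', '好', '世', '界']
```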
@ -401,7 +404,7 @@ def decode_one_batch(
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            num_paths=params.num_paths,
-            ref_texts=supervisions["text"],
+            ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
            nbest_scale=params.nbest_scale,
        )
        for i in range(encoder_out.size(0)):
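Similarly, the oracle search compares lattice paths against token ID sequences rather than strings, so the reference transcripts are first mapped through the char lexicon. A toy stand-in (not the actual `CharCtcTrainingGraphCompiler` API) for what `texts_to_ids` does here:

```
token2id = {"你": 10, "好": 11}  # hypothetical char lexicon entries

def texts_to_ids(texts):
    return [[token2id[c] for c in t if c in token2id] for t in texts]

print(texts_to_ids(["你好"]))  # [[10, 11]]
```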
@ -473,6 +476,7 @@ def decode_dataset(
    params: AttributeDict,
    model: nn.Module,
    lexicon: Lexicon,
+    graph_compiler: CharCtcTrainingGraphCompiler,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
    """Decode dataset.
@ -515,6 +519,7 @@ def decode_dataset(
            params=params,
            model=model,
            lexicon=lexicon,
+            graph_compiler=graph_compiler,
            decoding_graph=decoding_graph,
            batch=batch,
        )
@ -642,6 +647,11 @@ def main():
    params.unk_id = lexicon.token_table["<unk>"]
    params.vocab_size = max(lexicon.tokens) + 1

+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )

    logging.info(params)
    logging.info("About to create model")
@ -728,7 +738,18 @@ def main():
    model.eval()

    if "fast_beam_search" in params.decoding_method:
-        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        if params.decoding_method == "fast_beam_search_nbest_LG":
+            lexicon = Lexicon(params.lang_dir)
+            lg_filename = params.lang_dir / "LG.pt"
+            logging.info(f"Loading {lg_filename}")
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
+            decoding_graph.scores *= params.ngram_lm_scale
+        else:
+            decoding_graph = k2.trivial_graph(
+                params.vocab_size - 1, device=device
+            )
    else:
        decoding_graph = None
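Scaling `decoding_graph.scores` by `params.ngram_lm_scale` works because an FSA's `.scores` is a per-arc tensor of log-scores, so one multiplication uniformly reweights the n-gram LM against the transducer scores. A minimal torch-only illustration:

```
import torch

lg_scores = torch.tensor([-2.3, -0.7, -4.1])  # hypothetical LG arc scores
ngram_lm_scale = 0.5
lg_scores *= ngram_lm_scale
print(lg_scores)  # tensor([-1.1500, -0.3500, -2.0500])
```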
@ -753,6 +774,7 @@ def main():
            params=params,
            model=model,
            lexicon=lexicon,
+            graph_compiler=graph_compiler,
            decoding_graph=decoding_graph,
        )

View File

@ -75,6 +75,16 @@ def main():
logging.info("Starting writing the words.txt")
f_out = open(output_file, "w", encoding="utf-8")
# LG decoding needs below symbols.
id1, id2, id3 = (
str(len(new_lines)),
str(len(new_lines) + 1),
str(len(new_lines) + 2),
)
add_words = ["#0 " + id1, "<s> " + id2, "</s> " + id3]
new_lines.extend(add_words)
for line in new_lines:
f_out.write(line)
f_out.write("\n")