mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
update result
This commit is contained in:
parent
3016440035
commit
dfff69b160
@ -13,6 +13,9 @@ When training with context size equals to 1, the WERs are
|
||||
| greedy search | 5.57 | 5.89 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
| modified beam search (beam size 4) | 5.32 | 5.56 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
| fast beam search (set as default) | 5.5 | 5.78 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
| fast beam search nbest | 5.46 | 5.74 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
| fast beam search oracle | 1.92 | 2.2 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
| fast beam search nbest LG | 5.59 | 5.93 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
|
||||
The training command for reproducing is given below:
|
||||
|
||||
@ -37,11 +40,13 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
The decoding command is:
|
||||
```
|
||||
for method in greedy_search modified_beam_search fast_beam_search; do
|
||||
for method in greedy_search modified_beam_search \
|
||||
fast_beam_search fast_beam_search_nbest \
|
||||
fast_beam_search_nbest_oracle fast_beam_search_nbest_LG; do
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 25 \
|
||||
--avg 5 \
|
||||
--exp-dir /result \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method $method \
|
||||
--max-sym-per-frame 1 \
|
||||
@ -51,7 +56,13 @@ for method in greedy_search modified_beam_search fast_beam_search; do
|
||||
--encoder-dim 384 \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512 \
|
||||
--use-averaged-model True
|
||||
--context-size 1 \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5 \
|
||||
--use-averaged-model False
|
||||
done
|
||||
```
|
||||
The tensorboard training log can be found at
|
||||
@ -66,6 +77,9 @@ When training with context size equals to 2, the WERs are
|
||||
| greedy search | 5.47 | 5.81 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
| modified beam search (beam size 4) | 5.38 | 5.61 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
| fast beam search (set as default) | 5.36 | 5.61 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
| fast beam search nbest | 5.37 | 5.6 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
| fast beam search oracle | 2.04 | 2.2 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
| fast beam search nbest LG | 5.59 | 5.82 | --epoch 25, --avg 5, --max-duration 600 |
|
||||
|
||||
The tensorboard training log can be found at
|
||||
https://tensorboard.dev/experiment/5AxJ8LHoSre8kDAuLp4L7Q/#scalars
|
||||
|
@ -112,9 +112,9 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
lang_char_dir=data/lang_char
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Prepare char based lang"
|
||||
lang_char_dir=data/lang_char
|
||||
mkdir -p $lang_char_dir
|
||||
|
||||
# Prepare text.
|
||||
@ -151,3 +151,31 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
python3 ./local/prepare_char.py
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Prepare G"
|
||||
# We assume you have install kaldilm, if not, please install
|
||||
# it using: pip install kaldilm
|
||||
|
||||
if [ ! -f ${lang_char_dir}/3-gram.unpruned.arpa ]; then
|
||||
./shared/make_kn_lm.py \
|
||||
-ngram-order 3 \
|
||||
-text $lang_char_dir/text_words_segmentation \
|
||||
-lm $lang_char_dir/3-gram.unpruned.arpa
|
||||
fi
|
||||
|
||||
mkdir -p data/lm
|
||||
if [ ! -f data/lm/G_3_gram.fst.txt ]; then
|
||||
# It is used in building LG
|
||||
python3 -m kaldilm \
|
||||
--read-symbol-table="$lang_char_dir/words.txt" \
|
||||
--disambig-symbol='#0' \
|
||||
--max-order=3 \
|
||||
$lang_char_dir/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
log "Stage 7: Compile LG"
|
||||
./local/compile_lg.py --lang-dir $lang_char_dir
|
||||
fi
|
||||
|
@ -123,6 +123,7 @@ from beam_search import (
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
@ -306,6 +307,7 @@ def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
lexicon: Lexicon,
|
||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||
batch: dict,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
@ -376,7 +378,8 @@ def decode_one_batch(
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([lexicon.word_table[i] for i in hyp])
|
||||
sentence = "".join([lexicon.word_table[i] for i in hyp])
|
||||
hyps.append(list(sentence))
|
||||
elif params.decoding_method == "fast_beam_search_nbest":
|
||||
hyp_tokens = fast_beam_search_nbest(
|
||||
model=model,
|
||||
@ -401,7 +404,7 @@ def decode_one_batch(
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=supervisions["text"],
|
||||
ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
@ -473,6 +476,7 @@ def decode_dataset(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
lexicon: Lexicon,
|
||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
@ -515,6 +519,7 @@ def decode_dataset(
|
||||
params=params,
|
||||
model=model,
|
||||
lexicon=lexicon,
|
||||
graph_compiler=graph_compiler,
|
||||
decoding_graph=decoding_graph,
|
||||
batch=batch,
|
||||
)
|
||||
@ -642,6 +647,11 @@ def main():
|
||||
params.unk_id = lexicon.token_table["<unk>"]
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||
lexicon=lexicon,
|
||||
device=device,
|
||||
)
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
@ -728,7 +738,18 @@ def main():
|
||||
model.eval()
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
decoding_graph = k2.trivial_graph(
|
||||
params.vocab_size - 1, device=device
|
||||
)
|
||||
else:
|
||||
decoding_graph = None
|
||||
|
||||
@ -753,6 +774,7 @@ def main():
|
||||
params=params,
|
||||
model=model,
|
||||
lexicon=lexicon,
|
||||
graph_compiler=graph_compiler,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
|
@ -75,6 +75,16 @@ def main():
|
||||
|
||||
logging.info("Starting writing the words.txt")
|
||||
f_out = open(output_file, "w", encoding="utf-8")
|
||||
|
||||
# LG decoding needs below symbols.
|
||||
id1, id2, id3 = (
|
||||
str(len(new_lines)),
|
||||
str(len(new_lines) + 1),
|
||||
str(len(new_lines) + 2),
|
||||
)
|
||||
add_words = ["#0 " + id1, "<s> " + id2, "</s> " + id3]
|
||||
new_lines.extend(add_words)
|
||||
|
||||
for line in new_lines:
|
||||
f_out.write(line)
|
||||
f_out.write("\n")
|
||||
|
Loading…
x
Reference in New Issue
Block a user