mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00

update result

This commit is contained in:
parent 3016440035
commit dfff69b160
@@ -12,7 +12,10 @@ When training with context size equals to 1, the WERs are
 |------------------------------------|-------|----------|----------------------------------|
 | greedy search                      | 5.57  | 5.89     | --epoch 25, --avg 5, --max-duration 600 |
 | modified beam search (beam size 4) | 5.32  | 5.56     | --epoch 25, --avg 5, --max-duration 600 |
 | fast beam search (set as default)  | 5.5   | 5.78     | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search nbest             | 5.46  | 5.74     | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search oracle            | 1.92  | 2.2      | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search nbest LG          | 5.59  | 5.93     | --epoch 25, --avg 5, --max-duration 600 |

 The training command for reproducing is given below:

@@ -37,11 +40,13 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"

 The decoding command is:
 ```
-for method in greedy_search modified_beam_search fast_beam_search; do
+for method in greedy_search modified_beam_search \
+    fast_beam_search fast_beam_search_nbest \
+    fast_beam_search_nbest_oracle fast_beam_search_nbest_LG; do
   ./pruned_transducer_stateless5/decode.py \
     --epoch 25 \
     --avg 5 \
-    --exp-dir /result \
+    --exp-dir ./pruned_transducer_stateless5/exp \
     --max-duration 600 \
     --decoding-method $method \
     --max-sym-per-frame 1 \
@@ -51,7 +56,13 @@ for method in greedy_search modified_beam_search fast_beam_search; do
     --encoder-dim 384 \
     --decoder-dim 512 \
     --joiner-dim 512 \
-    --use-averaged-model True
+    --context-size 1 \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5 \
+    --use-averaged-model False
 done
 ```
 The tensorboard training log can be found at
@@ -66,6 +77,9 @@ When training with context size equals to 2, the WERs are
 | greedy search                      | 5.47  | 5.81     | --epoch 25, --avg 5, --max-duration 600 |
 | modified beam search (beam size 4) | 5.38  | 5.61     | --epoch 25, --avg 5, --max-duration 600 |
 | fast beam search (set as default)  | 5.36  | 5.61     | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search nbest             | 5.37  | 5.6      | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search oracle            | 2.04  | 2.2      | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search nbest LG          | 5.59  | 5.82     | --epoch 25, --avg 5, --max-duration 600 |

 The tensorboard training log can be found at
 https://tensorboard.dev/experiment/5AxJ8LHoSre8kDAuLp4L7Q/#scalars
@@ -112,9 +112,9 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   fi
 fi

+lang_char_dir=data/lang_char
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   log "Stage 5: Prepare char based lang"
-  lang_char_dir=data/lang_char
   mkdir -p $lang_char_dir

   # Prepare text.
@@ -151,3 +151,31 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
     python3 ./local/prepare_char.py
   fi
 fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Prepare G"
+  # We assume you have install kaldilm, if not, please install
+  # it using: pip install kaldilm
+
+  if [ ! -f ${lang_char_dir}/3-gram.unpruned.arpa ]; then
+    ./shared/make_kn_lm.py \
+      -ngram-order 3 \
+      -text $lang_char_dir/text_words_segmentation \
+      -lm $lang_char_dir/3-gram.unpruned.arpa
+  fi
+
+  mkdir -p data/lm
+  if [ ! -f data/lm/G_3_gram.fst.txt ]; then
+    # It is used in building LG
+    python3 -m kaldilm \
+      --read-symbol-table="$lang_char_dir/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=3 \
+      $lang_char_dir/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt
+  fi
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Compile LG"
+  ./local/compile_lg.py --lang-dir $lang_char_dir
+fi
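For reference, the G built in stage 6 can be loaded back with k2 before stage 7 compiles LG, as a quick sanity check. This is not part of the commit; it is a minimal sketch that assumes k2's `Fsa.from_openfst` API and the paths used above:

```python
import k2

# Read the word-level 3-gram written by `python3 -m kaldilm` in stage 6.
# acceptor=False because the text format carries both input and output labels.
with open("data/lm/G_3_gram.fst.txt") as f:
    G = k2.Fsa.from_openfst(f.read(), acceptor=False)

print(f"G_3_gram: {G.shape[0]} states, {G.num_arcs} arcs")
```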
@@ -123,6 +123,7 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model

+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -306,6 +307,7 @@ def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
     lexicon: Lexicon,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     batch: dict,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
@@ -376,7 +378,8 @@ def decode_one_batch(
             nbest_scale=params.nbest_scale,
         )
         for hyp in hyp_tokens:
-            hyps.append([lexicon.word_table[i] for i in hyp])
+            sentence = "".join([lexicon.word_table[i] for i in hyp])
+            hyps.append(list(sentence))
     elif params.decoding_method == "fast_beam_search_nbest":
         hyp_tokens = fast_beam_search_nbest(
             model=model,
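The replacement above changes how `fast_beam_search_nbest_LG` hypotheses are collected: the LG graph produces word-level entries from `word_table`, and the new code joins them and re-splits the result into characters so they line up with the character-level hypotheses of the other decoding methods. A small illustration of the effect (the strings are made up):

```python
# Hypothetical word-level hypothesis looked up from lexicon.word_table.
words = ["你好", "世界"]

# Old behaviour kept word-level units: ["你好", "世界"].
# New behaviour joins and re-splits into characters:
sentence = "".join(words)   # "你好世界"
hyp = list(sentence)        # ["你", "好", "世", "界"]
print(hyp)
```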
@@ -401,7 +404,7 @@
             max_contexts=params.max_contexts,
             max_states=params.max_states,
             num_paths=params.num_paths,
-            ref_texts=supervisions["text"],
+            ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
             nbest_scale=params.nbest_scale,
         )
         for i in range(encoder_out.size(0)):
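`fast_beam_search_nbest_oracle` needs the reference transcripts as token ids rather than raw strings, which is why `ref_texts` now goes through the graph compiler. A minimal sketch of that conversion, assuming the `data/lang_char` directory from stage 5 and the compiler construction added in `main()` further below:

```python
from pathlib import Path

import torch

from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.lexicon import Lexicon

# Assumes prepare.sh stage 5 has produced data/lang_char.
lexicon = Lexicon(Path("data/lang_char"))
graph_compiler = CharCtcTrainingGraphCompiler(
    lexicon=lexicon,
    device=torch.device("cpu"),
)

# Each reference string becomes a list of token ids from tokens.txt.
token_ids = graph_compiler.texts_to_ids(["你好世界"])
print(token_ids)  # one id list per utterance; values depend on tokens.txt
```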
@@ -473,6 +476,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     lexicon: Lexicon,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
@@ -515,6 +519,7 @@ def decode_dataset(
             params=params,
             model=model,
             lexicon=lexicon,
+            graph_compiler=graph_compiler,
             decoding_graph=decoding_graph,
             batch=batch,
         )
@@ -642,6 +647,11 @@ def main():
     params.unk_id = lexicon.token_table["<unk>"]
     params.vocab_size = max(lexicon.tokens) + 1

+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
+
     logging.info(params)

     logging.info("About to create model")
@@ -728,7 +738,18 @@ def main():
     model.eval()

     if "fast_beam_search" in params.decoding_method:
-        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        if params.decoding_method == "fast_beam_search_nbest_LG":
+            lexicon = Lexicon(params.lang_dir)
+            lg_filename = params.lang_dir / "LG.pt"
+            logging.info(f"Loading {lg_filename}")
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
+            decoding_graph.scores *= params.ngram_lm_scale
+        else:
+            decoding_graph = k2.trivial_graph(
+                params.vocab_size - 1, device=device
+            )
     else:
         decoding_graph = None

@@ -753,6 +774,7 @@ def main():
         params=params,
         model=model,
         lexicon=lexicon,
+        graph_compiler=graph_compiler,
         decoding_graph=decoding_graph,
     )

|
@ -75,6 +75,16 @@ def main():
|
|||||||
|
|
||||||
logging.info("Starting writing the words.txt")
|
logging.info("Starting writing the words.txt")
|
||||||
f_out = open(output_file, "w", encoding="utf-8")
|
f_out = open(output_file, "w", encoding="utf-8")
|
||||||
|
|
||||||
|
# LG decoding needs below symbols.
|
||||||
|
id1, id2, id3 = (
|
||||||
|
str(len(new_lines)),
|
||||||
|
str(len(new_lines) + 1),
|
||||||
|
str(len(new_lines) + 2),
|
||||||
|
)
|
||||||
|
add_words = ["#0 " + id1, "<s> " + id2, "</s> " + id3]
|
||||||
|
new_lines.extend(add_words)
|
||||||
|
|
||||||
for line in new_lines:
|
for line in new_lines:
|
||||||
f_out.write(line)
|
f_out.write(line)
|
||||||
f_out.write("\n")
|
f_out.write("\n")
|
||||||
|
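These three extra entries give the LG build its disambiguation symbol (`#0`, the `--disambig-symbol` used in stage 6) plus the sentence boundary symbols `<s>` and `</s>`, appended after the existing vocabulary so the ids stay contiguous. A minimal illustration of the id arithmetic, assuming (as that arithmetic suggests) that `new_lines` holds "word id" pairs; the existing entries here are made up:

```python
# Placeholder contents only; the real list is built earlier in the script.
new_lines = ["<eps> 0", "你好 1", "世界 2"]

id1, id2, id3 = (
    str(len(new_lines)),
    str(len(new_lines) + 1),
    str(len(new_lines) + 2),
)
new_lines.extend(["#0 " + id1, "<s> " + id2, "</s> " + id3])

print(new_lines)
# ['<eps> 0', '你好 1', '世界 2', '#0 3', '<s> 4', '</s> 5']
```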