mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
Merge branch 'k2-fsa:master' into master
This commit is contained in:
commit
36d625bc9b
@ -29,6 +29,9 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
|
|||||||
ls -lh data/fbank
|
ls -lh data/fbank
|
||||||
ls -lh pruned_transducer_stateless2/exp
|
ls -lh pruned_transducer_stateless2/exp
|
||||||
|
|
||||||
|
ln -s data/fbank/cuts_DEV.jsonl.gz data/fbank/gigaspeech_cuts_DEV.jsonl.gz
|
||||||
|
ln -s data/fbank/cuts_TEST.jsonl.gz data/fbank/gigaspeech_cuts_TEST.jsonl.gz
|
||||||
|
|
||||||
log "Decoding dev and test"
|
log "Decoding dev and test"
|
||||||
|
|
||||||
# use a small value for decoding with CPU
|
# use a small value for decoding with CPU
|
||||||
|
@ -38,7 +38,7 @@ log "Decode with models exported by torch.jit.trace()"
|
|||||||
for m in ctc-decoding 1best; do
|
for m in ctc-decoding 1best; do
|
||||||
./conformer_ctc3/jit_pretrained.py \
|
./conformer_ctc3/jit_pretrained.py \
|
||||||
--model-filename $repo/exp/jit_trace.pt \
|
--model-filename $repo/exp/jit_trace.pt \
|
||||||
--words-file $repo/data/lang_bpe_500/words.txt \
|
--words-file $repo/data/lang_bpe_500/words.txt \
|
||||||
--HLG $repo/data/lang_bpe_500/HLG.pt \
|
--HLG $repo/data/lang_bpe_500/HLG.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
--G $repo/data/lm/G_4_gram.pt \
|
--G $repo/data/lm/G_4_gram.pt \
|
||||||
@ -53,7 +53,7 @@ log "Export to torchscript model"
|
|||||||
|
|
||||||
./conformer_ctc3/export.py \
|
./conformer_ctc3/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--lang-dir $repo/data/lang_bpe_500 \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--jit-trace 1 \
|
--jit-trace 1 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -80,9 +80,9 @@ done
|
|||||||
for m in ctc-decoding 1best; do
|
for m in ctc-decoding 1best; do
|
||||||
./conformer_ctc3/pretrained.py \
|
./conformer_ctc3/pretrained.py \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--words-file $repo/data/lang_bpe_500/words.txt \
|
--words-file $repo/data/lang_bpe_500/words.txt \
|
||||||
--HLG $repo/data/lang_bpe_500/HLG.pt \
|
--HLG $repo/data/lang_bpe_500/HLG.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--G $repo/data/lm/G_4_gram.pt \
|
--G $repo/data/lm/G_4_gram.pt \
|
||||||
--method $m \
|
--method $m \
|
||||||
--sample-rate 16000 \
|
--sample-rate 16000 \
|
||||||
@ -93,7 +93,7 @@ done
|
|||||||
|
|
||||||
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
||||||
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
|
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
|
||||||
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
|
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
|
||||||
mkdir -p conformer_ctc3/exp
|
mkdir -p conformer_ctc3/exp
|
||||||
ln -s $PWD/$repo/exp/pretrained.pt conformer_ctc3/exp/epoch-999.pt
|
ln -s $PWD/$repo/exp/pretrained.pt conformer_ctc3/exp/epoch-999.pt
|
||||||
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
||||||
|
@ -31,7 +31,7 @@ log "Test exporting with torch.jit.trace()"
|
|||||||
|
|
||||||
./lstm_transducer_stateless2/export.py \
|
./lstm_transducer_stateless2/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
@ -55,7 +55,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -68,7 +68,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
@ -28,7 +28,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
@ -36,7 +36,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -49,7 +49,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
@ -35,7 +35,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -48,7 +48,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
@ -30,14 +30,14 @@ popd
|
|||||||
log "Export to torchscript model"
|
log "Export to torchscript model"
|
||||||
./pruned_transducer_stateless3/export.py \
|
./pruned_transducer_stateless3/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--jit 1
|
--jit 1
|
||||||
|
|
||||||
./pruned_transducer_stateless3/export.py \
|
./pruned_transducer_stateless3/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--jit-trace 1
|
--jit-trace 1
|
||||||
@ -74,7 +74,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -87,7 +87,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
@ -32,7 +32,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--num-encoder-layers 18 \
|
--num-encoder-layers 18 \
|
||||||
--dim-feedforward 2048 \
|
--dim-feedforward 2048 \
|
||||||
--nhead 8 \
|
--nhead 8 \
|
||||||
@ -51,7 +51,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav \
|
$repo/test_wavs/1221-135766-0002.wav \
|
||||||
|
@ -33,7 +33,7 @@ log "Export to torchscript model"
|
|||||||
./pruned_transducer_stateless7/export.py \
|
./pruned_transducer_stateless7/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--use-averaged-model false \
|
--use-averaged-model false \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--jit 1
|
--jit 1
|
||||||
@ -56,7 +56,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -69,7 +69,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
@ -37,7 +37,7 @@ log "Export to torchscript model"
|
|||||||
./pruned_transducer_stateless7_ctc/export.py \
|
./pruned_transducer_stateless7_ctc/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--use-averaged-model false \
|
--use-averaged-model false \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--jit 1
|
--jit 1
|
||||||
@ -74,7 +74,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -87,7 +87,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
@ -36,7 +36,7 @@ log "Export to torchscript model"
|
|||||||
./pruned_transducer_stateless7_ctc_bs/export.py \
|
./pruned_transducer_stateless7_ctc_bs/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--use-averaged-model false \
|
--use-averaged-model false \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--jit 1
|
--jit 1
|
||||||
@ -72,7 +72,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -85,7 +85,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
@ -37,7 +37,7 @@ log "Export to torchscript model"
|
|||||||
./pruned_transducer_stateless7_streaming/export.py \
|
./pruned_transducer_stateless7_streaming/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--use-averaged-model false \
|
--use-averaged-model false \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--decode-chunk-len 32 \
|
--decode-chunk-len 32 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -81,7 +81,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--decode-chunk-len 32 \
|
--decode-chunk-len 32 \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
@ -95,7 +95,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--decode-chunk-len 32 \
|
--decode-chunk-len 32 \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
|
@ -41,7 +41,7 @@ log "Decode with models exported by torch.jit.script()"
|
|||||||
log "Export to torchscript model"
|
log "Export to torchscript model"
|
||||||
./pruned_transducer_stateless8/export.py \
|
./pruned_transducer_stateless8/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model false \
|
--use-averaged-model false \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -65,7 +65,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -78,7 +78,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
@ -32,7 +32,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--simulate-streaming 1 \
|
--simulate-streaming 1 \
|
||||||
--causal-convolution 1 \
|
--causal-convolution 1 \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
@ -47,7 +47,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--simulate-streaming 1 \
|
--simulate-streaming 1 \
|
||||||
--causal-convolution 1 \
|
--causal-convolution 1 \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
|
@ -28,7 +28,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
@ -37,7 +37,7 @@ log "Export to torchscript model"
|
|||||||
./zipformer_mmi/export.py \
|
./zipformer_mmi/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--use-averaged-model false \
|
--use-averaged-model false \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--jit 1
|
--jit 1
|
||||||
@ -61,7 +61,7 @@ for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescor
|
|||||||
--method $method \
|
--method $method \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--lang-dir $repo/data/lang_bpe_500 \
|
--lang-dir $repo/data/lang_bpe_500 \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
51
.github/scripts/run-multi-zh_hans-zipformer.sh
vendored
Executable file
51
.github/scripts/run-multi-zh_hans-zipformer.sh
vendored
Executable file
@ -0,0 +1,51 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
log() {
|
||||||
|
# This function is from espnet
|
||||||
|
local fname=${BASH_SOURCE[1]##*/}
|
||||||
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
cd egs/multi_zh-hans/ASR
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/
|
||||||
|
|
||||||
|
log "Downloading pre-trained model from $repo_url"
|
||||||
|
git lfs install
|
||||||
|
git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
|
||||||
|
log "Display test files"
|
||||||
|
tree $repo/
|
||||||
|
ls -lh $repo/test_wavs/*.wav
|
||||||
|
|
||||||
|
pushd $repo/exp
|
||||||
|
ln -s epoch-20.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
ls -lh $repo/exp/*.pt
|
||||||
|
|
||||||
|
|
||||||
|
./zipformer/pretrained.py \
|
||||||
|
--checkpoint $repo/exp/epoch-99.pt \
|
||||||
|
--tokens $repo/data/lang_bpe_2000/tokens.txt \
|
||||||
|
--method greedy_search \
|
||||||
|
$repo/test_wavs/DEV_T0000000000.wav \
|
||||||
|
$repo/test_wavs/DEV_T0000000001.wav \
|
||||||
|
$repo/test_wavs/DEV_T0000000002.wav
|
||||||
|
|
||||||
|
for method in modified_beam_search fast_beam_search; do
|
||||||
|
log "$method"
|
||||||
|
|
||||||
|
./zipformer/pretrained.py \
|
||||||
|
--method $method \
|
||||||
|
--beam-size 4 \
|
||||||
|
--checkpoint $repo/exp/epoch-99.pt \
|
||||||
|
--tokens $repo/data/lang_bpe_2000/tokens.txt \
|
||||||
|
$repo/test_wavs/DEV_T0000000000.wav \
|
||||||
|
$repo/test_wavs/DEV_T0000000001.wav \
|
||||||
|
$repo/test_wavs/DEV_T0000000002.wav
|
||||||
|
done
|
@ -27,7 +27,7 @@ log "CTC decoding"
|
|||||||
--method ctc-decoding \
|
--method ctc-decoding \
|
||||||
--num-classes 500 \
|
--num-classes 500 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.flac \
|
$repo/test_wavs/1089-134686-0001.flac \
|
||||||
$repo/test_wavs/1221-135766-0001.flac \
|
$repo/test_wavs/1221-135766-0001.flac \
|
||||||
$repo/test_wavs/1221-135766-0002.flac
|
$repo/test_wavs/1221-135766-0002.flac
|
||||||
@ -38,7 +38,7 @@ log "HLG decoding"
|
|||||||
--method 1best \
|
--method 1best \
|
||||||
--num-classes 500 \
|
--num-classes 500 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--words-file $repo/data/lang_bpe_500/words.txt \
|
--words-file $repo/data/lang_bpe_500/words.txt \
|
||||||
--HLG $repo/data/lang_bpe_500/HLG.pt \
|
--HLG $repo/data/lang_bpe_500/HLG.pt \
|
||||||
$repo/test_wavs/1089-134686-0001.flac \
|
$repo/test_wavs/1089-134686-0001.flac \
|
||||||
|
@ -28,7 +28,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
@ -28,7 +28,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
@ -28,7 +28,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
@ -27,7 +27,7 @@ log "Beam search decoding"
|
|||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
@ -17,7 +17,6 @@ git lfs install
|
|||||||
git clone $repo_url
|
git clone $repo_url
|
||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
|
||||||
log "Display test files"
|
log "Display test files"
|
||||||
tree $repo/
|
tree $repo/
|
||||||
ls -lh $repo/test_wavs/*.wav
|
ls -lh $repo/test_wavs/*.wav
|
||||||
@ -29,12 +28,11 @@ popd
|
|||||||
|
|
||||||
log "Test exporting to ONNX format"
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
./pruned_transducer_stateless2/export.py \
|
./pruned_transducer_stateless2/export-onnx.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--lang-dir $repo/data/lang_char \
|
--lang-dir $repo/data/lang_char \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1
|
||||||
--onnx 1
|
|
||||||
|
|
||||||
log "Export to torchscript model"
|
log "Export to torchscript model"
|
||||||
|
|
||||||
@ -59,19 +57,17 @@ log "Decode with ONNX models"
|
|||||||
|
|
||||||
./pruned_transducer_stateless2/onnx_check.py \
|
./pruned_transducer_stateless2/onnx_check.py \
|
||||||
--jit-filename $repo/exp/cpu_jit.pt \
|
--jit-filename $repo/exp/cpu_jit.pt \
|
||||||
--onnx-encoder-filename $repo/exp/encoder.onnx \
|
--onnx-encoder-filename $repo/exp/encoder-epoch-10-avg-2.onnx \
|
||||||
--onnx-decoder-filename $repo/exp/decoder.onnx \
|
--onnx-decoder-filename $repo/exp/decoder-epoch-10-avg-2.onnx \
|
||||||
--onnx-joiner-filename $repo/exp/joiner.onnx \
|
--onnx-joiner-filename $repo/exp/joiner-epoch-10-avg-2.onnx \
|
||||||
--onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \
|
--onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
|
||||||
--onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx
|
--onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj-epoch-10-avg-2.onnx
|
||||||
|
|
||||||
./pruned_transducer_stateless2/onnx_pretrained.py \
|
./pruned_transducer_stateless2/onnx_pretrained.py \
|
||||||
--tokens $repo/data/lang_char/tokens.txt \
|
--tokens $repo/data/lang_char/tokens.txt \
|
||||||
--encoder-model-filename $repo/exp/encoder.onnx \
|
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||||
--decoder-model-filename $repo/exp/decoder.onnx \
|
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||||
--joiner-model-filename $repo/exp/joiner.onnx \
|
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
||||||
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
|
|
||||||
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
|
|
||||||
$repo/test_wavs/DEV_T0000000000.wav \
|
$repo/test_wavs/DEV_T0000000000.wav \
|
||||||
$repo/test_wavs/DEV_T0000000001.wav \
|
$repo/test_wavs/DEV_T0000000001.wav \
|
||||||
$repo/test_wavs/DEV_T0000000002.wav
|
$repo/test_wavs/DEV_T0000000002.wav
|
||||||
@ -104,9 +100,9 @@ for sym in 1 2 3; do
|
|||||||
--lang-dir $repo/data/lang_char \
|
--lang-dir $repo/data/lang_char \
|
||||||
--decoding-method greedy_search \
|
--decoding-method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
$repo/test_wavs/DEV_T0000000000.wav \
|
$repo/test_wavs/DEV_T0000000000.wav \
|
||||||
$repo/test_wavs/DEV_T0000000001.wav \
|
$repo/test_wavs/DEV_T0000000001.wav \
|
||||||
$repo/test_wavs/DEV_T0000000002.wav
|
$repo/test_wavs/DEV_T0000000002.wav
|
||||||
done
|
done
|
||||||
|
|
||||||
for method in modified_beam_search beam_search fast_beam_search; do
|
for method in modified_beam_search beam_search fast_beam_search; do
|
||||||
@ -117,7 +113,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/epoch-99.pt \
|
--checkpoint $repo/exp/epoch-99.pt \
|
||||||
--lang-dir $repo/data/lang_char \
|
--lang-dir $repo/data/lang_char \
|
||||||
$repo/test_wavs/DEV_T0000000000.wav \
|
$repo/test_wavs/DEV_T0000000000.wav \
|
||||||
$repo/test_wavs/DEV_T0000000001.wav \
|
$repo/test_wavs/DEV_T0000000001.wav \
|
||||||
$repo/test_wavs/DEV_T0000000002.wav
|
$repo/test_wavs/DEV_T0000000002.wav
|
||||||
done
|
done
|
||||||
|
12
.github/scripts/test-ncnn-export.sh
vendored
12
.github/scripts/test-ncnn-export.sh
vendored
@ -45,7 +45,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
pushd $repo
|
pushd $repo
|
||||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
|
||||||
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
|
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
|
||||||
|
|
||||||
cd exp
|
cd exp
|
||||||
@ -56,11 +55,10 @@ log "Export via torch.jit.trace()"
|
|||||||
|
|
||||||
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
|
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
\
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--num-encoder-layers 12 \
|
--num-encoder-layers 12 \
|
||||||
--chunk-length 32 \
|
--chunk-length 32 \
|
||||||
--cnn-module-kernel 31 \
|
--cnn-module-kernel 31 \
|
||||||
@ -91,7 +89,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
pushd $repo
|
pushd $repo
|
||||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
|
||||||
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
|
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
|
||||||
|
|
||||||
cd exp
|
cd exp
|
||||||
@ -102,7 +99,7 @@ log "Export via torch.jit.trace()"
|
|||||||
|
|
||||||
./lstm_transducer_stateless2/export-for-ncnn.py \
|
./lstm_transducer_stateless2/export-for-ncnn.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0
|
--use-averaged-model 0
|
||||||
@ -140,7 +137,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
pushd $repo
|
pushd $repo
|
||||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
|
||||||
git lfs pull --include "exp/pretrained.pt"
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
cd exp
|
cd exp
|
||||||
@ -148,7 +144,7 @@ ln -s pretrained.pt epoch-99.pt
|
|||||||
popd
|
popd
|
||||||
|
|
||||||
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
|
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
@ -199,7 +195,7 @@ ln -s pretrained.pt epoch-9999.pt
|
|||||||
popd
|
popd
|
||||||
|
|
||||||
./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
|
./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
|
||||||
--lang-dir $repo/data/lang_char_bpe \
|
--tokens $repo/data/lang_char_bpe/tokens.txt \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 9999 \
|
--epoch 9999 \
|
||||||
|
138
.github/scripts/test-onnx-export.sh
vendored
138
.github/scripts/test-onnx-export.sh
vendored
@ -10,7 +10,123 @@ log() {
|
|||||||
|
|
||||||
cd egs/librispeech/ASR
|
cd egs/librispeech/ASR
|
||||||
|
|
||||||
|
log "=========================================================================="
|
||||||
|
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||||
|
log "Downloading pre-trained model from $repo_url"
|
||||||
|
git lfs install
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
log "Export via torch.jit.script()"
|
||||||
|
./zipformer/export.py \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--jit 1
|
||||||
|
|
||||||
|
log "Test export to ONNX format"
|
||||||
|
./zipformer/export-onnx.py \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--num-encoder-layers "2,2,3,4,3,2" \
|
||||||
|
--downsampling-factor "1,2,4,8,4,2" \
|
||||||
|
--feedforward-dim "512,768,1024,1536,1024,768" \
|
||||||
|
--num-heads "4,4,4,8,4,4" \
|
||||||
|
--encoder-dim "192,256,384,512,384,256" \
|
||||||
|
--query-head-dim 32 \
|
||||||
|
--value-head-dim 12 \
|
||||||
|
--pos-head-dim 4 \
|
||||||
|
--pos-dim 48 \
|
||||||
|
--encoder-unmasked-dim "192,192,256,256,256,192" \
|
||||||
|
--cnn-module-kernel "31,31,15,15,15,31" \
|
||||||
|
--decoder-dim 512 \
|
||||||
|
--joiner-dim 512 \
|
||||||
|
--causal False \
|
||||||
|
--chunk-size "16,32,64,-1" \
|
||||||
|
--left-context-frames "64,128,256,-1"
|
||||||
|
|
||||||
|
ls -lh $repo/exp
|
||||||
|
|
||||||
|
log "Run onnx_check.py"
|
||||||
|
|
||||||
|
./zipformer/onnx_check.py \
|
||||||
|
--jit-filename $repo/exp/jit_script.pt \
|
||||||
|
--onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||||
|
--onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||||
|
--onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
|
||||||
|
|
||||||
|
log "Run onnx_pretrained.py"
|
||||||
|
|
||||||
|
./zipformer/onnx_pretrained.py \
|
||||||
|
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||||
|
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||||
|
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
|
||||||
|
log "Downloading pre-trained model from $repo_url"
|
||||||
|
git lfs install
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
log "Test export streaming model to ONNX format"
|
||||||
|
./zipformer/export-onnx-streaming.py \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--num-encoder-layers "2,2,3,4,3,2" \
|
||||||
|
--downsampling-factor "1,2,4,8,4,2" \
|
||||||
|
--feedforward-dim "512,768,1024,1536,1024,768" \
|
||||||
|
--num-heads "4,4,4,8,4,4" \
|
||||||
|
--encoder-dim "192,256,384,512,384,256" \
|
||||||
|
--query-head-dim 32 \
|
||||||
|
--value-head-dim 12 \
|
||||||
|
--pos-head-dim 4 \
|
||||||
|
--pos-dim 48 \
|
||||||
|
--encoder-unmasked-dim "192,192,256,256,256,192" \
|
||||||
|
--cnn-module-kernel "31,31,15,15,15,31" \
|
||||||
|
--decoder-dim 512 \
|
||||||
|
--joiner-dim 512 \
|
||||||
|
--causal True \
|
||||||
|
--chunk-size 16 \
|
||||||
|
--left-context-frames 64
|
||||||
|
|
||||||
|
ls -lh $repo/exp
|
||||||
|
|
||||||
|
log "Run onnx_pretrained-streaming.py"
|
||||||
|
|
||||||
|
./zipformer/onnx_pretrained-streaming.py \
|
||||||
|
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1-chunk-16-left-64.onnx \
|
||||||
|
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1-chunk-16-left-64.onnx \
|
||||||
|
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1-chunk-16-left-64.onnx \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
|
||||||
|
log "--------------------------------------------------------------------------"
|
||||||
|
|
||||||
log "=========================================================================="
|
log "=========================================================================="
|
||||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||||
@ -39,7 +155,7 @@ log "Export via torch.jit.trace()"
|
|||||||
log "Test exporting to ONNX format"
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -88,7 +204,7 @@ popd
|
|||||||
log "Export via torch.jit.script()"
|
log "Export via torch.jit.script()"
|
||||||
|
|
||||||
./pruned_transducer_stateless3/export.py \
|
./pruned_transducer_stateless3/export.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 9999 \
|
--epoch 9999 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--exp-dir $repo/exp/ \
|
--exp-dir $repo/exp/ \
|
||||||
@ -97,7 +213,7 @@ log "Export via torch.jit.script()"
|
|||||||
log "Test exporting to ONNX format"
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
./pruned_transducer_stateless3/export-onnx.py \
|
./pruned_transducer_stateless3/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 9999 \
|
--epoch 9999 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--exp-dir $repo/exp/
|
--exp-dir $repo/exp/
|
||||||
@ -126,7 +242,6 @@ log "Run onnx_pretrained.py"
|
|||||||
rm -rf $repo
|
rm -rf $repo
|
||||||
log "--------------------------------------------------------------------------"
|
log "--------------------------------------------------------------------------"
|
||||||
|
|
||||||
|
|
||||||
log "=========================================================================="
|
log "=========================================================================="
|
||||||
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13
|
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13
|
||||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
@ -143,7 +258,7 @@ popd
|
|||||||
log "Export via torch.jit.script()"
|
log "Export via torch.jit.script()"
|
||||||
|
|
||||||
./pruned_transducer_stateless5/export.py \
|
./pruned_transducer_stateless5/export.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
@ -159,7 +274,7 @@ log "Export via torch.jit.script()"
|
|||||||
log "Test exporting to ONNX format"
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
./pruned_transducer_stateless5/export-onnx.py \
|
./pruned_transducer_stateless5/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
@ -205,7 +320,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
pushd $repo
|
pushd $repo
|
||||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
|
||||||
git lfs pull --include "exp/pretrained.pt"
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
cd exp
|
cd exp
|
||||||
@ -215,7 +329,7 @@ popd
|
|||||||
log "Export via torch.jit.script()"
|
log "Export via torch.jit.script()"
|
||||||
|
|
||||||
./pruned_transducer_stateless7/export.py \
|
./pruned_transducer_stateless7/export.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -226,7 +340,7 @@ log "Export via torch.jit.script()"
|
|||||||
log "Test exporting to ONNX format"
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
./pruned_transducer_stateless7/export-onnx.py \
|
./pruned_transducer_stateless7/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -270,7 +384,7 @@ popd
|
|||||||
log "Test exporting to ONNX format"
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
./conv_emformer_transducer_stateless2/export-onnx.py \
|
./conv_emformer_transducer_stateless2/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -310,7 +424,7 @@ popd
|
|||||||
log "Export via torch.jit.trace()"
|
log "Export via torch.jit.trace()"
|
||||||
|
|
||||||
./lstm_transducer_stateless2/export.py \
|
./lstm_transducer_stateless2/export.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -320,7 +434,7 @@ log "Export via torch.jit.trace()"
|
|||||||
log "Test exporting to ONNX format"
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
./lstm_transducer_stateless2/export-onnx.py \
|
./lstm_transducer_stateless2/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
|
2
.github/workflows/run-aishell-2022-06-20.yml
vendored
2
.github/workflows/run-aishell-2022-06-20.yml
vendored
@ -45,7 +45,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.7, 3.8, 3.9]
|
python-version: [3.8]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.7, 3.8, 3.9]
|
python-version: [3.8]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.7, 3.8, 3.9]
|
python-version: [3.8]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.7, 3.8, 3.9]
|
python-version: [3.8]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.7, 3.8, 3.9]
|
python-version: [3.8]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.7, 3.8, 3.9]
|
python-version: [3.8]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.7, 3.8, 3.9]
|
python-version: [3.8]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.7, 3.8, 3.9]
|
python-version: [3.8]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
84
.github/workflows/run-multi-zh_hans-zipformer.yml
vendored
Normal file
84
.github/workflows/run-multi-zh_hans-zipformer.yml
vendored
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
# Copyright 2023 Xiaomi Corp. (author: Zengrui Jin)
|
||||||
|
|
||||||
|
# See ../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
name: run-multi-zh_hans-zipformer
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
pull_request:
|
||||||
|
types: [labeled]
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: run_multi-zh_hans_zipformer-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
run_multi-zh_hans_zipformer:
|
||||||
|
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans'
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest]
|
||||||
|
python-version: [3.8]
|
||||||
|
|
||||||
|
fail-fast: false
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Setup Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
cache: 'pip'
|
||||||
|
cache-dependency-path: '**/requirements-ci.txt'
|
||||||
|
|
||||||
|
- name: Install Python dependencies
|
||||||
|
run: |
|
||||||
|
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
|
||||||
|
pip uninstall -y protobuf
|
||||||
|
pip install --no-binary protobuf protobuf==3.20.*
|
||||||
|
|
||||||
|
- name: Cache kaldifeat
|
||||||
|
id: my-cache
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/tmp/kaldifeat
|
||||||
|
key: cache-tmp-${{ matrix.python-version }}-2023-05-22
|
||||||
|
|
||||||
|
- name: Install kaldifeat
|
||||||
|
if: steps.my-cache.outputs.cache-hit != 'true'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
.github/scripts/install-kaldifeat.sh
|
||||||
|
|
||||||
|
- name: Inference with pre-trained model
|
||||||
|
shell: bash
|
||||||
|
env:
|
||||||
|
GITHUB_EVENT_NAME: ${{ github.event_name }}
|
||||||
|
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
|
||||||
|
run: |
|
||||||
|
sudo apt-get -qq install git-lfs tree
|
||||||
|
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||||
|
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
|
||||||
|
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
|
||||||
|
|
||||||
|
.github/scripts/run-multi-zh_hans-zipformer.sh
|
@ -34,7 +34,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.7, 3.8, 3.9]
|
python-version: [3.8]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.7, 3.8, 3.9]
|
python-version: [3.8]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.7, 3.8, 3.9]
|
python-version: [3.8]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.7, 3.8, 3.9]
|
python-version: [3.8]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.7, 3.8, 3.9]
|
python-version: [3.8]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.7, 3.8, 3.9]
|
python-version: [3.8]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.7, 3.8, 3.9]
|
python-version: [3.8]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
76
.github/workflows/run-yesno-recipe.yml
vendored
76
.github/workflows/run-yesno-recipe.yml
vendored
@ -44,11 +44,6 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Install graphviz
|
|
||||||
shell: bash
|
|
||||||
run: |
|
|
||||||
sudo apt-get -qq install graphviz
|
|
||||||
|
|
||||||
- name: Setup Python ${{ matrix.python-version }}
|
- name: Setup Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v2
|
uses: actions/setup-python@v2
|
||||||
with:
|
with:
|
||||||
@ -70,6 +65,7 @@ jobs:
|
|||||||
pip install --no-binary protobuf protobuf==3.20.*
|
pip install --no-binary protobuf protobuf==3.20.*
|
||||||
|
|
||||||
pip install --no-deps --force-reinstall https://huggingface.co/csukuangfj/k2/resolve/main/cpu/k2-1.24.3.dev20230508+cpu.torch1.13.1-cp38-cp38-linux_x86_64.whl
|
pip install --no-deps --force-reinstall https://huggingface.co/csukuangfj/k2/resolve/main/cpu/k2-1.24.3.dev20230508+cpu.torch1.13.1-cp38-cp38-linux_x86_64.whl
|
||||||
|
pip install kaldifeat==1.25.0.dev20230726+cpu.torch1.13.1 -f https://csukuangfj.github.io/kaldifeat/cpu.html
|
||||||
|
|
||||||
- name: Run yesno recipe
|
- name: Run yesno recipe
|
||||||
shell: bash
|
shell: bash
|
||||||
@ -78,9 +74,75 @@ jobs:
|
|||||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||||
echo $PYTHONPATH
|
echo $PYTHONPATH
|
||||||
|
|
||||||
|
|
||||||
cd egs/yesno/ASR
|
cd egs/yesno/ASR
|
||||||
./prepare.sh
|
./prepare.sh
|
||||||
python3 ./tdnn/train.py
|
python3 ./tdnn/train.py
|
||||||
python3 ./tdnn/decode.py
|
python3 ./tdnn/decode.py
|
||||||
# TODO: Check that the WER is less than some value
|
|
||||||
|
- name: Test exporting to pretrained.pt
|
||||||
|
shell: bash
|
||||||
|
working-directory: ${{github.workspace}}
|
||||||
|
run: |
|
||||||
|
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||||
|
echo $PYTHONPATH
|
||||||
|
|
||||||
|
cd egs/yesno/ASR
|
||||||
|
python3 ./tdnn/export.py --epoch 14 --avg 2
|
||||||
|
|
||||||
|
python3 ./tdnn/pretrained.py \
|
||||||
|
--checkpoint ./tdnn/exp/pretrained.pt \
|
||||||
|
--HLG ./data/lang_phone/HLG.pt \
|
||||||
|
--words-file ./data/lang_phone/words.txt \
|
||||||
|
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
|
||||||
|
download/waves_yesno/0_0_1_0_0_0_1_0.wav
|
||||||
|
|
||||||
|
- name: Test exporting to torchscript
|
||||||
|
shell: bash
|
||||||
|
working-directory: ${{github.workspace}}
|
||||||
|
run: |
|
||||||
|
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||||
|
echo $PYTHONPATH
|
||||||
|
|
||||||
|
cd egs/yesno/ASR
|
||||||
|
python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1
|
||||||
|
|
||||||
|
python3 ./tdnn/jit_pretrained.py \
|
||||||
|
--nn-model ./tdnn/exp/cpu_jit.pt \
|
||||||
|
--HLG ./data/lang_phone/HLG.pt \
|
||||||
|
--words-file ./data/lang_phone/words.txt \
|
||||||
|
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
|
||||||
|
download/waves_yesno/0_0_1_0_0_0_1_0.wav
|
||||||
|
|
||||||
|
- name: Test exporting to onnx
|
||||||
|
shell: bash
|
||||||
|
working-directory: ${{github.workspace}}
|
||||||
|
run: |
|
||||||
|
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||||
|
echo $PYTHONPATH
|
||||||
|
|
||||||
|
cd egs/yesno/ASR
|
||||||
|
python3 ./tdnn/export_onnx.py --epoch 14 --avg 2
|
||||||
|
|
||||||
|
echo "Test float32 model"
|
||||||
|
python3 ./tdnn/onnx_pretrained.py \
|
||||||
|
--nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \
|
||||||
|
--HLG ./data/lang_phone/HLG.pt \
|
||||||
|
--words-file ./data/lang_phone/words.txt \
|
||||||
|
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
|
||||||
|
download/waves_yesno/0_0_1_0_0_0_1_0.wav
|
||||||
|
|
||||||
|
|
||||||
|
echo "Test int8 model"
|
||||||
|
python3 ./tdnn/onnx_pretrained.py \
|
||||||
|
--nn-model ./tdnn/exp/model-epoch-14-avg-2.int8.onnx \
|
||||||
|
--HLG ./data/lang_phone/HLG.pt \
|
||||||
|
--words-file ./data/lang_phone/words.txt \
|
||||||
|
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
|
||||||
|
download/waves_yesno/0_0_1_0_0_0_1_0.wav
|
||||||
|
|
||||||
|
- name: Show generated files
|
||||||
|
shell: bash
|
||||||
|
working-directory: ${{github.workspace}}
|
||||||
|
run: |
|
||||||
|
cd egs/yesno/ASR
|
||||||
|
ls -lh tdnn/exp
|
||||||
|
@ -338,7 +338,7 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder
|
|||||||
|
|
||||||
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
|
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
|
||||||
|
|
||||||
The best results for Chinese CER(%) and English WER(%) respectivly (zh: Chinese, en: English):
|
The best results for Chinese CER(%) and English WER(%) respectively (zh: Chinese, en: English):
|
||||||
|decoding-method | dev | dev_zh | dev_en | test | test_zh | test_en |
|
|decoding-method | dev | dev_zh | dev_en | test | test_zh | test_en |
|
||||||
|--|--|--|--|--|--|--|
|
|--|--|--|--|--|--|--|
|
||||||
|greedy_search| 7.30 | 6.48 | 19.19 |7.39| 6.66 | 19.13|
|
|greedy_search| 7.30 | 6.48 | 19.19 |7.39| 6.66 | 19.13|
|
||||||
|
@ -95,4 +95,7 @@ rst_epilog = """
|
|||||||
.. _k2: https://github.com/k2-fsa/k2
|
.. _k2: https://github.com/k2-fsa/k2
|
||||||
.. _lhotse: https://github.com/lhotse-speech/lhotse
|
.. _lhotse: https://github.com/lhotse-speech/lhotse
|
||||||
.. _yesno: https://www.openslr.org/1/
|
.. _yesno: https://www.openslr.org/1/
|
||||||
|
.. _Next-gen Kaldi: https://github.com/k2-fsa
|
||||||
|
.. _Kaldi: https://github.com/kaldi-asr/kaldi
|
||||||
|
.. _lilcom: https://github.com/danpovey/lilcom
|
||||||
"""
|
"""
|
||||||
|
@ -71,9 +71,12 @@ As the initial step, let's download the pre-trained model.
|
|||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||||
$ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
$ cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||||
$ git lfs pull --include "pretrained.pt"
|
$ git lfs pull --include "pretrained.pt"
|
||||||
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
|
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
|
||||||
|
$ cd ../data/lang_bpe_500
|
||||||
|
$ git lfs pull --include bpe.model
|
||||||
|
$ cd ../../..
|
||||||
|
|
||||||
To test the model, let's have a look at the decoding results **without** using LM. This can be done via the following command:
|
To test the model, let's have a look at the decoding results **without** using LM. This can be done via the following command:
|
||||||
|
|
||||||
|
@ -34,9 +34,12 @@ As the initial step, let's download the pre-trained model.
|
|||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||||
$ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
$ cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||||
$ git lfs pull --include "pretrained.pt"
|
$ git lfs pull --include "pretrained.pt"
|
||||||
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
|
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
|
||||||
|
$ cd ../data/lang_bpe_500
|
||||||
|
$ git lfs pull --include bpe.model
|
||||||
|
$ cd ../../..
|
||||||
|
|
||||||
As usual, we first test the model's performance without external LM. This can be done via the following command:
|
As usual, we first test the model's performance without external LM. This can be done via the following command:
|
||||||
|
|
||||||
|
@ -32,9 +32,12 @@ As the initial step, let's download the pre-trained model.
|
|||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||||
$ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
$ cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||||
$ git lfs pull --include "pretrained.pt"
|
$ git lfs pull --include "pretrained.pt"
|
||||||
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
|
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
|
||||||
|
$ cd ../data/lang_bpe_500
|
||||||
|
$ git lfs pull --include bpe.model
|
||||||
|
$ cd ../../..
|
||||||
|
|
||||||
To test the model, let's have a look at the decoding results without using LM. This can be done via the following command:
|
To test the model, let's have a look at the decoding results without using LM. This can be done via the following command:
|
||||||
|
|
||||||
|
180
docs/source/for-dummies/data-preparation.rst
Normal file
180
docs/source/for-dummies/data-preparation.rst
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
.. _dummies_tutorial_data_preparation:
|
||||||
|
|
||||||
|
Data Preparation
|
||||||
|
================
|
||||||
|
|
||||||
|
After :ref:`dummies_tutorial_environment_setup`, we can start preparing the
|
||||||
|
data for training and decoding.
|
||||||
|
|
||||||
|
The first step is to prepare the data for training. We have already provided
|
||||||
|
`prepare.sh <https://github.com/k2-fsa/icefall/blob/master/egs/yesno/ASR/prepare.sh>`_
|
||||||
|
that would prepare everything required for training.
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
cd /tmp/icefall
|
||||||
|
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
|
||||||
|
cd egs/yesno/ASR
|
||||||
|
|
||||||
|
./prepare.sh
|
||||||
|
|
||||||
|
Note that in each recipe from `icefall`_, there exists a file ``prepare.sh``,
|
||||||
|
which you should run before you run anything else.
|
||||||
|
|
||||||
|
That is all you need for data preparation.
|
||||||
|
|
||||||
|
For the more curious
|
||||||
|
--------------------
|
||||||
|
|
||||||
|
If you are wondering how to prepare your own dataset, please refer to the following
|
||||||
|
URLs for more details:
|
||||||
|
|
||||||
|
- `<https://github.com/lhotse-speech/lhotse/tree/master/lhotse/recipes>`_
|
||||||
|
|
||||||
|
It contains recipes for a variety of dataset. If you want to add your own
|
||||||
|
dataset, please read recipes in this folder first.
|
||||||
|
|
||||||
|
- `<https://github.com/lhotse-speech/lhotse/blob/master/lhotse/recipes/yesno.py>`_
|
||||||
|
|
||||||
|
The `yesno`_ recipe in `lhotse`_.
|
||||||
|
|
||||||
|
If you already have a `Kaldi`_ dataset directory, which contains files like
|
||||||
|
``wav.scp``, ``feats.scp``, then you can refer to `<https://lhotse.readthedocs.io/en/latest/kaldi.html#example>`_.
|
||||||
|
|
||||||
|
A quick look to the generated files
|
||||||
|
-----------------------------------
|
||||||
|
|
||||||
|
``./prepare.sh`` puts generated files into two directories:
|
||||||
|
|
||||||
|
- ``download``
|
||||||
|
- ``data``
|
||||||
|
|
||||||
|
download
|
||||||
|
^^^^^^^^
|
||||||
|
|
||||||
|
The ``download`` directory contains downloaded dataset files:
|
||||||
|
|
||||||
|
.. code-block:: bas
|
||||||
|
|
||||||
|
tree -L 1 ./download/
|
||||||
|
|
||||||
|
./download/
|
||||||
|
|-- waves_yesno
|
||||||
|
`-- waves_yesno.tar.gz
|
||||||
|
|
||||||
|
.. hint::
|
||||||
|
|
||||||
|
Please refer to `<https://github.com/lhotse-speech/lhotse/blob/master/lhotse/recipes/yesno.py#L41>`_
|
||||||
|
for how the data is downloaded and extracted.
|
||||||
|
|
||||||
|
data
|
||||||
|
^^^^
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
tree ./data/
|
||||||
|
|
||||||
|
./data/
|
||||||
|
|-- fbank
|
||||||
|
| |-- yesno_cuts_test.jsonl.gz
|
||||||
|
| |-- yesno_cuts_train.jsonl.gz
|
||||||
|
| |-- yesno_feats_test.lca
|
||||||
|
| `-- yesno_feats_train.lca
|
||||||
|
|-- lang_phone
|
||||||
|
| |-- HLG.pt
|
||||||
|
| |-- L.pt
|
||||||
|
| |-- L_disambig.pt
|
||||||
|
| |-- Linv.pt
|
||||||
|
| |-- lexicon.txt
|
||||||
|
| |-- lexicon_disambig.txt
|
||||||
|
| |-- tokens.txt
|
||||||
|
| `-- words.txt
|
||||||
|
|-- lm
|
||||||
|
| |-- G.arpa
|
||||||
|
| `-- G.fst.txt
|
||||||
|
`-- manifests
|
||||||
|
|-- yesno_recordings_test.jsonl.gz
|
||||||
|
|-- yesno_recordings_train.jsonl.gz
|
||||||
|
|-- yesno_supervisions_test.jsonl.gz
|
||||||
|
`-- yesno_supervisions_train.jsonl.gz
|
||||||
|
|
||||||
|
4 directories, 18 files
|
||||||
|
|
||||||
|
**data/manifests**:
|
||||||
|
|
||||||
|
This directory contains manifests. They are used to generate files in
|
||||||
|
``data/fbank``.
|
||||||
|
|
||||||
|
To give you an idea of what it contains, we examine the first few lines of
|
||||||
|
the manifests related to the ``train`` dataset.
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
cd data/manifests
|
||||||
|
gunzip -c yesno_recordings_train.jsonl.gz | head -n 3
|
||||||
|
|
||||||
|
The output is given below:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
{"id": "0_0_0_0_1_1_1_1", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_0_1_1_1_1.wav"}], "sampling_rate": 8000, "num_samples": 50800, "duration": 6.35, "channel_ids": [0]}
|
||||||
|
{"id": "0_0_0_1_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_1_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48880, "duration": 6.11, "channel_ids": [0]}
|
||||||
|
{"id": "0_0_1_0_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_1_0_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48160, "duration": 6.02, "channel_ids": [0]}
|
||||||
|
|
||||||
|
Please refer to `<https://github.com/lhotse-speech/lhotse/blob/master/lhotse/audio.py#L300>`_
|
||||||
|
for the meaning of each field per line.
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
gunzip -c yesno_supervisions_train.jsonl.gz | head -n 3
|
||||||
|
|
||||||
|
The output is given below:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
{"id": "0_0_0_0_1_1_1_1", "recording_id": "0_0_0_0_1_1_1_1", "start": 0.0, "duration": 6.35, "channel": 0, "text": "NO NO NO NO YES YES YES YES", "language": "Hebrew"}
|
||||||
|
{"id": "0_0_0_1_0_1_1_0", "recording_id": "0_0_0_1_0_1_1_0", "start": 0.0, "duration": 6.11, "channel": 0, "text": "NO NO NO YES NO YES YES NO", "language": "Hebrew"}
|
||||||
|
{"id": "0_0_1_0_0_1_1_0", "recording_id": "0_0_1_0_0_1_1_0", "start": 0.0, "duration": 6.02, "channel": 0, "text": "NO NO YES NO NO YES YES NO", "language": "Hebrew"}
|
||||||
|
|
||||||
|
Please refer to `<https://github.com/lhotse-speech/lhotse/blob/master/lhotse/supervision.py#L510>`_
|
||||||
|
for the meaning of each field per line.
|
||||||
|
|
||||||
|
**data/fbank**:
|
||||||
|
|
||||||
|
This directory contains everything from ``data/manifests``. Furthermore, it also contains features
|
||||||
|
for training.
|
||||||
|
|
||||||
|
``data/fbank/yesno_feats_train.lca`` contains the features for the train dataset.
|
||||||
|
Features are compressed using `lilcom`_.
|
||||||
|
|
||||||
|
``data/fbank/yesno_cuts_train.jsonl.gz`` stores the `CutSet <https://github.com/lhotse-speech/lhotse/blob/master/lhotse/cut/set.py#L72>`_,
|
||||||
|
which stores `RecordingSet <https://github.com/lhotse-speech/lhotse/blob/master/lhotse/audio.py#L928>`_,
|
||||||
|
`SupervisionSet <https://github.com/lhotse-speech/lhotse/blob/master/lhotse/supervision.py#L510>`_,
|
||||||
|
and `FeatureSet <https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/base.py#L593>`_.
|
||||||
|
|
||||||
|
To give you an idea about what it looks like, we can run the following command:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
cd data/fbank
|
||||||
|
|
||||||
|
gunzip -c yesno_cuts_train.jsonl.gz | head -n 3
|
||||||
|
|
||||||
|
The output is given below:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
{"id": "0_0_0_0_1_1_1_1-0", "start": 0, "duration": 6.35, "channel": 0, "supervisions": [{"id": "0_0_0_0_1_1_1_1", "recording_id": "0_0_0_0_1_1_1_1", "start": 0.0, "duration": 6.35, "channel": 0, "text": "NO NO NO NO YES YES YES YES", "language": "Hebrew"}], "features": {"type": "kaldi-fbank", "num_frames": 635, "num_features": 23, "frame_shift": 0.01, "sampling_rate": 8000, "start": 0, "duration": 6.35, "storage_type": "lilcom_chunky", "storage_path": "data/fbank/yesno_feats_train.lca", "storage_key": "0,13000,3570", "channels": 0}, "recording": {"id": "0_0_0_0_1_1_1_1", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_0_1_1_1_1.wav"}], "sampling_rate": 8000, "num_samples": 50800, "duration": 6.35, "channel_ids": [0]}, "type": "MonoCut"}
|
||||||
|
{"id": "0_0_0_1_0_1_1_0-1", "start": 0, "duration": 6.11, "channel": 0, "supervisions": [{"id": "0_0_0_1_0_1_1_0", "recording_id": "0_0_0_1_0_1_1_0", "start": 0.0, "duration": 6.11, "channel": 0, "text": "NO NO NO YES NO YES YES NO", "language": "Hebrew"}], "features": {"type": "kaldi-fbank", "num_frames": 611, "num_features": 23, "frame_shift": 0.01, "sampling_rate": 8000, "start": 0, "duration": 6.11, "storage_type": "lilcom_chunky", "storage_path": "data/fbank/yesno_feats_train.lca", "storage_key": "16570,12964,2929", "channels": 0}, "recording": {"id": "0_0_0_1_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_1_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48880, "duration": 6.11, "channel_ids": [0]}, "type": "MonoCut"}
|
||||||
|
{"id": "0_0_1_0_0_1_1_0-2", "start": 0, "duration": 6.02, "channel": 0, "supervisions": [{"id": "0_0_1_0_0_1_1_0", "recording_id": "0_0_1_0_0_1_1_0", "start": 0.0, "duration": 6.02, "channel": 0, "text": "NO NO YES NO NO YES YES NO", "language": "Hebrew"}], "features": {"type": "kaldi-fbank", "num_frames": 602, "num_features": 23, "frame_shift": 0.01, "sampling_rate": 8000, "start": 0, "duration": 6.02, "storage_type": "lilcom_chunky", "storage_path": "data/fbank/yesno_feats_train.lca", "storage_key": "32463,12936,2696", "channels": 0}, "recording": {"id": "0_0_1_0_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_1_0_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48160, "duration": 6.02, "channel_ids": [0]}, "type": "MonoCut"}
|
||||||
|
|
||||||
|
Note that ``yesno_cuts_train.jsonl.gz`` only stores the information about how to read the features.
|
||||||
|
The actual features are stored separately in ``data/fbank/yesno_feats_train.lca``.
|
||||||
|
|
||||||
|
**data/lang**:
|
||||||
|
|
||||||
|
This directory contains the lexicon.
|
||||||
|
|
||||||
|
**data/lm**:
|
||||||
|
|
||||||
|
This directory contains language models.
|
39
docs/source/for-dummies/decoding.rst
Normal file
39
docs/source/for-dummies/decoding.rst
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
.. _dummies_tutorial_decoding:
|
||||||
|
|
||||||
|
Decoding
|
||||||
|
========
|
||||||
|
|
||||||
|
After :ref:`dummies_tutorial_training`, we can start decoding.
|
||||||
|
|
||||||
|
The command to start the decoding is quite simple:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
cd /tmp/icefall
|
||||||
|
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
|
||||||
|
cd egs/yesno/ASR
|
||||||
|
|
||||||
|
# We use CPU for decoding by setting the following environment variable
|
||||||
|
export CUDA_VISIBLE_DEVICES=""
|
||||||
|
|
||||||
|
./tdnn/decode.py
|
||||||
|
|
||||||
|
The output logs are given below:
|
||||||
|
|
||||||
|
.. literalinclude:: ./code/decoding-yesno.txt
|
||||||
|
|
||||||
|
For the more curious
|
||||||
|
--------------------
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
./tdnn/decode.py --help
|
||||||
|
|
||||||
|
will print the usage information about ``./tdnn/decode.py``. For instance, you
|
||||||
|
can specify:
|
||||||
|
|
||||||
|
- ``--epoch`` to use which checkpoint for decoding
|
||||||
|
- ``--avg`` to select how many checkpoints to use for model averaging
|
||||||
|
|
||||||
|
You usually try different combinations of ``--epoch`` and ``--avg`` and select
|
||||||
|
one that leads to the lowest WER (`Word Error Rate <https://en.wikipedia.org/wiki/Word_error_rate>`_).
|
121
docs/source/for-dummies/environment-setup.rst
Normal file
121
docs/source/for-dummies/environment-setup.rst
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
.. _dummies_tutorial_environment_setup:
|
||||||
|
|
||||||
|
Environment setup
|
||||||
|
=================
|
||||||
|
|
||||||
|
We will create an environment for `Next-gen Kaldi`_ that runs on ``CPU``
|
||||||
|
in this tutorial.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
Since the `yesno`_ dataset used in this tutorial is very tiny, training on
|
||||||
|
``CPU`` works very well for it.
|
||||||
|
|
||||||
|
If your dataset is very large, e.g., hundreds or thousands of hours of
|
||||||
|
training data, please follow :ref:`install icefall` to install `icefall`_
|
||||||
|
that works with ``GPU``.
|
||||||
|
|
||||||
|
|
||||||
|
Create a virtual environment
|
||||||
|
----------------------------
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
virtualenv -p python3 /tmp/icefall_env
|
||||||
|
|
||||||
|
The above command creates a virtual environment in the directory ``/tmp/icefall_env``.
|
||||||
|
You can select any directory you want.
|
||||||
|
|
||||||
|
The output of the above command is given below:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
Already using interpreter /usr/bin/python3
|
||||||
|
Using base prefix '/usr'
|
||||||
|
New python executable in /tmp/icefall_env/bin/python3
|
||||||
|
Also creating executable in /tmp/icefall_env/bin/python
|
||||||
|
Installing setuptools, pkg_resources, pip, wheel...done.
|
||||||
|
|
||||||
|
Now we can activate the environment using:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
source /tmp/icefall_env/bin/activate
|
||||||
|
|
||||||
|
Install dependencies
|
||||||
|
--------------------
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
Remeber to activate your virtual environment before you continue!
|
||||||
|
|
||||||
|
After activating the virtual environment, we can use the following command
|
||||||
|
to install dependencies of `icefall`_:
|
||||||
|
|
||||||
|
.. hint::
|
||||||
|
|
||||||
|
Remeber that we will run this tutorial on ``CPU``, so we install
|
||||||
|
dependencies required only by running on ``CPU``.
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
# Caution: Installation order matters!
|
||||||
|
|
||||||
|
# We use torch 2.0.0 and torchaduio 2.0.0 in this tutorial.
|
||||||
|
# Other versions should also work.
|
||||||
|
|
||||||
|
pip install torch==2.0.0+cpu torchaudio==2.0.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
|
||||||
|
# If you are using macOS or Windows, please use the following command to install torch and torchaudio
|
||||||
|
# pip install torch==2.0.0 torchaudio==2.0.0 -f https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
|
||||||
|
# Now install k2
|
||||||
|
# Please refer to https://k2-fsa.github.io/k2/installation/from_wheels.html#linux-cpu-example
|
||||||
|
|
||||||
|
pip install k2==1.24.3.dev20230726+cpu.torch2.0.0 -f https://k2-fsa.github.io/k2/cpu.html
|
||||||
|
|
||||||
|
# Install the latest version of lhotse
|
||||||
|
|
||||||
|
pip install git+https://github.com/lhotse-speech/lhotse
|
||||||
|
|
||||||
|
|
||||||
|
Install icefall
|
||||||
|
---------------
|
||||||
|
|
||||||
|
We will put the source code of `icefall`_ into the directory ``/tmp``
|
||||||
|
You can select any directory you want.
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
cd /tmp
|
||||||
|
git clone https://github.com/k2-fsa/icefall
|
||||||
|
cd icefall
|
||||||
|
pip install -r ./requirements.txt
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
# Anytime we want to use icefall, we have to set the following
|
||||||
|
# environment variable
|
||||||
|
|
||||||
|
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
|
||||||
|
|
||||||
|
.. hint::
|
||||||
|
|
||||||
|
If you get the following error during this tutorial:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
ModuleNotFoundError: No module named 'icefall'
|
||||||
|
|
||||||
|
please set the above environment variable to fix it.
|
||||||
|
|
||||||
|
|
||||||
|
Congratulations! You have installed `icefall`_ successfully.
|
||||||
|
|
||||||
|
For the more curious
|
||||||
|
--------------------
|
||||||
|
|
||||||
|
`icefall`_ contains a collection of Python scripts and you don't need to
|
||||||
|
use ``python3 setup.py install`` or ``pip install icefall`` to install it.
|
||||||
|
All you need to do is to download the code and set the environment variable
|
||||||
|
``PYTHONPATH``.
|
34
docs/source/for-dummies/index.rst
Normal file
34
docs/source/for-dummies/index.rst
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
Icefall for dummies tutorial
|
||||||
|
============================
|
||||||
|
|
||||||
|
This tutorial walks you step by step about how to create a simple
|
||||||
|
ASR (`Automatic Speech Recognition <https://en.wikipedia.org/wiki/Speech_recognition>`_)
|
||||||
|
system with `Next-gen Kaldi`_.
|
||||||
|
|
||||||
|
We use the `yesno`_ dataset for demonstration. We select it out of two reasons:
|
||||||
|
|
||||||
|
- It is quite tiny, containing only about 12 minutes of data
|
||||||
|
- The training can be finished within 20 seconds on ``CPU``.
|
||||||
|
|
||||||
|
That also means you don't need a ``GPU`` to run this tutorial.
|
||||||
|
|
||||||
|
Let's get started!
|
||||||
|
|
||||||
|
Please follow items below **sequentially**.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
The :ref:`dummies_tutorial_data_preparation` runs only on Linux and on macOS.
|
||||||
|
All other parts run on Linux, macOS, and Windows.
|
||||||
|
|
||||||
|
Help from the community is appreciated to port the :ref:`dummies_tutorial_data_preparation`
|
||||||
|
to Windows.
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: 2
|
||||||
|
|
||||||
|
./environment-setup.rst
|
||||||
|
./data-preparation.rst
|
||||||
|
./training.rst
|
||||||
|
./decoding.rst
|
||||||
|
./model-export.rst
|
310
docs/source/for-dummies/model-export.rst
Normal file
310
docs/source/for-dummies/model-export.rst
Normal file
@ -0,0 +1,310 @@
|
|||||||
|
Model Export
|
||||||
|
============
|
||||||
|
|
||||||
|
There are three ways to export a pre-trained model.
|
||||||
|
|
||||||
|
- Export the model parameters via `model.state_dict() <https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.state_dict>`_
|
||||||
|
- Export via `torchscript <https://pytorch.org/docs/stable/jit.html>`_: either `torch.jit.script() <https://pytorch.org/docs/stable/generated/torch.jit.script.html#torch.jit.script>`_ or `torch.jit.trace() <https://pytorch.org/docs/stable/generated/torch.jit.trace.html>`_
|
||||||
|
- Export to `ONNX`_ via `torch.onnx.export() <https://pytorch.org/docs/stable/onnx.html>`_
|
||||||
|
|
||||||
|
Each method is explained below in detail.
|
||||||
|
|
||||||
|
Export the model parameters via model.state_dict()
|
||||||
|
---------------------------------------------------
|
||||||
|
|
||||||
|
The command for this kind of export is
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
cd /tmp/icefall
|
||||||
|
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
|
||||||
|
cd egs/yesno/ASR
|
||||||
|
|
||||||
|
# assume that "--epoch 14 --avg 2" produces the lowest WER.
|
||||||
|
|
||||||
|
./tdnn/export.py --epoch 14 --avg 2
|
||||||
|
|
||||||
|
The output logs are given below:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
2023-08-16 20:42:03,912 INFO [export.py:76] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'jit': False}
|
||||||
|
2023-08-16 20:42:03,913 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
|
||||||
|
2023-08-16 20:42:03,950 INFO [export.py:93] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
|
||||||
|
2023-08-16 20:42:03,971 INFO [export.py:106] Not using torch.jit.script
|
||||||
|
2023-08-16 20:42:03,974 INFO [export.py:111] Saved to tdnn/exp/pretrained.pt
|
||||||
|
|
||||||
|
We can see from the logs that the exported model is saved to the file ``tdnn/exp/pretrained.pt``.
|
||||||
|
|
||||||
|
To give you an idea of what ``tdnn/exp/pretrained.pt`` contains, we can use the following command:
|
||||||
|
|
||||||
|
.. code-block:: python3
|
||||||
|
|
||||||
|
>>> import torch
|
||||||
|
>>> m = torch.load("tdnn/exp/pretrained.pt")
|
||||||
|
>>> list(m.keys())
|
||||||
|
['model']
|
||||||
|
>>> list(m["model"].keys())
|
||||||
|
['tdnn.0.weight', 'tdnn.0.bias', 'tdnn.2.running_mean', 'tdnn.2.running_var', 'tdnn.2.num_batches_tracked', 'tdnn.3.weight', 'tdnn.3.bias', 'tdnn.5.running_mean', 'tdnn.5.running_var', 'tdnn.5.num_batches_tracked', 'tdnn.6.weight', 'tdnn.6.bias', 'tdnn.8.running_mean', 'tdnn.8.running_var', 'tdnn.8.num_batches_tracked', 'output_linear.weight', 'output_linear.bias']
|
||||||
|
|
||||||
|
We can use ``tdnn/exp/pretrained.pt`` in the following way with ``./tdnn/decode.py``:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
cd tdnn/exp
|
||||||
|
ln -s pretrained.pt epoch-99.pt
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
./tdnn/decode.py --epoch 99 --avg 1
|
||||||
|
|
||||||
|
The output logs of the above command are given below:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
2023-08-16 20:45:48,089 INFO [decode.py:262] Decoding started
|
||||||
|
2023-08-16 20:45:48,090 INFO [decode.py:263] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'feature_dim': 23, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 99, 'avg': 1, 'export': False, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': False, 'k2-git-sha1': 'ad79f1c699c684de9785ed6ca5edb805a41f78c3', 'k2-git-date': 'Wed Jul 26 09:30:42 2023', 'lhotse-version': '1.16.0.dev+git.aa073f6.clean', 'torch-version': '2.0.0', 'torch-cuda-available': False, 'torch-cuda-version': None, 'python-version': '3.1', 'icefall-git-branch': 'master', 'icefall-git-sha1': '9a47c08-clean', 'icefall-git-date': 'Mon Aug 14 22:10:50 2023', 'icefall-path': '/private/tmp/icefall', 'k2-path': '/private/tmp/icefall_env/lib/python3.11/site-packages/k2/__init__.py', 'lhotse-path': '/private/tmp/icefall_env/lib/python3.11/site-packages/lhotse/__init__.py', 'hostname': 'fangjuns-MacBook-Pro.local', 'IP address': '127.0.0.1'}}
|
||||||
|
2023-08-16 20:45:48,092 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
|
||||||
|
2023-08-16 20:45:48,103 INFO [decode.py:272] device: cpu
|
||||||
|
2023-08-16 20:45:48,109 INFO [checkpoint.py:112] Loading checkpoint from tdnn/exp/epoch-99.pt
|
||||||
|
2023-08-16 20:45:48,115 INFO [asr_datamodule.py:218] About to get test cuts
|
||||||
|
2023-08-16 20:45:48,115 INFO [asr_datamodule.py:253] About to get test cuts
|
||||||
|
2023-08-16 20:45:50,386 INFO [decode.py:203] batch 0/?, cuts processed until now is 4
|
||||||
|
2023-08-16 20:45:50,556 INFO [decode.py:240] The transcripts are stored in tdnn/exp/recogs-test_set.txt
|
||||||
|
2023-08-16 20:45:50,557 INFO [utils.py:564] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
|
||||||
|
2023-08-16 20:45:50,558 INFO [decode.py:248] Wrote detailed error stats to tdnn/exp/errs-test_set.txt
|
||||||
|
2023-08-16 20:45:50,559 INFO [decode.py:315] Done!
|
||||||
|
|
||||||
|
We can see that it produces an identical WER as before.
|
||||||
|
|
||||||
|
We can also use it to decode files with the following command:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
# ./tdnn/pretrained.py requires kaldifeat
|
||||||
|
#
|
||||||
|
# Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html
|
||||||
|
# for how to install kaldifeat
|
||||||
|
|
||||||
|
pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html
|
||||||
|
|
||||||
|
./tdnn/pretrained.py \
|
||||||
|
--checkpoint ./tdnn/exp/pretrained.pt \
|
||||||
|
--HLG ./data/lang_phone/HLG.pt \
|
||||||
|
--words-file ./data/lang_phone/words.txt \
|
||||||
|
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
|
||||||
|
download/waves_yesno/0_0_1_0_0_0_1_0.wav
|
||||||
|
|
||||||
|
The output is given below:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
2023-08-16 20:53:19,208 INFO [pretrained.py:136] {'feature_dim': 23, 'num_classes': 4, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'checkpoint': './tdnn/exp/pretrained.pt', 'words_file': './data/lang_phone/words.txt', 'HLG': './data/lang_phone/HLG.pt', 'sound_files': ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']}
|
||||||
|
2023-08-16 20:53:19,208 INFO [pretrained.py:142] device: cpu
|
||||||
|
2023-08-16 20:53:19,208 INFO [pretrained.py:144] Creating model
|
||||||
|
2023-08-16 20:53:19,212 INFO [pretrained.py:156] Loading HLG from ./data/lang_phone/HLG.pt
|
||||||
|
2023-08-16 20:53:19,213 INFO [pretrained.py:160] Constructing Fbank computer
|
||||||
|
2023-08-16 20:53:19,213 INFO [pretrained.py:170] Reading sound files: ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']
|
||||||
|
2023-08-16 20:53:19,224 INFO [pretrained.py:176] Decoding started
|
||||||
|
2023-08-16 20:53:19,304 INFO [pretrained.py:212]
|
||||||
|
download/waves_yesno/0_0_0_1_0_0_0_1.wav:
|
||||||
|
NO NO NO YES NO NO NO YES
|
||||||
|
|
||||||
|
download/waves_yesno/0_0_1_0_0_0_1_0.wav:
|
||||||
|
NO NO YES NO NO NO YES NO
|
||||||
|
|
||||||
|
|
||||||
|
2023-08-16 20:53:19,304 INFO [pretrained.py:214] Decoding Done
|
||||||
|
|
||||||
|
|
||||||
|
Export via torch.jit.script()
|
||||||
|
-----------------------------
|
||||||
|
|
||||||
|
The command for this kind of export is
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
cd /tmp/icefall
|
||||||
|
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
|
||||||
|
cd egs/yesno/ASR
|
||||||
|
|
||||||
|
# assume that "--epoch 14 --avg 2" produces the lowest WER.
|
||||||
|
|
||||||
|
./tdnn/export.py --epoch 14 --avg 2 --jit true
|
||||||
|
|
||||||
|
The output logs are given below:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
2023-08-16 20:47:44,666 INFO [export.py:76] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'jit': True}
|
||||||
|
2023-08-16 20:47:44,667 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
|
||||||
|
2023-08-16 20:47:44,670 INFO [export.py:93] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
|
||||||
|
2023-08-16 20:47:44,677 INFO [export.py:100] Using torch.jit.script
|
||||||
|
2023-08-16 20:47:44,843 INFO [export.py:104] Saved to tdnn/exp/cpu_jit.pt
|
||||||
|
|
||||||
|
From the output logs we can see that the generated file is saved to ``tdnn/exp/cpu_jit.pt``.
|
||||||
|
|
||||||
|
Don't be confused by the name ``cpu_jit.pt``. The ``cpu`` part means the model is moved to
|
||||||
|
CPU before exporting. That means, when you load it with:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
torch.jit.load()
|
||||||
|
|
||||||
|
you don't need to specify the argument `map_location <https://pytorch.org/docs/stable/generated/torch.jit.load.html#torch.jit.load>`_
|
||||||
|
and it resides on CPU by default.
|
||||||
|
|
||||||
|
To use ``tdnn/exp/cpu_jit.pt`` with `icefall`_ to decode files, we can use:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
# ./tdnn/jit_pretrained.py requires kaldifeat
|
||||||
|
#
|
||||||
|
# Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html
|
||||||
|
# for how to install kaldifeat
|
||||||
|
|
||||||
|
pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html
|
||||||
|
|
||||||
|
|
||||||
|
./tdnn/jit_pretrained.py \
|
||||||
|
--nn-model ./tdnn/exp/cpu_jit.pt \
|
||||||
|
--HLG ./data/lang_phone/HLG.pt \
|
||||||
|
--words-file ./data/lang_phone/words.txt \
|
||||||
|
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
|
||||||
|
download/waves_yesno/0_0_1_0_0_0_1_0.wav
|
||||||
|
|
||||||
|
The output is given below:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
2023-08-16 20:56:00,603 INFO [jit_pretrained.py:121] {'feature_dim': 23, 'num_classes': 4, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'nn_model': './tdnn/exp/cpu_jit.pt', 'words_file': './data/lang_phone/words.txt', 'HLG': './data/lang_phone/HLG.pt', 'sound_files': ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']}
|
||||||
|
2023-08-16 20:56:00,603 INFO [jit_pretrained.py:127] device: cpu
|
||||||
|
2023-08-16 20:56:00,603 INFO [jit_pretrained.py:129] Loading torchscript model
|
||||||
|
2023-08-16 20:56:00,640 INFO [jit_pretrained.py:134] Loading HLG from ./data/lang_phone/HLG.pt
|
||||||
|
2023-08-16 20:56:00,641 INFO [jit_pretrained.py:138] Constructing Fbank computer
|
||||||
|
2023-08-16 20:56:00,641 INFO [jit_pretrained.py:148] Reading sound files: ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']
|
||||||
|
2023-08-16 20:56:00,642 INFO [jit_pretrained.py:154] Decoding started
|
||||||
|
2023-08-16 20:56:00,727 INFO [jit_pretrained.py:190]
|
||||||
|
download/waves_yesno/0_0_0_1_0_0_0_1.wav:
|
||||||
|
NO NO NO YES NO NO NO YES
|
||||||
|
|
||||||
|
download/waves_yesno/0_0_1_0_0_0_1_0.wav:
|
||||||
|
NO NO YES NO NO NO YES NO
|
||||||
|
|
||||||
|
|
||||||
|
2023-08-16 20:56:00,727 INFO [jit_pretrained.py:192] Decoding Done
|
||||||
|
|
||||||
|
.. hint::
|
||||||
|
|
||||||
|
We provide only code for ``torch.jit.script()``. You can try ``torch.jit.trace()``
|
||||||
|
if you want.
|
||||||
|
|
||||||
|
Export via torch.onnx.export()
|
||||||
|
------------------------------
|
||||||
|
|
||||||
|
The command for this kind of export is
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
cd /tmp/icefall
|
||||||
|
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
|
||||||
|
cd egs/yesno/ASR
|
||||||
|
|
||||||
|
# tdnn/export_onnx.py requires onnx and onnxruntime
|
||||||
|
pip install onnx onnxruntime
|
||||||
|
|
||||||
|
# assume that "--epoch 14 --avg 2" produces the lowest WER.
|
||||||
|
|
||||||
|
./tdnn/export_onnx.py \
|
||||||
|
--epoch 14 \
|
||||||
|
--avg 2
|
||||||
|
|
||||||
|
The output logs are given below:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
2023-08-16 20:59:20,888 INFO [export_onnx.py:83] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_double_scores': True, 'epoch': 14, 'avg': 2}
|
||||||
|
2023-08-16 20:59:20,888 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
|
||||||
|
2023-08-16 20:59:20,892 INFO [export_onnx.py:100] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
|
||||||
|
================ Diagnostic Run torch.onnx.export version 2.0.0 ================
|
||||||
|
verbose: False, log level: Level.ERROR
|
||||||
|
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
|
||||||
|
|
||||||
|
2023-08-16 20:59:21,047 INFO [export_onnx.py:127] Saved to tdnn/exp/model-epoch-14-avg-2.onnx
|
||||||
|
2023-08-16 20:59:21,047 INFO [export_onnx.py:136] meta_data: {'model_type': 'tdnn', 'version': '1', 'model_author': 'k2-fsa', 'comment': 'non-streaming tdnn for the yesno recipe', 'vocab_size': 4}
|
||||||
|
2023-08-16 20:59:21,049 INFO [export_onnx.py:140] Generate int8 quantization models
|
||||||
|
2023-08-16 20:59:21,075 INFO [onnx_quantizer.py:538] Quantization parameters for tensor:"/Transpose_1_output_0" not specified
|
||||||
|
2023-08-16 20:59:21,081 INFO [export_onnx.py:151] Saved to tdnn/exp/model-epoch-14-avg-2.int8.onnx
|
||||||
|
|
||||||
|
We can see from the logs that it generates two files:
|
||||||
|
|
||||||
|
- ``tdnn/exp/model-epoch-14-avg-2.onnx`` (ONNX model with ``float32`` weights)
|
||||||
|
- ``tdnn/exp/model-epoch-14-avg-2.int8.onnx`` (ONNX model with ``int8`` weights)
|
||||||
|
|
||||||
|
To use the generated ONNX model files for decoding with `onnxruntime`_, we can use
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
# ./tdnn/onnx_pretrained.py requires kaldifeat
|
||||||
|
#
|
||||||
|
# Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html
|
||||||
|
# for how to install kaldifeat
|
||||||
|
|
||||||
|
pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html
|
||||||
|
|
||||||
|
./tdnn/onnx_pretrained.py \
|
||||||
|
--nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \
|
||||||
|
--HLG ./data/lang_phone/HLG.pt \
|
||||||
|
--words-file ./data/lang_phone/words.txt \
|
||||||
|
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
|
||||||
|
download/waves_yesno/0_0_1_0_0_0_1_0.wav
|
||||||
|
|
||||||
|
The output is given below:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
2023-08-16 21:03:24,260 INFO [onnx_pretrained.py:166] {'feature_dim': 23, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'nn_model': './tdnn/exp/model-epoch-14-avg-2.onnx', 'words_file': './data/lang_phone/words.txt', 'HLG': './data/lang_phone/HLG.pt', 'sound_files': ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']}
|
||||||
|
2023-08-16 21:03:24,260 INFO [onnx_pretrained.py:171] device: cpu
|
||||||
|
2023-08-16 21:03:24,260 INFO [onnx_pretrained.py:173] Loading onnx model ./tdnn/exp/model-epoch-14-avg-2.onnx
|
||||||
|
2023-08-16 21:03:24,267 INFO [onnx_pretrained.py:176] Loading HLG from ./data/lang_phone/HLG.pt
|
||||||
|
2023-08-16 21:03:24,270 INFO [onnx_pretrained.py:180] Constructing Fbank computer
|
||||||
|
2023-08-16 21:03:24,273 INFO [onnx_pretrained.py:190] Reading sound files: ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']
|
||||||
|
2023-08-16 21:03:24,279 INFO [onnx_pretrained.py:196] Decoding started
|
||||||
|
2023-08-16 21:03:24,318 INFO [onnx_pretrained.py:232]
|
||||||
|
download/waves_yesno/0_0_0_1_0_0_0_1.wav:
|
||||||
|
NO NO NO YES NO NO NO YES
|
||||||
|
|
||||||
|
download/waves_yesno/0_0_1_0_0_0_1_0.wav:
|
||||||
|
NO NO YES NO NO NO YES NO
|
||||||
|
|
||||||
|
|
||||||
|
2023-08-16 21:03:24,318 INFO [onnx_pretrained.py:234] Decoding Done
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
To use the ``int8`` ONNX model for decoding, please use:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
./tdnn/onnx_pretrained.py \
|
||||||
|
--nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \
|
||||||
|
--HLG ./data/lang_phone/HLG.pt \
|
||||||
|
--words-file ./data/lang_phone/words.txt \
|
||||||
|
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
|
||||||
|
download/waves_yesno/0_0_1_0_0_0_1_0.wav
|
||||||
|
|
||||||
|
For the more curious
|
||||||
|
--------------------
|
||||||
|
|
||||||
|
If you are wondering how to deploy the model without ``torch``, please
|
||||||
|
continue reading. We will show how to use `sherpa-onnx`_ to run the
|
||||||
|
exported ONNX models, which depends only on `onnxruntime`_ and does not
|
||||||
|
depend on ``torch``.
|
||||||
|
|
||||||
|
In this tutorial, we will only demonstrate the usage of `sherpa-onnx`_ with the
|
||||||
|
pre-trained model of the `yesno`_ recipe. There are also other two frameworks
|
||||||
|
available:
|
||||||
|
|
||||||
|
- `sherpa`_. It works with torchscript models.
|
||||||
|
- `sherpa-ncnn`_. It works with models exported using :ref:`icefall_export_to_ncnn` with `ncnn`_
|
||||||
|
|
||||||
|
Please see `<https://k2-fsa.github.io/sherpa/>`_ for further details.
|
39
docs/source/for-dummies/training.rst
Normal file
39
docs/source/for-dummies/training.rst
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
.. _dummies_tutorial_training:
|
||||||
|
|
||||||
|
Training
|
||||||
|
========
|
||||||
|
|
||||||
|
After :ref:`dummies_tutorial_data_preparation`, we can start training.
|
||||||
|
|
||||||
|
The command to start the training is quite simple:
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
cd /tmp/icefall
|
||||||
|
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
|
||||||
|
cd egs/yesno/ASR
|
||||||
|
|
||||||
|
# We use CPU for training by setting the following environment variable
|
||||||
|
export CUDA_VISIBLE_DEVICES=""
|
||||||
|
|
||||||
|
./tdnn/train.py
|
||||||
|
|
||||||
|
That's it!
|
||||||
|
|
||||||
|
You can find the training logs below:
|
||||||
|
|
||||||
|
.. literalinclude:: ./code/train-yesno.txt
|
||||||
|
|
||||||
|
For the more curious
|
||||||
|
--------------------
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
./tdnn/train.py --help
|
||||||
|
|
||||||
|
will print the usage information about ``./tdnn/train.py``. For instance, you
|
||||||
|
can specify the number of epochs to train and the location to save the training
|
||||||
|
results.
|
||||||
|
|
||||||
|
The training text logs are saved in ``tdnn/exp/log`` while the tensorboard
|
||||||
|
logs are in ``tdnn/exp/tensorboard``.
|
@ -20,6 +20,7 @@ speech recognition recipes using `k2 <https://github.com/k2-fsa/k2>`_.
|
|||||||
:maxdepth: 2
|
:maxdepth: 2
|
||||||
:caption: Contents:
|
:caption: Contents:
|
||||||
|
|
||||||
|
for-dummies/index.rst
|
||||||
installation/index
|
installation/index
|
||||||
docker/index
|
docker/index
|
||||||
faqs
|
faqs
|
||||||
|
@ -41,7 +41,7 @@ as an example.
|
|||||||
|
|
||||||
./pruned_transducer_stateless3/export.py \
|
./pruned_transducer_stateless3/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -78,7 +78,7 @@ In each recipe, there is also a file ``pretrained.py``, which can use
|
|||||||
|
|
||||||
./pruned_transducer_stateless3/pretrained.py \
|
./pruned_transducer_stateless3/pretrained.py \
|
||||||
--checkpoint ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/pretrained-iter-1224000-avg-14.pt \
|
--checkpoint ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/pretrained-iter-1224000-avg-14.pt \
|
||||||
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/bpe.model \
|
--tokens ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav \
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav \
|
||||||
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav \
|
./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav \
|
||||||
|
@ -153,11 +153,10 @@ Next, we use the following code to export our model:
|
|||||||
|
|
||||||
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
|
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
|
||||||
--exp-dir $dir/exp \
|
--exp-dir $dir/exp \
|
||||||
--bpe-model $dir/data/lang_bpe_500/bpe.model \
|
--tokens $dir/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
\
|
|
||||||
--num-encoder-layers 12 \
|
--num-encoder-layers 12 \
|
||||||
--chunk-length 32 \
|
--chunk-length 32 \
|
||||||
--cnn-module-kernel 31 \
|
--cnn-module-kernel 31 \
|
||||||
|
@ -73,7 +73,7 @@ Next, we use the following code to export our model:
|
|||||||
|
|
||||||
./lstm_transducer_stateless2/export-for-ncnn.py \
|
./lstm_transducer_stateless2/export-for-ncnn.py \
|
||||||
--exp-dir $dir/exp \
|
--exp-dir $dir/exp \
|
||||||
--bpe-model $dir/data/lang_bpe_500/bpe.model \
|
--tokens $dir/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
|
@ -72,12 +72,11 @@ Next, we use the following code to export our model:
|
|||||||
dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||||
|
|
||||||
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
|
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
|
||||||
--bpe-model $dir/data/lang_bpe_500/bpe.model \
|
--tokens $dir/data/lang_bpe_500/tokens.txt \
|
||||||
--exp-dir $dir/exp \
|
--exp-dir $dir/exp \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
\
|
|
||||||
--decode-chunk-len 32 \
|
--decode-chunk-len 32 \
|
||||||
--num-left-chunks 4 \
|
--num-left-chunks 4 \
|
||||||
--num-encoder-layers "2,4,3,2,4" \
|
--num-encoder-layers "2,4,3,2,4" \
|
||||||
|
@ -71,7 +71,7 @@ Export the model to ONNX
|
|||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
|
@ -32,7 +32,7 @@ as an example in the following.
|
|||||||
|
|
||||||
./pruned_transducer_stateless3/export.py \
|
./pruned_transducer_stateless3/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch $epoch \
|
--epoch $epoch \
|
||||||
--avg $avg \
|
--avg $avg \
|
||||||
--jit 1
|
--jit 1
|
||||||
|
@ -33,7 +33,7 @@ as an example in the following.
|
|||||||
|
|
||||||
./lstm_transducer_stateless2/export.py \
|
./lstm_transducer_stateless2/export.py \
|
||||||
--exp-dir ./lstm_transducer_stateless2/exp \
|
--exp-dir ./lstm_transducer_stateless2/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--iter $iter \
|
--iter $iter \
|
||||||
--avg $avg \
|
--avg $avg \
|
||||||
--jit-trace 1
|
--jit-trace 1
|
||||||
|
@ -32,7 +32,7 @@ import torch
|
|||||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
from icefall.utils import get_executor, str2bool
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
# it wastes a lot of CPU and slow things down.
|
# it wastes a lot of CPU and slow things down.
|
||||||
@ -42,7 +42,7 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
|
def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||||
src_dir = Path("data/manifests/aidatatang_200zh")
|
src_dir = Path("data/manifests/aidatatang_200zh")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(15, os.cpu_count())
|
||||||
@ -85,7 +85,8 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
|
|||||||
recordings=m["recordings"],
|
recordings=m["recordings"],
|
||||||
supervisions=m["supervisions"],
|
supervisions=m["supervisions"],
|
||||||
)
|
)
|
||||||
if "train" in partition:
|
if "train" in partition and perturb_speed:
|
||||||
|
logging.info(f"Doing speed perturb")
|
||||||
cut_set = (
|
cut_set = (
|
||||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||||
)
|
)
|
||||||
@ -109,7 +110,12 @@ def get_args():
|
|||||||
default=80,
|
default=80,
|
||||||
help="""The number of mel bins for Fbank""",
|
help="""The number of mel bins for Fbank""",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--perturb-speed",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -119,4 +125,6 @@ if __name__ == "__main__":
|
|||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_aidatatang_200zh(num_mel_bins=args.num_mel_bins)
|
compute_fbank_aidatatang_200zh(
|
||||||
|
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
|
||||||
|
)
|
||||||
|
@ -77,7 +77,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
|||||||
log "Stage 4: Compute fbank for aidatatang_200zh"
|
log "Stage 4: Compute fbank for aidatatang_200zh"
|
||||||
if [ ! -f data/fbank/.aidatatang_200zh.done ]; then
|
if [ ! -f data/fbank/.aidatatang_200zh.done ]; then
|
||||||
mkdir -p data/fbank
|
mkdir -p data/fbank
|
||||||
./local/compute_fbank_aidatatang_200zh.py
|
./local/compute_fbank_aidatatang_200zh.py --perturb-speed True
|
||||||
touch data/fbank/.aidatatang_200zh.done
|
touch data/fbank/.aidatatang_200zh.done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
@ -37,7 +37,7 @@ from lhotse.dataset import (
|
|||||||
DynamicBucketingSampler,
|
DynamicBucketingSampler,
|
||||||
K2SpeechRecognitionDataset,
|
K2SpeechRecognitionDataset,
|
||||||
PrecomputedFeatures,
|
PrecomputedFeatures,
|
||||||
SingleCutSampler,
|
SimpleCutSampler,
|
||||||
SpecAugment,
|
SpecAugment,
|
||||||
)
|
)
|
||||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||||
@ -291,8 +291,8 @@ class Aidatatang_200zhAsrDataModule:
|
|||||||
drop_last=True,
|
drop_last=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SingleCutSampler.")
|
logging.info("Using SimpleCutSampler.")
|
||||||
train_sampler = SingleCutSampler(
|
train_sampler = SimpleCutSampler(
|
||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
|
@ -32,7 +32,7 @@ import torch
|
|||||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
from icefall.utils import get_executor, str2bool
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
# it wastes a lot of CPU and slow things down.
|
# it wastes a lot of CPU and slow things down.
|
||||||
@ -42,7 +42,7 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
|
def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||||
src_dir = Path("data/manifests")
|
src_dir = Path("data/manifests")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(15, os.cpu_count())
|
||||||
@ -85,7 +85,8 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80):
|
|||||||
recordings=m["recordings"],
|
recordings=m["recordings"],
|
||||||
supervisions=m["supervisions"],
|
supervisions=m["supervisions"],
|
||||||
)
|
)
|
||||||
if "train" in partition:
|
if "train" in partition and perturb_speed:
|
||||||
|
logging.info(f"Doing speed perturb")
|
||||||
cut_set = (
|
cut_set = (
|
||||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||||
)
|
)
|
||||||
@ -109,7 +110,12 @@ def get_args():
|
|||||||
default=80,
|
default=80,
|
||||||
help="""The number of mel bins for Fbank""",
|
help="""The number of mel bins for Fbank""",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--perturb-speed",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -119,4 +125,6 @@ if __name__ == "__main__":
|
|||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_aidatatang_200zh(num_mel_bins=args.num_mel_bins)
|
compute_fbank_aidatatang_200zh(
|
||||||
|
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
|
||||||
|
)
|
||||||
|
@ -32,7 +32,7 @@ import torch
|
|||||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
from icefall.utils import get_executor, str2bool
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
# it wastes a lot of CPU and slow things down.
|
# it wastes a lot of CPU and slow things down.
|
||||||
@ -42,7 +42,7 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_aishell(num_mel_bins: int = 80):
|
def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||||
src_dir = Path("data/manifests")
|
src_dir = Path("data/manifests")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(15, os.cpu_count())
|
||||||
@ -81,7 +81,8 @@ def compute_fbank_aishell(num_mel_bins: int = 80):
|
|||||||
recordings=m["recordings"],
|
recordings=m["recordings"],
|
||||||
supervisions=m["supervisions"],
|
supervisions=m["supervisions"],
|
||||||
)
|
)
|
||||||
if "train" in partition:
|
if "train" in partition and perturb_speed:
|
||||||
|
logging.info(f"Doing speed perturb")
|
||||||
cut_set = (
|
cut_set = (
|
||||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||||
)
|
)
|
||||||
@ -104,7 +105,12 @@ def get_args():
|
|||||||
default=80,
|
default=80,
|
||||||
help="""The number of mel bins for Fbank""",
|
help="""The number of mel bins for Fbank""",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--perturb-speed",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -114,4 +120,6 @@ if __name__ == "__main__":
|
|||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_aishell(num_mel_bins=args.num_mel_bins)
|
compute_fbank_aishell(
|
||||||
|
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
|
||||||
|
)
|
||||||
|
@ -114,7 +114,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
|||||||
log "Stage 3: Compute fbank for aishell"
|
log "Stage 3: Compute fbank for aishell"
|
||||||
if [ ! -f data/fbank/.aishell.done ]; then
|
if [ ! -f data/fbank/.aishell.done ]; then
|
||||||
mkdir -p data/fbank
|
mkdir -p data/fbank
|
||||||
./local/compute_fbank_aishell.py
|
./local/compute_fbank_aishell.py --perturb-speed True
|
||||||
touch data/fbank/.aishell.done
|
touch data/fbank/.aishell.done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
@ -53,7 +53,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
|||||||
log "Stage 2: Process aidatatang_200zh"
|
log "Stage 2: Process aidatatang_200zh"
|
||||||
if [ ! -f data/fbank/.aidatatang_200zh_fbank.done ]; then
|
if [ ! -f data/fbank/.aidatatang_200zh_fbank.done ]; then
|
||||||
mkdir -p data/fbank
|
mkdir -p data/fbank
|
||||||
./local/compute_fbank_aidatatang_200zh.py
|
./local/compute_fbank_aidatatang_200zh.py --perturb-speed True
|
||||||
touch data/fbank/.aidatatang_200zh_fbank.done
|
touch data/fbank/.aidatatang_200zh_fbank.done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
@ -1,321 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
#
|
|
||||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
|
||||||
#
|
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
# This script converts several saved checkpoints
|
|
||||||
# to a single one using model averaging.
|
|
||||||
"""
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
|
|
||||||
(1) Export to torchscript model using torch.jit.script()
|
|
||||||
|
|
||||||
./pruned_transducer_stateless7/export.py \
|
|
||||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
|
||||||
--lang-dir data/lang_char \
|
|
||||||
--epoch 30 \
|
|
||||||
--avg 9 \
|
|
||||||
--jit 1
|
|
||||||
|
|
||||||
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
|
|
||||||
load it by `torch.jit.load("cpu_jit.pt")`.
|
|
||||||
|
|
||||||
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
|
|
||||||
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
|
|
||||||
|
|
||||||
Check
|
|
||||||
https://github.com/k2-fsa/sherpa
|
|
||||||
for how to use the exported models outside of icefall.
|
|
||||||
|
|
||||||
(2) Export `model.state_dict()`
|
|
||||||
|
|
||||||
./pruned_transducer_stateless7/export.py \
|
|
||||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
|
||||||
--lang-dir data/lang_char \
|
|
||||||
--epoch 20 \
|
|
||||||
--avg 10
|
|
||||||
|
|
||||||
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
|
||||||
load it by `icefall.checkpoint.load_checkpoint()`.
|
|
||||||
|
|
||||||
To use the generated file with `pruned_transducer_stateless7/decode.py`,
|
|
||||||
you can do:
|
|
||||||
|
|
||||||
cd /path/to/exp_dir
|
|
||||||
ln -s pretrained.pt epoch-9999.pt
|
|
||||||
|
|
||||||
cd /path/to/egs/librispeech/ASR
|
|
||||||
./pruned_transducer_stateless7/decode.py \
|
|
||||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
|
||||||
--epoch 9999 \
|
|
||||||
--avg 1 \
|
|
||||||
--max-duration 600 \
|
|
||||||
--decoding-method greedy_search \
|
|
||||||
--lang-dir data/lang_char
|
|
||||||
|
|
||||||
Check ./pretrained.py for its usage.
|
|
||||||
|
|
||||||
Note: If you don't want to train a model from scratch, we have
|
|
||||||
provided one for you. You can get it at
|
|
||||||
|
|
||||||
https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21
|
|
||||||
|
|
||||||
with the following commands:
|
|
||||||
|
|
||||||
sudo apt-get install git-lfs
|
|
||||||
git lfs install
|
|
||||||
git clone https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21
|
|
||||||
# You will find the pre-trained model in icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21exp
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
|
||||||
from train2 import add_model_arguments, get_params, get_transducer_model
|
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
|
||||||
average_checkpoints,
|
|
||||||
average_checkpoints_with_averaged_model,
|
|
||||||
find_checkpoints,
|
|
||||||
load_checkpoint,
|
|
||||||
)
|
|
||||||
from icefall.lexicon import Lexicon
|
|
||||||
from icefall.utils import str2bool
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--epoch",
|
|
||||||
type=int,
|
|
||||||
default=30,
|
|
||||||
help="""It specifies the checkpoint to use for decoding.
|
|
||||||
Note: Epoch counts from 1.
|
|
||||||
You can specify --avg to use more checkpoints for model averaging.""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--iter",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="""If positive, --epoch is ignored and it
|
|
||||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
|
||||||
You can specify --avg to use more checkpoints for model averaging.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--avg",
|
|
||||||
type=int,
|
|
||||||
default=9,
|
|
||||||
help="Number of checkpoints to average. Automatically select "
|
|
||||||
"consecutive checkpoints before the checkpoint specified by "
|
|
||||||
"'--epoch' and '--iter'",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--use-averaged-model",
|
|
||||||
type=str2bool,
|
|
||||||
default=True,
|
|
||||||
help="Whether to load averaged model. Currently it only supports "
|
|
||||||
"using --epoch. If True, it would decode with the averaged model "
|
|
||||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
|
||||||
"Actually only the models with epoch number of `epoch-avg` and "
|
|
||||||
"`epoch` are loaded for averaging. ",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--exp-dir",
|
|
||||||
type=str,
|
|
||||||
default="pruned_transducer_stateless7/exp",
|
|
||||||
help="""It specifies the directory where all training related
|
|
||||||
files, e.g., checkpoints, log, etc, are saved
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--lang-dir",
|
|
||||||
type=str,
|
|
||||||
default="data/lang_char",
|
|
||||||
help="""The lang dir
|
|
||||||
It contains language related input files such as
|
|
||||||
"lexicon.txt"
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--jit",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""True to save a model after applying torch.jit.script.
|
|
||||||
It will generate a file named cpu_jit.pt
|
|
||||||
|
|
||||||
Check ./jit_pretrained.py for how to use it.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--context-size",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
|
||||||
)
|
|
||||||
|
|
||||||
add_model_arguments(parser)
|
|
||||||
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def main():
|
|
||||||
args = get_parser().parse_args()
|
|
||||||
args.exp_dir = Path(args.exp_dir)
|
|
||||||
|
|
||||||
params = get_params()
|
|
||||||
params.update(vars(args))
|
|
||||||
|
|
||||||
device = torch.device("cpu")
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device("cuda", 0)
|
|
||||||
|
|
||||||
logging.info(f"device: {device}")
|
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
|
||||||
params.blank_id = 0
|
|
||||||
params.vocab_size = max(lexicon.tokens) + 1
|
|
||||||
|
|
||||||
logging.info(params)
|
|
||||||
|
|
||||||
logging.info("About to create model")
|
|
||||||
model = get_transducer_model(params)
|
|
||||||
|
|
||||||
model.to(device)
|
|
||||||
|
|
||||||
if not params.use_averaged_model:
|
|
||||||
if params.iter > 0:
|
|
||||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
|
||||||
: params.avg
|
|
||||||
]
|
|
||||||
if len(filenames) == 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"No checkpoints found for"
|
|
||||||
f" --iter {params.iter}, --avg {params.avg}"
|
|
||||||
)
|
|
||||||
elif len(filenames) < params.avg:
|
|
||||||
raise ValueError(
|
|
||||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
|
||||||
f" --iter {params.iter}, --avg {params.avg}"
|
|
||||||
)
|
|
||||||
logging.info(f"averaging {filenames}")
|
|
||||||
model.to(device)
|
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
|
||||||
elif params.avg == 1:
|
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
|
||||||
else:
|
|
||||||
start = params.epoch - params.avg + 1
|
|
||||||
filenames = []
|
|
||||||
for i in range(start, params.epoch + 1):
|
|
||||||
if i >= 1:
|
|
||||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
|
||||||
logging.info(f"averaging {filenames}")
|
|
||||||
model.to(device)
|
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
|
||||||
else:
|
|
||||||
if params.iter > 0:
|
|
||||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
|
||||||
: params.avg + 1
|
|
||||||
]
|
|
||||||
if len(filenames) == 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"No checkpoints found for"
|
|
||||||
f" --iter {params.iter}, --avg {params.avg}"
|
|
||||||
)
|
|
||||||
elif len(filenames) < params.avg + 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
|
||||||
f" --iter {params.iter}, --avg {params.avg}"
|
|
||||||
)
|
|
||||||
filename_start = filenames[-1]
|
|
||||||
filename_end = filenames[0]
|
|
||||||
logging.info(
|
|
||||||
"Calculating the averaged model over iteration checkpoints"
|
|
||||||
f" from {filename_start} (excluded) to {filename_end}"
|
|
||||||
)
|
|
||||||
model.to(device)
|
|
||||||
model.load_state_dict(
|
|
||||||
average_checkpoints_with_averaged_model(
|
|
||||||
filename_start=filename_start,
|
|
||||||
filename_end=filename_end,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert params.avg > 0, params.avg
|
|
||||||
start = params.epoch - params.avg
|
|
||||||
assert start >= 1, start
|
|
||||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
|
||||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
|
||||||
logging.info(
|
|
||||||
f"Calculating the averaged model over epoch range from "
|
|
||||||
f"{start} (excluded) to {params.epoch}"
|
|
||||||
)
|
|
||||||
model.to(device)
|
|
||||||
model.load_state_dict(
|
|
||||||
average_checkpoints_with_averaged_model(
|
|
||||||
filename_start=filename_start,
|
|
||||||
filename_end=filename_end,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
model.to("cpu")
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
if params.jit is True:
|
|
||||||
convert_scaled_to_non_scaled(model, inplace=True)
|
|
||||||
# We won't use the forward() method of the model in C++, so just ignore
|
|
||||||
# it here.
|
|
||||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
|
||||||
# torch scriptabe.
|
|
||||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
|
||||||
logging.info("Using torch.jit.script")
|
|
||||||
model = torch.jit.script(model)
|
|
||||||
filename = params.exp_dir / "cpu_jit.pt"
|
|
||||||
model.save(str(filename))
|
|
||||||
logging.info(f"Saved to {filename}")
|
|
||||||
else:
|
|
||||||
logging.info("Not using torchscript. Export model.state_dict()")
|
|
||||||
# Save it using a format so that it can be loaded
|
|
||||||
# by :func:`load_checkpoint`
|
|
||||||
filename = params.exp_dir / "pretrained.pt"
|
|
||||||
torch.save({"model": model.state_dict()}, str(filename))
|
|
||||||
logging.info(f"Saved to {filename}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
|
||||||
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
|
||||||
main()
|
|
1
egs/aishell/ASR/pruned_transducer_stateless7/export.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7/export.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7/export.py
|
@ -1,348 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
|
||||||
#
|
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""
|
|
||||||
This script loads a checkpoint and uses it to decode waves.
|
|
||||||
You can generate the checkpoint with the following command:
|
|
||||||
|
|
||||||
./pruned_transducer_stateless7/export.py \
|
|
||||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
|
||||||
--lang-dir data/lang_char \
|
|
||||||
--epoch 20 \
|
|
||||||
--avg 10
|
|
||||||
|
|
||||||
Usage of this script:
|
|
||||||
|
|
||||||
(1) greedy search
|
|
||||||
./pruned_transducer_stateless7/pretrained.py \
|
|
||||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
|
||||||
--lang-dir ./data/lang_char \
|
|
||||||
--method greedy_search \
|
|
||||||
/path/to/foo.wav \
|
|
||||||
/path/to/bar.wav
|
|
||||||
|
|
||||||
(2) beam search
|
|
||||||
./pruned_transducer_stateless7/pretrained.py \
|
|
||||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
|
||||||
--lang-dir ./data/lang_char \
|
|
||||||
--method beam_search \
|
|
||||||
--beam-size 4 \
|
|
||||||
/path/to/foo.wav \
|
|
||||||
/path/to/bar.wav
|
|
||||||
|
|
||||||
(3) modified beam search
|
|
||||||
./pruned_transducer_stateless7/pretrained.py \
|
|
||||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
|
||||||
--lang-dir ./data/lang_char \
|
|
||||||
--method modified_beam_search \
|
|
||||||
--beam-size 4 \
|
|
||||||
/path/to/foo.wav \
|
|
||||||
/path/to/bar.wav
|
|
||||||
|
|
||||||
(4) fast beam search
|
|
||||||
./pruned_transducer_stateless7/pretrained.py \
|
|
||||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
|
||||||
--lang-dir ./data/lang_char \
|
|
||||||
--method fast_beam_search \
|
|
||||||
--beam-size 4 \
|
|
||||||
/path/to/foo.wav \
|
|
||||||
/path/to/bar.wav
|
|
||||||
|
|
||||||
You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`.
|
|
||||||
|
|
||||||
Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by
|
|
||||||
./pruned_transducer_stateless7/export.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import k2
|
|
||||||
import kaldifeat
|
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
|
||||||
import torchaudio
|
|
||||||
from beam_search import (
|
|
||||||
beam_search,
|
|
||||||
fast_beam_search_one_best,
|
|
||||||
greedy_search,
|
|
||||||
greedy_search_batch,
|
|
||||||
modified_beam_search,
|
|
||||||
)
|
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
|
||||||
|
|
||||||
from icefall.lexicon import Lexicon
|
|
||||||
from icefall.utils import str2bool
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--checkpoint",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path to the checkpoint. "
|
|
||||||
"The checkpoint is assumed to be saved by "
|
|
||||||
"icefall.checkpoint.save_checkpoint().",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--lang-dir",
|
|
||||||
type=str,
|
|
||||||
help="""The lang dir
|
|
||||||
It contains language related input files such as
|
|
||||||
"lexicon.txt"
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--method",
|
|
||||||
type=str,
|
|
||||||
default="greedy_search",
|
|
||||||
help="""Possible values are:
|
|
||||||
- greedy_search
|
|
||||||
- beam_search
|
|
||||||
- modified_beam_search
|
|
||||||
- fast_beam_search
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"sound_files",
|
|
||||||
type=str,
|
|
||||||
nargs="+",
|
|
||||||
help="The input sound file(s) to transcribe. "
|
|
||||||
"Supported formats are those supported by torchaudio.load(). "
|
|
||||||
"For example, wav and flac are supported. "
|
|
||||||
"The sample rate has to be 16kHz.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--sample-rate",
|
|
||||||
type=int,
|
|
||||||
default=16000,
|
|
||||||
help="The sample rate of the input sound file",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--beam-size",
|
|
||||||
type=int,
|
|
||||||
default=4,
|
|
||||||
help="""An integer indicating how many candidates we will keep for each
|
|
||||||
frame. Used only when --method is beam_search or
|
|
||||||
modified_beam_search.""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--beam",
|
|
||||||
type=float,
|
|
||||||
default=4,
|
|
||||||
help="""A floating point value to calculate the cutoff score during beam
|
|
||||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
|
||||||
`beam` in Kaldi.
|
|
||||||
Used only when --method is fast_beam_search""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-contexts",
|
|
||||||
type=int,
|
|
||||||
default=4,
|
|
||||||
help="""Used only when --method is fast_beam_search""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-states",
|
|
||||||
type=int,
|
|
||||||
default=8,
|
|
||||||
help="""Used only when --method is fast_beam_search""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--context-size",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-sym-per-frame",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="""Maximum number of symbols per frame. Used only when
|
|
||||||
--method is greedy_search.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
add_model_arguments(parser)
|
|
||||||
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def read_sound_files(
|
|
||||||
filenames: List[str], expected_sample_rate: float
|
|
||||||
) -> List[torch.Tensor]:
|
|
||||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
|
||||||
Args:
|
|
||||||
filenames:
|
|
||||||
A list of sound filenames.
|
|
||||||
expected_sample_rate:
|
|
||||||
The expected sample rate of the sound files.
|
|
||||||
Returns:
|
|
||||||
Return a list of 1-D float32 torch tensors.
|
|
||||||
"""
|
|
||||||
ans = []
|
|
||||||
for f in filenames:
|
|
||||||
wave, sample_rate = torchaudio.load(f)
|
|
||||||
assert (
|
|
||||||
sample_rate == expected_sample_rate
|
|
||||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
|
||||||
# We use only the first channel
|
|
||||||
ans.append(wave[0])
|
|
||||||
return ans
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def main():
|
|
||||||
parser = get_parser()
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
params = get_params()
|
|
||||||
|
|
||||||
params.update(vars(args))
|
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
|
||||||
params.blank_id = 0
|
|
||||||
params.vocab_size = max(lexicon.tokens) + 1
|
|
||||||
token_table = lexicon.token_table
|
|
||||||
|
|
||||||
logging.info(f"{params}")
|
|
||||||
|
|
||||||
device = torch.device("cpu")
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device("cuda", 0)
|
|
||||||
|
|
||||||
logging.info(f"device: {device}")
|
|
||||||
|
|
||||||
logging.info("Creating model")
|
|
||||||
model = get_transducer_model(params)
|
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
|
||||||
|
|
||||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
|
||||||
model.load_state_dict(checkpoint["model"], strict=False)
|
|
||||||
model.to(device)
|
|
||||||
model.eval()
|
|
||||||
model.device = device
|
|
||||||
|
|
||||||
logging.info("Constructing Fbank computer")
|
|
||||||
opts = kaldifeat.FbankOptions()
|
|
||||||
opts.device = device
|
|
||||||
opts.frame_opts.dither = 0
|
|
||||||
opts.frame_opts.snip_edges = False
|
|
||||||
opts.frame_opts.samp_freq = params.sample_rate
|
|
||||||
opts.mel_opts.num_bins = params.feature_dim
|
|
||||||
|
|
||||||
fbank = kaldifeat.Fbank(opts)
|
|
||||||
|
|
||||||
logging.info(f"Reading sound files: {params.sound_files}")
|
|
||||||
waves = read_sound_files(
|
|
||||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
|
||||||
)
|
|
||||||
waves = [w.to(device) for w in waves]
|
|
||||||
|
|
||||||
logging.info("Decoding started")
|
|
||||||
features = fbank(waves)
|
|
||||||
feature_lengths = [f.size(0) for f in features]
|
|
||||||
|
|
||||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
|
||||||
|
|
||||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
|
||||||
|
|
||||||
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
|
|
||||||
|
|
||||||
num_waves = encoder_out.size(0)
|
|
||||||
hyps = []
|
|
||||||
msg = f"Using {params.method}"
|
|
||||||
if params.method == "beam_search":
|
|
||||||
msg += f" with beam size {params.beam_size}"
|
|
||||||
logging.info(msg)
|
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
|
||||||
model=model,
|
|
||||||
decoding_graph=decoding_graph,
|
|
||||||
encoder_out=encoder_out,
|
|
||||||
encoder_out_lens=encoder_out_lens,
|
|
||||||
beam=params.beam,
|
|
||||||
max_contexts=params.max_contexts,
|
|
||||||
max_states=params.max_states,
|
|
||||||
)
|
|
||||||
elif params.method == "modified_beam_search":
|
|
||||||
hyp_tokens = modified_beam_search(
|
|
||||||
model=model,
|
|
||||||
encoder_out=encoder_out,
|
|
||||||
encoder_out_lens=encoder_out_lens,
|
|
||||||
beam=params.beam_size,
|
|
||||||
)
|
|
||||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
|
||||||
hyp_tokens = greedy_search_batch(
|
|
||||||
model=model,
|
|
||||||
encoder_out=encoder_out,
|
|
||||||
encoder_out_lens=encoder_out_lens,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
for i in range(num_waves):
|
|
||||||
# fmt: off
|
|
||||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
|
||||||
# fmt: on
|
|
||||||
if params.method == "greedy_search":
|
|
||||||
hyp_tokens = greedy_search(
|
|
||||||
model=model,
|
|
||||||
encoder_out=encoder_out_i,
|
|
||||||
max_sym_per_frame=params.max_sym_per_frame,
|
|
||||||
)
|
|
||||||
elif params.method == "beam_search":
|
|
||||||
hyp_tokens = beam_search(
|
|
||||||
model=model,
|
|
||||||
encoder_out=encoder_out_i,
|
|
||||||
beam=params.beam_size,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
|
||||||
|
|
||||||
hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]
|
|
||||||
s = "\n"
|
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
|
||||||
words = " ".join(hyp)
|
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
|
||||||
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
|
||||||
main()
|
|
1
egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7/pretrained.py
|
@ -30,7 +30,7 @@ from lhotse.dataset import (
|
|||||||
DynamicBucketingSampler,
|
DynamicBucketingSampler,
|
||||||
K2SpeechRecognitionDataset,
|
K2SpeechRecognitionDataset,
|
||||||
PrecomputedFeatures,
|
PrecomputedFeatures,
|
||||||
SingleCutSampler,
|
SimpleCutSampler,
|
||||||
SpecAugment,
|
SpecAugment,
|
||||||
)
|
)
|
||||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||||
@ -278,8 +278,8 @@ class AishellAsrDataModule:
|
|||||||
drop_last=self.args.drop_last,
|
drop_last=self.args.drop_last,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SingleCutSampler.")
|
logging.info("Using SimpleCutSampler.")
|
||||||
train_sampler = SingleCutSampler(
|
train_sampler = SimpleCutSampler(
|
||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
|
@ -32,7 +32,7 @@ import torch
|
|||||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
from icefall.utils import get_executor, str2bool
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
# it wastes a lot of CPU and slow things down.
|
# it wastes a lot of CPU and slow things down.
|
||||||
@ -42,7 +42,7 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_aishell2(num_mel_bins: int = 80):
|
def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||||
src_dir = Path("data/manifests")
|
src_dir = Path("data/manifests")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(15, os.cpu_count())
|
||||||
@ -81,7 +81,8 @@ def compute_fbank_aishell2(num_mel_bins: int = 80):
|
|||||||
recordings=m["recordings"],
|
recordings=m["recordings"],
|
||||||
supervisions=m["supervisions"],
|
supervisions=m["supervisions"],
|
||||||
)
|
)
|
||||||
if "train" in partition:
|
if "train" in partition and perturb_speed:
|
||||||
|
logging.info(f"Doing speed perturb")
|
||||||
cut_set = (
|
cut_set = (
|
||||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||||
)
|
)
|
||||||
@ -104,6 +105,12 @@ def get_args():
|
|||||||
default=80,
|
default=80,
|
||||||
help="""The number of mel bins for Fbank""",
|
help="""The number of mel bins for Fbank""",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--perturb-speed",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||||
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
@ -114,4 +121,6 @@ if __name__ == "__main__":
|
|||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_aishell2(num_mel_bins=args.num_mel_bins)
|
compute_fbank_aishell2(
|
||||||
|
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
|
||||||
|
)
|
||||||
|
@ -101,7 +101,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
|||||||
log "Stage 3: Compute fbank for aishell2"
|
log "Stage 3: Compute fbank for aishell2"
|
||||||
if [ ! -f data/fbank/.aishell2.done ]; then
|
if [ ! -f data/fbank/.aishell2.done ]; then
|
||||||
mkdir -p data/fbank
|
mkdir -p data/fbank
|
||||||
./local/compute_fbank_aishell2.py
|
./local/compute_fbank_aishell2.py --perturb-speed True
|
||||||
touch data/fbank/.aishell2.done
|
touch data/fbank/.aishell2.done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
|||||||
DynamicBucketingSampler,
|
DynamicBucketingSampler,
|
||||||
K2SpeechRecognitionDataset,
|
K2SpeechRecognitionDataset,
|
||||||
PrecomputedFeatures,
|
PrecomputedFeatures,
|
||||||
SingleCutSampler,
|
SimpleCutSampler,
|
||||||
SpecAugment,
|
SpecAugment,
|
||||||
)
|
)
|
||||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||||
@ -299,8 +299,8 @@ class AiShell2AsrDataModule:
|
|||||||
drop_last=self.args.drop_last,
|
drop_last=self.args.drop_last,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SingleCutSampler.")
|
logging.info("Using SimpleCutSampler.")
|
||||||
train_sampler = SingleCutSampler(
|
train_sampler = SimpleCutSampler(
|
||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
|
@ -32,7 +32,7 @@ import torch
|
|||||||
from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
|
from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
from icefall.utils import get_executor, str2bool
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
# it wastes a lot of CPU and slow things down.
|
# it wastes a lot of CPU and slow things down.
|
||||||
@ -42,7 +42,7 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_aishell4(num_mel_bins: int = 80):
|
def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||||
src_dir = Path("data/manifests/aishell4")
|
src_dir = Path("data/manifests/aishell4")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(15, os.cpu_count())
|
||||||
@ -83,10 +83,12 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
|
|||||||
recordings=m["recordings"],
|
recordings=m["recordings"],
|
||||||
supervisions=m["supervisions"],
|
supervisions=m["supervisions"],
|
||||||
)
|
)
|
||||||
if "train" in partition:
|
if "train" in partition and perturb_speed:
|
||||||
|
logging.info(f"Doing speed perturb")
|
||||||
cut_set = (
|
cut_set = (
|
||||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||||
)
|
)
|
||||||
|
|
||||||
cut_set = cut_set.compute_and_store_features(
|
cut_set = cut_set.compute_and_store_features(
|
||||||
extractor=extractor,
|
extractor=extractor,
|
||||||
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
||||||
@ -113,6 +115,12 @@ def get_args():
|
|||||||
default=80,
|
default=80,
|
||||||
help="""The number of mel bins for Fbank""",
|
help="""The number of mel bins for Fbank""",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--perturb-speed",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||||
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
@ -123,4 +131,6 @@ if __name__ == "__main__":
|
|||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_aishell4(num_mel_bins=args.num_mel_bins)
|
compute_fbank_aishell4(
|
||||||
|
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
|
||||||
|
)
|
||||||
|
@ -107,7 +107,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
|||||||
log "Stage 5: Compute fbank for aishell4"
|
log "Stage 5: Compute fbank for aishell4"
|
||||||
if [ ! -f data/fbank/.aishell4.done ]; then
|
if [ ! -f data/fbank/.aishell4.done ]; then
|
||||||
mkdir -p data/fbank
|
mkdir -p data/fbank
|
||||||
./local/compute_fbank_aishell4.py
|
./local/compute_fbank_aishell4.py --perturb-speed True
|
||||||
touch data/fbank/.aishell4.done
|
touch data/fbank/.aishell4.done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
@ -30,7 +30,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
|||||||
DynamicBucketingSampler,
|
DynamicBucketingSampler,
|
||||||
K2SpeechRecognitionDataset,
|
K2SpeechRecognitionDataset,
|
||||||
PrecomputedFeatures,
|
PrecomputedFeatures,
|
||||||
SingleCutSampler,
|
SimpleCutSampler,
|
||||||
SpecAugment,
|
SpecAugment,
|
||||||
)
|
)
|
||||||
from lhotse.dataset.input_strategies import ( # noqa F401 for AudioSamples
|
from lhotse.dataset.input_strategies import ( # noqa F401 for AudioSamples
|
||||||
@ -310,8 +310,8 @@ class Aishell4AsrDataModule:
|
|||||||
drop_last=self.args.drop_last,
|
drop_last=self.args.drop_last,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SingleCutSampler.")
|
logging.info("Using SimpleCutSampler.")
|
||||||
train_sampler = SingleCutSampler(
|
train_sampler = SimpleCutSampler(
|
||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
|
@ -32,7 +32,7 @@ import torch
|
|||||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
from icefall.utils import get_executor, str2bool
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
# it wastes a lot of CPU and slow things down.
|
# it wastes a lot of CPU and slow things down.
|
||||||
@ -42,7 +42,7 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_alimeeting(num_mel_bins: int = 80):
|
def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False):
|
||||||
src_dir = Path("data/manifests/alimeeting")
|
src_dir = Path("data/manifests/alimeeting")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(15, os.cpu_count())
|
||||||
@ -82,7 +82,8 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
|
|||||||
recordings=m["recordings"],
|
recordings=m["recordings"],
|
||||||
supervisions=m["supervisions"],
|
supervisions=m["supervisions"],
|
||||||
)
|
)
|
||||||
if "train" in partition:
|
if "train" in partition and perturb_speed:
|
||||||
|
logging.info(f"Doing speed perturb")
|
||||||
cut_set = (
|
cut_set = (
|
||||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||||
)
|
)
|
||||||
@ -114,6 +115,12 @@ def get_args():
|
|||||||
default=80,
|
default=80,
|
||||||
help="""The number of mel bins for Fbank""",
|
help="""The number of mel bins for Fbank""",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--perturb-speed",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||||
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
@ -124,4 +131,6 @@ if __name__ == "__main__":
|
|||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
compute_fbank_alimeeting(num_mel_bins=args.num_mel_bins)
|
compute_fbank_alimeeting(
|
||||||
|
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
|
||||||
|
)
|
||||||
|
@ -97,7 +97,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
|||||||
log "Stage 5: Compute fbank for alimeeting"
|
log "Stage 5: Compute fbank for alimeeting"
|
||||||
if [ ! -f data/fbank/.alimeeting.done ]; then
|
if [ ! -f data/fbank/.alimeeting.done ]; then
|
||||||
mkdir -p data/fbank
|
mkdir -p data/fbank
|
||||||
./local/compute_fbank_alimeeting.py
|
./local/compute_fbank_alimeeting.py --perturb-speed True
|
||||||
touch data/fbank/.alimeeting.done
|
touch data/fbank/.alimeeting.done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
@ -37,7 +37,7 @@ from lhotse.dataset import (
|
|||||||
DynamicBucketingSampler,
|
DynamicBucketingSampler,
|
||||||
K2SpeechRecognitionDataset,
|
K2SpeechRecognitionDataset,
|
||||||
PrecomputedFeatures,
|
PrecomputedFeatures,
|
||||||
SingleCutSampler,
|
SimpleCutSampler,
|
||||||
SpecAugment,
|
SpecAugment,
|
||||||
)
|
)
|
||||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||||
@ -292,8 +292,8 @@ class AlimeetingAsrDataModule:
|
|||||||
drop_last=True,
|
drop_last=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SingleCutSampler.")
|
logging.info("Using SimpleCutSampler.")
|
||||||
train_sampler = SingleCutSampler(
|
train_sampler = SimpleCutSampler(
|
||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
|
@ -25,6 +25,7 @@ It looks for manifests in the directory data/manifests.
|
|||||||
|
|
||||||
The generated fbank features are saved in data/fbank.
|
The generated fbank features are saved in data/fbank.
|
||||||
"""
|
"""
|
||||||
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@ -39,6 +40,8 @@ from lhotse.features.kaldifeat import (
|
|||||||
)
|
)
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
# it wastes a lot of CPU and slow things down.
|
# it wastes a lot of CPU and slow things down.
|
||||||
# Do this outside of main() in case it needs to take effect
|
# Do this outside of main() in case it needs to take effect
|
||||||
@ -48,7 +51,7 @@ torch.set_num_interop_threads(1)
|
|||||||
torch.multiprocessing.set_sharing_strategy("file_system")
|
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_ami():
|
def compute_fbank_ami(perturb_speed: bool = False):
|
||||||
src_dir = Path("data/manifests")
|
src_dir = Path("data/manifests")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
|
|
||||||
@ -84,8 +87,12 @@ def compute_fbank_ami():
|
|||||||
suffix="jsonl.gz",
|
suffix="jsonl.gz",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _extract_feats(cuts: CutSet, storage_path: Path, manifest_path: Path) -> None:
|
def _extract_feats(
|
||||||
cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
|
cuts: CutSet, storage_path: Path, manifest_path: Path, speed_perturb: bool
|
||||||
|
) -> None:
|
||||||
|
if speed_perturb:
|
||||||
|
logging.info(f"Doing speed perturb")
|
||||||
|
cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
|
||||||
_ = cuts.compute_and_store_features_batch(
|
_ = cuts.compute_and_store_features_batch(
|
||||||
extractor=extractor,
|
extractor=extractor,
|
||||||
storage_path=storage_path,
|
storage_path=storage_path,
|
||||||
@ -109,6 +116,7 @@ def compute_fbank_ami():
|
|||||||
cuts_ihm,
|
cuts_ihm,
|
||||||
output_dir / "feats_train_ihm",
|
output_dir / "feats_train_ihm",
|
||||||
src_dir / "cuts_train_ihm.jsonl.gz",
|
src_dir / "cuts_train_ihm.jsonl.gz",
|
||||||
|
perturb_speed,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Processing train split IHM + reverberated IHM")
|
logging.info("Processing train split IHM + reverberated IHM")
|
||||||
@ -117,6 +125,7 @@ def compute_fbank_ami():
|
|||||||
cuts_ihm_rvb,
|
cuts_ihm_rvb,
|
||||||
output_dir / "feats_train_ihm_rvb",
|
output_dir / "feats_train_ihm_rvb",
|
||||||
src_dir / "cuts_train_ihm_rvb.jsonl.gz",
|
src_dir / "cuts_train_ihm_rvb.jsonl.gz",
|
||||||
|
perturb_speed,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Processing train split SDM")
|
logging.info("Processing train split SDM")
|
||||||
@ -129,6 +138,7 @@ def compute_fbank_ami():
|
|||||||
cuts_sdm,
|
cuts_sdm,
|
||||||
output_dir / "feats_train_sdm",
|
output_dir / "feats_train_sdm",
|
||||||
src_dir / "cuts_train_sdm.jsonl.gz",
|
src_dir / "cuts_train_sdm.jsonl.gz",
|
||||||
|
perturb_speed,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Processing train split GSS")
|
logging.info("Processing train split GSS")
|
||||||
@ -141,6 +151,7 @@ def compute_fbank_ami():
|
|||||||
cuts_gss,
|
cuts_gss,
|
||||||
output_dir / "feats_train_gss",
|
output_dir / "feats_train_gss",
|
||||||
src_dir / "cuts_train_gss.jsonl.gz",
|
src_dir / "cuts_train_gss.jsonl.gz",
|
||||||
|
perturb_speed,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Preparing test cuts: IHM, SDM, GSS (optional)")
|
logging.info("Preparing test cuts: IHM, SDM, GSS (optional)")
|
||||||
@ -186,8 +197,21 @@ def compute_fbank_ami():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--perturb-speed",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
compute_fbank_ami()
|
args = get_args()
|
||||||
|
|
||||||
|
compute_fbank_ami(perturb_speed=args.perturb_speed)
|
||||||
|
@ -85,7 +85,7 @@ fi
|
|||||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||||
log "Stage 5: Compute fbank for alimeeting"
|
log "Stage 5: Compute fbank for alimeeting"
|
||||||
mkdir -p data/fbank
|
mkdir -p data/fbank
|
||||||
python local/compute_fbank_alimeeting.py
|
python local/compute_fbank_alimeeting.py --perturb-speed True
|
||||||
log "Combine features from train splits"
|
log "Combine features from train splits"
|
||||||
lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\
|
lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\
|
||||||
gzip -c > data/manifests/cuts_train_all.jsonl.gz
|
gzip -c > data/manifests/cuts_train_all.jsonl.gz
|
||||||
|
@ -257,7 +257,7 @@ class AmiAsrDataModule:
|
|||||||
drop_last=self.args.drop_last,
|
drop_last=self.args.drop_last,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SingleCutSampler.")
|
logging.info("Using SimpleCutSampler.")
|
||||||
train_sampler = SimpleCutSampler(
|
train_sampler = SimpleCutSampler(
|
||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
|
@ -30,7 +30,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
|||||||
DynamicBucketingSampler,
|
DynamicBucketingSampler,
|
||||||
K2SpeechRecognitionDataset,
|
K2SpeechRecognitionDataset,
|
||||||
PrecomputedFeatures,
|
PrecomputedFeatures,
|
||||||
SingleCutSampler,
|
SimpleCutSampler,
|
||||||
SpecAugment,
|
SpecAugment,
|
||||||
)
|
)
|
||||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||||
@ -311,8 +311,8 @@ class CommonVoiceAsrDataModule:
|
|||||||
drop_last=self.args.drop_last,
|
drop_last=self.args.drop_last,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SingleCutSampler.")
|
logging.info("Using SimpleCutSampler.")
|
||||||
train_sampler = SingleCutSampler(
|
train_sampler = SimpleCutSampler(
|
||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
|
@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
|||||||
DynamicBucketingSampler,
|
DynamicBucketingSampler,
|
||||||
K2SpeechRecognitionDataset,
|
K2SpeechRecognitionDataset,
|
||||||
PrecomputedFeatures,
|
PrecomputedFeatures,
|
||||||
SingleCutSampler,
|
SimpleCutSampler,
|
||||||
SpecAugment,
|
SpecAugment,
|
||||||
)
|
)
|
||||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||||
@ -339,8 +339,8 @@ class CSJAsrDataModule:
|
|||||||
drop_last=self.args.drop_last,
|
drop_last=self.args.drop_last,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SingleCutSampler.")
|
logging.info("Using SimpleCutSampler.")
|
||||||
train_sampler = SingleCutSampler(
|
train_sampler = SimpleCutSampler(
|
||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
|
@ -27,7 +27,7 @@ from lhotse.dataset import (
|
|||||||
DynamicBucketingSampler,
|
DynamicBucketingSampler,
|
||||||
K2SpeechRecognitionDataset,
|
K2SpeechRecognitionDataset,
|
||||||
PrecomputedFeatures,
|
PrecomputedFeatures,
|
||||||
SingleCutSampler,
|
SimpleCutSampler,
|
||||||
SpecAugment,
|
SpecAugment,
|
||||||
)
|
)
|
||||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||||
@ -264,8 +264,8 @@ class GigaSpeechAsrDataModule:
|
|||||||
drop_last=True,
|
drop_last=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SingleCutSampler.")
|
logging.info("Using SimpleCutSampler.")
|
||||||
train_sampler = SingleCutSampler(
|
train_sampler = SimpleCutSampler(
|
||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
|
@ -30,7 +30,7 @@ from lhotse.dataset import (
|
|||||||
DynamicBucketingSampler,
|
DynamicBucketingSampler,
|
||||||
K2SpeechRecognitionDataset,
|
K2SpeechRecognitionDataset,
|
||||||
PrecomputedFeatures,
|
PrecomputedFeatures,
|
||||||
SingleCutSampler,
|
SimpleCutSampler,
|
||||||
SpecAugment,
|
SpecAugment,
|
||||||
)
|
)
|
||||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||||
@ -297,8 +297,8 @@ class GigaSpeechAsrDataModule:
|
|||||||
drop_last=True,
|
drop_last=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SingleCutSampler.")
|
logging.info("Using SimpleCutSampler.")
|
||||||
train_sampler = SingleCutSampler(
|
train_sampler = SimpleCutSampler(
|
||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
|
@ -259,7 +259,7 @@ class LibriCssAsrDataModule:
|
|||||||
drop_last=self.args.drop_last,
|
drop_last=self.args.drop_last,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SingleCutSampler.")
|
logging.info("Using SimpleCutSampler.")
|
||||||
train_sampler = SimpleCutSampler(
|
train_sampler = SimpleCutSampler(
|
||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
|
@ -1 +0,0 @@
|
|||||||
../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py
|
|
1576
egs/libricss/SURT/dprnn_zipformer/scaling.py
Normal file
1576
egs/libricss/SURT/dprnn_zipformer/scaling.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -79,7 +79,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
|||||||
# ln -sfv /path/to/rirs_noises $dl_dir/
|
# ln -sfv /path/to/rirs_noises $dl_dir/
|
||||||
#
|
#
|
||||||
if [ ! -d $dl_dir/rirs_noises ]; then
|
if [ ! -d $dl_dir/rirs_noises ]; then
|
||||||
lhotse download rirs_noises $dl_dir
|
lhotse download rir-noise $dl_dir/rirs_noises
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
@ -89,6 +89,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
|||||||
# to $dl_dir/librispeech. We perform text normalization for the transcripts.
|
# to $dl_dir/librispeech. We perform text normalization for the transcripts.
|
||||||
# NOTE: Alignments are required for this recipe.
|
# NOTE: Alignments are required for this recipe.
|
||||||
mkdir -p data/manifests
|
mkdir -p data/manifests
|
||||||
|
|
||||||
lhotse prepare librispeech -p train-clean-100 -p train-clean-360 -p train-other-500 -p dev-clean \
|
lhotse prepare librispeech -p train-clean-100 -p train-clean-360 -p train-other-500 -p dev-clean \
|
||||||
-j 4 --alignments-dir $dl_dir/libri_alignments/LibriSpeech $dl_dir/librispeech data/manifests/
|
-j 4 --alignments-dir $dl_dir/libri_alignments/LibriSpeech $dl_dir/librispeech data/manifests/
|
||||||
fi
|
fi
|
||||||
@ -112,7 +113,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
|||||||
|
|
||||||
# We assume that you have downloaded the RIRS_NOISES corpus
|
# We assume that you have downloaded the RIRS_NOISES corpus
|
||||||
# to $dl_dir/rirs_noises
|
# to $dl_dir/rirs_noises
|
||||||
lhotse prepare rir-noise -p real_rir -p iso_noise $dl_dir/rirs_noises data/manifests
|
lhotse prepare rir-noise -p real_rir -p iso_noise $dl_dir/rirs_noises/RIRS_NOISES data/manifests
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||||
|
@ -23,12 +23,13 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import AttributeDict, str2bool
|
from icefall.utils import AttributeDict, num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -63,11 +64,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lang-dir",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500",
|
required=True,
|
||||||
help="""It contains language related input files such as "lexicon.txt"
|
help="Path to the tokens.txt.",
|
||||||
""",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -98,16 +98,16 @@ def get_params() -> AttributeDict:
|
|||||||
def main():
|
def main():
|
||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
args.lang_dir = Path(args.lang_dir)
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
# Load tokens.txt here
|
||||||
max_token_id = max(lexicon.tokens)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
num_classes = max_token_id + 1 # +1 for the blank
|
|
||||||
|
num_classes = num_tokens(token_table) + 1 # +1 for the blank
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
@ -24,7 +24,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
@ -70,11 +69,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.
|
help="Path to the tokens.txt.",
|
||||||
Used only when method is ctc-decoding.
|
|
||||||
""",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -83,10 +80,9 @@ def get_parser():
|
|||||||
default="1best",
|
default="1best",
|
||||||
help="""Decoding method.
|
help="""Decoding method.
|
||||||
Possible values are:
|
Possible values are:
|
||||||
(0) ctc-decoding - Use CTC decoding. It uses a sentence
|
(0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file
|
||||||
piece model, i.e., lang_dir/bpe.model, to convert
|
to convert tokens to actual words or characters. It needs
|
||||||
word pieces to words. It needs neither a lexicon
|
neither a lexicon nor an n-gram LM.
|
||||||
nor an n-gram LM.
|
|
||||||
(1) 1best - Use the best path as decoding output. Only
|
(1) 1best - Use the best path as decoding output. Only
|
||||||
the transformer encoder output is used for decoding.
|
the transformer encoder output is used for decoding.
|
||||||
We call it HLG decoding.
|
We call it HLG decoding.
|
||||||
@ -297,6 +293,7 @@ def main():
|
|||||||
waves = [w.to(device) for w in waves]
|
waves = [w.to(device) for w in waves]
|
||||||
|
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
|
hyps = []
|
||||||
features = fbank(waves)
|
features = fbank(waves)
|
||||||
|
|
||||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||||
@ -313,10 +310,17 @@ def main():
|
|||||||
|
|
||||||
if params.method == "ctc-decoding":
|
if params.method == "ctc-decoding":
|
||||||
logging.info("Use CTC decoding")
|
logging.info("Use CTC decoding")
|
||||||
bpe_model = spm.SentencePieceProcessor()
|
|
||||||
bpe_model.load(params.bpe_model)
|
|
||||||
max_token_id = params.num_classes - 1
|
max_token_id = params.num_classes - 1
|
||||||
|
|
||||||
|
# Load tokens.txt here
|
||||||
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
H = k2.ctc_topo(
|
H = k2.ctc_topo(
|
||||||
max_token=max_token_id,
|
max_token=max_token_id,
|
||||||
modified=params.num_classes > 500,
|
modified=params.num_classes > 500,
|
||||||
@ -337,9 +341,9 @@ def main():
|
|||||||
best_path = one_best_decoding(
|
best_path = one_best_decoding(
|
||||||
lattice=lattice, use_double_scores=params.use_double_scores
|
lattice=lattice, use_double_scores=params.use_double_scores
|
||||||
)
|
)
|
||||||
token_ids = get_texts(best_path)
|
hyp_tokens = get_texts(best_path)
|
||||||
hyps = bpe_model.decode(token_ids)
|
for hyp in hyp_tokens:
|
||||||
hyps = [s.split() for s in hyps]
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method in [
|
elif params.method in [
|
||||||
"1best",
|
"1best",
|
||||||
"whole-lattice-rescoring",
|
"whole-lattice-rescoring",
|
||||||
@ -408,16 +412,16 @@ def main():
|
|||||||
)
|
)
|
||||||
best_path = next(iter(best_path_dict.values()))
|
best_path = next(iter(best_path_dict.values()))
|
||||||
|
|
||||||
hyps = get_texts(best_path)
|
|
||||||
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||||
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
hyp_tokens = get_texts(best_path)
|
||||||
|
for hyp in hyp_tokens:
|
||||||
|
hyps.append(" ".join([word_sym_table[i] for i in hyp]))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported decoding method: {params.method}")
|
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
@ -23,6 +23,7 @@
|
|||||||
Usage:
|
Usage:
|
||||||
./conformer_ctc2/export.py \
|
./conformer_ctc2/export.py \
|
||||||
--exp-dir ./conformer_ctc2/exp \
|
--exp-dir ./conformer_ctc2/exp \
|
||||||
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -46,6 +47,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from decode import get_params
|
from decode import get_params
|
||||||
@ -56,8 +58,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.utils import num_tokens, str2bool
|
||||||
from icefall.utils import str2bool
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -123,10 +124,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lang-dir",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500",
|
required=True,
|
||||||
help="The lang dir",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -143,14 +144,14 @@ def get_parser():
|
|||||||
def main():
|
def main():
|
||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
args.lang_dir = Path(args.lang_dir)
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
# Load tokens.txt here
|
||||||
max_token_id = max(lexicon.tokens)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
num_classes = max_token_id + 1 # +1 for the blank
|
|
||||||
|
num_classes = num_tokens(token_table) + 1 # +1 for the blank
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
@ -25,7 +25,7 @@ Usage:
|
|||||||
|
|
||||||
./conformer_ctc3/export.py \
|
./conformer_ctc3/export.py \
|
||||||
--exp-dir ./conformer_ctc3/exp \
|
--exp-dir ./conformer_ctc3/exp \
|
||||||
--lang-dir data/lang_bpe_500 \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10 \
|
--avg 10 \
|
||||||
--jit-trace 1
|
--jit-trace 1
|
||||||
@ -36,7 +36,7 @@ It will generates the file: `jit_trace.pt`.
|
|||||||
|
|
||||||
./conformer_ctc3/export.py \
|
./conformer_ctc3/export.py \
|
||||||
--exp-dir ./conformer_ctc3/exp \
|
--exp-dir ./conformer_ctc3/exp \
|
||||||
--lang-dir data/lang_bpe_500 \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -62,6 +62,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_ctc_model, get_params
|
from train import add_model_arguments, get_ctc_model, get_params
|
||||||
@ -72,8 +73,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.utils import num_tokens, str2bool
|
||||||
from icefall.utils import str2bool
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -130,10 +130,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lang-dir",
|
"--tokens",
|
||||||
type=Path,
|
type=str,
|
||||||
default="data/lang_bpe_500",
|
required=True,
|
||||||
help="The lang dir containing word table and LG graph",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -171,9 +171,10 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
# Load tokens.txt here
|
||||||
max_token_id = max(lexicon.tokens)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
num_classes = max_token_id + 1 # +1 for the blank
|
|
||||||
|
num_classes = num_tokens(token_table) + 1 # +1 for the blank
|
||||||
params.vocab_size = num_classes
|
params.vocab_size = num_classes
|
||||||
|
|
||||||
if params.streaming_model:
|
if params.streaming_model:
|
||||||
|
@ -24,7 +24,7 @@ Usage (for non-streaming mode):
|
|||||||
(1) ctc-decoding
|
(1) ctc-decoding
|
||||||
./conformer_ctc3/pretrained.py \
|
./conformer_ctc3/pretrained.py \
|
||||||
--checkpoint conformer_ctc3/exp/pretrained.pt \
|
--checkpoint conformer_ctc3/exp/pretrained.pt \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method ctc-decoding \
|
--method ctc-decoding \
|
||||||
--sample-rate 16000 \
|
--sample-rate 16000 \
|
||||||
test_wavs/1089-134686-0001.wav
|
test_wavs/1089-134686-0001.wav
|
||||||
@ -67,7 +67,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from decode import get_decoding_params
|
from decode import get_decoding_params
|
||||||
@ -114,11 +113,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.
|
help="Path to the tokens.txt.",
|
||||||
Used only when method is ctc-decoding.
|
|
||||||
""",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -127,10 +124,9 @@ def get_parser():
|
|||||||
default="1best",
|
default="1best",
|
||||||
help="""Decoding method.
|
help="""Decoding method.
|
||||||
Possible values are:
|
Possible values are:
|
||||||
(0) ctc-decoding - Use CTC decoding. It uses a sentence
|
(0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file
|
||||||
piece model, i.e., lang_dir/bpe.model, to convert
|
to convert tokens to actual words or characters. It needs
|
||||||
word pieces to words. It needs neither a lexicon
|
neither a lexicon nor an n-gram LM.
|
||||||
nor an n-gram LM.
|
|
||||||
(1) 1best - Use the best path as decoding output. Only
|
(1) 1best - Use the best path as decoding output. Only
|
||||||
the transformer encoder output is used for decoding.
|
the transformer encoder output is used for decoding.
|
||||||
We call it HLG decoding.
|
We call it HLG decoding.
|
||||||
@ -316,6 +312,7 @@ def main():
|
|||||||
waves = [w.to(device) for w in waves]
|
waves = [w.to(device) for w in waves]
|
||||||
|
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
|
hyps = []
|
||||||
features = fbank(waves)
|
features = fbank(waves)
|
||||||
feature_lengths = [f.size(0) for f in features]
|
feature_lengths = [f.size(0) for f in features]
|
||||||
|
|
||||||
@ -348,10 +345,17 @@ def main():
|
|||||||
|
|
||||||
if params.method == "ctc-decoding":
|
if params.method == "ctc-decoding":
|
||||||
logging.info("Use CTC decoding")
|
logging.info("Use CTC decoding")
|
||||||
bpe_model = spm.SentencePieceProcessor()
|
|
||||||
bpe_model.load(params.bpe_model)
|
|
||||||
max_token_id = params.num_classes - 1
|
max_token_id = params.num_classes - 1
|
||||||
|
|
||||||
|
# Load tokens.txt here
|
||||||
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
H = k2.ctc_topo(
|
H = k2.ctc_topo(
|
||||||
max_token=max_token_id,
|
max_token=max_token_id,
|
||||||
modified=False,
|
modified=False,
|
||||||
@ -372,9 +376,9 @@ def main():
|
|||||||
best_path = one_best_decoding(
|
best_path = one_best_decoding(
|
||||||
lattice=lattice, use_double_scores=params.use_double_scores
|
lattice=lattice, use_double_scores=params.use_double_scores
|
||||||
)
|
)
|
||||||
token_ids = get_texts(best_path)
|
hyp_tokens = get_texts(best_path)
|
||||||
hyps = bpe_model.decode(token_ids)
|
for hyp in hyp_tokens:
|
||||||
hyps = [s.split() for s in hyps]
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method in [
|
elif params.method in [
|
||||||
"1best",
|
"1best",
|
||||||
"nbest-rescoring",
|
"nbest-rescoring",
|
||||||
@ -439,16 +443,16 @@ def main():
|
|||||||
)
|
)
|
||||||
best_path = next(iter(best_path_dict.values()))
|
best_path = next(iter(best_path_dict.values()))
|
||||||
|
|
||||||
hyps = get_texts(best_path)
|
|
||||||
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||||
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
hyp_tokens = get_texts(best_path)
|
||||||
|
for hyp in hyp_tokens:
|
||||||
|
hyps.append(" ".join([word_sym_table[i] for i in hyp]))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported decoding method: {params.method}")
|
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
@ -22,7 +22,7 @@
|
|||||||
Usage:
|
Usage:
|
||||||
./conv_emformer_transducer_stateless/export.py \
|
./conv_emformer_transducer_stateless/export.py \
|
||||||
--exp-dir ./conv_emformer_transducer_stateless/exp \
|
--exp-dir ./conv_emformer_transducer_stateless/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 10 \
|
--avg 10 \
|
||||||
--use-averaged-model=True \
|
--use-averaged-model=True \
|
||||||
@ -62,7 +62,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
@ -72,7 +72,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -118,10 +118,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
required=True,
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -166,12 +166,12 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# Load id of the <blk> token and the vocab size
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ for more details about how to use this file.
|
|||||||
Usage:
|
Usage:
|
||||||
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
|
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
|
||||||
--exp-dir ./conv_emformer_transducer_stateless2/exp \
|
--exp-dir ./conv_emformer_transducer_stateless2/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 10 \
|
--avg 10 \
|
||||||
--use-averaged-model=True \
|
--use-averaged-model=True \
|
||||||
@ -37,7 +37,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train2 import add_model_arguments, get_params, get_transducer_model
|
from train2 import add_model_arguments, get_params, get_transducer_model
|
||||||
@ -48,7 +48,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import setup_logger, str2bool
|
from icefall.utils import num_tokens, setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -94,10 +94,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
required=True,
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -217,12 +217,12 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# Load id of the <blk> token and the vocab size
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user