Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-07 16:14:17 +00:00)

Commit 8df405b6b2 ("update"), parent commit 1c4db88747.
@@ -29,9 +29,6 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
  ls -lh data/fbank
  ls -lh pruned_transducer_stateless2/exp

- ln -s data/fbank/cuts_DEV.jsonl.gz data/fbank/gigaspeech_cuts_DEV.jsonl.gz
- ln -s data/fbank/cuts_TEST.jsonl.gz data/fbank/gigaspeech_cuts_TEST.jsonl.gz
-
  log "Decoding dev and test"

  # use a small value for decoding with CPU
@@ -38,7 +38,7 @@ log "Decode with models exported by torch.jit.trace()"
  for m in ctc-decoding 1best; do
  ./conformer_ctc3/jit_pretrained.py \
  --model-filename $repo/exp/jit_trace.pt \
  --words-file $repo/data/lang_bpe_500/words.txt \
  --HLG $repo/data/lang_bpe_500/HLG.pt \
  --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --G $repo/data/lm/G_4_gram.pt \
@@ -53,7 +53,7 @@ log "Export to torchscript model"

  ./conformer_ctc3/export.py \
  --exp-dir $repo/exp \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --lang-dir $repo/data/lang_bpe_500 \
  --jit-trace 1 \
  --epoch 99 \
  --avg 1 \
@@ -80,9 +80,9 @@ done
  for m in ctc-decoding 1best; do
  ./conformer_ctc3/pretrained.py \
  --checkpoint $repo/exp/pretrained.pt \
  --words-file $repo/data/lang_bpe_500/words.txt \
  --HLG $repo/data/lang_bpe_500/HLG.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --G $repo/data/lm/G_4_gram.pt \
  --method $m \
  --sample-rate 16000 \
@@ -93,7 +93,7 @@ done

  echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
  echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
  if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
  mkdir -p conformer_ctc3/exp
  ln -s $PWD/$repo/exp/pretrained.pt conformer_ctc3/exp/epoch-999.pt
  ln -s $PWD/$repo/data/lang_bpe_500 data/
@@ -31,7 +31,7 @@ log "Test exporting with torch.jit.trace()"

  ./lstm_transducer_stateless2/export.py \
  --exp-dir $repo/exp \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --epoch 99 \
  --avg 1 \
  --use-averaged-model 0 \
@@ -55,7 +55,7 @@ for sym in 1 2 3; do
  --method greedy_search \
  --max-sym-per-frame $sym \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -68,7 +68,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
  --method $method \
  --beam-size 4 \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -28,7 +28,7 @@ for sym in 1 2 3; do
  --method greedy_search \
  --max-sym-per-frame $sym \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do
  --method $method \
  --beam-size 4 \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -36,7 +36,7 @@ for sym in 1 2 3; do
  --method greedy_search \
  --max-sym-per-frame $sym \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -49,7 +49,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
  --method $method \
  --beam-size 4 \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -35,7 +35,7 @@ for sym in 1 2 3; do
  --method greedy_search \
  --max-sym-per-frame $sym \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -48,7 +48,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
  --method $method \
  --beam-size 4 \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -30,14 +30,14 @@ popd
  log "Export to torchscript model"
  ./pruned_transducer_stateless3/export.py \
  --exp-dir $repo/exp \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --epoch 99 \
  --avg 1 \
  --jit 1

  ./pruned_transducer_stateless3/export.py \
  --exp-dir $repo/exp \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --epoch 99 \
  --avg 1 \
  --jit-trace 1
@@ -74,7 +74,7 @@ for sym in 1 2 3; do
  --method greedy_search \
  --max-sym-per-frame $sym \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -87,7 +87,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
  --method $method \
  --beam-size 4 \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -32,7 +32,7 @@ for sym in 1 2 3; do
  --method greedy_search \
  --max-sym-per-frame $sym \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --num-encoder-layers 18 \
  --dim-feedforward 2048 \
  --nhead 8 \
@@ -51,7 +51,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
  --method $method \
  --beam-size 4 \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav \
@@ -33,7 +33,7 @@ log "Export to torchscript model"
  ./pruned_transducer_stateless7/export.py \
  --exp-dir $repo/exp \
  --use-averaged-model false \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --epoch 99 \
  --avg 1 \
  --jit 1
@@ -56,7 +56,7 @@ for sym in 1 2 3; do
  --method greedy_search \
  --max-sym-per-frame $sym \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -69,7 +69,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
  --method $method \
  --beam-size 4 \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -37,7 +37,7 @@ log "Export to torchscript model"
  ./pruned_transducer_stateless7_ctc/export.py \
  --exp-dir $repo/exp \
  --use-averaged-model false \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --epoch 99 \
  --avg 1 \
  --jit 1
@@ -74,7 +74,7 @@ for sym in 1 2 3; do
  --method greedy_search \
  --max-sym-per-frame $sym \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -87,7 +87,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
  --method $method \
  --beam-size 4 \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -36,7 +36,7 @@ log "Export to torchscript model"
  ./pruned_transducer_stateless7_ctc_bs/export.py \
  --exp-dir $repo/exp \
  --use-averaged-model false \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --epoch 99 \
  --avg 1 \
  --jit 1
@@ -72,7 +72,7 @@ for sym in 1 2 3; do
  --method greedy_search \
  --max-sym-per-frame $sym \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -85,7 +85,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
  --method $method \
  --beam-size 4 \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -37,7 +37,7 @@ log "Export to torchscript model"
  ./pruned_transducer_stateless7_streaming/export.py \
  --exp-dir $repo/exp \
  --use-averaged-model false \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --decode-chunk-len 32 \
  --epoch 99 \
  --avg 1 \
@@ -81,7 +81,7 @@ for sym in 1 2 3; do
  --method greedy_search \
  --max-sym-per-frame $sym \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --decode-chunk-len 32 \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
@@ -95,7 +95,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
  --method $method \
  --beam-size 4 \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --decode-chunk-len 32 \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
@@ -41,7 +41,7 @@ log "Decode with models exported by torch.jit.script()"
  log "Export to torchscript model"
  ./pruned_transducer_stateless8/export.py \
  --exp-dir $repo/exp \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --use-averaged-model false \
  --epoch 99 \
  --avg 1 \
@@ -65,7 +65,7 @@ for sym in 1 2 3; do
  --method greedy_search \
  --max-sym-per-frame $sym \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -78,7 +78,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
  --method $method \
  --beam-size 4 \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -32,7 +32,7 @@ for sym in 1 2 3; do
  --method greedy_search \
  --max-sym-per-frame $sym \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --simulate-streaming 1 \
  --causal-convolution 1 \
  $repo/test_wavs/1089-134686-0001.wav \
@@ -47,7 +47,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
  --method $method \
  --beam-size 4 \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --simulate-streaming 1 \
  --causal-convolution 1 \
  $repo/test_wavs/1089-134686-0001.wav \
@@ -28,7 +28,7 @@ for sym in 1 2 3; do
  --method greedy_search \
  --max-sym-per-frame $sym \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do
  --method $method \
  --beam-size 4 \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -37,7 +37,7 @@ log "Export to torchscript model"
  ./zipformer_mmi/export.py \
  --exp-dir $repo/exp \
  --use-averaged-model false \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --epoch 99 \
  --avg 1 \
  --jit 1
@@ -61,7 +61,7 @@ for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescor
  --method $method \
  --checkpoint $repo/exp/pretrained.pt \
  --lang-dir $repo/data/lang_bpe_500 \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
.github/scripts/run-multi-zh_hans-zipformer.sh (vendored, 51 lines deleted)
@@ -1,51 +0,0 @@
#!/usr/bin/env bash

set -e

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

cd egs/multi_zh-hans/ASR

repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/

log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
repo=$(basename $repo_url)


log "Display test files"
tree $repo/
ls -lh $repo/test_wavs/*.wav

pushd $repo/exp
ln -s epoch-20.pt epoch-99.pt
popd

ls -lh $repo/exp/*.pt


./zipformer/pretrained.py \
  --checkpoint $repo/exp/epoch-99.pt \
  --tokens $repo/data/lang_bpe_2000/tokens.txt \
  --method greedy_search \
  $repo/test_wavs/DEV_T0000000000.wav \
  $repo/test_wavs/DEV_T0000000001.wav \
  $repo/test_wavs/DEV_T0000000002.wav

for method in modified_beam_search fast_beam_search; do
  log "$method"

  ./zipformer/pretrained.py \
    --method $method \
    --beam-size 4 \
    --checkpoint $repo/exp/epoch-99.pt \
    --tokens $repo/data/lang_bpe_2000/tokens.txt \
    $repo/test_wavs/DEV_T0000000000.wav \
    $repo/test_wavs/DEV_T0000000001.wav \
    $repo/test_wavs/DEV_T0000000002.wav
done
@@ -27,7 +27,7 @@ log "CTC decoding"
  --method ctc-decoding \
  --num-classes 500 \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.flac \
  $repo/test_wavs/1221-135766-0001.flac \
  $repo/test_wavs/1221-135766-0002.flac
@@ -38,7 +38,7 @@ log "HLG decoding"
  --method 1best \
  --num-classes 500 \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --words-file $repo/data/lang_bpe_500/words.txt \
  --HLG $repo/data/lang_bpe_500/HLG.pt \
  $repo/test_wavs/1089-134686-0001.flac \
@@ -28,7 +28,7 @@ for sym in 1 2 3; do
  --method greedy_search \
  --max-sym-per-frame $sym \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
  --method $method \
  --beam-size 4 \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -28,7 +28,7 @@ for sym in 1 2 3; do
  --method greedy_search \
  --max-sym-per-frame $sym \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
  --method $method \
  --beam-size 4 \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -28,7 +28,7 @@ for sym in 1 2 3; do
  --method greedy_search \
  --max-sym-per-frame $sym \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do
  --method $method \
  --beam-size 4 \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -27,7 +27,7 @@ log "Beam search decoding"
  --method beam_search \
  --beam-size 4 \
  --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  $repo/test_wavs/1089-134686-0001.wav \
  $repo/test_wavs/1221-135766-0001.wav \
  $repo/test_wavs/1221-135766-0002.wav
@@ -17,6 +17,7 @@ git lfs install
  git clone $repo_url
  repo=$(basename $repo_url)

+
  log "Display test files"
  tree $repo/
  ls -lh $repo/test_wavs/*.wav
@@ -28,11 +29,12 @@ popd

  log "Test exporting to ONNX format"

- ./pruned_transducer_stateless2/export-onnx.py \
+ ./pruned_transducer_stateless2/export.py \
  --exp-dir $repo/exp \
  --lang-dir $repo/data/lang_char \
  --epoch 99 \
- --avg 1
+ --avg 1 \
+ --onnx 1

  log "Export to torchscript model"

@@ -57,17 +59,19 @@ log "Decode with ONNX models"

  ./pruned_transducer_stateless2/onnx_check.py \
  --jit-filename $repo/exp/cpu_jit.pt \
- --onnx-encoder-filename $repo/exp/encoder-epoch-10-avg-2.onnx \
+ --onnx-encoder-filename $repo/exp/encoder.onnx \
- --onnx-decoder-filename $repo/exp/decoder-epoch-10-avg-2.onnx \
+ --onnx-decoder-filename $repo/exp/decoder.onnx \
- --onnx-joiner-filename $repo/exp/joiner-epoch-10-avg-2.onnx \
+ --onnx-joiner-filename $repo/exp/joiner.onnx \
- --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
+ --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \
- --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj-epoch-10-avg-2.onnx
+ --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx

  ./pruned_transducer_stateless2/onnx_pretrained.py \
  --tokens $repo/data/lang_char/tokens.txt \
- --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
+ --encoder-model-filename $repo/exp/encoder.onnx \
- --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
+ --decoder-model-filename $repo/exp/decoder.onnx \
- --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
+ --joiner-model-filename $repo/exp/joiner.onnx \
+ --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
+ --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
  $repo/test_wavs/DEV_T0000000000.wav \
  $repo/test_wavs/DEV_T0000000001.wav \
  $repo/test_wavs/DEV_T0000000002.wav
@@ -100,9 +104,9 @@ for sym in 1 2 3; do
  --lang-dir $repo/data/lang_char \
  --decoding-method greedy_search \
  --max-sym-per-frame $sym \
  $repo/test_wavs/DEV_T0000000000.wav \
  $repo/test_wavs/DEV_T0000000001.wav \
  $repo/test_wavs/DEV_T0000000002.wav
  done

  for method in modified_beam_search beam_search fast_beam_search; do
@@ -113,7 +117,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
  --beam-size 4 \
  --checkpoint $repo/exp/epoch-99.pt \
  --lang-dir $repo/data/lang_char \
  $repo/test_wavs/DEV_T0000000000.wav \
  $repo/test_wavs/DEV_T0000000001.wav \
  $repo/test_wavs/DEV_T0000000002.wav
  done
|
12
.github/scripts/test-ncnn-export.sh
vendored
12
.github/scripts/test-ncnn-export.sh
vendored
@ -45,6 +45,7 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
pushd $repo
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||||
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
|
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
|
||||||
|
|
||||||
cd exp
|
cd exp
|
||||||
@ -55,10 +56,11 @@ log "Export via torch.jit.trace()"
|
|||||||
|
|
||||||
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
|
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
\
|
||||||
--num-encoder-layers 12 \
|
--num-encoder-layers 12 \
|
||||||
--chunk-length 32 \
|
--chunk-length 32 \
|
||||||
--cnn-module-kernel 31 \
|
--cnn-module-kernel 31 \
|
||||||
@ -89,6 +91,7 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
pushd $repo
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||||
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
|
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
|
||||||
|
|
||||||
cd exp
|
cd exp
|
||||||
@ -99,7 +102,7 @@ log "Export via torch.jit.trace()"
|
|||||||
|
|
||||||
./lstm_transducer_stateless2/export-for-ncnn.py \
|
./lstm_transducer_stateless2/export-for-ncnn.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0
|
--use-averaged-model 0
|
||||||
@ -137,6 +140,7 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
pushd $repo
|
pushd $repo
|
||||||
|
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
||||||
git lfs pull --include "exp/pretrained.pt"
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
cd exp
|
cd exp
|
||||||
@ -144,7 +148,7 @@ ln -s pretrained.pt epoch-99.pt
|
|||||||
popd
|
popd
|
||||||
|
|
||||||
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
|
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
|
||||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
@ -195,7 +199,7 @@ ln -s pretrained.pt epoch-9999.pt
|
|||||||
popd
|
popd
|
||||||
|
|
||||||
./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
|
./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
|
||||||
--tokens $repo/data/lang_char_bpe/tokens.txt \
|
--lang-dir $repo/data/lang_char_bpe \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 9999 \
|
--epoch 9999 \
|
||||||
|
138
.github/scripts/test-onnx-export.sh
vendored
138
.github/scripts/test-onnx-export.sh
vendored
@ -10,123 +10,7 @@ log() {
|
|||||||
|
|
||||||
cd egs/librispeech/ASR
|
cd egs/librispeech/ASR
|
||||||
|
|
||||||
log "=========================================================================="
|
|
||||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
|
||||||
log "Downloading pre-trained model from $repo_url"
|
|
||||||
git lfs install
|
|
||||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|
||||||
repo=$(basename $repo_url)
|
|
||||||
|
|
||||||
pushd $repo
|
|
||||||
git lfs pull --include "exp/pretrained.pt"
|
|
||||||
cd exp
|
|
||||||
ln -s pretrained.pt epoch-99.pt
|
|
||||||
popd
|
|
||||||
|
|
||||||
log "Export via torch.jit.script()"
|
|
||||||
./zipformer/export.py \
|
|
||||||
--exp-dir $repo/exp \
|
|
||||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
|
||||||
--epoch 99 \
|
|
||||||
--avg 1 \
|
|
||||||
--jit 1
|
|
||||||
|
|
||||||
log "Test export to ONNX format"
|
|
||||||
./zipformer/export-onnx.py \
|
|
||||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
|
||||||
--use-averaged-model 0 \
|
|
||||||
--epoch 99 \
|
|
||||||
--avg 1 \
|
|
||||||
--exp-dir $repo/exp \
|
|
||||||
--num-encoder-layers "2,2,3,4,3,2" \
|
|
||||||
--downsampling-factor "1,2,4,8,4,2" \
|
|
||||||
--feedforward-dim "512,768,1024,1536,1024,768" \
|
|
||||||
--num-heads "4,4,4,8,4,4" \
|
|
||||||
--encoder-dim "192,256,384,512,384,256" \
|
|
||||||
--query-head-dim 32 \
|
|
||||||
--value-head-dim 12 \
|
|
||||||
--pos-head-dim 4 \
|
|
||||||
--pos-dim 48 \
|
|
||||||
--encoder-unmasked-dim "192,192,256,256,256,192" \
|
|
||||||
--cnn-module-kernel "31,31,15,15,15,31" \
|
|
||||||
--decoder-dim 512 \
|
|
||||||
--joiner-dim 512 \
|
|
||||||
--causal False \
|
|
||||||
--chunk-size "16,32,64,-1" \
|
|
||||||
--left-context-frames "64,128,256,-1"
|
|
||||||
|
|
||||||
ls -lh $repo/exp
|
|
||||||
|
|
||||||
log "Run onnx_check.py"
|
|
||||||
|
|
||||||
./zipformer/onnx_check.py \
|
|
||||||
--jit-filename $repo/exp/jit_script.pt \
|
|
||||||
--onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
|
||||||
--onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
|
||||||
--onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
|
|
||||||
|
|
||||||
log "Run onnx_pretrained.py"
|
|
||||||
|
|
||||||
./zipformer/onnx_pretrained.py \
|
|
||||||
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
|
||||||
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
|
||||||
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
|
||||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
|
||||||
$repo/test_wavs/1089-134686-0001.wav
|
|
||||||
|
|
||||||
rm -rf $repo
|
|
||||||
|
|
||||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
|
|
||||||
log "Downloading pre-trained model from $repo_url"
|
|
||||||
git lfs install
|
|
||||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|
||||||
repo=$(basename $repo_url)
|
|
||||||
|
|
||||||
pushd $repo
|
|
||||||
git lfs pull --include "exp/pretrained.pt"
|
|
||||||
|
|
||||||
cd exp
|
|
||||||
ln -s pretrained.pt epoch-99.pt
|
|
||||||
popd
|
|
||||||
|
|
||||||
log "Test export streaming model to ONNX format"
|
|
||||||
./zipformer/export-onnx-streaming.py \
|
|
||||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
|
||||||
--use-averaged-model 0 \
|
|
||||||
--epoch 99 \
|
|
||||||
--avg 1 \
|
|
||||||
--exp-dir $repo/exp \
|
|
||||||
--num-encoder-layers "2,2,3,4,3,2" \
|
|
||||||
--downsampling-factor "1,2,4,8,4,2" \
|
|
||||||
--feedforward-dim "512,768,1024,1536,1024,768" \
|
|
||||||
--num-heads "4,4,4,8,4,4" \
|
|
||||||
--encoder-dim "192,256,384,512,384,256" \
|
|
||||||
--query-head-dim 32 \
|
|
||||||
--value-head-dim 12 \
|
|
||||||
--pos-head-dim 4 \
|
|
||||||
--pos-dim 48 \
|
|
||||||
--encoder-unmasked-dim "192,192,256,256,256,192" \
|
|
||||||
--cnn-module-kernel "31,31,15,15,15,31" \
|
|
||||||
--decoder-dim 512 \
|
|
||||||
--joiner-dim 512 \
|
|
||||||
--causal True \
|
|
||||||
--chunk-size 16 \
|
|
||||||
--left-context-frames 64
|
|
||||||
|
|
||||||
ls -lh $repo/exp
|
|
||||||
|
|
||||||
log "Run onnx_pretrained-streaming.py"
|
|
||||||
|
|
||||||
./zipformer/onnx_pretrained-streaming.py \
|
|
||||||
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1-chunk-16-left-64.onnx \
|
|
||||||
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1-chunk-16-left-64.onnx \
|
|
||||||
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1-chunk-16-left-64.onnx \
|
|
||||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
|
||||||
$repo/test_wavs/1089-134686-0001.wav
|
|
||||||
|
|
||||||
rm -rf $repo
|
|
||||||
|
|
||||||
log "--------------------------------------------------------------------------"
|
|
||||||
|
|
||||||
log "=========================================================================="
|
log "=========================================================================="
|
||||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||||
@@ -155,7 +39,7 @@ log "Export via torch.jit.trace()"
  log "Test exporting to ONNX format"

  ./pruned_transducer_stateless7_streaming/export-onnx.py \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --use-averaged-model 0 \
  --epoch 99 \
  --avg 1 \
@@ -204,7 +88,7 @@ popd
  log "Export via torch.jit.script()"

  ./pruned_transducer_stateless3/export.py \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --epoch 9999 \
  --avg 1 \
  --exp-dir $repo/exp/ \
@@ -213,7 +97,7 @@ log "Export via torch.jit.script()"
  log "Test exporting to ONNX format"

  ./pruned_transducer_stateless3/export-onnx.py \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --epoch 9999 \
  --avg 1 \
  --exp-dir $repo/exp/
@@ -242,6 +126,7 @@ log "Run onnx_pretrained.py"
  rm -rf $repo
  log "--------------------------------------------------------------------------"

+
  log "=========================================================================="
  repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13
  GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
@@ -258,7 +143,7 @@ popd
  log "Export via torch.jit.script()"

  ./pruned_transducer_stateless5/export.py \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --epoch 99 \
  --avg 1 \
  --use-averaged-model 0 \
@@ -274,7 +159,7 @@ log "Export via torch.jit.script()"
  log "Test exporting to ONNX format"

  ./pruned_transducer_stateless5/export-onnx.py \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --epoch 99 \
  --avg 1 \
  --use-averaged-model 0 \
@@ -320,6 +205,7 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
  repo=$(basename $repo_url)

  pushd $repo
+ git lfs pull --include "data/lang_bpe_500/bpe.model"
  git lfs pull --include "exp/pretrained.pt"

  cd exp
@@ -329,7 +215,7 @@ popd
  log "Export via torch.jit.script()"

  ./pruned_transducer_stateless7/export.py \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --use-averaged-model 0 \
  --epoch 99 \
  --avg 1 \
@@ -340,7 +226,7 @@ log "Export via torch.jit.script()"
  log "Test exporting to ONNX format"

  ./pruned_transducer_stateless7/export-onnx.py \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --use-averaged-model 0 \
  --epoch 99 \
  --avg 1 \
@@ -384,7 +270,7 @@ popd
  log "Test exporting to ONNX format"

  ./conv_emformer_transducer_stateless2/export-onnx.py \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --use-averaged-model 0 \
  --epoch 99 \
  --avg 1 \
@@ -424,7 +310,7 @@ popd
  log "Export via torch.jit.trace()"

  ./lstm_transducer_stateless2/export.py \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --use-averaged-model 0 \
  --epoch 99 \
  --avg 1 \
@@ -434,7 +320,7 @@ log "Export via torch.jit.trace()"
  log "Test exporting to ONNX format"

  ./lstm_transducer_stateless2/export-onnx.py \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --use-averaged-model 0 \
  --epoch 99 \
  --avg 1 \
|
2
.github/workflows/run-aishell-2022-06-20.yml
vendored
2
.github/workflows/run-aishell-2022-06-20.yml
vendored
@ -45,7 +45,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.8]
|
python-version: [3.7, 3.8, 3.9]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.8]
|
python-version: [3.7, 3.8, 3.9]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.8]
|
python-version: [3.7, 3.8, 3.9]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.8]
|
python-version: [3.7, 3.8, 3.9]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.8]
|
python-version: [3.7, 3.8, 3.9]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.8]
|
python-version: [3.7, 3.8, 3.9]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.8]
|
python-version: [3.7, 3.8, 3.9]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.8]
|
python-version: [3.7, 3.8, 3.9]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@@ -1,84 +0,0 @@
# Copyright 2023 Xiaomi Corp. (author: Zengrui Jin)

# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: run-multi-zh_hans-zipformer

on:
  push:
    branches:
      - master
  pull_request:
    types: [labeled]

concurrency:
  group: run_multi-zh_hans_zipformer-${{ github.ref }}
  cancel-in-progress: true

jobs:
  run_multi-zh_hans_zipformer:
    if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans'
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest]
        python-version: [3.8]

      fail-fast: false

    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
          cache: 'pip'
          cache-dependency-path: '**/requirements-ci.txt'

      - name: Install Python dependencies
        run: |
          grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
          pip uninstall -y protobuf
          pip install --no-binary protobuf protobuf==3.20.*

      - name: Cache kaldifeat
        id: my-cache
        uses: actions/cache@v2
        with:
          path: |
            ~/tmp/kaldifeat
          key: cache-tmp-${{ matrix.python-version }}-2023-05-22

      - name: Install kaldifeat
        if: steps.my-cache.outputs.cache-hit != 'true'
        shell: bash
        run: |
          .github/scripts/install-kaldifeat.sh

      - name: Inference with pre-trained model
        shell: bash
        env:
          GITHUB_EVENT_NAME: ${{ github.event_name }}
          GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
        run: |
          sudo apt-get -qq install git-lfs tree
          export PYTHONPATH=$PWD:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

          .github/scripts/run-multi-zh_hans-zipformer.sh
@@ -34,7 +34,7 @@ jobs:
  strategy:
  matrix:
  os: [ubuntu-latest]
- python-version: [3.8]
+ python-version: [3.7, 3.8, 3.9]

  fail-fast: false

@@ -43,7 +43,7 @@ jobs:
  strategy:
  matrix:
  os: [ubuntu-latest]
- python-version: [3.8]
+ python-version: [3.7, 3.8, 3.9]

  fail-fast: false

@@ -43,7 +43,7 @@ jobs:
  strategy:
  matrix:
  os: [ubuntu-latest]
- python-version: [3.8]
+ python-version: [3.7, 3.8, 3.9]

  fail-fast: false

@@ -34,7 +34,7 @@ jobs:
  strategy:
  matrix:
  os: [ubuntu-latest]
- python-version: [3.8]
+ python-version: [3.7, 3.8, 3.9]

  fail-fast: false

@@ -34,7 +34,7 @@ jobs:
  strategy:
  matrix:
  os: [ubuntu-latest]
- python-version: [3.8]
+ python-version: [3.7, 3.8, 3.9]

  fail-fast: false

@@ -43,7 +43,7 @@ jobs:
  strategy:
  matrix:
  os: [ubuntu-latest]
- python-version: [3.8]
+ python-version: [3.7, 3.8, 3.9]

  fail-fast: false

@@ -34,7 +34,7 @@ jobs:
  strategy:
  matrix:
  os: [ubuntu-latest]
- python-version: [3.8]
+ python-version: [3.7, 3.8, 3.9]

  fail-fast: false

.github/workflows/run-yesno-recipe.yml (vendored, 76 changed lines)
@ -44,6 +44,11 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Install graphviz
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
sudo apt-get -qq install graphviz
|
||||||
|
|
||||||
- name: Setup Python ${{ matrix.python-version }}
|
- name: Setup Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v2
|
uses: actions/setup-python@v2
|
||||||
with:
|
with:
|
||||||
@ -65,7 +70,6 @@ jobs:
|
|||||||
pip install --no-binary protobuf protobuf==3.20.*
|
pip install --no-binary protobuf protobuf==3.20.*
|
||||||
|
|
||||||
pip install --no-deps --force-reinstall https://huggingface.co/csukuangfj/k2/resolve/main/cpu/k2-1.24.3.dev20230508+cpu.torch1.13.1-cp38-cp38-linux_x86_64.whl
|
pip install --no-deps --force-reinstall https://huggingface.co/csukuangfj/k2/resolve/main/cpu/k2-1.24.3.dev20230508+cpu.torch1.13.1-cp38-cp38-linux_x86_64.whl
|
||||||
pip install kaldifeat==1.25.0.dev20230726+cpu.torch1.13.1 -f https://csukuangfj.github.io/kaldifeat/cpu.html
|
|
||||||
|
|
||||||
- name: Run yesno recipe
|
- name: Run yesno recipe
|
||||||
shell: bash
|
shell: bash
|
||||||
@ -74,75 +78,9 @@ jobs:
|
|||||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||||
echo $PYTHONPATH
|
echo $PYTHONPATH
|
||||||
|
|
||||||
|
|
||||||
cd egs/yesno/ASR
|
cd egs/yesno/ASR
|
||||||
./prepare.sh
|
./prepare.sh
|
||||||
python3 ./tdnn/train.py
|
python3 ./tdnn/train.py
|
||||||
python3 ./tdnn/decode.py
|
python3 ./tdnn/decode.py
|
||||||
|
# TODO: Check that the WER is less than some value
|
||||||
- name: Test exporting to pretrained.pt
|
|
||||||
shell: bash
|
|
||||||
working-directory: ${{github.workspace}}
|
|
||||||
run: |
|
|
||||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
|
||||||
echo $PYTHONPATH
|
|
||||||
|
|
||||||
cd egs/yesno/ASR
|
|
||||||
python3 ./tdnn/export.py --epoch 14 --avg 2
|
|
||||||
|
|
||||||
python3 ./tdnn/pretrained.py \
|
|
||||||
--checkpoint ./tdnn/exp/pretrained.pt \
|
|
||||||
--HLG ./data/lang_phone/HLG.pt \
|
|
||||||
--words-file ./data/lang_phone/words.txt \
|
|
||||||
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
|
|
||||||
download/waves_yesno/0_0_1_0_0_0_1_0.wav
|
|
||||||
|
|
||||||
- name: Test exporting to torchscript
|
|
||||||
shell: bash
|
|
||||||
working-directory: ${{github.workspace}}
|
|
||||||
run: |
|
|
||||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
|
||||||
echo $PYTHONPATH
|
|
||||||
|
|
||||||
cd egs/yesno/ASR
|
|
||||||
python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1
|
|
||||||
|
|
||||||
python3 ./tdnn/jit_pretrained.py \
|
|
||||||
--nn-model ./tdnn/exp/cpu_jit.pt \
|
|
||||||
--HLG ./data/lang_phone/HLG.pt \
|
|
||||||
--words-file ./data/lang_phone/words.txt \
|
|
||||||
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
|
|
||||||
download/waves_yesno/0_0_1_0_0_0_1_0.wav
|
|
||||||
|
|
||||||
- name: Test exporting to onnx
|
|
||||||
shell: bash
|
|
||||||
working-directory: ${{github.workspace}}
|
|
||||||
run: |
|
|
||||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
|
||||||
echo $PYTHONPATH
|
|
||||||
|
|
||||||
cd egs/yesno/ASR
|
|
||||||
python3 ./tdnn/export_onnx.py --epoch 14 --avg 2
|
|
||||||
|
|
||||||
echo "Test float32 model"
|
|
||||||
python3 ./tdnn/onnx_pretrained.py \
|
|
||||||
--nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \
|
|
||||||
--HLG ./data/lang_phone/HLG.pt \
|
|
||||||
--words-file ./data/lang_phone/words.txt \
|
|
||||||
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
|
|
||||||
download/waves_yesno/0_0_1_0_0_0_1_0.wav
|
|
||||||
|
|
||||||
|
|
||||||
echo "Test int8 model"
|
|
||||||
python3 ./tdnn/onnx_pretrained.py \
|
|
||||||
--nn-model ./tdnn/exp/model-epoch-14-avg-2.int8.onnx \
|
|
||||||
--HLG ./data/lang_phone/HLG.pt \
|
|
||||||
--words-file ./data/lang_phone/words.txt \
|
|
||||||
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
|
|
||||||
download/waves_yesno/0_0_1_0_0_0_1_0.wav
|
|
||||||
|
|
||||||
- name: Show generated files
|
|
||||||
shell: bash
|
|
||||||
working-directory: ${{github.workspace}}
|
|
||||||
run: |
|
|
||||||
cd egs/yesno/ASR
|
|
||||||
ls -lh tdnn/exp
|
|
||||||
|
.gitignore (vendored, 1 changed line)
@ -4,6 +4,7 @@ __pycache__
|
|||||||
path.sh
|
path.sh
|
||||||
exp
|
exp
|
||||||
exp*/
|
exp*/
|
||||||
|
tensorboard
|
||||||
*.pt
|
*.pt
|
||||||
download
|
download
|
||||||
dask-worker-space
|
dask-worker-space
|
||||||
|
@ -338,7 +338,7 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder
|
|||||||
|
|
||||||
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
|
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
|
||||||
|
|
||||||
The best results for Chinese CER(%) and English WER(%) respectively (zh: Chinese, en: English):
|
The best results for Chinese CER(%) and English WER(%) respectivly (zh: Chinese, en: English):
|
||||||
|decoding-method | dev | dev_zh | dev_en | test | test_zh | test_en |
|
|decoding-method | dev | dev_zh | dev_en | test | test_zh | test_en |
|
||||||
|--|--|--|--|--|--|--|
|
|--|--|--|--|--|--|--|
|
||||||
|greedy_search| 7.30 | 6.48 | 19.19 |7.39| 6.66 | 19.13|
|
|greedy_search| 7.30 | 6.48 | 19.19 |7.39| 6.66 | 19.13|
|
||||||
|
@ -95,7 +95,4 @@ rst_epilog = """
|
|||||||
.. _k2: https://github.com/k2-fsa/k2
|
.. _k2: https://github.com/k2-fsa/k2
|
||||||
.. _lhotse: https://github.com/lhotse-speech/lhotse
|
.. _lhotse: https://github.com/lhotse-speech/lhotse
|
||||||
.. _yesno: https://www.openslr.org/1/
|
.. _yesno: https://www.openslr.org/1/
|
||||||
.. _Next-gen Kaldi: https://github.com/k2-fsa
|
|
||||||
.. _Kaldi: https://github.com/kaldi-asr/kaldi
|
|
||||||
.. _lilcom: https://github.com/danpovey/lilcom
|
|
||||||
"""
|
"""
|
||||||
|
@ -71,12 +71,9 @@ As the initial step, let's download the pre-trained model.
|
|||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||||
$ cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
$ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||||
$ git lfs pull --include "pretrained.pt"
|
$ git lfs pull --include "pretrained.pt"
|
||||||
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
|
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
|
||||||
$ cd ../data/lang_bpe_500
|
|
||||||
$ git lfs pull --include bpe.model
|
|
||||||
$ cd ../../..
|
|
||||||
|
|
||||||
To test the model, let's have a look at the decoding results **without** using LM. This can be done via the following command:
|
To test the model, let's have a look at the decoding results **without** using LM. This can be done via the following command:
|
||||||
|
|
||||||
@ -88,7 +85,7 @@ To test the model, let's have a look at the decoding results **without** using L
|
|||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model False \
|
--use-averaged-model False \
|
||||||
--exp-dir $exp_dir \
|
--exp-dir $exp_dir \
|
||||||
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
|
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decode-chunk-len 32 \
|
--decode-chunk-len 32 \
|
||||||
--decoding-method modified_beam_search
|
--decoding-method modified_beam_search
|
||||||
@ -138,8 +135,8 @@ Then, we perform LODR decoding by setting ``--decoding-method`` to ``modified_be
|
|||||||
--exp-dir $exp_dir \
|
--exp-dir $exp_dir \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decode-chunk-len 32 \
|
--decode-chunk-len 32 \
|
||||||
--decoding-method modified_beam_search_LODR \
|
--decoding-method modified_beam_search_lm_LODR \
|
||||||
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
|
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model
|
||||||
--use-shallow-fusion 1 \
|
--use-shallow-fusion 1 \
|
||||||
--lm-type rnn \
|
--lm-type rnn \
|
||||||
--lm-exp-dir $lm_dir \
|
--lm-exp-dir $lm_dir \
|
||||||
|
@ -34,12 +34,9 @@ As the initial step, let's download the pre-trained model.
|
|||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||||
$ cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
$ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||||
$ git lfs pull --include "pretrained.pt"
|
$ git lfs pull --include "pretrained.pt"
|
||||||
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
|
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
|
||||||
$ cd ../data/lang_bpe_500
|
|
||||||
$ git lfs pull --include bpe.model
|
|
||||||
$ cd ../../..
|
|
||||||
|
|
||||||
As usual, we first test the model's performance without external LM. This can be done via the following command:
|
As usual, we first test the model's performance without external LM. This can be done via the following command:
|
||||||
|
|
||||||
@ -51,7 +48,7 @@ As usual, we first test the model's performance without external LM. This can be
|
|||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model False \
|
--use-averaged-model False \
|
||||||
--exp-dir $exp_dir \
|
--exp-dir $exp_dir \
|
||||||
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
|
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decode-chunk-len 32 \
|
--decode-chunk-len 32 \
|
||||||
--decoding-method modified_beam_search
|
--decoding-method modified_beam_search
|
||||||
@ -104,7 +101,7 @@ is set to `False`.
|
|||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decode-chunk-len 32 \
|
--decode-chunk-len 32 \
|
||||||
--decoding-method modified_beam_search_lm_rescore \
|
--decoding-method modified_beam_search_lm_rescore \
|
||||||
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
|
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model
|
||||||
--use-shallow-fusion 0 \
|
--use-shallow-fusion 0 \
|
||||||
--lm-type rnn \
|
--lm-type rnn \
|
||||||
--lm-exp-dir $lm_dir \
|
--lm-exp-dir $lm_dir \
|
||||||
@ -176,7 +173,7 @@ Then we can performn LM rescoring + LODR by changing the decoding method to `mod
|
|||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decode-chunk-len 32 \
|
--decode-chunk-len 32 \
|
||||||
--decoding-method modified_beam_search_lm_rescore_LODR \
|
--decoding-method modified_beam_search_lm_rescore_LODR \
|
||||||
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
|
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model
|
||||||
--use-shallow-fusion 0 \
|
--use-shallow-fusion 0 \
|
||||||
--lm-type rnn \
|
--lm-type rnn \
|
||||||
--lm-exp-dir $lm_dir \
|
--lm-exp-dir $lm_dir \
|
||||||
|
@ -32,12 +32,9 @@ As the initial step, let's download the pre-trained model.
|
|||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||||
$ cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
$ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||||
$ git lfs pull --include "pretrained.pt"
|
$ git lfs pull --include "pretrained.pt"
|
||||||
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
|
$ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded
|
||||||
$ cd ../data/lang_bpe_500
|
|
||||||
$ git lfs pull --include bpe.model
|
|
||||||
$ cd ../../..
|
|
||||||
|
|
||||||
To test the model, let's have a look at the decoding results without using LM. This can be done via the following command:
|
To test the model, let's have a look at the decoding results without using LM. This can be done via the following command:
|
||||||
|
|
||||||
@ -49,7 +46,7 @@ To test the model, let's have a look at the decoding results without using LM. T
|
|||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model False \
|
--use-averaged-model False \
|
||||||
--exp-dir $exp_dir \
|
--exp-dir $exp_dir \
|
||||||
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
|
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decode-chunk-len 32 \
|
--decode-chunk-len 32 \
|
||||||
--decoding-method modified_beam_search
|
--decoding-method modified_beam_search
|
||||||
@ -98,7 +95,7 @@ To use shallow fusion for decoding, we can execute the following command:
|
|||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decode-chunk-len 32 \
|
--decode-chunk-len 32 \
|
||||||
--decoding-method modified_beam_search_lm_shallow_fusion \
|
--decoding-method modified_beam_search_lm_shallow_fusion \
|
||||||
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
|
--bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model
|
||||||
--use-shallow-fusion 1 \
|
--use-shallow-fusion 1 \
|
||||||
--lm-type rnn \
|
--lm-type rnn \
|
||||||
--lm-exp-dir $lm_dir \
|
--lm-exp-dir $lm_dir \
|
||||||
|
Binary image file (356 KiB) removed; content not shown.
@ -1,17 +0,0 @@
|
|||||||
.. _icefall_docker:
|
|
||||||
|
|
||||||
Docker
|
|
||||||
======
|
|
||||||
|
|
||||||
This section describes how to use pre-built docker images to run `icefall`_.
|
|
||||||
|
|
||||||
.. hint::
|
|
||||||
|
|
||||||
If you only have CPUs available, you can still use the pre-built docker
|
|
||||||
images.
|
|
||||||
|
|
||||||
.. toctree::
|
|
||||||
:maxdepth: 2
|
|
||||||
|
|
||||||
./intro.rst
|
|
||||||
|
|
@ -1,171 +0,0 @@
|
|||||||
Introduction
|
|
||||||
=============
|
|
||||||
|
|
||||||
We have pre-built docker images hosted at the following address:
|
|
||||||
|
|
||||||
`<https://hub.docker.com/repository/docker/k2fsa/icefall/general>`_
|
|
||||||
|
|
||||||
.. figure:: img/docker-hub.png
|
|
||||||
:width: 600
|
|
||||||
:align: center
|
|
||||||
|
|
||||||
You can find the ``Dockerfile`` at `<https://github.com/k2-fsa/icefall/tree/master/docker>`_.
|
|
||||||
|
|
||||||
We describe the following items in this section:
|
|
||||||
|
|
||||||
- How to view available tags
|
|
||||||
- How to download pre-built docker images
|
|
||||||
- How to run the `yesno`_ recipe within a docker container on ``CPU``
|
|
||||||
|
|
||||||
View available tags
|
|
||||||
===================
|
|
||||||
|
|
||||||
You can use the following command to view available tags:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
curl -s 'https://registry.hub.docker.com/v2/repositories/k2fsa/icefall/tags/'|jq '."results"[]["name"]'
|
|
||||||
|
|
||||||
which will give you something like below:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
"torch2.0.0-cuda11.7"
|
|
||||||
"torch1.12.1-cuda11.3"
|
|
||||||
"torch1.9.0-cuda10.2"
|
|
||||||
"torch1.13.0-cuda11.6"
|
|
||||||
|
|
||||||
.. hint::
|
|
||||||
|
|
||||||
Available tags will be updated when there are new releases of `torch`_.
|
|
||||||
|
|
||||||
Please select an appropriate combination of `torch`_ and CUDA.
|
|
||||||
|
|
||||||
Download a docker image
|
|
||||||
=======================
|
|
||||||
|
|
||||||
Suppose that you select the tag ``torch1.13.0-cuda11.6``, you can use
|
|
||||||
the following command to download it:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
sudo docker image pull k2fsa/icefall:torch1.13.0-cuda11.6
|
|
||||||
|
|
||||||
Run a docker image with GPU
|
|
||||||
===========================
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
sudo docker run --gpus all --rm -it k2fsa/icefall:torch1.13.0-cuda11.6 /bin/bash
|
|
||||||
|
|
||||||
Run a docker image with CPU
|
|
||||||
===========================
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
sudo docker run --rm -it k2fsa/icefall:torch1.13.0-cuda11.6 /bin/bash
|
|
||||||
|
|
||||||
Run yesno within a docker container
|
|
||||||
===================================
|
|
||||||
|
|
||||||
After starting the container, the following interface is presented:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
root@60c947eac59c:/workspace/icefall#
|
|
||||||
|
|
||||||
It shows the current user is ``root`` and the current working directory
|
|
||||||
is ``/workspace/icefall``.
|
|
||||||
|
|
||||||
Update the code
|
|
||||||
---------------
|
|
||||||
|
|
||||||
Please first run:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
root@60c947eac59c:/workspace/icefall# git pull
|
|
||||||
|
|
||||||
so that your local copy contains the latest code.
|
|
||||||
|
|
||||||
Data preparation
|
|
||||||
----------------
|
|
||||||
|
|
||||||
Now we can use
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
root@60c947eac59c:/workspace/icefall# cd egs/yesno/ASR/
|
|
||||||
|
|
||||||
to switch to the ``yesno`` recipe and run
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
root@60c947eac59c:/workspace/icefall/egs/yesno/ASR# ./prepare.sh
|
|
||||||
|
|
||||||
.. hint::
|
|
||||||
|
|
||||||
If you are running without GPU, it may report the following error:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
File "/opt/conda/lib/python3.9/site-packages/k2/__init__.py", line 23, in <module>
|
|
||||||
from _k2 import DeterminizeWeightPushingType
|
|
||||||
ImportError: libcuda.so.1: cannot open shared object file: No such file or directory
|
|
||||||
|
|
||||||
We can use the following command to fix it:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
root@60c947eac59c:/workspace/icefall/egs/yesno/ASR# ln -s /opt/conda/lib/stubs/libcuda.so /opt/conda/lib/stubs/libcuda.so.1
|
|
||||||
|
|
||||||
The logs of running ``./prepare.sh`` are listed below:
|
|
||||||
|
|
||||||
.. literalinclude:: ./log/log-preparation.txt
|
|
||||||
|
|
||||||
Training
|
|
||||||
--------
|
|
||||||
|
|
||||||
After preparing the data, we can start training with the following command
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
root@60c947eac59c:/workspace/icefall/egs/yesno/ASR# ./tdnn/train.py
|
|
||||||
|
|
||||||
All of the training logs are given below:
|
|
||||||
|
|
||||||
.. hint::
|
|
||||||
|
|
||||||
It is running on CPU and it takes only 16 seconds for this run.
|
|
||||||
|
|
||||||
.. literalinclude:: ./log/log-train-2023-08-01-01-55-27
|
|
||||||
|
|
||||||
|
|
||||||
Decoding
|
|
||||||
--------
|
|
||||||
|
|
||||||
After training, we can decode the trained model with
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
root@60c947eac59c:/workspace/icefall/egs/yesno/ASR# ./tdnn/decode.py
|
|
||||||
|
|
||||||
The decoding logs are given below:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
2023-08-01 02:06:22,400 INFO [decode.py:263] Decoding started
|
|
||||||
2023-08-01 02:06:22,400 INFO [decode.py:264] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lm_dir': PosixPath('data/lm'), 'feature_dim': 23, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'export': False, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '4c05309499a08454997adf500b56dcc629e35ae5', 'k2-git-date': 'Tue Jul 25 16:23:36 2023', 'lhotse-version': '1.16.0.dev+git.7640d663.clean', 'torch-version': '1.13.0', 'torch-cuda-available': False, 'torch-cuda-version': '11.6', 'python-version': '3.9', 'icefall-git-branch': 'master', 'icefall-git-sha1': '375520d-clean', 'icefall-git-date': 'Fri Jul 28 07:43:08 2023', 'icefall-path': '/workspace/icefall', 'k2-path': '/opt/conda/lib/python3.9/site-packages/k2/__init__.py', 'lhotse-path': '/opt/conda/lib/python3.9/site-packages/lhotse/__init__.py', 'hostname': '60c947eac59c', 'IP address': '172.17.0.2'}}
|
|
||||||
2023-08-01 02:06:22,401 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
|
|
||||||
2023-08-01 02:06:22,403 INFO [decode.py:273] device: cpu
|
|
||||||
2023-08-01 02:06:22,406 INFO [decode.py:291] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
|
|
||||||
2023-08-01 02:06:22,424 INFO [asr_datamodule.py:218] About to get test cuts
|
|
||||||
2023-08-01 02:06:22,425 INFO [asr_datamodule.py:252] About to get test cuts
|
|
||||||
2023-08-01 02:06:22,504 INFO [decode.py:204] batch 0/?, cuts processed until now is 4
|
|
||||||
[W NNPACK.cpp:53] Could not initialize NNPACK! Reason: Unsupported hardware.
|
|
||||||
2023-08-01 02:06:22,687 INFO [decode.py:241] The transcripts are stored in tdnn/exp/recogs-test_set.txt
|
|
||||||
2023-08-01 02:06:22,688 INFO [utils.py:564] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
|
|
||||||
2023-08-01 02:06:22,690 INFO [decode.py:249] Wrote detailed error stats to tdnn/exp/errs-test_set.txt
|
|
||||||
2023-08-01 02:06:22,690 INFO [decode.py:316] Done!
|
|
||||||
|
|
||||||
Congratulations! You have finished successfully running `icefall`_ within a docker container.
|
|
@ -1,180 +0,0 @@
|
|||||||
.. _dummies_tutorial_data_preparation:
|
|
||||||
|
|
||||||
Data Preparation
|
|
||||||
================
|
|
||||||
|
|
||||||
After :ref:`dummies_tutorial_environment_setup`, we can start preparing the
|
|
||||||
data for training and decoding.
|
|
||||||
|
|
||||||
The first step is to prepare the data for training. We have already provided
|
|
||||||
`prepare.sh <https://github.com/k2-fsa/icefall/blob/master/egs/yesno/ASR/prepare.sh>`_
|
|
||||||
that would prepare everything required for training.
|
|
||||||
|
|
||||||
.. code-block::
|
|
||||||
|
|
||||||
cd /tmp/icefall
|
|
||||||
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
|
|
||||||
cd egs/yesno/ASR
|
|
||||||
|
|
||||||
./prepare.sh
|
|
||||||
|
|
||||||
Note that in each recipe from `icefall`_, there exists a file ``prepare.sh``,
|
|
||||||
which you should run before you run anything else.
|
|
||||||
|
|
||||||
That is all you need for data preparation.
|
|
||||||
|
|
||||||
For the more curious
|
|
||||||
--------------------
|
|
||||||
|
|
||||||
If you are wondering how to prepare your own dataset, please refer to the following
|
|
||||||
URLs for more details:
|
|
||||||
|
|
||||||
- `<https://github.com/lhotse-speech/lhotse/tree/master/lhotse/recipes>`_
|
|
||||||
|
|
||||||
It contains recipes for a variety of datasets. If you want to add your own
|
|
||||||
dataset, please read recipes in this folder first.
|
|
||||||
|
|
||||||
- `<https://github.com/lhotse-speech/lhotse/blob/master/lhotse/recipes/yesno.py>`_
|
|
||||||
|
|
||||||
The `yesno`_ recipe in `lhotse`_.
|
|
||||||
|
|
||||||
If you already have a `Kaldi`_ dataset directory, which contains files like
|
|
||||||
``wav.scp``, ``feats.scp``, then you can refer to `<https://lhotse.readthedocs.io/en/latest/kaldi.html#example>`_.
|
|
||||||
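As a rough sketch of that import path (the exact return values of the helper differ between `lhotse`_ versions, so please check the linked example):

.. code-block:: python3

    # A minimal sketch, assuming lhotse's load_kaldi_data_dir() helper and a
    # hypothetical Kaldi-style directory "data/train" containing wav.scp etc.
    from lhotse.kaldi import load_kaldi_data_dir

    manifests = load_kaldi_data_dir("data/train", sampling_rate=16000)
    recordings, supervisions = manifests[0], manifests[1]

    recordings.to_file("recordings_train.jsonl.gz")
    supervisions.to_file("supervisions_train.jsonl.gz")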
|
|
||||||
A quick look at the generated files
|
|
||||||
-----------------------------------
|
|
||||||
|
|
||||||
``./prepare.sh`` puts generated files into two directories:
|
|
||||||
|
|
||||||
- ``download``
|
|
||||||
- ``data``
|
|
||||||
|
|
||||||
download
|
|
||||||
^^^^^^^^
|
|
||||||
|
|
||||||
The ``download`` directory contains downloaded dataset files:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
tree -L 1 ./download/
|
|
||||||
|
|
||||||
./download/
|
|
||||||
|-- waves_yesno
|
|
||||||
`-- waves_yesno.tar.gz
|
|
||||||
|
|
||||||
.. hint::
|
|
||||||
|
|
||||||
Please refer to `<https://github.com/lhotse-speech/lhotse/blob/master/lhotse/recipes/yesno.py#L41>`_
|
|
||||||
for how the data is downloaded and extracted.
|
|
||||||
|
|
||||||
data
|
|
||||||
^^^^
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
tree ./data/
|
|
||||||
|
|
||||||
./data/
|
|
||||||
|-- fbank
|
|
||||||
| |-- yesno_cuts_test.jsonl.gz
|
|
||||||
| |-- yesno_cuts_train.jsonl.gz
|
|
||||||
| |-- yesno_feats_test.lca
|
|
||||||
| `-- yesno_feats_train.lca
|
|
||||||
|-- lang_phone
|
|
||||||
| |-- HLG.pt
|
|
||||||
| |-- L.pt
|
|
||||||
| |-- L_disambig.pt
|
|
||||||
| |-- Linv.pt
|
|
||||||
| |-- lexicon.txt
|
|
||||||
| |-- lexicon_disambig.txt
|
|
||||||
| |-- tokens.txt
|
|
||||||
| `-- words.txt
|
|
||||||
|-- lm
|
|
||||||
| |-- G.arpa
|
|
||||||
| `-- G.fst.txt
|
|
||||||
`-- manifests
|
|
||||||
|-- yesno_recordings_test.jsonl.gz
|
|
||||||
|-- yesno_recordings_train.jsonl.gz
|
|
||||||
|-- yesno_supervisions_test.jsonl.gz
|
|
||||||
`-- yesno_supervisions_train.jsonl.gz
|
|
||||||
|
|
||||||
4 directories, 18 files
|
|
||||||
|
|
||||||
**data/manifests**:
|
|
||||||
|
|
||||||
This directory contains manifests. They are used to generate files in
|
|
||||||
``data/fbank``.
|
|
||||||
|
|
||||||
To give you an idea of what it contains, we examine the first few lines of
|
|
||||||
the manifests related to the ``train`` dataset.
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
cd data/manifests
|
|
||||||
gunzip -c yesno_recordings_train.jsonl.gz | head -n 3
|
|
||||||
|
|
||||||
The output is given below:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
{"id": "0_0_0_0_1_1_1_1", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_0_1_1_1_1.wav"}], "sampling_rate": 8000, "num_samples": 50800, "duration": 6.35, "channel_ids": [0]}
|
|
||||||
{"id": "0_0_0_1_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_1_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48880, "duration": 6.11, "channel_ids": [0]}
|
|
||||||
{"id": "0_0_1_0_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_1_0_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48160, "duration": 6.02, "channel_ids": [0]}
|
|
||||||
|
|
||||||
Please refer to `<https://github.com/lhotse-speech/lhotse/blob/master/lhotse/audio.py#L300>`_
|
|
||||||
for the meaning of each field per line.
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
gunzip -c yesno_supervisions_train.jsonl.gz | head -n 3
|
|
||||||
|
|
||||||
The output is given below:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
{"id": "0_0_0_0_1_1_1_1", "recording_id": "0_0_0_0_1_1_1_1", "start": 0.0, "duration": 6.35, "channel": 0, "text": "NO NO NO NO YES YES YES YES", "language": "Hebrew"}
|
|
||||||
{"id": "0_0_0_1_0_1_1_0", "recording_id": "0_0_0_1_0_1_1_0", "start": 0.0, "duration": 6.11, "channel": 0, "text": "NO NO NO YES NO YES YES NO", "language": "Hebrew"}
|
|
||||||
{"id": "0_0_1_0_0_1_1_0", "recording_id": "0_0_1_0_0_1_1_0", "start": 0.0, "duration": 6.02, "channel": 0, "text": "NO NO YES NO NO YES YES NO", "language": "Hebrew"}
|
|
||||||
|
|
||||||
Please refer to `<https://github.com/lhotse-speech/lhotse/blob/master/lhotse/supervision.py#L510>`_
|
|
||||||
for the meaning of each field per line.
|
|
||||||
|
|
||||||
**data/fbank**:
|
|
||||||
|
|
||||||
This directory contains everything from ``data/manifests``. Furthermore, it also contains features
|
|
||||||
for training.
|
|
||||||
|
|
||||||
``data/fbank/yesno_feats_train.lca`` contains the features for the train dataset.
|
|
||||||
Features are compressed using `lilcom`_.
|
|
||||||
|
|
||||||
``data/fbank/yesno_cuts_train.jsonl.gz`` stores the `CutSet <https://github.com/lhotse-speech/lhotse/blob/master/lhotse/cut/set.py#L72>`_,
|
|
||||||
which stores `RecordingSet <https://github.com/lhotse-speech/lhotse/blob/master/lhotse/audio.py#L928>`_,
|
|
||||||
`SupervisionSet <https://github.com/lhotse-speech/lhotse/blob/master/lhotse/supervision.py#L510>`_,
|
|
||||||
and `FeatureSet <https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/base.py#L593>`_.
|
|
||||||
|
|
||||||
To give you an idea about what it looks like, we can run the following command:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
cd data/fbank
|
|
||||||
|
|
||||||
gunzip -c yesno_cuts_train.jsonl.gz | head -n 3
|
|
||||||
|
|
||||||
The output is given below:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
{"id": "0_0_0_0_1_1_1_1-0", "start": 0, "duration": 6.35, "channel": 0, "supervisions": [{"id": "0_0_0_0_1_1_1_1", "recording_id": "0_0_0_0_1_1_1_1", "start": 0.0, "duration": 6.35, "channel": 0, "text": "NO NO NO NO YES YES YES YES", "language": "Hebrew"}], "features": {"type": "kaldi-fbank", "num_frames": 635, "num_features": 23, "frame_shift": 0.01, "sampling_rate": 8000, "start": 0, "duration": 6.35, "storage_type": "lilcom_chunky", "storage_path": "data/fbank/yesno_feats_train.lca", "storage_key": "0,13000,3570", "channels": 0}, "recording": {"id": "0_0_0_0_1_1_1_1", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_0_1_1_1_1.wav"}], "sampling_rate": 8000, "num_samples": 50800, "duration": 6.35, "channel_ids": [0]}, "type": "MonoCut"}
|
|
||||||
{"id": "0_0_0_1_0_1_1_0-1", "start": 0, "duration": 6.11, "channel": 0, "supervisions": [{"id": "0_0_0_1_0_1_1_0", "recording_id": "0_0_0_1_0_1_1_0", "start": 0.0, "duration": 6.11, "channel": 0, "text": "NO NO NO YES NO YES YES NO", "language": "Hebrew"}], "features": {"type": "kaldi-fbank", "num_frames": 611, "num_features": 23, "frame_shift": 0.01, "sampling_rate": 8000, "start": 0, "duration": 6.11, "storage_type": "lilcom_chunky", "storage_path": "data/fbank/yesno_feats_train.lca", "storage_key": "16570,12964,2929", "channels": 0}, "recording": {"id": "0_0_0_1_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_1_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48880, "duration": 6.11, "channel_ids": [0]}, "type": "MonoCut"}
|
|
||||||
{"id": "0_0_1_0_0_1_1_0-2", "start": 0, "duration": 6.02, "channel": 0, "supervisions": [{"id": "0_0_1_0_0_1_1_0", "recording_id": "0_0_1_0_0_1_1_0", "start": 0.0, "duration": 6.02, "channel": 0, "text": "NO NO YES NO NO YES YES NO", "language": "Hebrew"}], "features": {"type": "kaldi-fbank", "num_frames": 602, "num_features": 23, "frame_shift": 0.01, "sampling_rate": 8000, "start": 0, "duration": 6.02, "storage_type": "lilcom_chunky", "storage_path": "data/fbank/yesno_feats_train.lca", "storage_key": "32463,12936,2696", "channels": 0}, "recording": {"id": "0_0_1_0_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_1_0_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48160, "duration": 6.02, "channel_ids": [0]}, "type": "MonoCut"}
|
|
||||||
|
|
||||||
Note that ``yesno_cuts_train.jsonl.gz`` only stores the information about how to read the features.
|
|
||||||
The actual features are stored separately in ``data/fbank/yesno_feats_train.lca``.
|
|
||||||
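If you prefer to inspect the cuts programmatically instead of using ``gunzip``, a short sketch with `lhotse`_ (already installed if you have run ``./prepare.sh``) is given below:

.. code-block:: python3

    # A minimal sketch: load the cuts manifest and read the features of the
    # first cut; the feature matrix is loaded from yesno_feats_train.lca on demand.
    from lhotse import CutSet

    cuts = CutSet.from_file("data/fbank/yesno_cuts_train.jsonl.gz")
    cut = next(iter(cuts))
    print(cut.id, cut.duration, cut.supervisions[0].text)

    features = cut.load_features()  # numpy array of shape (num_frames, num_features)
    print(features.shape)           # e.g. (635, 23) for the first cut shown above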
|
|
||||||
**data/lang_phone**:
|
|
||||||
|
|
||||||
This directory contains the lexicon.
|
|
||||||
|
|
||||||
**data/lm**:
|
|
||||||
|
|
||||||
This directory contains language models.
|
|
@ -1,39 +0,0 @@
|
|||||||
.. _dummies_tutorial_decoding:
|
|
||||||
|
|
||||||
Decoding
|
|
||||||
========
|
|
||||||
|
|
||||||
After :ref:`dummies_tutorial_training`, we can start decoding.
|
|
||||||
|
|
||||||
The command to start the decoding is quite simple:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
cd /tmp/icefall
|
|
||||||
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
|
|
||||||
cd egs/yesno/ASR
|
|
||||||
|
|
||||||
# We use CPU for decoding by setting the following environment variable
|
|
||||||
export CUDA_VISIBLE_DEVICES=""
|
|
||||||
|
|
||||||
./tdnn/decode.py
|
|
||||||
|
|
||||||
The output logs are given below:
|
|
||||||
|
|
||||||
.. literalinclude:: ./code/decoding-yesno.txt
|
|
||||||
|
|
||||||
For the more curious
|
|
||||||
--------------------
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
./tdnn/decode.py --help
|
|
||||||
|
|
||||||
will print the usage information about ``./tdnn/decode.py``. For instance, you
|
|
||||||
can specify:
|
|
||||||
|
|
||||||
- ``--epoch`` to select which checkpoint (epoch) to use for decoding
|
|
||||||
- ``--avg`` to select how many checkpoints to use for model averaging
|
|
||||||
|
|
||||||
You usually try different combinations of ``--epoch`` and ``--avg`` and select
|
|
||||||
one that leads to the lowest WER (`Word Error Rate <https://en.wikipedia.org/wiki/Word_error_rate>`_).
|
|
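For instance, a small sweep over a few combinations can be scripted as below (a sketch only; the epoch numbers are placeholders and the best values depend on your own training run):

.. code-block:: python3

    # A minimal sketch: run ./tdnn/decode.py for several (epoch, avg) pairs and
    # compare the WER reported at the end of each run.
    import subprocess

    for epoch in (12, 13, 14):
        for avg in (1, 2):
            if avg > epoch:
                continue
            subprocess.run(
                ["python3", "./tdnn/decode.py", "--epoch", str(epoch), "--avg", str(avg)],
                check=True,
            )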
@ -1,121 +0,0 @@
|
|||||||
.. _dummies_tutorial_environment_setup:
|
|
||||||
|
|
||||||
Environment setup
|
|
||||||
=================
|
|
||||||
|
|
||||||
We will create an environment for `Next-gen Kaldi`_ that runs on ``CPU``
|
|
||||||
in this tutorial.
|
|
||||||
|
|
||||||
.. note::
|
|
||||||
|
|
||||||
Since the `yesno`_ dataset used in this tutorial is very tiny, training on
|
|
||||||
``CPU`` works very well for it.
|
|
||||||
|
|
||||||
If your dataset is very large, e.g., hundreds or thousands of hours of
|
|
||||||
training data, please follow :ref:`install icefall` to install `icefall`_
|
|
||||||
that works with ``GPU``.
|
|
||||||
|
|
||||||
|
|
||||||
Create a virtual environment
|
|
||||||
----------------------------
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
virtualenv -p python3 /tmp/icefall_env
|
|
||||||
|
|
||||||
The above command creates a virtual environment in the directory ``/tmp/icefall_env``.
|
|
||||||
You can select any directory you want.
|
|
||||||
|
|
||||||
The output of the above command is given below:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
Already using interpreter /usr/bin/python3
|
|
||||||
Using base prefix '/usr'
|
|
||||||
New python executable in /tmp/icefall_env/bin/python3
|
|
||||||
Also creating executable in /tmp/icefall_env/bin/python
|
|
||||||
Installing setuptools, pkg_resources, pip, wheel...done.
|
|
||||||
|
|
||||||
Now we can activate the environment using:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
source /tmp/icefall_env/bin/activate
|
|
||||||
|
|
||||||
Install dependencies
|
|
||||||
--------------------
|
|
||||||
|
|
||||||
.. warning::
|
|
||||||
|
|
||||||
Remember to activate your virtual environment before you continue!
|
|
||||||
|
|
||||||
After activating the virtual environment, we can use the following command
|
|
||||||
to install dependencies of `icefall`_:
|
|
||||||
|
|
||||||
.. hint::
|
|
||||||
|
|
||||||
Remember that we will run this tutorial on ``CPU``, so we install
|
|
||||||
only the dependencies required for running on ``CPU``.
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
# Caution: Installation order matters!
|
|
||||||
|
|
||||||
# We use torch 2.0.0 and torchaudio 2.0.0 in this tutorial.
|
|
||||||
# Other versions should also work.
|
|
||||||
|
|
||||||
pip install torch==2.0.0+cpu torchaudio==2.0.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
|
||||||
|
|
||||||
# If you are using macOS or Windows, please use the following command to install torch and torchaudio
|
|
||||||
# pip install torch==2.0.0 torchaudio==2.0.0 -f https://download.pytorch.org/whl/torch_stable.html
|
|
||||||
|
|
||||||
# Now install k2
|
|
||||||
# Please refer to https://k2-fsa.github.io/k2/installation/from_wheels.html#linux-cpu-example
|
|
||||||
|
|
||||||
pip install k2==1.24.3.dev20230726+cpu.torch2.0.0 -f https://k2-fsa.github.io/k2/cpu.html
|
|
||||||
|
|
||||||
# Install the latest version of lhotse
|
|
||||||
|
|
||||||
pip install git+https://github.com/lhotse-speech/lhotse
|
|
||||||
|
|
||||||
|
|
||||||
Install icefall
|
|
||||||
---------------
|
|
||||||
|
|
||||||
We will put the source code of `icefall`_ into the directory ``/tmp``.
|
|
||||||
You can select any directory you want.
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
cd /tmp
|
|
||||||
git clone https://github.com/k2-fsa/icefall
|
|
||||||
cd icefall
|
|
||||||
pip install -r ./requirements.txt
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
# Anytime we want to use icefall, we have to set the following
|
|
||||||
# environment variable
|
|
||||||
|
|
||||||
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
|
|
||||||
|
|
||||||
.. hint::
|
|
||||||
|
|
||||||
If you get the following error during this tutorial:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
ModuleNotFoundError: No module named 'icefall'
|
|
||||||
|
|
||||||
please set the above environment variable to fix it.
|
|
||||||
|
|
||||||
|
|
||||||
Congratulations! You have installed `icefall`_ successfully.
|
|
||||||
|
|
||||||
For the more curious
|
|
||||||
--------------------
|
|
||||||
|
|
||||||
`icefall`_ contains a collection of Python scripts and you don't need to
|
|
||||||
use ``python3 setup.py install`` or ``pip install icefall`` to install it.
|
|
||||||
All you need to do is to download the code and set the environment variable
|
|
||||||
``PYTHONPATH``.
|
|
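A quick way to verify that ``PYTHONPATH`` is set correctly is the following check (a sketch):

.. code-block:: python3

    # If PYTHONPATH is set correctly, this import succeeds and the printed
    # path points into /tmp/icefall.
    import icefall

    print(icefall.__file__)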
@ -1,34 +0,0 @@
|
|||||||
Icefall for dummies tutorial
|
|
||||||
============================
|
|
||||||
|
|
||||||
This tutorial walks you step by step through how to create a simple
|
|
||||||
ASR (`Automatic Speech Recognition <https://en.wikipedia.org/wiki/Speech_recognition>`_)
|
|
||||||
system with `Next-gen Kaldi`_.
|
|
||||||
|
|
||||||
We use the `yesno`_ dataset for demonstration. We select it for two reasons:
|
|
||||||
|
|
||||||
- It is quite tiny, containing only about 12 minutes of data
|
|
||||||
- The training can be finished within 20 seconds on ``CPU``.
|
|
||||||
|
|
||||||
That also means you don't need a ``GPU`` to run this tutorial.
|
|
||||||
|
|
||||||
Let's get started!
|
|
||||||
|
|
||||||
Please follow items below **sequentially**.
|
|
||||||
|
|
||||||
.. note::
|
|
||||||
|
|
||||||
The :ref:`dummies_tutorial_data_preparation` runs only on Linux and on macOS.
|
|
||||||
All other parts run on Linux, macOS, and Windows.
|
|
||||||
|
|
||||||
Help from the community is appreciated to port the :ref:`dummies_tutorial_data_preparation`
|
|
||||||
to Windows.
|
|
||||||
|
|
||||||
.. toctree::
|
|
||||||
:maxdepth: 2
|
|
||||||
|
|
||||||
./environment-setup.rst
|
|
||||||
./data-preparation.rst
|
|
||||||
./training.rst
|
|
||||||
./decoding.rst
|
|
||||||
./model-export.rst
|
|
@ -1,310 +0,0 @@
|
|||||||
Model Export
|
|
||||||
============
|
|
||||||
|
|
||||||
There are three ways to export a pre-trained model.
|
|
||||||
|
|
||||||
- Export the model parameters via `model.state_dict() <https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.state_dict>`_
|
|
||||||
- Export via `torchscript <https://pytorch.org/docs/stable/jit.html>`_: either `torch.jit.script() <https://pytorch.org/docs/stable/generated/torch.jit.script.html#torch.jit.script>`_ or `torch.jit.trace() <https://pytorch.org/docs/stable/generated/torch.jit.trace.html>`_
|
|
||||||
- Export to `ONNX`_ via `torch.onnx.export() <https://pytorch.org/docs/stable/onnx.html>`_
|
|
||||||
|
|
||||||
Each method is explained below in detail.
|
|
||||||
|
|
||||||
Export the model parameters via model.state_dict()
|
|
||||||
---------------------------------------------------
|
|
||||||
|
|
||||||
The command for this kind of export is
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
cd /tmp/icefall
|
|
||||||
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
|
|
||||||
cd egs/yesno/ASR
|
|
||||||
|
|
||||||
# assume that "--epoch 14 --avg 2" produces the lowest WER.
|
|
||||||
|
|
||||||
./tdnn/export.py --epoch 14 --avg 2
|
|
||||||
|
|
||||||
The output logs are given below:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
2023-08-16 20:42:03,912 INFO [export.py:76] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'jit': False}
|
|
||||||
2023-08-16 20:42:03,913 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
|
|
||||||
2023-08-16 20:42:03,950 INFO [export.py:93] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
|
|
||||||
2023-08-16 20:42:03,971 INFO [export.py:106] Not using torch.jit.script
|
|
||||||
2023-08-16 20:42:03,974 INFO [export.py:111] Saved to tdnn/exp/pretrained.pt
|
|
||||||
|
|
||||||
We can see from the logs that the exported model is saved to the file ``tdnn/exp/pretrained.pt``.
|
|
||||||
|
|
||||||
To give you an idea of what ``tdnn/exp/pretrained.pt`` contains, we can use the following command:
|
|
||||||
|
|
||||||
.. code-block:: python3
|
|
||||||
|
|
||||||
>>> import torch
|
|
||||||
>>> m = torch.load("tdnn/exp/pretrained.pt")
|
|
||||||
>>> list(m.keys())
|
|
||||||
['model']
|
|
||||||
>>> list(m["model"].keys())
|
|
||||||
['tdnn.0.weight', 'tdnn.0.bias', 'tdnn.2.running_mean', 'tdnn.2.running_var', 'tdnn.2.num_batches_tracked', 'tdnn.3.weight', 'tdnn.3.bias', 'tdnn.5.running_mean', 'tdnn.5.running_var', 'tdnn.5.num_batches_tracked', 'tdnn.6.weight', 'tdnn.6.bias', 'tdnn.8.running_mean', 'tdnn.8.running_var', 'tdnn.8.num_batches_tracked', 'output_linear.weight', 'output_linear.bias']
|
|
||||||
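To restore these parameters into a model instance yourself, the usual ``load_state_dict()`` pattern applies. The sketch below assumes the recipe defines a ``Tdnn`` class taking ``num_features`` and ``num_classes``; please check ``./tdnn/model.py`` for the actual interface:

.. code-block:: python3

    # A minimal sketch (run from egs/yesno/ASR). The class name and its
    # arguments are assumptions; consult ./tdnn/model.py for the real API.
    import sys

    import torch

    sys.path.insert(0, "./tdnn")  # so that the recipe's model.py can be imported
    from model import Tdnn

    model = Tdnn(num_features=23, num_classes=4)
    checkpoint = torch.load("tdnn/exp/pretrained.pt", map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model.eval()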
|
|
||||||
We can use ``tdnn/exp/pretrained.pt`` in the following way with ``./tdnn/decode.py``:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
cd tdnn/exp
|
|
||||||
ln -s pretrained.pt epoch-99.pt
|
|
||||||
cd ../..
|
|
||||||
|
|
||||||
./tdnn/decode.py --epoch 99 --avg 1
|
|
||||||
|
|
||||||
The output logs of the above command are given below:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
2023-08-16 20:45:48,089 INFO [decode.py:262] Decoding started
|
|
||||||
2023-08-16 20:45:48,090 INFO [decode.py:263] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'feature_dim': 23, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 99, 'avg': 1, 'export': False, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': False, 'k2-git-sha1': 'ad79f1c699c684de9785ed6ca5edb805a41f78c3', 'k2-git-date': 'Wed Jul 26 09:30:42 2023', 'lhotse-version': '1.16.0.dev+git.aa073f6.clean', 'torch-version': '2.0.0', 'torch-cuda-available': False, 'torch-cuda-version': None, 'python-version': '3.1', 'icefall-git-branch': 'master', 'icefall-git-sha1': '9a47c08-clean', 'icefall-git-date': 'Mon Aug 14 22:10:50 2023', 'icefall-path': '/private/tmp/icefall', 'k2-path': '/private/tmp/icefall_env/lib/python3.11/site-packages/k2/__init__.py', 'lhotse-path': '/private/tmp/icefall_env/lib/python3.11/site-packages/lhotse/__init__.py', 'hostname': 'fangjuns-MacBook-Pro.local', 'IP address': '127.0.0.1'}}
|
|
||||||
2023-08-16 20:45:48,092 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
|
|
||||||
2023-08-16 20:45:48,103 INFO [decode.py:272] device: cpu
|
|
||||||
2023-08-16 20:45:48,109 INFO [checkpoint.py:112] Loading checkpoint from tdnn/exp/epoch-99.pt
|
|
||||||
2023-08-16 20:45:48,115 INFO [asr_datamodule.py:218] About to get test cuts
|
|
||||||
2023-08-16 20:45:48,115 INFO [asr_datamodule.py:253] About to get test cuts
|
|
||||||
2023-08-16 20:45:50,386 INFO [decode.py:203] batch 0/?, cuts processed until now is 4
|
|
||||||
2023-08-16 20:45:50,556 INFO [decode.py:240] The transcripts are stored in tdnn/exp/recogs-test_set.txt
|
|
||||||
2023-08-16 20:45:50,557 INFO [utils.py:564] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ]
|
|
||||||
2023-08-16 20:45:50,558 INFO [decode.py:248] Wrote detailed error stats to tdnn/exp/errs-test_set.txt
|
|
||||||
2023-08-16 20:45:50,559 INFO [decode.py:315] Done!
|
|
||||||
|
|
||||||
We can see that it produces the same WER as before.
|
|
||||||
|
|
||||||
We can also use it to decode files with the following command:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
# ./tdnn/pretrained.py requires kaldifeat
|
|
||||||
#
|
|
||||||
# Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html
|
|
||||||
# for how to install kaldifeat
|
|
||||||
|
|
||||||
pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html
|
|
||||||
|
|
||||||
./tdnn/pretrained.py \
|
|
||||||
--checkpoint ./tdnn/exp/pretrained.pt \
|
|
||||||
--HLG ./data/lang_phone/HLG.pt \
|
|
||||||
--words-file ./data/lang_phone/words.txt \
|
|
||||||
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
|
|
||||||
download/waves_yesno/0_0_1_0_0_0_1_0.wav
|
|
||||||
|
|
||||||
The output is given below:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
2023-08-16 20:53:19,208 INFO [pretrained.py:136] {'feature_dim': 23, 'num_classes': 4, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'checkpoint': './tdnn/exp/pretrained.pt', 'words_file': './data/lang_phone/words.txt', 'HLG': './data/lang_phone/HLG.pt', 'sound_files': ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']}
|
|
||||||
2023-08-16 20:53:19,208 INFO [pretrained.py:142] device: cpu
|
|
||||||
2023-08-16 20:53:19,208 INFO [pretrained.py:144] Creating model
|
|
||||||
2023-08-16 20:53:19,212 INFO [pretrained.py:156] Loading HLG from ./data/lang_phone/HLG.pt
|
|
||||||
2023-08-16 20:53:19,213 INFO [pretrained.py:160] Constructing Fbank computer
|
|
||||||
2023-08-16 20:53:19,213 INFO [pretrained.py:170] Reading sound files: ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']
|
|
||||||
2023-08-16 20:53:19,224 INFO [pretrained.py:176] Decoding started
|
|
||||||
2023-08-16 20:53:19,304 INFO [pretrained.py:212]
|
|
||||||
download/waves_yesno/0_0_0_1_0_0_0_1.wav:
|
|
||||||
NO NO NO YES NO NO NO YES
|
|
||||||
|
|
||||||
download/waves_yesno/0_0_1_0_0_0_1_0.wav:
|
|
||||||
NO NO YES NO NO NO YES NO
|
|
||||||
|
|
||||||
|
|
||||||
2023-08-16 20:53:19,304 INFO [pretrained.py:214] Decoding Done
|
|
||||||
|
|
||||||
|
|
||||||
Export via torch.jit.script()
|
|
||||||
-----------------------------
|
|
||||||
|
|
||||||
The command for this kind of export is
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
cd /tmp/icefall
|
|
||||||
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
|
|
||||||
cd egs/yesno/ASR
|
|
||||||
|
|
||||||
# assume that "--epoch 14 --avg 2" produces the lowest WER.
|
|
||||||
|
|
||||||
./tdnn/export.py --epoch 14 --avg 2 --jit true
|
|
||||||
|
|
||||||
The output logs are given below:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
2023-08-16 20:47:44,666 INFO [export.py:76] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'jit': True}
|
|
||||||
2023-08-16 20:47:44,667 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
|
|
||||||
2023-08-16 20:47:44,670 INFO [export.py:93] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
|
|
||||||
2023-08-16 20:47:44,677 INFO [export.py:100] Using torch.jit.script
|
|
||||||
2023-08-16 20:47:44,843 INFO [export.py:104] Saved to tdnn/exp/cpu_jit.pt
|
|
||||||
|
|
||||||
From the output logs we can see that the generated file is saved to ``tdnn/exp/cpu_jit.pt``.
|
|
||||||
|
|
||||||
Don't be confused by the name ``cpu_jit.pt``. The ``cpu`` part means the model is moved to
|
|
||||||
CPU before it is exported. That means, when you load it with:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
torch.jit.load()
|
|
||||||
|
|
||||||
you don't need to specify the argument `map_location <https://pytorch.org/docs/stable/generated/torch.jit.load.html#torch.jit.load>`_
|
|
||||||
and it resides on CPU by default.
|
|
||||||
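As a quick standalone check, independent of the `icefall`_ scripts, the exported file can be loaded and run with plain ``torch``. This is only a sketch: the assumed input layout (batch x frames x 23-dim fbank) follows the feature configuration above, so please check ``./tdnn/model.py`` for the exact interface.

.. code-block:: python3

    # A minimal sketch: the scripted model runs without the original Python
    # class definitions being importable.
    import torch

    model = torch.jit.load("tdnn/exp/cpu_jit.pt")  # no map_location needed: it is a CPU model
    features = torch.rand(1, 100, 23)              # dummy fbank features (batch, frames, dim)
    with torch.no_grad():
        log_probs = model(features)
    print(log_probs.shape)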
|
|
||||||
To use ``tdnn/exp/cpu_jit.pt`` with `icefall`_ to decode files, we can use:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
# ./tdnn/jit_pretrained.py requires kaldifeat
|
|
||||||
#
|
|
||||||
# Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html
|
|
||||||
# for how to install kaldifeat
|
|
||||||
|
|
||||||
pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html
|
|
||||||
|
|
||||||
|
|
||||||
./tdnn/jit_pretrained.py \
|
|
||||||
--nn-model ./tdnn/exp/cpu_jit.pt \
|
|
||||||
--HLG ./data/lang_phone/HLG.pt \
|
|
||||||
--words-file ./data/lang_phone/words.txt \
|
|
||||||
download/waves_yesno/0_0_0_1_0_0_0_1.wav \
|
|
||||||
download/waves_yesno/0_0_1_0_0_0_1_0.wav
|
|
||||||
|
|
||||||
The output is given below:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
2023-08-16 20:56:00,603 INFO [jit_pretrained.py:121] {'feature_dim': 23, 'num_classes': 4, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'nn_model': './tdnn/exp/cpu_jit.pt', 'words_file': './data/lang_phone/words.txt', 'HLG': './data/lang_phone/HLG.pt', 'sound_files': ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']}
|
|
||||||
2023-08-16 20:56:00,603 INFO [jit_pretrained.py:127] device: cpu
|
|
||||||
2023-08-16 20:56:00,603 INFO [jit_pretrained.py:129] Loading torchscript model
|
|
||||||
2023-08-16 20:56:00,640 INFO [jit_pretrained.py:134] Loading HLG from ./data/lang_phone/HLG.pt
|
|
||||||
2023-08-16 20:56:00,641 INFO [jit_pretrained.py:138] Constructing Fbank computer
|
|
||||||
2023-08-16 20:56:00,641 INFO [jit_pretrained.py:148] Reading sound files: ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']
|
|
||||||
2023-08-16 20:56:00,642 INFO [jit_pretrained.py:154] Decoding started
|
|
||||||
2023-08-16 20:56:00,727 INFO [jit_pretrained.py:190]
|
|
||||||
download/waves_yesno/0_0_0_1_0_0_0_1.wav:
|
|
||||||
NO NO NO YES NO NO NO YES
|
|
||||||
|
|
||||||
download/waves_yesno/0_0_1_0_0_0_1_0.wav:
|
|
||||||
NO NO YES NO NO NO YES NO
|
|
||||||
|
|
||||||
|
|
||||||
2023-08-16 20:56:00,727 INFO [jit_pretrained.py:192] Decoding Done
|
|
||||||
|
|
||||||
.. hint::
|
|
||||||
|
|
||||||
We provide only code for ``torch.jit.script()``. You can try ``torch.jit.trace()``
|
|
||||||
if you want.
|
|
||||||
|
|
||||||
Export via torch.onnx.export()
|
|
||||||
------------------------------
|
|
||||||
|
|
||||||
The command for this kind of export is
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
cd /tmp/icefall
|
|
||||||
export PYTHONPATH=/tmp/icefall:$PYTHONPATH
|
|
||||||
cd egs/yesno/ASR
|
|
||||||
|
|
||||||
# tdnn/export_onnx.py requires onnx and onnxruntime
|
|
||||||
pip install onnx onnxruntime
|
|
||||||
|
|
||||||
# assume that "--epoch 14 --avg 2" produces the lowest WER.
|
|
||||||
|
|
||||||
./tdnn/export_onnx.py \
|
|
||||||
--epoch 14 \
|
|
||||||
--avg 2
|
|
||||||
|
|
||||||
The output logs are given below:

.. code-block:: bash

  2023-08-16 20:59:20,888 INFO [export_onnx.py:83] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_double_scores': True, 'epoch': 14, 'avg': 2}
  2023-08-16 20:59:20,888 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt
  2023-08-16 20:59:20,892 INFO [export_onnx.py:100] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt']
  ================ Diagnostic Run torch.onnx.export version 2.0.0 ================
  verbose: False, log level: Level.ERROR
  ======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

  2023-08-16 20:59:21,047 INFO [export_onnx.py:127] Saved to tdnn/exp/model-epoch-14-avg-2.onnx
  2023-08-16 20:59:21,047 INFO [export_onnx.py:136] meta_data: {'model_type': 'tdnn', 'version': '1', 'model_author': 'k2-fsa', 'comment': 'non-streaming tdnn for the yesno recipe', 'vocab_size': 4}
  2023-08-16 20:59:21,049 INFO [export_onnx.py:140] Generate int8 quantization models
  2023-08-16 20:59:21,075 INFO [onnx_quantizer.py:538] Quantization parameters for tensor:"/Transpose_1_output_0" not specified
  2023-08-16 20:59:21,081 INFO [export_onnx.py:151] Saved to tdnn/exp/model-epoch-14-avg-2.int8.onnx

We can see from the logs that it generates two files:

- ``tdnn/exp/model-epoch-14-avg-2.onnx`` (ONNX model with ``float32`` weights)
- ``tdnn/exp/model-epoch-14-avg-2.int8.onnx`` (ONNX model with ``int8`` weights)

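The ``int8`` file is produced by post-training weight quantization with
`onnxruntime`_ (the ``onnx_quantizer.py`` message in the logs above comes from
it). As a rough illustration of that step, and not necessarily the exact code
used by ``tdnn/export_onnx.py``, dynamic ``int8`` quantization looks roughly
like this:

.. code-block:: python

  # A rough sketch of int8 dynamic quantization with onnxruntime; the
  # actual tdnn/export_onnx.py may use different options.
  from onnxruntime.quantization import QuantType, quantize_dynamic

  quantize_dynamic(
      model_input="tdnn/exp/model-epoch-14-avg-2.onnx",
      model_output="tdnn/exp/model-epoch-14-avg-2.int8.onnx",
      weight_type=QuantType.QInt8,
  )

Dynamic quantization stores the weights in ``int8`` and dequantizes them on
the fly, so the resulting file is smaller while inputs and outputs remain
``float32``.
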
To use the generated ONNX model files for decoding with `onnxruntime`_, we can use

.. code-block:: bash

  # ./tdnn/onnx_pretrained.py requires kaldifeat
  #
  # Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html
  # for how to install kaldifeat

  pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html

  ./tdnn/onnx_pretrained.py \
    --nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \
    --HLG ./data/lang_phone/HLG.pt \
    --words-file ./data/lang_phone/words.txt \
    download/waves_yesno/0_0_0_1_0_0_0_1.wav \
    download/waves_yesno/0_0_1_0_0_0_1_0.wav

The output is given below:

.. code-block:: bash

  2023-08-16 21:03:24,260 INFO [onnx_pretrained.py:166] {'feature_dim': 23, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'nn_model': './tdnn/exp/model-epoch-14-avg-2.onnx', 'words_file': './data/lang_phone/words.txt', 'HLG': './data/lang_phone/HLG.pt', 'sound_files': ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']}
  2023-08-16 21:03:24,260 INFO [onnx_pretrained.py:171] device: cpu
  2023-08-16 21:03:24,260 INFO [onnx_pretrained.py:173] Loading onnx model ./tdnn/exp/model-epoch-14-avg-2.onnx
  2023-08-16 21:03:24,267 INFO [onnx_pretrained.py:176] Loading HLG from ./data/lang_phone/HLG.pt
  2023-08-16 21:03:24,270 INFO [onnx_pretrained.py:180] Constructing Fbank computer
  2023-08-16 21:03:24,273 INFO [onnx_pretrained.py:190] Reading sound files: ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']
  2023-08-16 21:03:24,279 INFO [onnx_pretrained.py:196] Decoding started
  2023-08-16 21:03:24,318 INFO [onnx_pretrained.py:232]

  download/waves_yesno/0_0_0_1_0_0_0_1.wav:
  NO NO NO YES NO NO NO YES

  download/waves_yesno/0_0_1_0_0_0_1_0.wav:
  NO NO YES NO NO NO YES NO

  2023-08-16 21:03:24,318 INFO [onnx_pretrained.py:234] Decoding Done

.. note::

  To use the ``int8`` ONNX model for decoding, please use:

  .. code-block:: bash

    ./tdnn/onnx_pretrained.py \
      --nn-model ./tdnn/exp/model-epoch-14-avg-2.int8.onnx \
      --HLG ./data/lang_phone/HLG.pt \
      --words-file ./data/lang_phone/words.txt \
      download/waves_yesno/0_0_0_1_0_0_0_1.wav \
      download/waves_yesno/0_0_1_0_0_0_1_0.wav

For the more curious
--------------------

If you are wondering how to deploy the model without ``torch``, please
continue reading. We will show how to use `sherpa-onnx`_ to run the
exported ONNX models; it depends only on `onnxruntime`_ and does not
depend on ``torch``.

In this tutorial, we only demonstrate the usage of `sherpa-onnx`_ with the
pre-trained model of the `yesno`_ recipe. There are also two other frameworks
available:

- `sherpa`_. It works with torchscript models.
- `sherpa-ncnn`_. It works with models exported to `ncnn`_ using :ref:`icefall_export_to_ncnn`.

Please see `<https://k2-fsa.github.io/sherpa/>`_ for further details.

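To see why no ``torch`` is required at deployment time, here is a minimal
sketch that opens the exported ONNX model with nothing but ``onnxruntime`` and
``numpy``. It is only an illustration: the input/output names, the number of
inputs, and the ``(N, T, 23)`` feature shape are assumptions that you should
confirm by inspecting the model, and a real deployment also has to reproduce
the fbank feature extraction, which is exactly what `sherpa-onnx`_ handles for
you.

.. code-block:: python

  # A minimal sketch that depends only on onnxruntime and numpy, not torch.
  # The input/output names and the (N, T, 23) feature shape are assumptions;
  # inspect the exported model (as done below) to confirm them.
  import numpy as np
  import onnxruntime as ort

  session = ort.InferenceSession("tdnn/exp/model-epoch-14-avg-2.onnx")

  for i in session.get_inputs():
      print("input:", i.name, i.shape)
  for o in session.get_outputs():
      print("output:", o.name, o.shape)

  # Assuming the model takes a single batch of fbank features
  features = np.random.rand(1, 100, 23).astype(np.float32)
  feed = {session.get_inputs()[0].name: features}
  outputs = session.run(None, feed)
  print([out.shape for out in outputs])
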
@ -1,39 +0,0 @@
.. _dummies_tutorial_training:

Training
========

After :ref:`dummies_tutorial_data_preparation`, we can start training.

The command to start the training is quite simple:

.. code-block:: bash

  cd /tmp/icefall
  export PYTHONPATH=/tmp/icefall:$PYTHONPATH
  cd egs/yesno/ASR

  # We use CPU for training by setting the following environment variable
  export CUDA_VISIBLE_DEVICES=""

  ./tdnn/train.py

That's it!

You can find the training logs below:

.. literalinclude:: ./code/train-yesno.txt

For the more curious
--------------------

.. code-block:: bash

  ./tdnn/train.py --help

will print the usage information about ``./tdnn/train.py``. For instance, you
can specify the number of epochs to train and the location to save the training
results.

The training text logs are saved in ``tdnn/exp/log`` while the tensorboard
logs are in ``tdnn/exp/tensorboard``.

@ -20,13 +20,10 @@ speech recognition recipes using `k2 <https://github.com/k2-fsa/k2>`_.
   :maxdepth: 2
   :caption: Contents:

   for-dummies/index.rst
   installation/index
   docker/index
   faqs
   model-export/index


.. toctree::
   :maxdepth: 3

@ -3,11 +3,6 @@
Installation
============

.. hint::

  We also provide :ref:`icefall_docker` support, which has already set up
  the environment for you.

.. hint::

  We have a colab notebook guiding you step by step to set up the environment.

@ -41,7 +41,7 @@ as an example.

./pruned_transducer_stateless3/export.py \
  --exp-dir ./pruned_transducer_stateless3/exp \
  --tokens data/lang_bpe_500/tokens.txt \
  --bpe-model data/lang_bpe_500/bpe.model \
  --epoch 20 \
  --avg 10

@ -78,7 +78,7 @@ In each recipe, there is also a file ``pretrained.py``, which can use

./pruned_transducer_stateless3/pretrained.py \
  --checkpoint ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/pretrained-iter-1224000-avg-14.pt \
  --tokens ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \
  --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/bpe.model \
  --method greedy_search \
  ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav \
  ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav \

@ -153,10 +153,11 @@ Next, we use the following code to export our model:

./conv_emformer_transducer_stateless2/export-for-ncnn.py \
  --exp-dir $dir/exp \
  --tokens $dir/data/lang_bpe_500/tokens.txt \
  --bpe-model $dir/data/lang_bpe_500/bpe.model \
  --epoch 30 \
  --avg 1 \
  --use-averaged-model 0 \
  \
  --num-encoder-layers 12 \
  --chunk-length 32 \
  --cnn-module-kernel 31 \

@ -73,7 +73,7 @@ Next, we use the following code to export our model:

./lstm_transducer_stateless2/export-for-ncnn.py \
  --exp-dir $dir/exp \
  --tokens $dir/data/lang_bpe_500/tokens.txt \
  --bpe-model $dir/data/lang_bpe_500/bpe.model \
  --epoch 99 \
  --avg 1 \
  --use-averaged-model 0 \

@ -72,11 +72,12 @@ Next, we use the following code to export our model:
dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29

./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
  --tokens $dir/data/lang_bpe_500/tokens.txt \
  --bpe-model $dir/data/lang_bpe_500/bpe.model \
  --exp-dir $dir/exp \
  --use-averaged-model 0 \
  --epoch 99 \
  --avg 1 \
  \
  --decode-chunk-len 32 \
  --num-left-chunks 4 \
  --num-encoder-layers "2,4,3,2,4" \

|
|||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
||||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
|
@ -32,7 +32,7 @@ as an example in the following.

./pruned_transducer_stateless3/export.py \
  --exp-dir ./pruned_transducer_stateless3/exp \
  --tokens data/lang_bpe_500/tokens.txt \
  --bpe-model data/lang_bpe_500/bpe.model \
  --epoch $epoch \
  --avg $avg \
  --jit 1

@ -33,7 +33,7 @@ as an example in the following.

./lstm_transducer_stateless2/export.py \
  --exp-dir ./lstm_transducer_stateless2/exp \
  --tokens data/lang_bpe_500/tokens.txt \
  --bpe-model data/lang_bpe_500/bpe.model \
  --iter $iter \
  --avg $avg \
  --jit-trace 1


egs/disc_tts/ASR/local/compute_fbank_disc_tts.py (80 lines, Executable file)
@ -0,0 +1,80 @@
import argparse
import logging
import os
from pathlib import Path
from typing import Optional

import sentencepiece as spm
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor, str2bool

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--dataset",
        type=str,
        help="""Dataset parts to compute fbank. If None, we will use all""",
    )
    return parser.parse_args()


def compute_fbank_disc_tts(
    dataset: Optional[str] = None,
):
    src_dir = Path("data/fbank")
    output_dir = Path("data/fbank")
    num_jobs = min(1, os.cpu_count())
    num_mel_bins = 80

    if dataset is None:
        dataset_parts = ("dac", "encodec", "gt", "hifigan", "hubert", "vq", "wavlm")
    else:
        dataset_parts = dataset.split(" ", -1)

    prefix = "disc_tts"
    suffix = "jsonl.gz"

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:  # Initialize the executor only once.
        for partition in dataset_parts:
            cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
            if (output_dir / cuts_filename).is_file():
                logging.info(f"{partition} already exists - skipping.")
                continue
            logging.info(f"Processing {partition}")
            raw_cuts_filename = f"{prefix}_cuts_{partition}_raw.{suffix}"
            cut_set = CutSet.from_file(src_dir / raw_cuts_filename)

            cut_set = cut_set.compute_and_store_features(
                extractor=extractor,
                storage_path=f"{output_dir}/{prefix}_feats_{partition}",
                # when an executor is specified, make more partitions
                num_jobs=num_jobs if ex is None else 80,
                executor=ex,
                storage_type=LilcomChunkyWriter,
            )
            cut_set.to_file(output_dir / cuts_filename)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    args = get_args()
    logging.info(vars(args))
    compute_fbank_disc_tts(
        dataset=args.dataset,
    )

@ -1,5 +1,5 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -15,34 +15,14 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
This file computes fbank features of the LibriSpeech dataset.
|
|
||||||
It looks for manifests in the directory data/manifests.
|
|
||||||
|
|
||||||
The generated fbank features are saved in data/fbank.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import os
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
from lhotse import CutSet, SupervisionSegment
|
||||||
from lhotse import CutSet
|
|
||||||
from lhotse.cut import MonoCut
|
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from icefall.utils import get_executor, str2bool
|
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
|
||||||
# it wastes a lot of CPU and slow things down.
|
|
||||||
# Do this outside of main() in case it needs to take effect
|
|
||||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
|
||||||
torch.set_num_threads(1)
|
|
||||||
torch.set_num_interop_threads(1)
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
@ -57,28 +37,31 @@ def get_args():
|
|||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_librispeech(
|
def normalize_text(utt: str) -> str:
|
||||||
|
utt = re.sub(r"[{0}]+".format("-"), " ", utt)
|
||||||
|
return re.sub(r"[^a-zA-Z\s']", "", utt).upper()
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_disc_tts(
|
||||||
dataset: Optional[str] = None,
|
dataset: Optional[str] = None,
|
||||||
):
|
):
|
||||||
src_dir = Path("data/manifests")
|
src_dir = Path(f"data/manifests")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path(f"data/fbank")
|
||||||
|
output_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
dataset_parts = (
|
dataset_parts = ("dac", "encodec", "gt", "hifigan", "hubert", "vq", "wavlm")
|
||||||
"train-clean-100-sp1_1",
|
|
||||||
"train-clean-360-sp1_1",
|
|
||||||
"train-other-500-sp1_1",
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
dataset_parts = dataset.split(" ", -1)
|
dataset_parts = dataset.split(" ", -1)
|
||||||
|
|
||||||
prefix = "librispeech"
|
logging.info("Loading manifest")
|
||||||
|
prefix = f"disc_tts"
|
||||||
suffix = "jsonl.gz"
|
suffix = "jsonl.gz"
|
||||||
manifests = read_manifests_if_cached(
|
manifests = read_manifests_if_cached(
|
||||||
dataset_parts=dataset_parts,
|
dataset_parts=dataset_parts,
|
||||||
output_dir=src_dir,
|
output_dir=src_dir,
|
||||||
prefix=prefix,
|
|
||||||
suffix=suffix,
|
suffix=suffix,
|
||||||
|
prefix=prefix,
|
||||||
)
|
)
|
||||||
assert manifests is not None
|
assert manifests is not None
|
||||||
|
|
||||||
@ -90,31 +73,46 @@ def compute_fbank_librispeech(
|
|||||||
)
|
)
|
||||||
|
|
||||||
for partition, m in manifests.items():
|
for partition, m in manifests.items():
|
||||||
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
|
logging.info(f"Processing {partition}")
|
||||||
if (output_dir / cuts_filename).is_file():
|
raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}"
|
||||||
logging.info(f"{partition} already exists - skipping.")
|
if raw_cuts_path.is_file():
|
||||||
|
logging.info(f"{partition} already exists - skipping")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
logging.info(f"Normalizing text in {partition}")
|
||||||
|
for sup in m["supervisions"]:
|
||||||
|
text = str(sup.text)
|
||||||
|
orig_text = text
|
||||||
|
sup.text = normalize_text(sup.text)
|
||||||
|
text = str(sup.text)
|
||||||
|
if len(orig_text) != len(text):
|
||||||
|
logging.info(
|
||||||
|
f"\nOriginal text vs normalized text:\n{orig_text}\n{text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create long-recording cut manifests.
|
||||||
cut_set = CutSet.from_manifests(
|
cut_set = CutSet.from_manifests(
|
||||||
recordings=m["recordings"],
|
recordings=m["recordings"],
|
||||||
supervisions=m["supervisions"],
|
supervisions=m["supervisions"],
|
||||||
)
|
).resample(16000)
|
||||||
logging.info(f"Processing {partition}")
|
|
||||||
for i in tqdm(range(len(cut_set))):
|
|
||||||
cut_set[i].discrete_tokens = cut_set[i].supervisions[0].discrete_tokens
|
|
||||||
try:
|
|
||||||
del cut_set[i].supervisions[0].custom
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
cut_set.to_file(output_dir / cuts_filename)
|
# Run data augmentation that needs to be done in the
|
||||||
|
# time domain.
|
||||||
|
logging.info(f"Saving to {raw_cuts_path}")
|
||||||
|
cut_set.to_file(raw_cuts_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def main():
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
args = get_args()
|
args = get_args()
|
||||||
logging.info(vars(args))
|
logging.info(vars(args))
|
||||||
compute_fbank_librispeech(
|
preprocess_disc_tts(
|
||||||
dataset=args.dataset,
|
dataset=args.dataset,
|
||||||
)
|
)
|
||||||
|
logging.info("Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
egs/disc_tts/ASR/shared (1 line, Symbolic link)
@ -0,0 +1 @@
../../../icefall/shared

egs/disc_tts/ASR/zipformer/asr_datamodule.py (170 lines, Normal file)
@ -0,0 +1,170 @@
|
|||||||
|
import argparse
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||||
|
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||||
|
K2SpeechRecognitionDataset,
|
||||||
|
PrecomputedFeatures,
|
||||||
|
SimpleCutSampler,
|
||||||
|
)
|
||||||
|
from lhotse.dataset.input_strategies import AudioSamples # noqa F401 For AudioSamples
|
||||||
|
from lhotse.utils import fix_random_seed
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
|
class _SeedWorkers:
|
||||||
|
def __init__(self, seed: int):
|
||||||
|
self.seed = seed
|
||||||
|
|
||||||
|
def __call__(self, worker_id: int):
|
||||||
|
fix_random_seed(self.seed + worker_id)
|
||||||
|
|
||||||
|
|
||||||
|
class DiscTTSAsrDataModule:
|
||||||
|
"""
|
||||||
|
DataModule for k2 ASR experiments.
|
||||||
|
It assumes there is always one train and valid dataloader,
|
||||||
|
but there can be multiple test dataloaders (e.g. DiscTTS test-clean
|
||||||
|
and test-other).
|
||||||
|
|
||||||
|
It contains all the common data pipeline modules used in ASR
|
||||||
|
experiments, e.g.:
|
||||||
|
- dynamic batch size,
|
||||||
|
- bucketing samplers,
|
||||||
|
- cut concatenation,
|
||||||
|
- augmentation,
|
||||||
|
- on-the-fly feature extraction
|
||||||
|
|
||||||
|
This class should be derived for specific corpora used in ASR tasks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, args: argparse.Namespace):
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||||
|
group = parser.add_argument_group(
|
||||||
|
title="ASR data related options",
|
||||||
|
description="These options are used for the preparation of "
|
||||||
|
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||||
|
"effective batch sizes, sampling strategies, applied data "
|
||||||
|
"augmentations, etc.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--manifest-dir",
|
||||||
|
type=Path,
|
||||||
|
default=Path("data/fbank"),
|
||||||
|
help="Path to directory with train/valid/test cuts.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--max-duration",
|
||||||
|
type=int,
|
||||||
|
default=200.0,
|
||||||
|
help="Maximum pooled recordings duration (seconds) in a "
|
||||||
|
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--shuffle",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="When enabled (=default), the examples will be "
|
||||||
|
"shuffled for each epoch.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--drop-last",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to drop last batch. Used by sampler.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--return-cuts",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="When enabled, each batch will have the "
|
||||||
|
"field: batch['supervisions']['cut'] with the cuts that "
|
||||||
|
"were used to construct it.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--num-workers",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The number of training dataloader workers that "
|
||||||
|
"collect the batches.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--input-strategy",
|
||||||
|
type=str,
|
||||||
|
default="PrecomputedFeatures",
|
||||||
|
help="AudioSamples or PrecomputedFeatures",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||||
|
logging.debug("About to create test dataset")
|
||||||
|
test = K2SpeechRecognitionDataset(
|
||||||
|
input_strategy=eval(self.args.input_strategy)(),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
)
|
||||||
|
sampler = SimpleCutSampler(
|
||||||
|
cuts,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
shuffle=False,
|
||||||
|
drop_last=False,
|
||||||
|
)
|
||||||
|
logging.debug("About to create test dataloader")
|
||||||
|
test_dl = DataLoader(
|
||||||
|
test,
|
||||||
|
batch_size=None,
|
||||||
|
sampler=sampler,
|
||||||
|
num_workers=self.args.num_workers,
|
||||||
|
)
|
||||||
|
return test_dl
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def test_dac_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get dac test cuts")
|
||||||
|
return load_manifest_lazy(self.args.manifest_dir / "disc_tts_cuts_dac.jsonl.gz")
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def test_encodec_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get encodec test cuts")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "disc_tts_cuts_encodec.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def test_gt_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get gt test cuts")
|
||||||
|
return load_manifest_lazy(self.args.manifest_dir / "disc_tts_cuts_gt.jsonl.gz")
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def test_hifigan_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get hifigan test cuts")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "disc_tts_cuts_hifigan.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def test_hubert_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get hubert test cuts")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "disc_tts_cuts_hubert.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def test_vq_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get vq test cuts")
|
||||||
|
return load_manifest_lazy(self.args.manifest_dir / "disc_tts_cuts_vq.jsonl.gz")
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def test_wavlm_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get wavlm test cuts")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "disc_tts_cuts_wavlm.jsonl.gz"
|
||||||
|
)
|
egs/disc_tts/ASR/zipformer/decode.py (1020 lines, Executable file)
File diff suppressed because it is too large

egs/disc_tts/ASR/zipformer/generate_averaged_model.py (193 lines, Executable file)
@ -0,0 +1,193 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
(1) use the checkpoint exp_dir/epoch-xxx.pt
|
||||||
|
./zipformer/generate_averaged_model.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./zipformer/exp
|
||||||
|
|
||||||
|
It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
|
||||||
|
You can later load it by `torch.load("epoch-28-avg-15.pt")`.
|
||||||
|
|
||||||
|
(2) use the checkpoint exp_dir/checkpoint-iter.pt
|
||||||
|
./zipformer/generate_averaged_model.py \
|
||||||
|
--iter 22000 \
|
||||||
|
--avg 5 \
|
||||||
|
--exp-dir ./zipformer/exp
|
||||||
|
|
||||||
|
It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`.
|
||||||
|
You can later load it by `torch.load("iter-22000-avg-5.pt")`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
|
from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=30,
|
||||||
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
|
Note: Epoch counts from 1.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=9,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="zipformer/exp",
|
||||||
|
help="The experiment dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
|
help="Path to the tokens.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||||
|
else:
|
||||||
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
|
print("Script started")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
print(f"Device: {device}")
|
||||||
|
|
||||||
|
symbol_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
params.blank_id = symbol_table["<blk>"]
|
||||||
|
params.unk_id = symbol_table["<unk>"]
|
||||||
|
params.vocab_size = len(symbol_table)
|
||||||
|
|
||||||
|
print("About to create model")
|
||||||
|
model = get_model(params)
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
print(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
|
||||||
|
torch.save({"model": model.state_dict()}, filename)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
print(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||||
|
torch.save({"model": model.state_dict()}, filename)
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
print(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
print("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
egs/disc_tts/ASR/zipformer/model.py (358 lines, Normal file)
@ -0,0 +1,358 @@
|
|||||||
|
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
|
# Wei Kang,
|
||||||
|
# Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from encoder_interface import EncoderInterface
|
||||||
|
|
||||||
|
from icefall.utils import add_sos, make_pad_mask
|
||||||
|
from scaling import ScaledLinear
|
||||||
|
|
||||||
|
|
||||||
|
class AsrModel(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
encoder_embed: nn.Module,
|
||||||
|
encoder: EncoderInterface,
|
||||||
|
decoder: Optional[nn.Module] = None,
|
||||||
|
joiner: Optional[nn.Module] = None,
|
||||||
|
encoder_dim: int = 384,
|
||||||
|
decoder_dim: int = 512,
|
||||||
|
vocab_size: int = 500,
|
||||||
|
use_transducer: bool = True,
|
||||||
|
use_ctc: bool = False,
|
||||||
|
):
|
||||||
|
"""A joint CTC & Transducer ASR model.
|
||||||
|
|
||||||
|
- Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
|
||||||
|
- Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
|
||||||
|
- Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_embed:
|
||||||
|
It is a Convolutional 2D subsampling module. It converts
|
||||||
|
an input of shape (N, T, idim) to an output of of shape
|
||||||
|
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
|
||||||
|
encoder:
|
||||||
|
It is the transcription network in the paper. Its accepts
|
||||||
|
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||||
|
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
|
||||||
|
`logit_lens` of shape (N,).
|
||||||
|
decoder:
|
||||||
|
It is the prediction network in the paper. Its input shape
|
||||||
|
is (N, U) and its output shape is (N, U, decoder_dim).
|
||||||
|
It should contain one attribute: `blank_id`.
|
||||||
|
It is used when use_transducer is True.
|
||||||
|
joiner:
|
||||||
|
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
|
||||||
|
Its output shape is (N, T, U, vocab_size). Note that its output contains
|
||||||
|
unnormalized probs, i.e., not processed by log-softmax.
|
||||||
|
It is used when use_transducer is True.
|
||||||
|
use_transducer:
|
||||||
|
Whether use transducer head. Default: True.
|
||||||
|
use_ctc:
|
||||||
|
Whether use CTC head. Default: False.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
use_transducer or use_ctc
|
||||||
|
), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
|
||||||
|
|
||||||
|
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||||
|
|
||||||
|
self.encoder_embed = encoder_embed
|
||||||
|
self.encoder = encoder
|
||||||
|
|
||||||
|
self.use_transducer = use_transducer
|
||||||
|
if use_transducer:
|
||||||
|
# Modules for Transducer head
|
||||||
|
assert decoder is not None
|
||||||
|
assert hasattr(decoder, "blank_id")
|
||||||
|
assert joiner is not None
|
||||||
|
|
||||||
|
self.decoder = decoder
|
||||||
|
self.joiner = joiner
|
||||||
|
|
||||||
|
self.simple_am_proj = ScaledLinear(
|
||||||
|
encoder_dim, vocab_size, initial_scale=0.25
|
||||||
|
)
|
||||||
|
self.simple_lm_proj = ScaledLinear(
|
||||||
|
decoder_dim, vocab_size, initial_scale=0.25
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert decoder is None
|
||||||
|
assert joiner is None
|
||||||
|
|
||||||
|
self.use_ctc = use_ctc
|
||||||
|
if use_ctc:
|
||||||
|
# Modules for CTC head
|
||||||
|
self.ctc_output = nn.Sequential(
|
||||||
|
nn.Dropout(p=0.1),
|
||||||
|
nn.Linear(encoder_dim, vocab_size),
|
||||||
|
nn.LogSoftmax(dim=-1),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_encoder(
|
||||||
|
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Compute encoder outputs.
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
A 3-D tensor of shape (N, T, C).
|
||||||
|
x_lens:
|
||||||
|
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||||
|
before padding.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
encoder_out:
|
||||||
|
Encoder output, of shape (N, T, C).
|
||||||
|
encoder_out_lens:
|
||||||
|
Encoder output lengths, of shape (N,).
|
||||||
|
"""
|
||||||
|
# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
|
||||||
|
x, x_lens = self.encoder_embed(x, x_lens)
|
||||||
|
# logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
|
||||||
|
|
||||||
|
src_key_padding_mask = make_pad_mask(x_lens)
|
||||||
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
|
||||||
|
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
|
||||||
|
|
||||||
|
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
|
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
|
||||||
|
|
||||||
|
return encoder_out, encoder_out_lens
|
||||||
|
|
||||||
|
def forward_ctc(
|
||||||
|
self,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
targets: torch.Tensor,
|
||||||
|
target_lengths: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Compute CTC loss.
|
||||||
|
Args:
|
||||||
|
encoder_out:
|
||||||
|
Encoder output, of shape (N, T, C).
|
||||||
|
encoder_out_lens:
|
||||||
|
Encoder output lengths, of shape (N,).
|
||||||
|
targets:
|
||||||
|
Target Tensor of shape (sum(target_lengths)). The targets are assumed
|
||||||
|
to be un-padded and concatenated within 1 dimension.
|
||||||
|
"""
|
||||||
|
# Compute CTC log-prob
|
||||||
|
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
|
||||||
|
|
||||||
|
ctc_loss = torch.nn.functional.ctc_loss(
|
||||||
|
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
|
||||||
|
targets=targets,
|
||||||
|
input_lengths=encoder_out_lens,
|
||||||
|
target_lengths=target_lengths,
|
||||||
|
reduction="sum",
|
||||||
|
)
|
||||||
|
return ctc_loss
|
||||||
|
|
||||||
|
def forward_transducer(
|
||||||
|
self,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
y: k2.RaggedTensor,
|
||||||
|
y_lens: torch.Tensor,
|
||||||
|
prune_range: int = 5,
|
||||||
|
am_scale: float = 0.0,
|
||||||
|
lm_scale: float = 0.0,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Compute Transducer loss.
|
||||||
|
Args:
|
||||||
|
encoder_out:
|
||||||
|
Encoder output, of shape (N, T, C).
|
||||||
|
encoder_out_lens:
|
||||||
|
Encoder output lengths, of shape (N,).
|
||||||
|
y:
|
||||||
|
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||||
|
utterance.
|
||||||
|
prune_range:
|
||||||
|
The prune range for rnnt loss, it means how many symbols(context)
|
||||||
|
we are considering for each frame to compute the loss.
|
||||||
|
am_scale:
|
||||||
|
The scale to smooth the loss with am (output of encoder network)
|
||||||
|
part
|
||||||
|
lm_scale:
|
||||||
|
The scale to smooth the loss with lm (output of predictor network)
|
||||||
|
part
|
||||||
|
"""
|
||||||
|
# Now for the decoder, i.e., the prediction network
|
||||||
|
blank_id = self.decoder.blank_id
|
||||||
|
sos_y = add_sos(y, sos_id=blank_id)
|
||||||
|
|
||||||
|
# sos_y_padded: [B, S + 1], start with SOS.
|
||||||
|
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||||
|
|
||||||
|
# decoder_out: [B, S + 1, decoder_dim]
|
||||||
|
decoder_out = self.decoder(sos_y_padded)
|
||||||
|
|
||||||
|
# Note: y does not start with SOS
|
||||||
|
# y_padded : [B, S]
|
||||||
|
y_padded = y.pad(mode="constant", padding_value=0)
|
||||||
|
|
||||||
|
y_padded = y_padded.to(torch.int64)
|
||||||
|
boundary = torch.zeros(
|
||||||
|
(encoder_out.size(0), 4),
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=encoder_out.device,
|
||||||
|
)
|
||||||
|
boundary[:, 2] = y_lens
|
||||||
|
boundary[:, 3] = encoder_out_lens
|
||||||
|
|
||||||
|
lm = self.simple_lm_proj(decoder_out)
|
||||||
|
am = self.simple_am_proj(encoder_out)
|
||||||
|
|
||||||
|
# if self.training and random.random() < 0.25:
|
||||||
|
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
|
||||||
|
# if self.training and random.random() < 0.25:
|
||||||
|
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
|
||||||
|
|
||||||
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
|
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||||
|
lm=lm.float(),
|
||||||
|
am=am.float(),
|
||||||
|
symbols=y_padded,
|
||||||
|
termination_symbol=blank_id,
|
||||||
|
lm_only_scale=lm_scale,
|
||||||
|
am_only_scale=am_scale,
|
||||||
|
boundary=boundary,
|
||||||
|
reduction="sum",
|
||||||
|
return_grad=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ranges : [B, T, prune_range]
|
||||||
|
ranges = k2.get_rnnt_prune_ranges(
|
||||||
|
px_grad=px_grad,
|
||||||
|
py_grad=py_grad,
|
||||||
|
boundary=boundary,
|
||||||
|
s_range=prune_range,
|
||||||
|
)
|
||||||
|
|
||||||
|
# am_pruned : [B, T, prune_range, encoder_dim]
|
||||||
|
# lm_pruned : [B, T, prune_range, decoder_dim]
|
||||||
|
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||||
|
am=self.joiner.encoder_proj(encoder_out),
|
||||||
|
lm=self.joiner.decoder_proj(decoder_out),
|
||||||
|
ranges=ranges,
|
||||||
|
)
|
||||||
|
|
||||||
|
# logits : [B, T, prune_range, vocab_size]
|
||||||
|
|
||||||
|
# project_input=False since we applied the decoder's input projections
|
||||||
|
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||||
|
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||||
|
|
||||||
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
|
pruned_loss = k2.rnnt_loss_pruned(
|
||||||
|
logits=logits.float(),
|
||||||
|
symbols=y_padded,
|
||||||
|
ranges=ranges,
|
||||||
|
termination_symbol=blank_id,
|
||||||
|
boundary=boundary,
|
||||||
|
reduction="sum",
|
||||||
|
)
|
||||||
|
|
||||||
|
return simple_loss, pruned_loss
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
x_lens: torch.Tensor,
|
||||||
|
y: k2.RaggedTensor,
|
||||||
|
prune_range: int = 5,
|
||||||
|
am_scale: float = 0.0,
|
||||||
|
lm_scale: float = 0.0,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
A 3-D tensor of shape (N, T, C).
|
||||||
|
x_lens:
|
||||||
|
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||||
|
before padding.
|
||||||
|
y:
|
||||||
|
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||||
|
utterance.
|
||||||
|
prune_range:
|
||||||
|
The prune range for rnnt loss, it means how many symbols(context)
|
||||||
|
we are considering for each frame to compute the loss.
|
||||||
|
am_scale:
|
||||||
|
The scale to smooth the loss with am (output of encoder network)
|
||||||
|
part
|
||||||
|
lm_scale:
|
||||||
|
The scale to smooth the loss with lm (output of predictor network)
|
||||||
|
part
|
||||||
|
Returns:
|
||||||
|
Return the transducer losses and CTC loss,
|
||||||
|
in form of (simple_loss, pruned_loss, ctc_loss)
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Regarding am_scale & lm_scale, it will make the loss-function one of
|
||||||
|
the form:
|
||||||
|
lm_scale * lm_probs + am_scale * am_probs +
|
||||||
|
(1-lm_scale-am_scale) * combined_probs
|
||||||
|
"""
|
||||||
|
assert x.ndim == 3, x.shape
|
||||||
|
assert x_lens.ndim == 1, x_lens.shape
|
||||||
|
assert y.num_axes == 2, y.num_axes
|
||||||
|
|
||||||
|
assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
|
||||||
|
|
||||||
|
# Compute encoder outputs
|
||||||
|
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
|
||||||
|
|
||||||
|
row_splits = y.shape.row_splits(1)
|
||||||
|
y_lens = row_splits[1:] - row_splits[:-1]
|
||||||
|
|
||||||
|
if self.use_transducer:
|
||||||
|
# Compute transducer loss
|
||||||
|
simple_loss, pruned_loss = self.forward_transducer(
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
y=y.to(x.device),
|
||||||
|
y_lens=y_lens,
|
||||||
|
prune_range=prune_range,
|
||||||
|
am_scale=am_scale,
|
||||||
|
lm_scale=lm_scale,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
simple_loss = torch.empty(0)
|
||||||
|
pruned_loss = torch.empty(0)
|
||||||
|
|
||||||
|
if self.use_ctc:
|
||||||
|
# Compute CTC loss
|
||||||
|
targets = y.values
|
||||||
|
ctc_loss = self.forward_ctc(
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
targets=targets,
|
||||||
|
target_lengths=y_lens,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ctc_loss = torch.empty(0)
|
||||||
|
|
||||||
|
return simple_loss, pruned_loss, ctc_loss
|
egs/disc_tts/ASR/zipformer/pretrained.py (381 lines, Executable file)
@ -0,0 +1,381 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
This script loads a checkpoint and uses it to decode waves.
|
||||||
|
You can generate the checkpoint with the following command:
|
||||||
|
|
||||||
|
Note: This is a example for librispeech dataset, if you are using different
|
||||||
|
dataset, you should change the argument values according to your dataset.
|
||||||
|
|
||||||
|
- For non-streaming model:
|
||||||
|
|
||||||
|
./zipformer/export.py \
|
||||||
|
--exp-dir ./zipformer/exp \
|
||||||
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 9
|
||||||
|
|
||||||
|
- For streaming model:
|
||||||
|
|
||||||
|
./zipformer/export.py \
|
||||||
|
--exp-dir ./zipformer/exp \
|
||||||
|
--causal 1 \
|
||||||
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 9
|
||||||
|
|
||||||
|
Usage of this script:
|
||||||
|
|
||||||
|
- For non-streaming model:
|
||||||
|
|
||||||
|
(1) greedy search
|
||||||
|
./zipformer/pretrained.py \
|
||||||
|
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||||
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
|
--method greedy_search \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav
|
||||||
|
|
||||||
|
(2) modified beam search
|
||||||
|
./zipformer/pretrained.py \
|
||||||
|
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||||
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
|
--method modified_beam_search \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav
|
||||||
|
|
||||||
|
(3) fast beam search
|
||||||
|
./zipformer/pretrained.py \
|
||||||
|
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||||
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
|
--method fast_beam_search \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav
|
||||||
|
|
||||||
|
- For streaming model:
|
||||||
|
|
||||||
|
(1) greedy search
|
||||||
|
./zipformer/pretrained.py \
|
||||||
|
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||||
|
--causal 1 \
|
||||||
|
--chunk-size 16 \
|
||||||
|
--left-context-frames 128 \
|
||||||
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
|
--method greedy_search \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav
|
||||||
|
|
||||||
|
(2) modified beam search
|
||||||
|
./zipformer/pretrained.py \
|
||||||
|
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||||
|
--causal 1 \
|
||||||
|
--chunk-size 16 \
|
||||||
|
--left-context-frames 128 \
|
||||||
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
|
--method modified_beam_search \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav
|
||||||
|
|
||||||
|
(3) fast beam search
|
||||||
|
./zipformer/pretrained.py \
|
||||||
|
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||||
|
--causal 1 \
|
||||||
|
--chunk-size 16 \
|
||||||
|
--left-context-frames 128 \
|
||||||
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
|
--method fast_beam_search \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav
|
||||||
|
|
||||||
|
|
||||||
|
You can also use `./zipformer/exp/epoch-xx.pt`.
|
||||||
|
|
||||||
|
Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import kaldifeat
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from beam_search import (
|
||||||
|
fast_beam_search_one_best,
|
||||||
|
greedy_search_batch,
|
||||||
|
modified_beam_search,
|
||||||
|
)
|
||||||
|
from export import num_tokens
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params

from icefall.utils import make_pad_mask


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to the checkpoint. "
        "The checkpoint is assumed to be saved by "
        "icefall.checkpoint.save_checkpoint().",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        help="""Path to tokens.txt.""",
    )

    parser.add_argument(
        "--method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
          - greedy_search
          - modified_beam_search
          - fast_beam_search
        """,
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
        help="""An integer indicating how many candidates we will keep for each
        frame. Used only when --method is beam_search or
        modified_beam_search.""",
    )

    parser.add_argument(
        "--beam",
        type=float,
        default=4,
        help="""A floating point value to calculate the cutoff score during beam
        search (i.e., `cutoff = max-score - beam`), which is the same as the
        `beam` in Kaldi.
        Used only when --method is fast_beam_search""",
    )

    parser.add_argument(
        "--max-contexts",
        type=int,
        default=4,
        help="""Used only when --method is fast_beam_search""",
    )

    parser.add_argument(
        "--max-states",
        type=int,
        default=8,
        help="""Used only when --method is fast_beam_search""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    parser.add_argument(
        "--max-sym-per-frame",
        type=int,
        default=1,
        help="""Maximum number of symbols per frame. Used only when
        --method is greedy_search.
        """,
    )

    add_model_arguments(parser)

    return parser


def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert (
            sample_rate == expected_sample_rate
        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
        # We use only the first channel
        ans.append(wave[0].contiguous())
    return ans


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()

    params.update(vars(args))

    token_table = k2.SymbolTable.from_file(params.tokens)

    params.blank_id = token_table["<blk>"]
    params.unk_id = token_table["<unk>"]
    params.vocab_size = num_tokens(token_table) + 1

    logging.info(f"{params}")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    if params.causal:
        assert (
            "," not in params.chunk_size
        ), "chunk_size should be one value in decoding."
        assert (
            "," not in params.left_context_frames
        ), "left_context_frames should be one value in decoding."

    logging.info("Creating model")
    model = get_model(params)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    model.load_state_dict(checkpoint["model"], strict=False)
    model.to(device)
    model.eval()

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = params.sample_rate
    opts.mel_opts.num_bins = params.feature_dim

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {params.sound_files}")
    waves = read_sound_files(
        filenames=params.sound_files, expected_sample_rate=params.sample_rate
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lengths = [f.size(0) for f in features]

    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
    feature_lengths = torch.tensor(feature_lengths, device=device)

    # model forward
    encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)

    hyps = []
    msg = f"Using {params.method}"
    logging.info(msg)

    def token_ids_to_words(token_ids: List[int]) -> str:
        text = ""
        for i in token_ids:
            text += token_table[i]
        return text.replace("▁", " ").strip()

    if params.method == "fast_beam_search":
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
        hyp_tokens = fast_beam_search_one_best(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )
        for hyp in hyp_tokens:
            hyps.append(token_ids_to_words(hyp))
    elif params.method == "modified_beam_search":
        hyp_tokens = modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
        )
        for hyp in hyp_tokens:
            hyps.append(token_ids_to_words(hyp))
    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
        for hyp in hyp_tokens:
            hyps.append(token_ids_to_words(hyp))
    else:
        raise ValueError(f"Unsupported method: {params.method}")

    s = "\n"
    for filename, hyp in zip(params.sound_files, hyps):
        s += f"{filename}:\n{hyp}\n\n"
    logging.info(s)

    logging.info("Decoding Done")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
egs/disc_tts/ASR/zipformer/pretrained_ctc.py  (new executable file, 445 lines)
@@ -0,0 +1,445 @@
#!/usr/bin/env python3
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
#                                            Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with the following command:

- For non-streaming model:

./zipformer/export.py \
  --exp-dir ./zipformer/exp \
  --use-ctc 1 \
  --tokens data/lang_bpe_500/tokens.txt \
  --epoch 30 \
  --avg 9

- For streaming model:

./zipformer/export.py \
  --exp-dir ./zipformer/exp \
  --use-ctc 1 \
  --causal 1 \
  --tokens data/lang_bpe_500/tokens.txt \
  --epoch 30 \
  --avg 9

Usage of this script:

(1) ctc-decoding
./zipformer/pretrained_ctc.py \
  --checkpoint ./zipformer/exp/pretrained.pt \
  --tokens data/lang_bpe_500/tokens.txt \
  --method ctc-decoding \
  --sample-rate 16000 \
  /path/to/foo.wav \
  /path/to/bar.wav

(2) 1best
./zipformer/pretrained_ctc.py \
  --checkpoint ./zipformer/exp/pretrained.pt \
  --HLG data/lang_bpe_500/HLG.pt \
  --words-file data/lang_bpe_500/words.txt \
  --method 1best \
  --sample-rate 16000 \
  /path/to/foo.wav \
  /path/to/bar.wav

(3) nbest-rescoring
./zipformer/pretrained_ctc.py \
  --checkpoint ./zipformer/exp/pretrained.pt \
  --HLG data/lang_bpe_500/HLG.pt \
  --words-file data/lang_bpe_500/words.txt \
  --G data/lm/G_4_gram.pt \
  --method nbest-rescoring \
  --sample-rate 16000 \
  /path/to/foo.wav \
  /path/to/bar.wav

(4) whole-lattice-rescoring
./zipformer/pretrained_ctc.py \
  --checkpoint ./zipformer/exp/pretrained.pt \
  --HLG data/lang_bpe_500/HLG.pt \
  --words-file data/lang_bpe_500/words.txt \
  --G data/lm/G_4_gram.pt \
  --method whole-lattice-rescoring \
  --sample-rate 16000 \
  /path/to/foo.wav \
  /path/to/bar.wav
"""

import argparse
import logging
import math
from typing import List

import k2
import kaldifeat
import torch
import torchaudio
from ctc_decode import get_decoding_params
from export import num_tokens
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params

from icefall.decode import (
    get_lattice,
    one_best_decoding,
    rescore_with_n_best_list,
    rescore_with_whole_lattice,
)
from icefall.utils import get_texts


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to the checkpoint. "
        "The checkpoint is assumed to be saved by "
        "icefall.checkpoint.save_checkpoint().",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    parser.add_argument(
        "--words-file",
        type=str,
        help="""Path to words.txt.
        Used only when method is not ctc-decoding.
        """,
    )

    parser.add_argument(
        "--HLG",
        type=str,
        help="""Path to HLG.pt.
        Used only when method is not ctc-decoding.
        """,
    )

    parser.add_argument(
        "--tokens",
        type=str,
        help="""Path to tokens.txt.
        Used only when method is ctc-decoding.
        """,
    )

    parser.add_argument(
        "--method",
        type=str,
        default="1best",
        help="""Decoding method.
        Possible values are:
        (0) ctc-decoding - Use CTC decoding. It uses a token table,
            i.e., lang_dir/tokens.txt, to convert
            word pieces to words. It needs neither a lexicon
            nor an n-gram LM.
        (1) 1best - Use the best path as decoding output. Only
            the transformer encoder output is used for decoding.
            We call it HLG decoding.
        (2) nbest-rescoring. Extract n paths from the decoding lattice,
            rescore them with an LM, the path with
            the highest score is the decoding result.
            We call it HLG decoding + nbest n-gram LM rescoring.
        (3) whole-lattice-rescoring - Use an LM to rescore the
            decoding lattice and then use 1best to decode the
            rescored lattice.
            We call it HLG decoding + whole-lattice n-gram LM rescoring.
        """,
    )

    parser.add_argument(
        "--G",
        type=str,
        help="""An LM for rescoring.
        Used only when method is
        whole-lattice-rescoring or nbest-rescoring.
        It's usually a 4-gram LM.
        """,
    )

    parser.add_argument(
        "--num-paths",
        type=int,
        default=100,
        help="""
        Used only when method is attention-decoder.
        It specifies the size of n-best list.""",
    )

    parser.add_argument(
        "--ngram-lm-scale",
        type=float,
        default=1.3,
        help="""
        Used only when method is whole-lattice-rescoring and nbest-rescoring.
        It specifies the scale for n-gram LM scores.
        (Note: You need to tune it on a dataset.)
        """,
    )

    parser.add_argument(
        "--nbest-scale",
        type=float,
        default=1.0,
        help="""
        Used only when method is nbest-rescoring.
        It specifies the scale for lattice.scores when
        extracting n-best lists. A smaller value results in
        more unique number of paths with the risk of missing
        the best path.
        """,
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    add_model_arguments(parser)

    return parser


def read_sound_files(
    filenames: List[str], expected_sample_rate: float = 16000
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert sample_rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
        )
        # We use only the first channel
        ans.append(wave[0].contiguous())
    return ans


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()
    # add decoding params
    params.update(get_decoding_params())
    params.update(vars(args))

    token_table = k2.SymbolTable.from_file(params.tokens)
    params.vocab_size = num_tokens(token_table)
    params.blank_id = token_table["<blk>"]
    assert params.blank_id == 0

    logging.info(f"{params}")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    logging.info("Creating model")
    model = get_model(params)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    model.load_state_dict(checkpoint["model"], strict=False)
    model.to(device)
    model.eval()

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = params.sample_rate
    opts.mel_opts.num_bins = params.feature_dim

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {params.sound_files}")
    waves = read_sound_files(
        filenames=params.sound_files, expected_sample_rate=params.sample_rate
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lengths = [f.size(0) for f in features]

    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
    feature_lengths = torch.tensor(feature_lengths, device=device)

    encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
    ctc_output = model.ctc_output(encoder_out)  # (N, T, C)

    batch_size = ctc_output.shape[0]
    supervision_segments = torch.tensor(
        [
            [i, 0, feature_lengths[i].item() // params.subsampling_factor]
            for i in range(batch_size)
        ],
        dtype=torch.int32,
    )

    if params.method == "ctc-decoding":
        logging.info("Use CTC decoding")
        max_token_id = params.vocab_size - 1

        H = k2.ctc_topo(
            max_token=max_token_id,
            modified=False,
            device=device,
        )

        lattice = get_lattice(
            nnet_output=ctc_output,
            decoding_graph=H,
            supervision_segments=supervision_segments,
            search_beam=params.search_beam,
            output_beam=params.output_beam,
            min_active_states=params.min_active_states,
            max_active_states=params.max_active_states,
            subsampling_factor=params.subsampling_factor,
        )

        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        token_ids = get_texts(best_path)
        hyps = [[token_table[i] for i in ids] for ids in token_ids]
    elif params.method in [
        "1best",
        "nbest-rescoring",
        "whole-lattice-rescoring",
    ]:
        logging.info(f"Loading HLG from {params.HLG}")
        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
        HLG = HLG.to(device)
        if not hasattr(HLG, "lm_scores"):
            # For whole-lattice-rescoring and attention-decoder
            HLG.lm_scores = HLG.scores.clone()

        if params.method in [
            "nbest-rescoring",
            "whole-lattice-rescoring",
        ]:
            logging.info(f"Loading G from {params.G}")
            G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
            G = G.to(device)
            if params.method == "whole-lattice-rescoring":
                # Add epsilon self-loops to G as we will compose
                # it with the whole lattice later
                G = k2.add_epsilon_self_loops(G)
                G = k2.arc_sort(G)

            # G.lm_scores is used to replace HLG.lm_scores during
            # LM rescoring.
            G.lm_scores = G.scores.clone()

        lattice = get_lattice(
            nnet_output=ctc_output,
            decoding_graph=HLG,
            supervision_segments=supervision_segments,
            search_beam=params.search_beam,
            output_beam=params.output_beam,
            min_active_states=params.min_active_states,
            max_active_states=params.max_active_states,
            subsampling_factor=params.subsampling_factor,
        )

        if params.method == "1best":
            logging.info("Use HLG decoding")
            best_path = one_best_decoding(
                lattice=lattice, use_double_scores=params.use_double_scores
            )
        if params.method == "nbest-rescoring":
            logging.info("Use HLG decoding + LM rescoring")
            best_path_dict = rescore_with_n_best_list(
                lattice=lattice,
                G=G,
                num_paths=params.num_paths,
                lm_scale_list=[params.ngram_lm_scale],
                nbest_scale=params.nbest_scale,
            )
            best_path = next(iter(best_path_dict.values()))
        elif params.method == "whole-lattice-rescoring":
            logging.info("Use HLG decoding + LM rescoring")
            best_path_dict = rescore_with_whole_lattice(
                lattice=lattice,
                G_with_epsilon_loops=G,
                lm_scale_list=[params.ngram_lm_scale],
            )
            best_path = next(iter(best_path_dict.values()))

        hyps = get_texts(best_path)
        word_sym_table = k2.SymbolTable.from_file(params.words_file)
        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
    else:
        raise ValueError(f"Unsupported decoding method: {params.method}")

    s = "\n"
    for filename, hyp in zip(params.sound_files, hyps):
        words = " ".join(hyp)
        words = words.replace("▁", " ").strip()
        s += f"{filename}:\n{words}\n\n"
    logging.info(s)

    logging.info("Decoding Done")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
@@ -65,7 +65,7 @@ import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule2 import LibriSpeechAsrDataModule
+from asr_datamodule import DiscTTSAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -468,12 +468,6 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
-    parser.add_argument(
-        "--use-codebook",
-        type=str2bool,
-        default=False,
-    )
-
     add_model_arguments(parser)
 
     return parser
@@ -512,8 +506,8 @@ def get_params() -> AttributeDict:
 
         - valid_interval:  Run validation if batch_idx % valid_interval is 0
 
-        - token_dim: The model input dim. It has to match the one used
-          in computing tokens.
+        - feature_dim: The model input dim. It has to match the one used
+          in computing features.
 
         - subsampling_factor:  The subsampling factor for the model.
 
@@ -535,7 +529,7 @@ def get_params() -> AttributeDict:
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
             # parameters for zipformer
-            "token_dim": 80,
+            "feature_dim": 80,
             "subsampling_factor": 4,  # not passed in, this is fixed.
             "warm_step": 2000,
             "env_info": get_env_info(),
@@ -550,33 +544,20 @@ def _to_int_tuple(s: str):
 
 
 def get_encoder_embed(params: AttributeDict) -> nn.Module:
-    # encoder_embed converts the input of shape (N, T, num_tokens)
+    # encoder_embed converts the input of shape (N, T, num_features)
     # to the shape (N, (T - 7) // 2, encoder_dims).
     # That is, it does two things simultaneously:
     #   (1) subsampling: T -> (T - 7) // 2
-    #   (2) embedding: num_tokens -> encoder_dims
+    #   (2) embedding: num_features -> encoder_dims
     # In the normal configuration, we will downsample once more at the end
    # by a factor of 2, and most of the encoder stacks will run at a lower
     # sampling rate.
-    if params.use_codebook:
-        codebook_path = (
-            "./download/DiscreteAudioToken/wavlm_large_l24_kms2000/codebook.pt"
-        )
-        tokens_embed = nn.Embedding.from_pretrained(
-            torch.load(codebook_path), padding_idx=2000
-        )
-    else:
-        tokens_embed = nn.Embedding(
-            num_embeddings=2001,
-            embedding_dim=80,
-            padding_idx=2000,
-        )
     encoder_embed = Conv2dSubsampling(
-        in_channels=params.token_dim,
+        in_channels=params.feature_dim,
         out_channels=_to_int_tuple(params.encoder_dim)[0],
         dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
     )
-    return tokens_embed, encoder_embed
+    return encoder_embed
 
 
 def get_encoder_model(params: AttributeDict) -> nn.Module:
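For intuition, here is a small arithmetic sketch of the frame-rate reduction that the comments above describe; the frame count is an assumed illustrative value, not something taken from this diff.

# Illustrative frame-count arithmetic for the encoder_embed comments above.
T = 1000                           # assumed: number of 10 ms fbank frames (~10 s of audio)
after_embed = (T - 7) // 2         # Conv2dSubsampling output length: 496
after_encoder = after_embed // 2   # one further 2x downsample in the encoder: 248
# 1000 / 248 is roughly 4, which is why "subsampling_factor": 4 is treated as fixed
# in get_params() earlier in this diff.
print(T, after_embed, after_encoder)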
@@ -623,13 +604,13 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
 
 
 def get_model(params: AttributeDict) -> nn.Module:
-    assert params.use_transducer or params.use_ctc, (
-        f"At least one of them should be True, "
-        f"but got params.use_transducer={params.use_transducer}, "
-        f"params.use_ctc={params.use_ctc}"
-    )
+    assert (
+        params.use_transducer or params.use_ctc
+    ), (f"At least one of them should be True, "
+    f"but got params.use_transducer={params.use_transducer}, "
+    f"params.use_ctc={params.use_ctc}")
 
-    token_embed, encoder_embed = get_encoder_embed(params)
+    encoder_embed = get_encoder_embed(params)
     encoder = get_encoder_model(params)
 
     if params.use_transducer:
@@ -640,7 +621,6 @@ def get_model(params: AttributeDict) -> nn.Module:
         joiner = None
 
     model = AsrModel(
-        token_embed=token_embed,
         encoder_embed=encoder_embed,
         encoder=encoder,
         decoder=decoder,
@@ -650,7 +630,6 @@ def get_model(params: AttributeDict) -> nn.Module:
         vocab_size=params.vocab_size,
         use_transducer=params.use_transducer,
         use_ctc=params.use_ctc,
-        use_codebook=params.use_codebook,
     )
     return model
 
@@ -797,31 +776,25 @@ def compute_loss(
       values >= 1.0 are fully warmed up and have all modules present.
     """
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
-    tokens = batch["tokens"]
-    # at entry, token is (N, T, C)
-    assert tokens.ndim == 2
-    tokens = tokens.to(device)
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
 
-    token_lens = batch["token_lens"].to(device)
-    frequency_masks = (
-        batch["frequency_masks"].to(device)
-        if "frequency_masks" in batch.keys()
-        else None
-    )
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
 
     batch_idx_train = params.batch_idx_train
     warm_step = params.warm_step
 
-    texts = [c.supervisions[0].text for c in batch["cuts"]]
+    texts = batch["supervisions"]["text"]
     y = sp.encode(texts, out_type=int)
    y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss, ctc_loss = model(
-            x=tokens,
-            x_lens=token_lens,
-            frequency_masks=frequency_masks,
+            x=feature,
+            x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
@@ -835,16 +808,17 @@ def compute_loss(
         # take down the scale on the simple loss from 1.0 at the start
         # to params.simple_loss scale by warm_step.
         simple_loss_scale = (
-            s
-            if batch_idx_train >= warm_step
+            s if batch_idx_train >= warm_step
             else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
         )
         pruned_loss_scale = (
-            1.0
-            if batch_idx_train >= warm_step
+            1.0 if batch_idx_train >= warm_step
             else 0.1 + 0.9 * (batch_idx_train / warm_step)
         )
-        loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+        loss += (
+            simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
+        )
 
     if params.use_ctc:
         loss += params.ctc_loss_scale * ctc_loss
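To make the warm-up behaviour above concrete, a small numeric sketch follows; s (standing in for params.simple_loss_scale) and batch_idx_train are assumed illustrative values, while warm_step=2000 matches the default shown in get_params() earlier in this diff.

# Numeric sketch of the warm-up interpolation used for the loss scales above.
s, warm_step, batch_idx_train = 0.5, 2000, 500  # assumed illustrative values
simple_loss_scale = (
    s if batch_idx_train >= warm_step
    else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)  # 0.875: the simple loss still dominates early in training
pruned_loss_scale = (
    1.0 if batch_idx_train >= warm_step
    else 0.1 + 0.9 * (batch_idx_train / warm_step)
)  # 0.325: the pruned loss is ramped up gradually toward 1.0
print(simple_loss_scale, pruned_loss_scale)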
@@ -854,7 +828,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (token_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -970,7 +944,7 @@ def train_one_epoch(
             set_batch_count(model, get_adjusted_batch_count(params))
 
         params.batch_idx_train += 1
-        batch_size = len(batch["cuts"])
+        batch_size = len(batch["supervisions"]["text"])
 
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
@@ -1199,7 +1173,7 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)
 
-    librispeech = LibriSpeechAsrDataModule(args)
+    librispeech = DiscTTSAsrDataModule(args)
 
     train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
@@ -1222,7 +1196,7 @@ def run(rank, world_size, args):
             return False
 
         # In pruned RNN-T, we require that T >= S
-        # where T is the number of token frames after subsampling
+        # where T is the number of feature frames after subsampling
         # and S is the number of tokens in the utterance
 
        # In ./zipformer.py, the conv module uses the following expression
@@ -1243,7 +1217,7 @@ def run(rank, world_size, args):
 
         return True
 
-    # train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
@@ -1260,7 +1234,7 @@ def run(rank, world_size, args):
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
-    if 0 and not params.print_diagnostics:
+    if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
@@ -1343,12 +1317,12 @@ def display_and_save_batch(
     logging.info(f"Saving batch to {filename}")
     torch.save(batch, filename)
 
-    tokens = batch["tokens"]
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
 
-    logging.info(f"tokens shape: {tokens.shape}")
+    logging.info(f"features shape: {features.shape}")
 
-    texts = [c.supervisions[0].text for c in batch["cuts"]]
-    y = sp.encode(texts, out_type=int)
+    y = sp.encode(supervisions["text"], out_type=int)
     num_tokens = sum(len(i) for i in y)
     logging.info(f"num tokens: {num_tokens}")
 
@@ -1397,7 +1371,7 @@ def scan_pessimistic_batches_for_oom(
 
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    DiscTTSAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
 
@@ -1,20 +0,0 @@  (file deleted)
import jsonlines
from tqdm import tqdm

with open(
    "/mnt/lustre/sjtu/home/yfy62/discrete_token_data/GigaSpeech/xl/wavlm_large_l21_kms2000/out_quantized_sp1.1"
) as f:
    discrete_tokens = f.read().splitlines()

discrete_tokens_info = {}
for discrete_token in discrete_tokens:
    discrete_token = discrete_token.split(" ", 1)
    discrete_tokens_info[discrete_token[0]] = discrete_token[1]


with jsonlines.open("gigaspeech_supervisions_XL.jsonl") as reader:
    with jsonlines.open("gigaspeech_supervisions_XL_new.jsonl", mode="w") as writer:
        for obj in tqdm(reader):
            obj["custom"] = {"discrete_tokens": discrete_tokens_info[obj["id"]]}

            writer.write(obj)
@@ -21,11 +21,10 @@ import re
 from pathlib import Path
 
 import jsonlines
-from tqdm import tqdm
-
 from lhotse import CutSet, SupervisionSegment
 from lhotse.recipes.utils import read_manifests_if_cached
 from lhotse.serialization import open_best
+from tqdm import tqdm
 
 # Similar text filtering and normalization procedure as in:
 # https://github.com/SpeechColab/GigaSpeech/blob/main/toolkits/kaldi/gigaspeech_data_prep.sh
@@ -40,26 +39,32 @@ def normalize_text(
 
 
 def has_no_oov(
-    sup: SupervisionSegment, oov_pattern=re.compile(r"<(SIL|MUSIC|NOISE|OTHER)>"),
+    sup: SupervisionSegment,
+    oov_pattern=re.compile(r"<(SIL|MUSIC|NOISE|OTHER)>"),
 ) -> bool:
     return oov_pattern.search(sup.text) is None
 
 
 def preprocess_gigaspeech():
-    # src_dir = Path("data/manifests")
-    # output_dir = Path("data/fbank")
-    src_dir = Path(".")
-    output_dir = Path(".")
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/fbank")
     output_dir.mkdir(exist_ok=True)
 
-    dataset_parts = ("XL",)
+    dataset_parts = (
+        "DEV",
+        "TEST",
+        "M",
+    )
 
     prefix = "gigaspeech"
     suffix = "jsonl.gz"
 
     logging.info("Loading manifest (may take 1 minutes)")
     manifests = read_manifests_if_cached(
-        dataset_parts=dataset_parts, output_dir=src_dir, prefix=prefix, suffix=suffix,
+        dataset_parts=dataset_parts,
+        output_dir=src_dir,
+        prefix=prefix,
+        suffix=suffix,
     )
     assert manifests is not None
 
@@ -71,7 +76,7 @@ def preprocess_gigaspeech():
     )
 
     for partition, m in manifests.items():
-        raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.jsonl"
+        raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.jsonl.gz"
         if raw_cuts_path.is_file():
             logging.info(f"{partition} already exists - skipping")
             continue
@@ -88,7 +93,8 @@ def preprocess_gigaspeech():
         # Create long-recording cut manifests.
         logging.info(f"Preprocessing {partition}")
         cut_set = CutSet.from_manifests(
-            recordings=m["recordings"], supervisions=m["supervisions"],
+            recordings=m["recordings"],
+            supervisions=m["supervisions"],
         )
 
         logging.info("About to split cuts into smaller chunks.")
@@ -99,6 +105,27 @@ def preprocess_gigaspeech():
         logging.info(f"Saving to {raw_cuts_path}")
         cut_set.to_file(raw_cuts_path)
 
+    for partition in dataset_parts:
+        cuts_path = output_dir / f"{prefix}_cuts_{partition}.jsonl"
+        if cuts_path.is_file():
+            logging.info(f"{partition} already exists - skipping")
+            continue
+
+        logging.info(f"Processing {partition}")
+        raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.jsonl.gz"
+        with open_best(raw_cuts_path) as reader, jsonlines.open(
+            cuts_path, "a"
+        ) as writer:
+            for cut in reader:
+                cut = eval(cut)
+                cut["custom"] = {
+                    "discrete_tokens": cut["supervisions"][0]["custom"][
+                        "discrete_tokens"
+                    ]
+                }
+                del cut["supervisions"][0]["custom"]
+                writer.write(cut)
+
 
 def main():
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
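For reference, a minimal sketch of the per-cut rewrite performed by the loop added above; the cut id, text, and token string are made-up placeholders, not data from this diff.

# Before: the discrete tokens live on the first supervision of the serialized cut.
cut = {
    "id": "CUT0001",
    "supervisions": [
        {"id": "CUT0001", "text": "HELLO", "custom": {"discrete_tokens": "12 57 957 3"}}
    ],
}
# The loop hoists them to a top-level "custom" field and drops the supervision-level copy.
cut["custom"] = {"discrete_tokens": cut["supervisions"][0]["custom"]["discrete_tokens"]}
del cut["supervisions"][0]["custom"]
# cut is now:
# {"id": "CUT0001", "supervisions": [{"id": "CUT0001", "text": "HELLO"}],
#  "custom": {"discrete_tokens": "12 57 957 3"}}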
@@ -1,18 +0,0 @@  (file deleted)
import jsonlines
from tqdm import tqdm

with jsonlines.open("gigaspeech_cuts_XL_raw.jsonl") as reader:
    with jsonlines.open("gigaspeech_cuts_XL.jsonl", mode="w") as writer:
        for obj in tqdm(reader):
            obj["custom"] = {
                "discrete_tokens": obj["supervisions"][0]["custom"]["discrete_tokens"]
            }
            del obj["supervisions"][0]["custom"]

            # Speed perturb
            obj["duration"] /= 1.1
            obj["supervisions"][0]["duration"] /= 1.1
            obj["id"] += "_sp1.1"
            obj["supervisions"][0]["id"] += "_sp1.1"

            writer.write(obj)
@@ -1,120 +0,0 @@  (file deleted)
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the LibriSpeech dataset.
It looks for manifests in the directory data/manifests.

The generated fbank features are saved in data/fbank.
"""

import argparse
import logging
import os
from pathlib import Path
from typing import Optional

import torch
from lhotse import CutSet
from lhotse.cut import MonoCut
from lhotse.recipes.utils import read_manifests_if_cached
from tqdm import tqdm

from icefall.utils import get_executor, str2bool

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--dataset",
        type=str,
        help="""Dataset parts to compute fbank. If None, we will use all""",
    )

    return parser.parse_args()


def compute_fbank_librispeech(
    dataset: Optional[str] = None,
):
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")

    if dataset is None:
        dataset_parts = (
            "train-clean-100-sp0_9",
            "train-clean-360-sp0_9",
            "train-other-500-sp0_9",
        )
    else:
        dataset_parts = dataset.split(" ", -1)

    prefix = "librispeech"
    suffix = "jsonl.gz"
    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts,
        output_dir=src_dir,
        prefix=prefix,
        suffix=suffix,
    )
    assert manifests is not None

    assert len(manifests) == len(dataset_parts), (
        len(manifests),
        len(dataset_parts),
        list(manifests.keys()),
        dataset_parts,
    )

    for partition, m in manifests.items():
        cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
        if (output_dir / cuts_filename).is_file():
            logging.info(f"{partition} already exists - skipping.")
            continue
        cut_set = CutSet.from_manifests(
            recordings=m["recordings"],
            supervisions=m["supervisions"],
        )
        logging.info(f"Processing {partition}")
        for i in tqdm(range(len(cut_set))):
            cut_set[i].discrete_tokens = cut_set[i].supervisions[0].discrete_tokens
            try:
                del cut_set[i].supervisions[0].custom
            except:
                pass

        cut_set.to_file(output_dir / cuts_filename)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    args = get_args()
    logging.info(vars(args))
    compute_fbank_librispeech(
        dataset=args.dataset,
    )
@@ -105,10 +105,7 @@ def compute_fbank_librispeech(
         logging.info(f"Processing {partition}")
         for i in tqdm(range(len(cut_set))):
             cut_set[i].discrete_tokens = cut_set[i].supervisions[0].discrete_tokens
-            try:
-                del cut_set[i].supervisions[0].custom
-            except:
-                pass
+            del cut_set[i].supervisions[0].custom
 
         cut_set.to_file(output_dir / cuts_filename)
 
@@ -28,7 +28,7 @@ from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     DiscretizedInputAugment,
     DiscretizedInputSpeechRecognitionDataset,
     DynamicBucketingSampler,
-    SimpleCutSampler,
+    SingleCutSampler,
 )
 from lhotse.utils import fix_random_seed
 from torch.utils.data import DataLoader
@@ -190,7 +190,7 @@ class LibriSpeechAsrDataModule:
                     time_warp_factor=self.args.spec_aug_time_warp_factor,
                     num_frame_masks=num_frame_masks,
                     tokens_mask_size=27,
-                    num_token_masks=4,
+                    num_token_masks=2,
                     frames_mask_size=100,
                     )
                 )
Some files were not shown because too many files have changed in this diff.