mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-09 14:05:33 +00:00
Use tokens.txt to replace bpe.model (#1162)
This commit is contained in:
parent
d6b28a11a7
commit
a81396b482
@ -38,7 +38,7 @@ log "Decode with models exported by torch.jit.trace()"
|
|||||||
for m in ctc-decoding 1best; do
|
for m in ctc-decoding 1best; do
|
||||||
./conformer_ctc3/jit_pretrained.py \
|
./conformer_ctc3/jit_pretrained.py \
|
||||||
--model-filename $repo/exp/jit_trace.pt \
|
--model-filename $repo/exp/jit_trace.pt \
|
||||||
--words-file $repo/data/lang_bpe_500/words.txt \
|
--words-file $repo/data/lang_bpe_500/words.txt \
|
||||||
--HLG $repo/data/lang_bpe_500/HLG.pt \
|
--HLG $repo/data/lang_bpe_500/HLG.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
--G $repo/data/lm/G_4_gram.pt \
|
--G $repo/data/lm/G_4_gram.pt \
|
||||||
@ -53,7 +53,7 @@ log "Export to torchscript model"
|
|||||||
|
|
||||||
./conformer_ctc3/export.py \
|
./conformer_ctc3/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--lang-dir $repo/data/lang_bpe_500 \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--jit-trace 1 \
|
--jit-trace 1 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -80,9 +80,9 @@ done
|
|||||||
for m in ctc-decoding 1best; do
|
for m in ctc-decoding 1best; do
|
||||||
./conformer_ctc3/pretrained.py \
|
./conformer_ctc3/pretrained.py \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--words-file $repo/data/lang_bpe_500/words.txt \
|
--words-file $repo/data/lang_bpe_500/words.txt \
|
||||||
--HLG $repo/data/lang_bpe_500/HLG.pt \
|
--HLG $repo/data/lang_bpe_500/HLG.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--G $repo/data/lm/G_4_gram.pt \
|
--G $repo/data/lm/G_4_gram.pt \
|
||||||
--method $m \
|
--method $m \
|
||||||
--sample-rate 16000 \
|
--sample-rate 16000 \
|
||||||
@ -93,7 +93,7 @@ done
|
|||||||
|
|
||||||
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
|
||||||
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
|
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
|
||||||
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
|
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
|
||||||
mkdir -p conformer_ctc3/exp
|
mkdir -p conformer_ctc3/exp
|
||||||
ln -s $PWD/$repo/exp/pretrained.pt conformer_ctc3/exp/epoch-999.pt
|
ln -s $PWD/$repo/exp/pretrained.pt conformer_ctc3/exp/epoch-999.pt
|
||||||
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
ln -s $PWD/$repo/data/lang_bpe_500 data/
|
||||||
|
|||||||
@ -31,7 +31,7 @@ log "Test exporting with torch.jit.trace()"
|
|||||||
|
|
||||||
./lstm_transducer_stateless2/export.py \
|
./lstm_transducer_stateless2/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
@ -55,7 +55,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -68,7 +68,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|||||||
@ -28,7 +28,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|||||||
@ -36,7 +36,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -49,7 +49,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|||||||
@ -35,7 +35,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -48,7 +48,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|||||||
@ -30,14 +30,14 @@ popd
|
|||||||
log "Export to torchscript model"
|
log "Export to torchscript model"
|
||||||
./pruned_transducer_stateless3/export.py \
|
./pruned_transducer_stateless3/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--jit 1
|
--jit 1
|
||||||
|
|
||||||
./pruned_transducer_stateless3/export.py \
|
./pruned_transducer_stateless3/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--jit-trace 1
|
--jit-trace 1
|
||||||
@ -74,7 +74,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -87,7 +87,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|||||||
@ -32,7 +32,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--num-encoder-layers 18 \
|
--num-encoder-layers 18 \
|
||||||
--dim-feedforward 2048 \
|
--dim-feedforward 2048 \
|
||||||
--nhead 8 \
|
--nhead 8 \
|
||||||
@ -51,7 +51,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav \
|
$repo/test_wavs/1221-135766-0002.wav \
|
||||||
|
|||||||
@ -33,7 +33,7 @@ log "Export to torchscript model"
|
|||||||
./pruned_transducer_stateless7/export.py \
|
./pruned_transducer_stateless7/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--use-averaged-model false \
|
--use-averaged-model false \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--jit 1
|
--jit 1
|
||||||
@ -56,7 +56,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -69,7 +69,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|||||||
@ -37,7 +37,7 @@ log "Export to torchscript model"
|
|||||||
./pruned_transducer_stateless7_ctc/export.py \
|
./pruned_transducer_stateless7_ctc/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--use-averaged-model false \
|
--use-averaged-model false \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--jit 1
|
--jit 1
|
||||||
@ -74,7 +74,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -87,7 +87,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|||||||
@ -36,7 +36,7 @@ log "Export to torchscript model"
|
|||||||
./pruned_transducer_stateless7_ctc_bs/export.py \
|
./pruned_transducer_stateless7_ctc_bs/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--use-averaged-model false \
|
--use-averaged-model false \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--jit 1
|
--jit 1
|
||||||
@ -72,7 +72,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -85,7 +85,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|||||||
@ -37,7 +37,7 @@ log "Export to torchscript model"
|
|||||||
./pruned_transducer_stateless7_streaming/export.py \
|
./pruned_transducer_stateless7_streaming/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--use-averaged-model false \
|
--use-averaged-model false \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--decode-chunk-len 32 \
|
--decode-chunk-len 32 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -81,7 +81,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--decode-chunk-len 32 \
|
--decode-chunk-len 32 \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
@ -95,7 +95,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--decode-chunk-len 32 \
|
--decode-chunk-len 32 \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
|
|||||||
@ -41,7 +41,7 @@ log "Decode with models exported by torch.jit.script()"
|
|||||||
log "Export to torchscript model"
|
log "Export to torchscript model"
|
||||||
./pruned_transducer_stateless8/export.py \
|
./pruned_transducer_stateless8/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model false \
|
--use-averaged-model false \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -65,7 +65,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -78,7 +78,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|||||||
@ -32,7 +32,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--simulate-streaming 1 \
|
--simulate-streaming 1 \
|
||||||
--causal-convolution 1 \
|
--causal-convolution 1 \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
@ -47,7 +47,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--simulate-streaming 1 \
|
--simulate-streaming 1 \
|
||||||
--causal-convolution 1 \
|
--causal-convolution 1 \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
|
|||||||
@ -28,7 +28,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|||||||
@ -37,7 +37,7 @@ log "Export to torchscript model"
|
|||||||
./zipformer_mmi/export.py \
|
./zipformer_mmi/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--use-averaged-model false \
|
--use-averaged-model false \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--jit 1
|
--jit 1
|
||||||
@ -61,7 +61,7 @@ for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescor
|
|||||||
--method $method \
|
--method $method \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--lang-dir $repo/data/lang_bpe_500 \
|
--lang-dir $repo/data/lang_bpe_500 \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|||||||
@ -27,7 +27,7 @@ log "CTC decoding"
|
|||||||
--method ctc-decoding \
|
--method ctc-decoding \
|
||||||
--num-classes 500 \
|
--num-classes 500 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.flac \
|
$repo/test_wavs/1089-134686-0001.flac \
|
||||||
$repo/test_wavs/1221-135766-0001.flac \
|
$repo/test_wavs/1221-135766-0001.flac \
|
||||||
$repo/test_wavs/1221-135766-0002.flac
|
$repo/test_wavs/1221-135766-0002.flac
|
||||||
@ -38,7 +38,7 @@ log "HLG decoding"
|
|||||||
--method 1best \
|
--method 1best \
|
||||||
--num-classes 500 \
|
--num-classes 500 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--words-file $repo/data/lang_bpe_500/words.txt \
|
--words-file $repo/data/lang_bpe_500/words.txt \
|
||||||
--HLG $repo/data/lang_bpe_500/HLG.pt \
|
--HLG $repo/data/lang_bpe_500/HLG.pt \
|
||||||
$repo/test_wavs/1089-134686-0001.flac \
|
$repo/test_wavs/1089-134686-0001.flac \
|
||||||
|
|||||||
@ -28,7 +28,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|||||||
@ -28,7 +28,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|||||||
@ -28,7 +28,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|||||||
@ -27,7 +27,7 @@ log "Beam search decoding"
|
|||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
|||||||
@ -17,7 +17,6 @@ git lfs install
|
|||||||
git clone $repo_url
|
git clone $repo_url
|
||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
|
||||||
log "Display test files"
|
log "Display test files"
|
||||||
tree $repo/
|
tree $repo/
|
||||||
ls -lh $repo/test_wavs/*.wav
|
ls -lh $repo/test_wavs/*.wav
|
||||||
@ -29,12 +28,11 @@ popd
|
|||||||
|
|
||||||
log "Test exporting to ONNX format"
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
./pruned_transducer_stateless2/export.py \
|
./pruned_transducer_stateless2/export-onnx.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--lang-dir $repo/data/lang_char \
|
--lang-dir $repo/data/lang_char \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1
|
||||||
--onnx 1
|
|
||||||
|
|
||||||
log "Export to torchscript model"
|
log "Export to torchscript model"
|
||||||
|
|
||||||
@ -59,19 +57,17 @@ log "Decode with ONNX models"
|
|||||||
|
|
||||||
./pruned_transducer_stateless2/onnx_check.py \
|
./pruned_transducer_stateless2/onnx_check.py \
|
||||||
--jit-filename $repo/exp/cpu_jit.pt \
|
--jit-filename $repo/exp/cpu_jit.pt \
|
||||||
--onnx-encoder-filename $repo/exp/encoder.onnx \
|
--onnx-encoder-filename $repo/exp/encoder-epoch-10-avg-2.onnx \
|
||||||
--onnx-decoder-filename $repo/exp/decoder.onnx \
|
--onnx-decoder-filename $repo/exp/decoder-epoch-10-avg-2.onnx \
|
||||||
--onnx-joiner-filename $repo/exp/joiner.onnx \
|
--onnx-joiner-filename $repo/exp/joiner-epoch-10-avg-2.onnx \
|
||||||
--onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \
|
--onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
|
||||||
--onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx
|
--onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj-epoch-10-avg-2.onnx
|
||||||
|
|
||||||
./pruned_transducer_stateless2/onnx_pretrained.py \
|
./pruned_transducer_stateless2/onnx_pretrained.py \
|
||||||
--tokens $repo/data/lang_char/tokens.txt \
|
--tokens $repo/data/lang_char/tokens.txt \
|
||||||
--encoder-model-filename $repo/exp/encoder.onnx \
|
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||||
--decoder-model-filename $repo/exp/decoder.onnx \
|
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||||
--joiner-model-filename $repo/exp/joiner.onnx \
|
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
||||||
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
|
|
||||||
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
|
|
||||||
$repo/test_wavs/DEV_T0000000000.wav \
|
$repo/test_wavs/DEV_T0000000000.wav \
|
||||||
$repo/test_wavs/DEV_T0000000001.wav \
|
$repo/test_wavs/DEV_T0000000001.wav \
|
||||||
$repo/test_wavs/DEV_T0000000002.wav
|
$repo/test_wavs/DEV_T0000000002.wav
|
||||||
@ -104,9 +100,9 @@ for sym in 1 2 3; do
|
|||||||
--lang-dir $repo/data/lang_char \
|
--lang-dir $repo/data/lang_char \
|
||||||
--decoding-method greedy_search \
|
--decoding-method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
$repo/test_wavs/DEV_T0000000000.wav \
|
$repo/test_wavs/DEV_T0000000000.wav \
|
||||||
$repo/test_wavs/DEV_T0000000001.wav \
|
$repo/test_wavs/DEV_T0000000001.wav \
|
||||||
$repo/test_wavs/DEV_T0000000002.wav
|
$repo/test_wavs/DEV_T0000000002.wav
|
||||||
done
|
done
|
||||||
|
|
||||||
for method in modified_beam_search beam_search fast_beam_search; do
|
for method in modified_beam_search beam_search fast_beam_search; do
|
||||||
@ -117,7 +113,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/epoch-99.pt \
|
--checkpoint $repo/exp/epoch-99.pt \
|
||||||
--lang-dir $repo/data/lang_char \
|
--lang-dir $repo/data/lang_char \
|
||||||
$repo/test_wavs/DEV_T0000000000.wav \
|
$repo/test_wavs/DEV_T0000000000.wav \
|
||||||
$repo/test_wavs/DEV_T0000000001.wav \
|
$repo/test_wavs/DEV_T0000000001.wav \
|
||||||
$repo/test_wavs/DEV_T0000000002.wav
|
$repo/test_wavs/DEV_T0000000002.wav
|
||||||
done
|
done
|
||||||
|
|||||||
12
.github/scripts/test-ncnn-export.sh
vendored
12
.github/scripts/test-ncnn-export.sh
vendored
@ -45,7 +45,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
pushd $repo
|
pushd $repo
|
||||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
|
||||||
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
|
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
|
||||||
|
|
||||||
cd exp
|
cd exp
|
||||||
@ -56,11 +55,10 @@ log "Export via torch.jit.trace()"
|
|||||||
|
|
||||||
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
|
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
\
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--num-encoder-layers 12 \
|
--num-encoder-layers 12 \
|
||||||
--chunk-length 32 \
|
--chunk-length 32 \
|
||||||
--cnn-module-kernel 31 \
|
--cnn-module-kernel 31 \
|
||||||
@ -91,7 +89,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
pushd $repo
|
pushd $repo
|
||||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
|
||||||
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
|
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
|
||||||
|
|
||||||
cd exp
|
cd exp
|
||||||
@ -102,7 +99,7 @@ log "Export via torch.jit.trace()"
|
|||||||
|
|
||||||
./lstm_transducer_stateless2/export-for-ncnn.py \
|
./lstm_transducer_stateless2/export-for-ncnn.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0
|
--use-averaged-model 0
|
||||||
@ -140,7 +137,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
pushd $repo
|
pushd $repo
|
||||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
|
||||||
git lfs pull --include "exp/pretrained.pt"
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
cd exp
|
cd exp
|
||||||
@ -148,7 +144,7 @@ ln -s pretrained.pt epoch-99.pt
|
|||||||
popd
|
popd
|
||||||
|
|
||||||
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
|
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
@ -199,7 +195,7 @@ ln -s pretrained.pt epoch-9999.pt
|
|||||||
popd
|
popd
|
||||||
|
|
||||||
./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
|
./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
|
||||||
--lang-dir $repo/data/lang_char_bpe \
|
--tokens $repo/data/lang_char_bpe/tokens.txt \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 9999 \
|
--epoch 9999 \
|
||||||
|
|||||||
138
.github/scripts/test-onnx-export.sh
vendored
138
.github/scripts/test-onnx-export.sh
vendored
@ -10,7 +10,123 @@ log() {
|
|||||||
|
|
||||||
cd egs/librispeech/ASR
|
cd egs/librispeech/ASR
|
||||||
|
|
||||||
|
log "=========================================================================="
|
||||||
|
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||||
|
log "Downloading pre-trained model from $repo_url"
|
||||||
|
git lfs install
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
log "Export via torch.jit.script()"
|
||||||
|
./zipformer/export.py \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--jit 1
|
||||||
|
|
||||||
|
log "Test export to ONNX format"
|
||||||
|
./zipformer/export-onnx.py \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--num-encoder-layers "2,2,3,4,3,2" \
|
||||||
|
--downsampling-factor "1,2,4,8,4,2" \
|
||||||
|
--feedforward-dim "512,768,1024,1536,1024,768" \
|
||||||
|
--num-heads "4,4,4,8,4,4" \
|
||||||
|
--encoder-dim "192,256,384,512,384,256" \
|
||||||
|
--query-head-dim 32 \
|
||||||
|
--value-head-dim 12 \
|
||||||
|
--pos-head-dim 4 \
|
||||||
|
--pos-dim 48 \
|
||||||
|
--encoder-unmasked-dim "192,192,256,256,256,192" \
|
||||||
|
--cnn-module-kernel "31,31,15,15,15,31" \
|
||||||
|
--decoder-dim 512 \
|
||||||
|
--joiner-dim 512 \
|
||||||
|
--causal False \
|
||||||
|
--chunk-size "16,32,64,-1" \
|
||||||
|
--left-context-frames "64,128,256,-1"
|
||||||
|
|
||||||
|
ls -lh $repo/exp
|
||||||
|
|
||||||
|
log "Run onnx_check.py"
|
||||||
|
|
||||||
|
./zipformer/onnx_check.py \
|
||||||
|
--jit-filename $repo/exp/jit_script.pt \
|
||||||
|
--onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||||
|
--onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||||
|
--onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
|
||||||
|
|
||||||
|
log "Run onnx_pretrained.py"
|
||||||
|
|
||||||
|
./zipformer/onnx_pretrained.py \
|
||||||
|
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
|
||||||
|
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
|
||||||
|
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
|
||||||
|
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
|
||||||
|
log "Downloading pre-trained model from $repo_url"
|
||||||
|
git lfs install
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
|
cd exp
|
||||||
|
ln -s pretrained.pt epoch-99.pt
|
||||||
|
popd
|
||||||
|
|
||||||
|
log "Test export streaming model to ONNX format"
|
||||||
|
./zipformer/export-onnx-streaming.py \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
--use-averaged-model 0 \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--num-encoder-layers "2,2,3,4,3,2" \
|
||||||
|
--downsampling-factor "1,2,4,8,4,2" \
|
||||||
|
--feedforward-dim "512,768,1024,1536,1024,768" \
|
||||||
|
--num-heads "4,4,4,8,4,4" \
|
||||||
|
--encoder-dim "192,256,384,512,384,256" \
|
||||||
|
--query-head-dim 32 \
|
||||||
|
--value-head-dim 12 \
|
||||||
|
--pos-head-dim 4 \
|
||||||
|
--pos-dim 48 \
|
||||||
|
--encoder-unmasked-dim "192,192,256,256,256,192" \
|
||||||
|
--cnn-module-kernel "31,31,15,15,15,31" \
|
||||||
|
--decoder-dim 512 \
|
||||||
|
--joiner-dim 512 \
|
||||||
|
--causal True \
|
||||||
|
--chunk-size 16 \
|
||||||
|
--left-context-frames 64
|
||||||
|
|
||||||
|
ls -lh $repo/exp
|
||||||
|
|
||||||
|
log "Run onnx_pretrained-streaming.py"
|
||||||
|
|
||||||
|
./zipformer/onnx_pretrained-streaming.py \
|
||||||
|
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1-chunk-16-left-64.onnx \
|
||||||
|
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1-chunk-16-left-64.onnx \
|
||||||
|
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1-chunk-16-left-64.onnx \
|
||||||
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
|
$repo/test_wavs/1089-134686-0001.wav
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
|
||||||
|
log "--------------------------------------------------------------------------"
|
||||||
|
|
||||||
log "=========================================================================="
|
log "=========================================================================="
|
||||||
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||||
@ -39,7 +155,7 @@ log "Export via torch.jit.trace()"
|
|||||||
log "Test exporting to ONNX format"
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -88,7 +204,7 @@ popd
|
|||||||
log "Export via torch.jit.script()"
|
log "Export via torch.jit.script()"
|
||||||
|
|
||||||
./pruned_transducer_stateless3/export.py \
|
./pruned_transducer_stateless3/export.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 9999 \
|
--epoch 9999 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--exp-dir $repo/exp/ \
|
--exp-dir $repo/exp/ \
|
||||||
@ -97,7 +213,7 @@ log "Export via torch.jit.script()"
|
|||||||
log "Test exporting to ONNX format"
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
./pruned_transducer_stateless3/export-onnx.py \
|
./pruned_transducer_stateless3/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 9999 \
|
--epoch 9999 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--exp-dir $repo/exp/
|
--exp-dir $repo/exp/
|
||||||
@ -126,7 +242,6 @@ log "Run onnx_pretrained.py"
|
|||||||
rm -rf $repo
|
rm -rf $repo
|
||||||
log "--------------------------------------------------------------------------"
|
log "--------------------------------------------------------------------------"
|
||||||
|
|
||||||
|
|
||||||
log "=========================================================================="
|
log "=========================================================================="
|
||||||
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13
|
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13
|
||||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
@ -143,7 +258,7 @@ popd
|
|||||||
log "Export via torch.jit.script()"
|
log "Export via torch.jit.script()"
|
||||||
|
|
||||||
./pruned_transducer_stateless5/export.py \
|
./pruned_transducer_stateless5/export.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
@ -159,7 +274,7 @@ log "Export via torch.jit.script()"
|
|||||||
log "Test exporting to ONNX format"
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
./pruned_transducer_stateless5/export-onnx.py \
|
./pruned_transducer_stateless5/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
@ -205,7 +320,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
pushd $repo
|
pushd $repo
|
||||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
|
||||||
git lfs pull --include "exp/pretrained.pt"
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
cd exp
|
cd exp
|
||||||
@ -215,7 +329,7 @@ popd
|
|||||||
log "Export via torch.jit.script()"
|
log "Export via torch.jit.script()"
|
||||||
|
|
||||||
./pruned_transducer_stateless7/export.py \
|
./pruned_transducer_stateless7/export.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -226,7 +340,7 @@ log "Export via torch.jit.script()"
|
|||||||
log "Test exporting to ONNX format"
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
./pruned_transducer_stateless7/export-onnx.py \
|
./pruned_transducer_stateless7/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -270,7 +384,7 @@ popd
|
|||||||
log "Test exporting to ONNX format"
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
./conv_emformer_transducer_stateless2/export-onnx.py \
|
./conv_emformer_transducer_stateless2/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -310,7 +424,7 @@ popd
|
|||||||
log "Export via torch.jit.trace()"
|
log "Export via torch.jit.trace()"
|
||||||
|
|
||||||
./lstm_transducer_stateless2/export.py \
|
./lstm_transducer_stateless2/export.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -320,7 +434,7 @@ log "Export via torch.jit.trace()"
|
|||||||
log "Test exporting to ONNX format"
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
./lstm_transducer_stateless2/export-onnx.py \
|
./lstm_transducer_stateless2/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
|
|||||||
@ -1,321 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
#
|
|
||||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
|
||||||
#
|
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
# This script converts several saved checkpoints
|
|
||||||
# to a single one using model averaging.
|
|
||||||
"""
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
|
|
||||||
(1) Export to torchscript model using torch.jit.script()
|
|
||||||
|
|
||||||
./pruned_transducer_stateless7/export.py \
|
|
||||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
|
||||||
--lang-dir data/lang_char \
|
|
||||||
--epoch 30 \
|
|
||||||
--avg 9 \
|
|
||||||
--jit 1
|
|
||||||
|
|
||||||
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
|
|
||||||
load it by `torch.jit.load("cpu_jit.pt")`.
|
|
||||||
|
|
||||||
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
|
|
||||||
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
|
|
||||||
|
|
||||||
Check
|
|
||||||
https://github.com/k2-fsa/sherpa
|
|
||||||
for how to use the exported models outside of icefall.
|
|
||||||
|
|
||||||
(2) Export `model.state_dict()`
|
|
||||||
|
|
||||||
./pruned_transducer_stateless7/export.py \
|
|
||||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
|
||||||
--lang-dir data/lang_char \
|
|
||||||
--epoch 20 \
|
|
||||||
--avg 10
|
|
||||||
|
|
||||||
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
|
||||||
load it by `icefall.checkpoint.load_checkpoint()`.
|
|
||||||
|
|
||||||
To use the generated file with `pruned_transducer_stateless7/decode.py`,
|
|
||||||
you can do:
|
|
||||||
|
|
||||||
cd /path/to/exp_dir
|
|
||||||
ln -s pretrained.pt epoch-9999.pt
|
|
||||||
|
|
||||||
cd /path/to/egs/librispeech/ASR
|
|
||||||
./pruned_transducer_stateless7/decode.py \
|
|
||||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
|
||||||
--epoch 9999 \
|
|
||||||
--avg 1 \
|
|
||||||
--max-duration 600 \
|
|
||||||
--decoding-method greedy_search \
|
|
||||||
--lang-dir data/lang_char
|
|
||||||
|
|
||||||
Check ./pretrained.py for its usage.
|
|
||||||
|
|
||||||
Note: If you don't want to train a model from scratch, we have
|
|
||||||
provided one for you. You can get it at
|
|
||||||
|
|
||||||
https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21
|
|
||||||
|
|
||||||
with the following commands:
|
|
||||||
|
|
||||||
sudo apt-get install git-lfs
|
|
||||||
git lfs install
|
|
||||||
git clone https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21
|
|
||||||
# You will find the pre-trained model in icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21exp
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
|
||||||
from train2 import add_model_arguments, get_params, get_transducer_model
|
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
|
||||||
average_checkpoints,
|
|
||||||
average_checkpoints_with_averaged_model,
|
|
||||||
find_checkpoints,
|
|
||||||
load_checkpoint,
|
|
||||||
)
|
|
||||||
from icefall.lexicon import Lexicon
|
|
||||||
from icefall.utils import str2bool
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--epoch",
|
|
||||||
type=int,
|
|
||||||
default=30,
|
|
||||||
help="""It specifies the checkpoint to use for decoding.
|
|
||||||
Note: Epoch counts from 1.
|
|
||||||
You can specify --avg to use more checkpoints for model averaging.""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--iter",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="""If positive, --epoch is ignored and it
|
|
||||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
|
||||||
You can specify --avg to use more checkpoints for model averaging.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--avg",
|
|
||||||
type=int,
|
|
||||||
default=9,
|
|
||||||
help="Number of checkpoints to average. Automatically select "
|
|
||||||
"consecutive checkpoints before the checkpoint specified by "
|
|
||||||
"'--epoch' and '--iter'",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--use-averaged-model",
|
|
||||||
type=str2bool,
|
|
||||||
default=True,
|
|
||||||
help="Whether to load averaged model. Currently it only supports "
|
|
||||||
"using --epoch. If True, it would decode with the averaged model "
|
|
||||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
|
||||||
"Actually only the models with epoch number of `epoch-avg` and "
|
|
||||||
"`epoch` are loaded for averaging. ",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--exp-dir",
|
|
||||||
type=str,
|
|
||||||
default="pruned_transducer_stateless7/exp",
|
|
||||||
help="""It specifies the directory where all training related
|
|
||||||
files, e.g., checkpoints, log, etc, are saved
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--lang-dir",
|
|
||||||
type=str,
|
|
||||||
default="data/lang_char",
|
|
||||||
help="""The lang dir
|
|
||||||
It contains language related input files such as
|
|
||||||
"lexicon.txt"
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--jit",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""True to save a model after applying torch.jit.script.
|
|
||||||
It will generate a file named cpu_jit.pt
|
|
||||||
|
|
||||||
Check ./jit_pretrained.py for how to use it.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--context-size",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
|
||||||
)
|
|
||||||
|
|
||||||
add_model_arguments(parser)
|
|
||||||
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def main():
|
|
||||||
args = get_parser().parse_args()
|
|
||||||
args.exp_dir = Path(args.exp_dir)
|
|
||||||
|
|
||||||
params = get_params()
|
|
||||||
params.update(vars(args))
|
|
||||||
|
|
||||||
device = torch.device("cpu")
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device("cuda", 0)
|
|
||||||
|
|
||||||
logging.info(f"device: {device}")
|
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
|
||||||
params.blank_id = 0
|
|
||||||
params.vocab_size = max(lexicon.tokens) + 1
|
|
||||||
|
|
||||||
logging.info(params)
|
|
||||||
|
|
||||||
logging.info("About to create model")
|
|
||||||
model = get_transducer_model(params)
|
|
||||||
|
|
||||||
model.to(device)
|
|
||||||
|
|
||||||
if not params.use_averaged_model:
|
|
||||||
if params.iter > 0:
|
|
||||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
|
||||||
: params.avg
|
|
||||||
]
|
|
||||||
if len(filenames) == 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"No checkpoints found for"
|
|
||||||
f" --iter {params.iter}, --avg {params.avg}"
|
|
||||||
)
|
|
||||||
elif len(filenames) < params.avg:
|
|
||||||
raise ValueError(
|
|
||||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
|
||||||
f" --iter {params.iter}, --avg {params.avg}"
|
|
||||||
)
|
|
||||||
logging.info(f"averaging {filenames}")
|
|
||||||
model.to(device)
|
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
|
||||||
elif params.avg == 1:
|
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
|
||||||
else:
|
|
||||||
start = params.epoch - params.avg + 1
|
|
||||||
filenames = []
|
|
||||||
for i in range(start, params.epoch + 1):
|
|
||||||
if i >= 1:
|
|
||||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
|
||||||
logging.info(f"averaging {filenames}")
|
|
||||||
model.to(device)
|
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
|
||||||
else:
|
|
||||||
if params.iter > 0:
|
|
||||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
|
||||||
: params.avg + 1
|
|
||||||
]
|
|
||||||
if len(filenames) == 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"No checkpoints found for"
|
|
||||||
f" --iter {params.iter}, --avg {params.avg}"
|
|
||||||
)
|
|
||||||
elif len(filenames) < params.avg + 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
|
||||||
f" --iter {params.iter}, --avg {params.avg}"
|
|
||||||
)
|
|
||||||
filename_start = filenames[-1]
|
|
||||||
filename_end = filenames[0]
|
|
||||||
logging.info(
|
|
||||||
"Calculating the averaged model over iteration checkpoints"
|
|
||||||
f" from {filename_start} (excluded) to {filename_end}"
|
|
||||||
)
|
|
||||||
model.to(device)
|
|
||||||
model.load_state_dict(
|
|
||||||
average_checkpoints_with_averaged_model(
|
|
||||||
filename_start=filename_start,
|
|
||||||
filename_end=filename_end,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert params.avg > 0, params.avg
|
|
||||||
start = params.epoch - params.avg
|
|
||||||
assert start >= 1, start
|
|
||||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
|
||||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
|
||||||
logging.info(
|
|
||||||
f"Calculating the averaged model over epoch range from "
|
|
||||||
f"{start} (excluded) to {params.epoch}"
|
|
||||||
)
|
|
||||||
model.to(device)
|
|
||||||
model.load_state_dict(
|
|
||||||
average_checkpoints_with_averaged_model(
|
|
||||||
filename_start=filename_start,
|
|
||||||
filename_end=filename_end,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
model.to("cpu")
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
if params.jit is True:
|
|
||||||
convert_scaled_to_non_scaled(model, inplace=True)
|
|
||||||
# We won't use the forward() method of the model in C++, so just ignore
|
|
||||||
# it here.
|
|
||||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
|
||||||
# torch scriptabe.
|
|
||||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
|
||||||
logging.info("Using torch.jit.script")
|
|
||||||
model = torch.jit.script(model)
|
|
||||||
filename = params.exp_dir / "cpu_jit.pt"
|
|
||||||
model.save(str(filename))
|
|
||||||
logging.info(f"Saved to {filename}")
|
|
||||||
else:
|
|
||||||
logging.info("Not using torchscript. Export model.state_dict()")
|
|
||||||
# Save it using a format so that it can be loaded
|
|
||||||
# by :func:`load_checkpoint`
|
|
||||||
filename = params.exp_dir / "pretrained.pt"
|
|
||||||
torch.save({"model": model.state_dict()}, str(filename))
|
|
||||||
logging.info(f"Saved to {filename}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
|
||||||
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
|
||||||
main()
|
|
||||||
1
egs/aishell/ASR/pruned_transducer_stateless7/export.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7/export.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7/export.py
|
||||||
@ -1,348 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
|
||||||
#
|
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""
|
|
||||||
This script loads a checkpoint and uses it to decode waves.
|
|
||||||
You can generate the checkpoint with the following command:
|
|
||||||
|
|
||||||
./pruned_transducer_stateless7/export.py \
|
|
||||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
|
||||||
--lang-dir data/lang_char \
|
|
||||||
--epoch 20 \
|
|
||||||
--avg 10
|
|
||||||
|
|
||||||
Usage of this script:
|
|
||||||
|
|
||||||
(1) greedy search
|
|
||||||
./pruned_transducer_stateless7/pretrained.py \
|
|
||||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
|
||||||
--lang-dir ./data/lang_char \
|
|
||||||
--method greedy_search \
|
|
||||||
/path/to/foo.wav \
|
|
||||||
/path/to/bar.wav
|
|
||||||
|
|
||||||
(2) beam search
|
|
||||||
./pruned_transducer_stateless7/pretrained.py \
|
|
||||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
|
||||||
--lang-dir ./data/lang_char \
|
|
||||||
--method beam_search \
|
|
||||||
--beam-size 4 \
|
|
||||||
/path/to/foo.wav \
|
|
||||||
/path/to/bar.wav
|
|
||||||
|
|
||||||
(3) modified beam search
|
|
||||||
./pruned_transducer_stateless7/pretrained.py \
|
|
||||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
|
||||||
--lang-dir ./data/lang_char \
|
|
||||||
--method modified_beam_search \
|
|
||||||
--beam-size 4 \
|
|
||||||
/path/to/foo.wav \
|
|
||||||
/path/to/bar.wav
|
|
||||||
|
|
||||||
(4) fast beam search
|
|
||||||
./pruned_transducer_stateless7/pretrained.py \
|
|
||||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
|
||||||
--lang-dir ./data/lang_char \
|
|
||||||
--method fast_beam_search \
|
|
||||||
--beam-size 4 \
|
|
||||||
/path/to/foo.wav \
|
|
||||||
/path/to/bar.wav
|
|
||||||
|
|
||||||
You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`.
|
|
||||||
|
|
||||||
Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by
|
|
||||||
./pruned_transducer_stateless7/export.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import k2
|
|
||||||
import kaldifeat
|
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
|
||||||
import torchaudio
|
|
||||||
from beam_search import (
|
|
||||||
beam_search,
|
|
||||||
fast_beam_search_one_best,
|
|
||||||
greedy_search,
|
|
||||||
greedy_search_batch,
|
|
||||||
modified_beam_search,
|
|
||||||
)
|
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
|
||||||
|
|
||||||
from icefall.lexicon import Lexicon
|
|
||||||
from icefall.utils import str2bool
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--checkpoint",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path to the checkpoint. "
|
|
||||||
"The checkpoint is assumed to be saved by "
|
|
||||||
"icefall.checkpoint.save_checkpoint().",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--lang-dir",
|
|
||||||
type=str,
|
|
||||||
help="""The lang dir
|
|
||||||
It contains language related input files such as
|
|
||||||
"lexicon.txt"
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--method",
|
|
||||||
type=str,
|
|
||||||
default="greedy_search",
|
|
||||||
help="""Possible values are:
|
|
||||||
- greedy_search
|
|
||||||
- beam_search
|
|
||||||
- modified_beam_search
|
|
||||||
- fast_beam_search
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"sound_files",
|
|
||||||
type=str,
|
|
||||||
nargs="+",
|
|
||||||
help="The input sound file(s) to transcribe. "
|
|
||||||
"Supported formats are those supported by torchaudio.load(). "
|
|
||||||
"For example, wav and flac are supported. "
|
|
||||||
"The sample rate has to be 16kHz.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--sample-rate",
|
|
||||||
type=int,
|
|
||||||
default=16000,
|
|
||||||
help="The sample rate of the input sound file",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--beam-size",
|
|
||||||
type=int,
|
|
||||||
default=4,
|
|
||||||
help="""An integer indicating how many candidates we will keep for each
|
|
||||||
frame. Used only when --method is beam_search or
|
|
||||||
modified_beam_search.""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--beam",
|
|
||||||
type=float,
|
|
||||||
default=4,
|
|
||||||
help="""A floating point value to calculate the cutoff score during beam
|
|
||||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
|
||||||
`beam` in Kaldi.
|
|
||||||
Used only when --method is fast_beam_search""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-contexts",
|
|
||||||
type=int,
|
|
||||||
default=4,
|
|
||||||
help="""Used only when --method is fast_beam_search""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-states",
|
|
||||||
type=int,
|
|
||||||
default=8,
|
|
||||||
help="""Used only when --method is fast_beam_search""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--context-size",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-sym-per-frame",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="""Maximum number of symbols per frame. Used only when
|
|
||||||
--method is greedy_search.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
add_model_arguments(parser)
|
|
||||||
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def read_sound_files(
|
|
||||||
filenames: List[str], expected_sample_rate: float
|
|
||||||
) -> List[torch.Tensor]:
|
|
||||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
|
||||||
Args:
|
|
||||||
filenames:
|
|
||||||
A list of sound filenames.
|
|
||||||
expected_sample_rate:
|
|
||||||
The expected sample rate of the sound files.
|
|
||||||
Returns:
|
|
||||||
Return a list of 1-D float32 torch tensors.
|
|
||||||
"""
|
|
||||||
ans = []
|
|
||||||
for f in filenames:
|
|
||||||
wave, sample_rate = torchaudio.load(f)
|
|
||||||
assert (
|
|
||||||
sample_rate == expected_sample_rate
|
|
||||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
|
||||||
# We use only the first channel
|
|
||||||
ans.append(wave[0])
|
|
||||||
return ans
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def main():
|
|
||||||
parser = get_parser()
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
params = get_params()
|
|
||||||
|
|
||||||
params.update(vars(args))
|
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
|
||||||
params.blank_id = 0
|
|
||||||
params.vocab_size = max(lexicon.tokens) + 1
|
|
||||||
token_table = lexicon.token_table
|
|
||||||
|
|
||||||
logging.info(f"{params}")
|
|
||||||
|
|
||||||
device = torch.device("cpu")
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device("cuda", 0)
|
|
||||||
|
|
||||||
logging.info(f"device: {device}")
|
|
||||||
|
|
||||||
logging.info("Creating model")
|
|
||||||
model = get_transducer_model(params)
|
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
|
||||||
|
|
||||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
|
||||||
model.load_state_dict(checkpoint["model"], strict=False)
|
|
||||||
model.to(device)
|
|
||||||
model.eval()
|
|
||||||
model.device = device
|
|
||||||
|
|
||||||
logging.info("Constructing Fbank computer")
|
|
||||||
opts = kaldifeat.FbankOptions()
|
|
||||||
opts.device = device
|
|
||||||
opts.frame_opts.dither = 0
|
|
||||||
opts.frame_opts.snip_edges = False
|
|
||||||
opts.frame_opts.samp_freq = params.sample_rate
|
|
||||||
opts.mel_opts.num_bins = params.feature_dim
|
|
||||||
|
|
||||||
fbank = kaldifeat.Fbank(opts)
|
|
||||||
|
|
||||||
logging.info(f"Reading sound files: {params.sound_files}")
|
|
||||||
waves = read_sound_files(
|
|
||||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
|
||||||
)
|
|
||||||
waves = [w.to(device) for w in waves]
|
|
||||||
|
|
||||||
logging.info("Decoding started")
|
|
||||||
features = fbank(waves)
|
|
||||||
feature_lengths = [f.size(0) for f in features]
|
|
||||||
|
|
||||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
|
||||||
|
|
||||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
|
||||||
|
|
||||||
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
|
|
||||||
|
|
||||||
num_waves = encoder_out.size(0)
|
|
||||||
hyps = []
|
|
||||||
msg = f"Using {params.method}"
|
|
||||||
if params.method == "beam_search":
|
|
||||||
msg += f" with beam size {params.beam_size}"
|
|
||||||
logging.info(msg)
|
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
|
||||||
model=model,
|
|
||||||
decoding_graph=decoding_graph,
|
|
||||||
encoder_out=encoder_out,
|
|
||||||
encoder_out_lens=encoder_out_lens,
|
|
||||||
beam=params.beam,
|
|
||||||
max_contexts=params.max_contexts,
|
|
||||||
max_states=params.max_states,
|
|
||||||
)
|
|
||||||
elif params.method == "modified_beam_search":
|
|
||||||
hyp_tokens = modified_beam_search(
|
|
||||||
model=model,
|
|
||||||
encoder_out=encoder_out,
|
|
||||||
encoder_out_lens=encoder_out_lens,
|
|
||||||
beam=params.beam_size,
|
|
||||||
)
|
|
||||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
|
||||||
hyp_tokens = greedy_search_batch(
|
|
||||||
model=model,
|
|
||||||
encoder_out=encoder_out,
|
|
||||||
encoder_out_lens=encoder_out_lens,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
for i in range(num_waves):
|
|
||||||
# fmt: off
|
|
||||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
|
||||||
# fmt: on
|
|
||||||
if params.method == "greedy_search":
|
|
||||||
hyp_tokens = greedy_search(
|
|
||||||
model=model,
|
|
||||||
encoder_out=encoder_out_i,
|
|
||||||
max_sym_per_frame=params.max_sym_per_frame,
|
|
||||||
)
|
|
||||||
elif params.method == "beam_search":
|
|
||||||
hyp_tokens = beam_search(
|
|
||||||
model=model,
|
|
||||||
encoder_out=encoder_out_i,
|
|
||||||
beam=params.beam_size,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
|
||||||
|
|
||||||
hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]
|
|
||||||
s = "\n"
|
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
|
||||||
words = " ".join(hyp)
|
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
|
||||||
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
|
||||||
main()
|
|
||||||
1
egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py
Symbolic link
1
egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7/pretrained.py
|
||||||
@ -23,12 +23,13 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import AttributeDict, str2bool
|
from icefall.utils import AttributeDict, num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -63,11 +64,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lang-dir",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500",
|
required=True,
|
||||||
help="""It contains language related input files such as "lexicon.txt"
|
help="Path to the tokens.txt.",
|
||||||
""",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -98,16 +98,16 @@ def get_params() -> AttributeDict:
|
|||||||
def main():
|
def main():
|
||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
args.lang_dir = Path(args.lang_dir)
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
# Load tokens.txt here
|
||||||
max_token_id = max(lexicon.tokens)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
num_classes = max_token_id + 1 # +1 for the blank
|
|
||||||
|
num_classes = num_tokens(token_table) + 1 # +1 for the blank
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
|||||||
@ -24,7 +24,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
@ -70,11 +69,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.
|
help="Path to the tokens.txt.",
|
||||||
Used only when method is ctc-decoding.
|
|
||||||
""",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -83,10 +80,9 @@ def get_parser():
|
|||||||
default="1best",
|
default="1best",
|
||||||
help="""Decoding method.
|
help="""Decoding method.
|
||||||
Possible values are:
|
Possible values are:
|
||||||
(0) ctc-decoding - Use CTC decoding. It uses a sentence
|
(0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file
|
||||||
piece model, i.e., lang_dir/bpe.model, to convert
|
to convert tokens to actual words or characters. It needs
|
||||||
word pieces to words. It needs neither a lexicon
|
neither a lexicon nor an n-gram LM.
|
||||||
nor an n-gram LM.
|
|
||||||
(1) 1best - Use the best path as decoding output. Only
|
(1) 1best - Use the best path as decoding output. Only
|
||||||
the transformer encoder output is used for decoding.
|
the transformer encoder output is used for decoding.
|
||||||
We call it HLG decoding.
|
We call it HLG decoding.
|
||||||
@ -297,6 +293,7 @@ def main():
|
|||||||
waves = [w.to(device) for w in waves]
|
waves = [w.to(device) for w in waves]
|
||||||
|
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
|
hyps = []
|
||||||
features = fbank(waves)
|
features = fbank(waves)
|
||||||
|
|
||||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||||
@ -313,10 +310,17 @@ def main():
|
|||||||
|
|
||||||
if params.method == "ctc-decoding":
|
if params.method == "ctc-decoding":
|
||||||
logging.info("Use CTC decoding")
|
logging.info("Use CTC decoding")
|
||||||
bpe_model = spm.SentencePieceProcessor()
|
|
||||||
bpe_model.load(params.bpe_model)
|
|
||||||
max_token_id = params.num_classes - 1
|
max_token_id = params.num_classes - 1
|
||||||
|
|
||||||
|
# Load tokens.txt here
|
||||||
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
H = k2.ctc_topo(
|
H = k2.ctc_topo(
|
||||||
max_token=max_token_id,
|
max_token=max_token_id,
|
||||||
modified=params.num_classes > 500,
|
modified=params.num_classes > 500,
|
||||||
@ -337,9 +341,9 @@ def main():
|
|||||||
best_path = one_best_decoding(
|
best_path = one_best_decoding(
|
||||||
lattice=lattice, use_double_scores=params.use_double_scores
|
lattice=lattice, use_double_scores=params.use_double_scores
|
||||||
)
|
)
|
||||||
token_ids = get_texts(best_path)
|
hyp_tokens = get_texts(best_path)
|
||||||
hyps = bpe_model.decode(token_ids)
|
for hyp in hyp_tokens:
|
||||||
hyps = [s.split() for s in hyps]
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method in [
|
elif params.method in [
|
||||||
"1best",
|
"1best",
|
||||||
"whole-lattice-rescoring",
|
"whole-lattice-rescoring",
|
||||||
@ -408,16 +412,16 @@ def main():
|
|||||||
)
|
)
|
||||||
best_path = next(iter(best_path_dict.values()))
|
best_path = next(iter(best_path_dict.values()))
|
||||||
|
|
||||||
hyps = get_texts(best_path)
|
|
||||||
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||||
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
hyp_tokens = get_texts(best_path)
|
||||||
|
for hyp in hyp_tokens:
|
||||||
|
hyps.append(" ".join([word_sym_table[i] for i in hyp]))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported decoding method: {params.method}")
|
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
|||||||
@ -23,6 +23,7 @@
|
|||||||
Usage:
|
Usage:
|
||||||
./conformer_ctc2/export.py \
|
./conformer_ctc2/export.py \
|
||||||
--exp-dir ./conformer_ctc2/exp \
|
--exp-dir ./conformer_ctc2/exp \
|
||||||
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -46,6 +47,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from decode import get_params
|
from decode import get_params
|
||||||
@ -56,8 +58,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.utils import num_tokens, str2bool
|
||||||
from icefall.utils import str2bool
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -123,10 +124,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lang-dir",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500",
|
required=True,
|
||||||
help="The lang dir",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -143,14 +144,14 @@ def get_parser():
|
|||||||
def main():
|
def main():
|
||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
args.lang_dir = Path(args.lang_dir)
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
# Load tokens.txt here
|
||||||
max_token_id = max(lexicon.tokens)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
num_classes = max_token_id + 1 # +1 for the blank
|
|
||||||
|
num_classes = num_tokens(token_table) + 1 # +1 for the blank
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
|||||||
@ -25,7 +25,7 @@ Usage:
|
|||||||
|
|
||||||
./conformer_ctc3/export.py \
|
./conformer_ctc3/export.py \
|
||||||
--exp-dir ./conformer_ctc3/exp \
|
--exp-dir ./conformer_ctc3/exp \
|
||||||
--lang-dir data/lang_bpe_500 \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10 \
|
--avg 10 \
|
||||||
--jit-trace 1
|
--jit-trace 1
|
||||||
@ -36,7 +36,7 @@ It will generates the file: `jit_trace.pt`.
|
|||||||
|
|
||||||
./conformer_ctc3/export.py \
|
./conformer_ctc3/export.py \
|
||||||
--exp-dir ./conformer_ctc3/exp \
|
--exp-dir ./conformer_ctc3/exp \
|
||||||
--lang-dir data/lang_bpe_500 \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -62,6 +62,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_ctc_model, get_params
|
from train import add_model_arguments, get_ctc_model, get_params
|
||||||
@ -72,8 +73,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.utils import num_tokens, str2bool
|
||||||
from icefall.utils import str2bool
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -130,10 +130,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lang-dir",
|
"--tokens",
|
||||||
type=Path,
|
type=str,
|
||||||
default="data/lang_bpe_500",
|
required=True,
|
||||||
help="The lang dir containing word table and LG graph",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -171,9 +171,10 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
# Load tokens.txt here
|
||||||
max_token_id = max(lexicon.tokens)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
num_classes = max_token_id + 1 # +1 for the blank
|
|
||||||
|
num_classes = num_tokens(token_table) + 1 # +1 for the blank
|
||||||
params.vocab_size = num_classes
|
params.vocab_size = num_classes
|
||||||
|
|
||||||
if params.streaming_model:
|
if params.streaming_model:
|
||||||
|
|||||||
@ -24,7 +24,7 @@ Usage (for non-streaming mode):
|
|||||||
(1) ctc-decoding
|
(1) ctc-decoding
|
||||||
./conformer_ctc3/pretrained.py \
|
./conformer_ctc3/pretrained.py \
|
||||||
--checkpoint conformer_ctc3/exp/pretrained.pt \
|
--checkpoint conformer_ctc3/exp/pretrained.pt \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method ctc-decoding \
|
--method ctc-decoding \
|
||||||
--sample-rate 16000 \
|
--sample-rate 16000 \
|
||||||
test_wavs/1089-134686-0001.wav
|
test_wavs/1089-134686-0001.wav
|
||||||
@ -67,7 +67,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from decode import get_decoding_params
|
from decode import get_decoding_params
|
||||||
@ -114,11 +113,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.
|
help="Path to the tokens.txt.",
|
||||||
Used only when method is ctc-decoding.
|
|
||||||
""",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -127,10 +124,9 @@ def get_parser():
|
|||||||
default="1best",
|
default="1best",
|
||||||
help="""Decoding method.
|
help="""Decoding method.
|
||||||
Possible values are:
|
Possible values are:
|
||||||
(0) ctc-decoding - Use CTC decoding. It uses a sentence
|
(0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file
|
||||||
piece model, i.e., lang_dir/bpe.model, to convert
|
to convert tokens to actual words or characters. It needs
|
||||||
word pieces to words. It needs neither a lexicon
|
neither a lexicon nor an n-gram LM.
|
||||||
nor an n-gram LM.
|
|
||||||
(1) 1best - Use the best path as decoding output. Only
|
(1) 1best - Use the best path as decoding output. Only
|
||||||
the transformer encoder output is used for decoding.
|
the transformer encoder output is used for decoding.
|
||||||
We call it HLG decoding.
|
We call it HLG decoding.
|
||||||
@ -316,6 +312,7 @@ def main():
|
|||||||
waves = [w.to(device) for w in waves]
|
waves = [w.to(device) for w in waves]
|
||||||
|
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
|
hyps = []
|
||||||
features = fbank(waves)
|
features = fbank(waves)
|
||||||
feature_lengths = [f.size(0) for f in features]
|
feature_lengths = [f.size(0) for f in features]
|
||||||
|
|
||||||
@ -348,10 +345,17 @@ def main():
|
|||||||
|
|
||||||
if params.method == "ctc-decoding":
|
if params.method == "ctc-decoding":
|
||||||
logging.info("Use CTC decoding")
|
logging.info("Use CTC decoding")
|
||||||
bpe_model = spm.SentencePieceProcessor()
|
|
||||||
bpe_model.load(params.bpe_model)
|
|
||||||
max_token_id = params.num_classes - 1
|
max_token_id = params.num_classes - 1
|
||||||
|
|
||||||
|
# Load tokens.txt here
|
||||||
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
H = k2.ctc_topo(
|
H = k2.ctc_topo(
|
||||||
max_token=max_token_id,
|
max_token=max_token_id,
|
||||||
modified=False,
|
modified=False,
|
||||||
@ -372,9 +376,9 @@ def main():
|
|||||||
best_path = one_best_decoding(
|
best_path = one_best_decoding(
|
||||||
lattice=lattice, use_double_scores=params.use_double_scores
|
lattice=lattice, use_double_scores=params.use_double_scores
|
||||||
)
|
)
|
||||||
token_ids = get_texts(best_path)
|
hyp_tokens = get_texts(best_path)
|
||||||
hyps = bpe_model.decode(token_ids)
|
for hyp in hyp_tokens:
|
||||||
hyps = [s.split() for s in hyps]
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method in [
|
elif params.method in [
|
||||||
"1best",
|
"1best",
|
||||||
"nbest-rescoring",
|
"nbest-rescoring",
|
||||||
@ -439,16 +443,16 @@ def main():
|
|||||||
)
|
)
|
||||||
best_path = next(iter(best_path_dict.values()))
|
best_path = next(iter(best_path_dict.values()))
|
||||||
|
|
||||||
hyps = get_texts(best_path)
|
|
||||||
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||||
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
hyp_tokens = get_texts(best_path)
|
||||||
|
for hyp in hyp_tokens:
|
||||||
|
hyps.append(" ".join([word_sym_table[i] for i in hyp]))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported decoding method: {params.method}")
|
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
|||||||
@ -22,7 +22,7 @@
|
|||||||
Usage:
|
Usage:
|
||||||
./conv_emformer_transducer_stateless/export.py \
|
./conv_emformer_transducer_stateless/export.py \
|
||||||
--exp-dir ./conv_emformer_transducer_stateless/exp \
|
--exp-dir ./conv_emformer_transducer_stateless/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 10 \
|
--avg 10 \
|
||||||
--use-averaged-model=True \
|
--use-averaged-model=True \
|
||||||
@ -62,7 +62,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
@ -72,7 +72,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -118,10 +118,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
required=True,
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -166,12 +166,12 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# Load id of the <blk> token and the vocab size
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -8,7 +8,7 @@ for more details about how to use this file.
|
|||||||
Usage:
|
Usage:
|
||||||
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
|
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
|
||||||
--exp-dir ./conv_emformer_transducer_stateless2/exp \
|
--exp-dir ./conv_emformer_transducer_stateless2/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 10 \
|
--avg 10 \
|
||||||
--use-averaged-model=True \
|
--use-averaged-model=True \
|
||||||
@ -37,7 +37,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train2 import add_model_arguments, get_params, get_transducer_model
|
from train2 import add_model_arguments, get_params, get_transducer_model
|
||||||
@ -48,7 +48,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import setup_logger, str2bool
|
from icefall.utils import num_tokens, setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -94,10 +94,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
required=True,
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -217,12 +217,12 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# Load id of the <blk> token and the vocab size
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
pushd $repo
|
pushd $repo
|
||||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
|
||||||
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
|
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
|
||||||
|
|
||||||
cd exp
|
cd exp
|
||||||
@ -28,7 +27,7 @@ popd
|
|||||||
2. Export the model to ONNX
|
2. Export the model to ONNX
|
||||||
|
|
||||||
./conv_emformer_transducer_stateless2/export-onnx.py \
|
./conv_emformer_transducer_stateless2/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -55,14 +54,14 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
import onnx
|
import onnx
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
|
from emformer import Emformer
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train2 import add_model_arguments, get_params, get_transducer_model
|
from train2 import add_model_arguments, get_params, get_transducer_model
|
||||||
from emformer import Emformer
|
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
@ -70,7 +69,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import setup_logger, str2bool
|
from icefall.utils import num_tokens, setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -127,10 +126,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
required=True,
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -484,12 +483,12 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# Load id of the <blk> token and the vocab size
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -22,7 +22,7 @@
|
|||||||
Usage:
|
Usage:
|
||||||
./conv_emformer_transducer_stateless2/export.py \
|
./conv_emformer_transducer_stateless2/export.py \
|
||||||
--exp-dir ./conv_emformer_transducer_stateless2/exp \
|
--exp-dir ./conv_emformer_transducer_stateless2/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 10 \
|
--avg 10 \
|
||||||
--use-averaged-model=True \
|
--use-averaged-model=True \
|
||||||
@ -62,7 +62,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
@ -73,7 +73,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -119,10 +119,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
required=True,
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -167,12 +167,12 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# Load id of the <blk> token and the vocab size
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -28,7 +28,7 @@ popd
|
|||||||
2. Export the model to ONNX
|
2. Export the model to ONNX
|
||||||
|
|
||||||
./conv_emformer_transducer_stateless2/export-onnx.py \
|
./conv_emformer_transducer_stateless2/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
|
|||||||
@ -26,7 +26,7 @@ Usage:
|
|||||||
|
|
||||||
./lstm_transducer_stateless/export.py \
|
./lstm_transducer_stateless/export.py \
|
||||||
--exp-dir ./lstm_transducer_stateless/exp \
|
--exp-dir ./lstm_transducer_stateless/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 35 \
|
--epoch 35 \
|
||||||
--avg 10 \
|
--avg 10 \
|
||||||
--jit-trace 1
|
--jit-trace 1
|
||||||
@ -38,7 +38,7 @@ It will generate 3 files: `encoder_jit_trace.pt`,
|
|||||||
|
|
||||||
./lstm_transducer_stateless/export.py \
|
./lstm_transducer_stateless/export.py \
|
||||||
--exp-dir ./lstm_transducer_stateless/exp \
|
--exp-dir ./lstm_transducer_stateless/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 35 \
|
--epoch 35 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -79,7 +79,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
@ -91,7 +91,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -148,10 +148,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -266,12 +266,13 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# Load id of the <blk> token and the vocab size, <blk> is
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
# defined in local/train_bpe_model.py
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.blank_id = token_table["<blk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -20,7 +20,7 @@ Usage:
|
|||||||
(1) greedy search
|
(1) greedy search
|
||||||
./lstm_transducer_stateless/pretrained.py \
|
./lstm_transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
@ -28,7 +28,7 @@ Usage:
|
|||||||
(2) beam search
|
(2) beam search
|
||||||
./lstm_transducer_stateless/pretrained.py \
|
./lstm_transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -37,7 +37,7 @@ Usage:
|
|||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./lstm_transducer_stateless/pretrained.py \
|
./lstm_transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -46,7 +46,7 @@ Usage:
|
|||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./lstm_transducer_stateless/pretrained.py \
|
./lstm_transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method fast_beam_search \
|
--method fast_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -66,7 +66,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
@ -79,6 +78,8 @@ from beam_search import (
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.utils import num_tokens
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -95,9 +96,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -214,13 +215,14 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = token_table["<unk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@ -275,6 +277,12 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
if params.method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
@ -286,8 +294,8 @@ def main():
|
|||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "modified_beam_search":
|
elif params.method == "modified_beam_search":
|
||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
@ -296,16 +304,16 @@ def main():
|
|||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
else:
|
else:
|
||||||
for i in range(num_waves):
|
for i in range(num_waves):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@ -326,12 +334,11 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
|||||||
@ -29,7 +29,7 @@ popd
|
|||||||
|
|
||||||
./lstm_transducer_stateless2/export-for-ncnn.py \
|
./lstm_transducer_stateless2/export-for-ncnn.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
@ -49,7 +49,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
@ -60,7 +60,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import setup_logger, str2bool
|
from icefall.utils import num_tokens, setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -106,10 +106,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -221,12 +221,13 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# Load id of the <blk> token and the vocab size, <blk> is
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
# defined in local/train_bpe_model.py
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.blank_id = token_table["<blk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -613,7 +613,7 @@ def main():
|
|||||||
quantize_dynamic(
|
quantize_dynamic(
|
||||||
model_input=decoder_filename,
|
model_input=decoder_filename,
|
||||||
model_output=decoder_filename_int8,
|
model_output=decoder_filename_int8,
|
||||||
op_types_to_quantize=["MatMul"],
|
op_types_to_quantize=["MatMul", "Gather"],
|
||||||
weight_type=QuantType.QInt8,
|
weight_type=QuantType.QInt8,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -28,7 +28,7 @@ popd
|
|||||||
2. Export the model to ONNX
|
2. Export the model to ONNX
|
||||||
|
|
||||||
./lstm_transducer_stateless2/export-onnx.py \
|
./lstm_transducer_stateless2/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -52,8 +52,8 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, Tuple
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
import onnx
|
import onnx
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
@ -68,7 +68,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import setup_logger, str2bool
|
from icefall.utils import num_tokens, setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -125,10 +125,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -437,12 +437,13 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# Load id of the <blk> token and the vocab size, <blk> is
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
# defined in local/train_bpe_model.py
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.blank_id = token_table["<blk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
@ -607,7 +608,7 @@ def main():
|
|||||||
quantize_dynamic(
|
quantize_dynamic(
|
||||||
model_input=decoder_filename,
|
model_input=decoder_filename,
|
||||||
model_output=decoder_filename_int8,
|
model_output=decoder_filename_int8,
|
||||||
op_types_to_quantize=["MatMul"],
|
op_types_to_quantize=["MatMul", "Gather"],
|
||||||
weight_type=QuantType.QInt8,
|
weight_type=QuantType.QInt8,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -27,7 +27,7 @@ Usage:
|
|||||||
|
|
||||||
./lstm_transducer_stateless2/export.py \
|
./lstm_transducer_stateless2/export.py \
|
||||||
--exp-dir ./lstm_transducer_stateless2/exp \
|
--exp-dir ./lstm_transducer_stateless2/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 35 \
|
--epoch 35 \
|
||||||
--avg 10 \
|
--avg 10 \
|
||||||
--jit-trace 1
|
--jit-trace 1
|
||||||
@ -39,7 +39,7 @@ It will generate 3 files: `encoder_jit_trace.pt`,
|
|||||||
|
|
||||||
./lstm_transducer_stateless2/export.py \
|
./lstm_transducer_stateless2/export.py \
|
||||||
--exp-dir ./lstm_transducer_stateless2/exp \
|
--exp-dir ./lstm_transducer_stateless2/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 35 \
|
--epoch 35 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -80,7 +80,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
@ -92,7 +92,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -149,10 +149,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -267,12 +267,13 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# Load id of the <blk> token and the vocab size, <blk> is
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
# defined in local/train_bpe_model.py
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.blank_id = token_table["<blk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -20,7 +20,7 @@ Usage:
|
|||||||
(1) greedy search
|
(1) greedy search
|
||||||
./lstm_transducer_stateless2/pretrained.py \
|
./lstm_transducer_stateless2/pretrained.py \
|
||||||
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
|
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
@ -28,7 +28,7 @@ Usage:
|
|||||||
(2) beam search
|
(2) beam search
|
||||||
./lstm_transducer_stateless2/pretrained.py \
|
./lstm_transducer_stateless2/pretrained.py \
|
||||||
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
|
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -37,7 +37,7 @@ Usage:
|
|||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./lstm_transducer_stateless2/pretrained.py \
|
./lstm_transducer_stateless2/pretrained.py \
|
||||||
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
|
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -46,7 +46,7 @@ Usage:
|
|||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./lstm_transducer_stateless2/pretrained.py \
|
./lstm_transducer_stateless2/pretrained.py \
|
||||||
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
|
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method fast_beam_search \
|
--method fast_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -69,7 +69,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
@ -82,6 +81,8 @@ from beam_search import (
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.utils import num_tokens
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -98,9 +99,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="""Path to tokens.txt.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -217,13 +218,14 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = token_table["<unk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@ -278,6 +280,12 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
if params.method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
@ -289,8 +297,8 @@ def main():
|
|||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "modified_beam_search":
|
elif params.method == "modified_beam_search":
|
||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
@ -299,16 +307,16 @@ def main():
|
|||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
else:
|
else:
|
||||||
for i in range(num_waves):
|
for i in range(num_waves):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@ -329,12 +337,11 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
|||||||
@ -26,7 +26,7 @@ Usage:
|
|||||||
|
|
||||||
./lstm_transducer_stateless3/export.py \
|
./lstm_transducer_stateless3/export.py \
|
||||||
--exp-dir ./lstm_transducer_stateless3/exp \
|
--exp-dir ./lstm_transducer_stateless3/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 40 \
|
--epoch 40 \
|
||||||
--avg 20 \
|
--avg 20 \
|
||||||
--jit-trace 1
|
--jit-trace 1
|
||||||
@ -38,7 +38,7 @@ It will generate 3 files: `encoder_jit_trace.pt`,
|
|||||||
|
|
||||||
./lstm_transducer_stateless3/export.py \
|
./lstm_transducer_stateless3/export.py \
|
||||||
--exp-dir ./lstm_transducer_stateless3/exp \
|
--exp-dir ./lstm_transducer_stateless3/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 40 \
|
--epoch 40 \
|
||||||
--avg 20
|
--avg 20
|
||||||
|
|
||||||
@ -79,7 +79,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
@ -91,7 +91,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -148,10 +148,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -266,12 +266,13 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# Load id of the <blk> token and the vocab size, <blk> is
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
# defined in local/train_bpe_model.py
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.blank_id = token_table["<blk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -20,7 +20,7 @@ Usage:
|
|||||||
(1) greedy search
|
(1) greedy search
|
||||||
./lstm_transducer_stateless3/pretrained.py \
|
./lstm_transducer_stateless3/pretrained.py \
|
||||||
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
|
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
@ -28,7 +28,7 @@ Usage:
|
|||||||
(2) beam search
|
(2) beam search
|
||||||
./lstm_transducer_stateless3/pretrained.py \
|
./lstm_transducer_stateless3/pretrained.py \
|
||||||
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
|
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -37,7 +37,7 @@ Usage:
|
|||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./lstm_transducer_stateless3/pretrained.py \
|
./lstm_transducer_stateless3/pretrained.py \
|
||||||
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
|
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -79,6 +79,8 @@ from beam_search import (
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.utils import num_tokens
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -95,9 +97,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="""Path to tokens.txt.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -214,13 +216,14 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = token_table["<unk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@ -275,6 +278,12 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
if params.method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
@ -286,8 +295,8 @@ def main():
|
|||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "modified_beam_search":
|
elif params.method == "modified_beam_search":
|
||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
@ -296,16 +305,16 @@ def main():
|
|||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
else:
|
else:
|
||||||
for i in range(num_waves):
|
for i in range(num_waves):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@ -326,12 +335,11 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
|||||||
@ -22,7 +22,7 @@
|
|||||||
Usage:
|
Usage:
|
||||||
./prunted_stateless_emformer_rnnt/export.py \
|
./prunted_stateless_emformer_rnnt/export.py \
|
||||||
--exp-dir ./prunted_stateless_emformer_rnnt/exp \
|
--exp-dir ./prunted_stateless_emformer_rnnt/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -48,7 +48,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -115,10 +115,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -154,13 +154,12 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
# Load id of the <blk> token and the vocab size
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
params.vocab_size = sp.get_piece_size()
|
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -508,7 +508,7 @@ def main():
|
|||||||
quantize_dynamic(
|
quantize_dynamic(
|
||||||
model_input=decoder_filename,
|
model_input=decoder_filename,
|
||||||
model_output=decoder_filename_int8,
|
model_output=decoder_filename_int8,
|
||||||
op_types_to_quantize=["MatMul"],
|
op_types_to_quantize=["MatMul", "Gather"],
|
||||||
weight_type=QuantType.QInt8,
|
weight_type=QuantType.QInt8,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -22,7 +22,7 @@
|
|||||||
Usage:
|
Usage:
|
||||||
./pruned_transducer_stateless/export.py \
|
./pruned_transducer_stateless/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless/exp \
|
--exp-dir ./pruned_transducer_stateless/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -47,12 +47,12 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -87,10 +87,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -135,13 +135,13 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# Load id of the <blk> token and the vocab size, <blk> is
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
# defined in local/train_bpe_model.py
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
if params.streaming_model:
|
if params.streaming_model:
|
||||||
assert params.causal_convolution
|
assert params.causal_convolution
|
||||||
|
|||||||
@ -20,7 +20,7 @@ Usage:
|
|||||||
(1) greedy search
|
(1) greedy search
|
||||||
./pruned_transducer_stateless/pretrained.py \
|
./pruned_transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
@ -28,7 +28,7 @@ Usage:
|
|||||||
(2) beam search
|
(2) beam search
|
||||||
./pruned_transducer_stateless/pretrained.py \
|
./pruned_transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -37,7 +37,7 @@ Usage:
|
|||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./pruned_transducer_stateless/pretrained.py \
|
./pruned_transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -46,7 +46,7 @@ Usage:
|
|||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./pruned_transducer_stateless/pretrained.py \
|
./pruned_transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method fast_beam_search \
|
--method fast_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -66,7 +66,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
@ -79,7 +78,7 @@ from beam_search import (
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -97,9 +96,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="""Path to tokens.txt.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -237,13 +236,14 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = token_table["<unk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
if params.simulate_streaming:
|
if params.simulate_streaming:
|
||||||
assert (
|
assert (
|
||||||
@ -314,6 +314,12 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
if params.method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
@ -325,8 +331,8 @@ def main():
|
|||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "modified_beam_search":
|
elif params.method == "modified_beam_search":
|
||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
@ -335,16 +341,16 @@ def main():
|
|||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
else:
|
else:
|
||||||
for i in range(num_waves):
|
for i in range(num_waves):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@ -365,12 +371,11 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
|||||||
@ -22,7 +22,7 @@
|
|||||||
Usage:
|
Usage:
|
||||||
./pruned_transducer_stateless2/export.py \
|
./pruned_transducer_stateless2/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -47,12 +47,12 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -98,10 +98,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -145,12 +145,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
if params.streaming_model:
|
if params.streaming_model:
|
||||||
assert params.causal_convolution
|
assert params.causal_convolution
|
||||||
|
|||||||
@ -20,7 +20,7 @@ Usage:
|
|||||||
(1) greedy search
|
(1) greedy search
|
||||||
./pruned_transducer_stateless2/pretrained.py \
|
./pruned_transducer_stateless2/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
@ -28,7 +28,7 @@ Usage:
|
|||||||
(2) beam search
|
(2) beam search
|
||||||
./pruned_transducer_stateless2/pretrained.py \
|
./pruned_transducer_stateless2/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -37,7 +37,7 @@ Usage:
|
|||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./pruned_transducer_stateless2/pretrained.py \
|
./pruned_transducer_stateless2/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -46,7 +46,7 @@ Usage:
|
|||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./pruned_transducer_stateless2/pretrained.py \
|
./pruned_transducer_stateless2/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method fast_beam_search \
|
--method fast_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -66,7 +66,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
@ -79,7 +78,7 @@ from beam_search import (
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -97,9 +96,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="""Path to tokens.txt.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -238,13 +237,14 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = token_table["<unk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
if params.simulate_streaming:
|
if params.simulate_streaming:
|
||||||
assert (
|
assert (
|
||||||
@ -315,6 +315,12 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
if params.method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
@ -326,8 +332,8 @@ def main():
|
|||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "modified_beam_search":
|
elif params.method == "modified_beam_search":
|
||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
@ -336,16 +342,16 @@ def main():
|
|||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
else:
|
else:
|
||||||
for i in range(num_waves):
|
for i in range(num_waves):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@ -366,12 +372,11 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
|||||||
@ -28,7 +28,7 @@ popd
|
|||||||
2. Export the model to ONNX
|
2. Export the model to ONNX
|
||||||
|
|
||||||
./pruned_transducer_stateless3/export-onnx.py \
|
./pruned_transducer_stateless3/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 9999 \
|
--epoch 9999 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--exp-dir $repo/exp/
|
--exp-dir $repo/exp/
|
||||||
@ -48,8 +48,8 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
import onnx
|
import onnx
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
@ -59,7 +59,7 @@ from scaling_converter import convert_scaled_to_non_scaled
|
|||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||||
from icefall.utils import setup_logger
|
from icefall.utils import num_tokens, setup_logger
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -105,10 +105,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -393,12 +393,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
@ -518,7 +520,7 @@ def main():
|
|||||||
quantize_dynamic(
|
quantize_dynamic(
|
||||||
model_input=decoder_filename,
|
model_input=decoder_filename,
|
||||||
model_output=decoder_filename_int8,
|
model_output=decoder_filename_int8,
|
||||||
op_types_to_quantize=["MatMul"],
|
op_types_to_quantize=["MatMul", "Gather"],
|
||||||
weight_type=QuantType.QInt8,
|
weight_type=QuantType.QInt8,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -26,7 +26,7 @@ Usage:
|
|||||||
|
|
||||||
./pruned_transducer_stateless3/export.py \
|
./pruned_transducer_stateless3/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10 \
|
--avg 10 \
|
||||||
--jit 1
|
--jit 1
|
||||||
@ -44,7 +44,7 @@ It will also generate 3 other files: `encoder_jit_script.pt`,
|
|||||||
|
|
||||||
./pruned_transducer_stateless3/export.py \
|
./pruned_transducer_stateless3/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10 \
|
--avg 10 \
|
||||||
--jit-trace 1
|
--jit-trace 1
|
||||||
@ -56,7 +56,7 @@ It will generates 3 files: `encoder_jit_trace.pt`,
|
|||||||
|
|
||||||
./pruned_transducer_stateless3/export.py \
|
./pruned_transducer_stateless3/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -97,14 +97,14 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -150,10 +150,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -342,12 +342,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
if params.streaming_model:
|
if params.streaming_model:
|
||||||
assert params.causal_convolution
|
assert params.causal_convolution
|
||||||
|
|||||||
@ -20,7 +20,7 @@ You can generate the checkpoint with the following command:
|
|||||||
|
|
||||||
./pruned_transducer_stateless3/export.py \
|
./pruned_transducer_stateless3/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -29,7 +29,7 @@ Usage of this script:
|
|||||||
(1) greedy search
|
(1) greedy search
|
||||||
./pruned_transducer_stateless3/pretrained.py \
|
./pruned_transducer_stateless3/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
@ -37,7 +37,7 @@ Usage of this script:
|
|||||||
(2) beam search
|
(2) beam search
|
||||||
./pruned_transducer_stateless3/pretrained.py \
|
./pruned_transducer_stateless3/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -46,7 +46,7 @@ Usage of this script:
|
|||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./pruned_transducer_stateless3/pretrained.py \
|
./pruned_transducer_stateless3/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -55,7 +55,7 @@ Usage of this script:
|
|||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./pruned_transducer_stateless3/pretrained.py \
|
./pruned_transducer_stateless3/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method fast_beam_search \
|
--method fast_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -75,7 +75,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
@ -88,7 +87,7 @@ from beam_search import (
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -106,9 +105,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="""Path to tokens.txt.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -247,13 +246,14 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = token_table["<unk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
if params.simulate_streaming:
|
if params.simulate_streaming:
|
||||||
assert (
|
assert (
|
||||||
@ -324,6 +324,12 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
if params.method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
@ -335,8 +341,8 @@ def main():
|
|||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "modified_beam_search":
|
elif params.method == "modified_beam_search":
|
||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
@ -345,16 +351,16 @@ def main():
|
|||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
else:
|
else:
|
||||||
for i in range(num_waves):
|
for i in range(num_waves):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@ -375,12 +381,11 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
|||||||
@ -22,7 +22,7 @@
|
|||||||
Usage:
|
Usage:
|
||||||
./pruned_transducer_stateless4/export.py \
|
./pruned_transducer_stateless4/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless4/exp \
|
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -48,7 +48,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
@ -59,7 +59,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -116,10 +116,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -164,12 +164,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
if params.streaming_model:
|
if params.streaming_model:
|
||||||
assert params.causal_convolution
|
assert params.causal_convolution
|
||||||
|
|||||||
@ -28,7 +28,7 @@ popd
|
|||||||
2. Export the model to ONNX
|
2. Export the model to ONNX
|
||||||
|
|
||||||
./pruned_transducer_stateless5/export-onnx-streaming.py \
|
./pruned_transducer_stateless5/export-onnx-streaming.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
@ -58,13 +58,13 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
import onnx
|
import onnx
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
@ -74,7 +74,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import setup_logger, str2bool
|
from icefall.utils import num_tokens, setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -131,10 +131,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -489,12 +489,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
@ -662,7 +664,7 @@ def main():
|
|||||||
quantize_dynamic(
|
quantize_dynamic(
|
||||||
model_input=decoder_filename,
|
model_input=decoder_filename,
|
||||||
model_output=decoder_filename_int8,
|
model_output=decoder_filename_int8,
|
||||||
op_types_to_quantize=["MatMul"],
|
op_types_to_quantize=["MatMul", "Gather"],
|
||||||
weight_type=QuantType.QInt8,
|
weight_type=QuantType.QInt8,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -28,7 +28,7 @@ popd
|
|||||||
2. Export the model to ONNX
|
2. Export the model to ONNX
|
||||||
|
|
||||||
./pruned_transducer_stateless5/export-onnx.py \
|
./pruned_transducer_stateless5/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
@ -55,13 +55,13 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
import onnx
|
import onnx
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
@ -71,7 +71,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import setup_logger, str2bool
|
from icefall.utils import num_tokens, setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -128,10 +128,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -416,12 +416,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
@ -586,7 +588,7 @@ def main():
|
|||||||
quantize_dynamic(
|
quantize_dynamic(
|
||||||
model_input=decoder_filename,
|
model_input=decoder_filename,
|
||||||
model_output=decoder_filename_int8,
|
model_output=decoder_filename_int8,
|
||||||
op_types_to_quantize=["MatMul"],
|
op_types_to_quantize=["MatMul", "Gather"],
|
||||||
weight_type=QuantType.QInt8,
|
weight_type=QuantType.QInt8,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -22,7 +22,7 @@
|
|||||||
Usage:
|
Usage:
|
||||||
./pruned_transducer_stateless5/export.py \
|
./pruned_transducer_stateless5/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -48,7 +48,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
@ -59,7 +59,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -116,10 +116,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -164,12 +164,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
if params.streaming_model:
|
if params.streaming_model:
|
||||||
assert params.causal_convolution
|
assert params.causal_convolution
|
||||||
|
|||||||
@ -20,7 +20,7 @@ Usage:
|
|||||||
(1) greedy search
|
(1) greedy search
|
||||||
./pruned_transducer_stateless5/pretrained.py \
|
./pruned_transducer_stateless5/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
@ -28,7 +28,7 @@ Usage:
|
|||||||
(2) beam search
|
(2) beam search
|
||||||
./pruned_transducer_stateless5/pretrained.py \
|
./pruned_transducer_stateless5/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -37,7 +37,7 @@ Usage:
|
|||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./pruned_transducer_stateless5/pretrained.py \
|
./pruned_transducer_stateless5/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -46,7 +46,7 @@ Usage:
|
|||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./pruned_transducer_stateless5/pretrained.py \
|
./pruned_transducer_stateless5/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method fast_beam_search \
|
--method fast_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -66,7 +66,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
@ -79,6 +78,8 @@ from beam_search import (
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.utils import num_tokens
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -95,9 +96,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="""Path to tokens.txt.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -214,13 +215,14 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = token_table["<unk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@ -275,6 +277,12 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
if params.method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
@ -286,8 +294,8 @@ def main():
|
|||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "modified_beam_search":
|
elif params.method == "modified_beam_search":
|
||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
@ -296,16 +304,16 @@ def main():
|
|||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
else:
|
else:
|
||||||
for i in range(num_waves):
|
for i in range(num_waves):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@ -326,12 +334,11 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
|||||||
@ -22,7 +22,7 @@
|
|||||||
Usage:
|
Usage:
|
||||||
./pruned_transducer_stateless6/export.py \
|
./pruned_transducer_stateless6/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless6/exp \
|
--exp-dir ./pruned_transducer_stateless6/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -47,12 +47,12 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from train import get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -98,10 +98,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -135,12 +135,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
#
|
#
|
||||||
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
|
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang
|
||||||
|
# Zengrui Jin)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This script exports a transducer model from PyTorch to ONNX.
|
This script exports a transducer model from PyTorch to ONNX.
|
||||||
@ -18,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
pushd $repo
|
pushd $repo
|
||||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
|
||||||
git lfs pull --include "exp/pretrained-epoch-30-avg-9.pt"
|
git lfs pull --include "exp/pretrained-epoch-30-avg-9.pt"
|
||||||
|
|
||||||
cd exp
|
cd exp
|
||||||
@ -28,7 +28,7 @@ popd
|
|||||||
2. Export the model to ONNX
|
2. Export the model to ONNX
|
||||||
|
|
||||||
./pruned_transducer_stateless7/export-onnx.py \
|
./pruned_transducer_stateless7/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -50,8 +50,8 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
import onnx
|
import onnx
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
@ -66,7 +66,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import setup_logger, str2bool
|
from icefall.utils import num_tokens, setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -123,10 +123,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
help="Path to the tokens.txt.",
|
||||||
help="Path to the BPE model",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -411,12 +410,12 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# Load id of the <blk> token and the vocab size
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
@ -581,7 +580,7 @@ def main():
|
|||||||
quantize_dynamic(
|
quantize_dynamic(
|
||||||
model_input=decoder_filename,
|
model_input=decoder_filename,
|
||||||
model_output=decoder_filename_int8,
|
model_output=decoder_filename_int8,
|
||||||
op_types_to_quantize=["MatMul"],
|
op_types_to_quantize=["MatMul", "Gather"],
|
||||||
weight_type=QuantType.QInt8,
|
weight_type=QuantType.QInt8,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
#
|
#
|
||||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang
|
||||||
|
# Zengrui Jin)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -26,7 +27,7 @@ Usage:
|
|||||||
|
|
||||||
./pruned_transducer_stateless7/export.py \
|
./pruned_transducer_stateless7/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 9 \
|
--avg 9 \
|
||||||
--jit 1
|
--jit 1
|
||||||
@ -45,7 +46,7 @@ for how to use the exported models outside of icefall.
|
|||||||
|
|
||||||
./pruned_transducer_stateless7/export.py \
|
./pruned_transducer_stateless7/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -65,7 +66,7 @@ you can do:
|
|||||||
--avg 1 \
|
--avg 1 \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method greedy_search \
|
--decoding-method greedy_search \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
|
|
||||||
Check ./pretrained.py for its usage.
|
Check ./pretrained.py for its usage.
|
||||||
|
|
||||||
@ -86,7 +87,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
@ -98,7 +99,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -155,10 +156,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
help="Path to the tokens.txt.",
|
||||||
help="Path to the BPE model",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -198,12 +198,12 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# Load id of the <blk> token and the vocab size
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
@ -292,7 +292,7 @@ def main():
|
|||||||
model.to("cpu")
|
model.to("cpu")
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit is True:
|
if params.jit:
|
||||||
convert_scaled_to_non_scaled(model, inplace=True)
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
# We won't use the forward() method of the model in C++, so just ignore
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
# it here.
|
# it here.
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
|
||||||
|
# Zengrui Jin)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -20,7 +21,7 @@ You can generate the checkpoint with the following command:
|
|||||||
|
|
||||||
./pruned_transducer_stateless7/export.py \
|
./pruned_transducer_stateless7/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -29,7 +30,7 @@ Usage of this script:
|
|||||||
(1) greedy search
|
(1) greedy search
|
||||||
./pruned_transducer_stateless7/pretrained.py \
|
./pruned_transducer_stateless7/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
@ -37,7 +38,7 @@ Usage of this script:
|
|||||||
(2) beam search
|
(2) beam search
|
||||||
./pruned_transducer_stateless7/pretrained.py \
|
./pruned_transducer_stateless7/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -46,7 +47,7 @@ Usage of this script:
|
|||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./pruned_transducer_stateless7/pretrained.py \
|
./pruned_transducer_stateless7/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -55,7 +56,7 @@ Usage of this script:
|
|||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./pruned_transducer_stateless7/pretrained.py \
|
./pruned_transducer_stateless7/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method fast_beam_search \
|
--method fast_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -75,7 +76,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
@ -88,7 +88,7 @@ from beam_search import (
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -106,9 +106,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -225,13 +225,13 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# Load id of the <blk> token and the vocab size
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = token_table["<unk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@ -286,6 +286,12 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
if params.method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
@ -297,8 +303,8 @@ def main():
|
|||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "modified_beam_search":
|
elif params.method == "modified_beam_search":
|
||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
@ -307,16 +313,16 @@ def main():
|
|||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
else:
|
else:
|
||||||
for i in range(num_waves):
|
for i in range(num_waves):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@ -337,12 +343,11 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
|||||||
@ -26,7 +26,7 @@ Usage:
|
|||||||
|
|
||||||
./pruned_transducer_stateless7_ctc/export.py \
|
./pruned_transducer_stateless7_ctc/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
|
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 9 \
|
--avg 9 \
|
||||||
--jit 1
|
--jit 1
|
||||||
@ -45,7 +45,7 @@ for how to use the exported models outside of icefall.
|
|||||||
|
|
||||||
./pruned_transducer_stateless7_ctc/export.py \
|
./pruned_transducer_stateless7_ctc/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
|
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -86,7 +86,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
@ -97,7 +97,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -154,10 +154,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -197,12 +197,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -20,7 +20,7 @@ You can generate the checkpoint with the following command:
|
|||||||
|
|
||||||
./pruned_transducer_stateless7_ctc/export.py \
|
./pruned_transducer_stateless7_ctc/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
|
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -29,7 +29,7 @@ Usage of this script:
|
|||||||
(1) greedy search
|
(1) greedy search
|
||||||
./pruned_transducer_stateless7_ctc/pretrained.py \
|
./pruned_transducer_stateless7_ctc/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
@ -37,7 +37,7 @@ Usage of this script:
|
|||||||
(2) beam search
|
(2) beam search
|
||||||
./pruned_transducer_stateless7_ctc/pretrained.py \
|
./pruned_transducer_stateless7_ctc/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -46,7 +46,7 @@ Usage of this script:
|
|||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./pruned_transducer_stateless7_ctc/pretrained.py \
|
./pruned_transducer_stateless7_ctc/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -55,7 +55,7 @@ Usage of this script:
|
|||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./pruned_transducer_stateless7_ctc/pretrained.py \
|
./pruned_transducer_stateless7_ctc/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method fast_beam_search \
|
--method fast_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -75,7 +75,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
@ -88,6 +87,8 @@ from beam_search import (
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.utils import num_tokens
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -104,9 +105,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="""Path to tokens.txt.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -223,13 +224,14 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = token_table["<unk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@ -284,6 +286,12 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
if params.method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
@ -295,8 +303,8 @@ def main():
|
|||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "modified_beam_search":
|
elif params.method == "modified_beam_search":
|
||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
@ -305,16 +313,16 @@ def main():
|
|||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
else:
|
else:
|
||||||
for i in range(num_waves):
|
for i in range(num_waves):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@ -335,12 +343,11 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
|||||||
@ -22,14 +22,14 @@ You can use the following command to get the exported models:
|
|||||||
|
|
||||||
./pruned_transducer_stateless7_ctc/export.py \
|
./pruned_transducer_stateless7_ctc/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
|
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
Usage of this script:
|
Usage of this script:
|
||||||
|
|
||||||
(1) ctc-decoding
|
(1) ctc-decoding
|
||||||
./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
|
./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
--method ctc-decoding \
|
--method ctc-decoding \
|
||||||
@ -38,7 +38,7 @@ Usage of this script:
|
|||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
|
|
||||||
(2) 1best
|
(2) 1best
|
||||||
./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
|
./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
|
||||||
--HLG data/lang_bpe_500/HLG.pt \
|
--HLG data/lang_bpe_500/HLG.pt \
|
||||||
--words-file data/lang_bpe_500/words.txt \
|
--words-file data/lang_bpe_500/words.txt \
|
||||||
@ -48,7 +48,7 @@ Usage of this script:
|
|||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
|
|
||||||
(3) nbest-rescoring
|
(3) nbest-rescoring
|
||||||
./bruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
|
./bruned_transducer_stateless7_ctc/pretrained_ctc.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
|
||||||
--HLG data/lang_bpe_500/HLG.pt \
|
--HLG data/lang_bpe_500/HLG.pt \
|
||||||
--words-file data/lang_bpe_500/words.txt \
|
--words-file data/lang_bpe_500/words.txt \
|
||||||
@ -60,7 +60,7 @@ Usage of this script:
|
|||||||
|
|
||||||
|
|
||||||
(4) whole-lattice-rescoring
|
(4) whole-lattice-rescoring
|
||||||
./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
|
./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
|
||||||
--HLG data/lang_bpe_500/HLG.pt \
|
--HLG data/lang_bpe_500/HLG.pt \
|
||||||
--words-file data/lang_bpe_500/words.txt \
|
--words-file data/lang_bpe_500/words.txt \
|
||||||
|
|||||||
@ -26,7 +26,7 @@ Usage:
|
|||||||
|
|
||||||
./pruned_transducer_stateless7_ctc_bs/export.py \
|
./pruned_transducer_stateless7_ctc_bs/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
|
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 13 \
|
--avg 13 \
|
||||||
--jit 1
|
--jit 1
|
||||||
@ -45,7 +45,7 @@ for how to use the exported models outside of icefall.
|
|||||||
|
|
||||||
./pruned_transducer_stateless7_ctc_bs/export.py \
|
./pruned_transducer_stateless7_ctc_bs/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
|
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 13
|
--avg 13
|
||||||
|
|
||||||
@ -86,7 +86,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
@ -97,7 +97,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -154,10 +154,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -197,12 +197,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -28,7 +28,7 @@ Usage:
|
|||||||
|
|
||||||
./pruned_transducer_stateless7_ctc_bs/export_onnx.py \
|
./pruned_transducer_stateless7_ctc_bs/export_onnx.py \
|
||||||
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
|
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 13 \
|
--avg 13 \
|
||||||
--onnx 1
|
--onnx 1
|
||||||
@ -48,7 +48,7 @@ Check `onnx_check.py` for how to use them.
|
|||||||
(2) Export to ONNX format which can be used in Triton Server
|
(2) Export to ONNX format which can be used in Triton Server
|
||||||
./pruned_transducer_stateless7_ctc_bs/export_onnx.py \
|
./pruned_transducer_stateless7_ctc_bs/export_onnx.py \
|
||||||
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
|
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 13 \
|
--avg 13 \
|
||||||
--onnx-triton 1
|
--onnx-triton 1
|
||||||
@ -86,9 +86,10 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
@ -98,8 +99,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -156,10 +156,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -728,12 +728,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -20,7 +20,7 @@ You can generate the checkpoint with the following command:
|
|||||||
|
|
||||||
./pruned_transducer_stateless7_ctc_bs/export.py \
|
./pruned_transducer_stateless7_ctc_bs/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
|
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 13
|
--avg 13
|
||||||
|
|
||||||
@ -29,7 +29,7 @@ Usage of this script:
|
|||||||
(1) greedy search
|
(1) greedy search
|
||||||
./pruned_transducer_stateless7_ctc_bs/pretrained.py \
|
./pruned_transducer_stateless7_ctc_bs/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
@ -37,7 +37,7 @@ Usage of this script:
|
|||||||
(2) beam search
|
(2) beam search
|
||||||
./pruned_transducer_stateless7_ctc_bs/pretrained.py \
|
./pruned_transducer_stateless7_ctc_bs/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -46,7 +46,7 @@ Usage of this script:
|
|||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./pruned_transducer_stateless7_ctc_bs/pretrained.py \
|
./pruned_transducer_stateless7_ctc_bs/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -55,7 +55,7 @@ Usage of this script:
|
|||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./pruned_transducer_stateless7_ctc_bs/pretrained.py \
|
./pruned_transducer_stateless7_ctc_bs/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method fast_beam_search \
|
--method fast_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -75,7 +75,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
@ -88,6 +87,8 @@ from beam_search import (
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.utils import num_tokens
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -104,9 +105,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="""Path to tokens.txt.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -223,13 +224,14 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = token_table["<unk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@ -284,6 +286,12 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
if params.method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
@ -295,8 +303,8 @@ def main():
|
|||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "modified_beam_search":
|
elif params.method == "modified_beam_search":
|
||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
@ -305,16 +313,16 @@ def main():
|
|||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
else:
|
else:
|
||||||
for i in range(num_waves):
|
for i in range(num_waves):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@ -335,12 +343,11 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
|||||||
@ -22,14 +22,14 @@ You can use the following command to get the exported models:
|
|||||||
|
|
||||||
./pruned_transducer_stateless7_ctc_bs/export.py \
|
./pruned_transducer_stateless7_ctc_bs/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
|
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
Usage of this script:
|
Usage of this script:
|
||||||
|
|
||||||
(1) ctc-decoding
|
(1) ctc-decoding
|
||||||
./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \
|
./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
--method ctc-decoding \
|
--method ctc-decoding \
|
||||||
@ -38,7 +38,7 @@ Usage of this script:
|
|||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
|
|
||||||
(2) 1best
|
(2) 1best
|
||||||
./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \
|
./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
|
||||||
--HLG data/lang_bpe_500/HLG.pt \
|
--HLG data/lang_bpe_500/HLG.pt \
|
||||||
--words-file data/lang_bpe_500/words.txt \
|
--words-file data/lang_bpe_500/words.txt \
|
||||||
@ -48,7 +48,7 @@ Usage of this script:
|
|||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
|
|
||||||
(3) nbest-rescoring
|
(3) nbest-rescoring
|
||||||
./bruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
|
./bruned_transducer_stateless7_ctc/pretrained_ctc.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
|
||||||
--HLG data/lang_bpe_500/HLG.pt \
|
--HLG data/lang_bpe_500/HLG.pt \
|
||||||
--words-file data/lang_bpe_500/words.txt \
|
--words-file data/lang_bpe_500/words.txt \
|
||||||
@ -60,7 +60,7 @@ Usage of this script:
|
|||||||
|
|
||||||
|
|
||||||
(4) whole-lattice-rescoring
|
(4) whole-lattice-rescoring
|
||||||
./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \
|
./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
|
||||||
--HLG data/lang_bpe_500/HLG.pt \
|
--HLG data/lang_bpe_500/HLG.pt \
|
||||||
--words-file data/lang_bpe_500/words.txt \
|
--words-file data/lang_bpe_500/words.txt \
|
||||||
|
|||||||
@ -66,6 +66,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train2 import add_model_arguments, get_params, get_transducer_model
|
from train2 import add_model_arguments, get_params, get_transducer_model
|
||||||
@ -76,8 +77,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.utils import num_tokens, setup_logger, str2bool
|
||||||
from icefall.utils import setup_logger, str2bool
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -123,10 +123,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lang-dir",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_char",
|
default="data/lang_char/tokens.txt",
|
||||||
help="The lang dir",
|
help="The tokens.txt file",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -246,9 +246,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
# Load tokens.txt here
|
||||||
params.blank_id = 0
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
params.vocab_size = max(lexicon.tokens) + 1
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
|
params.blank_id = token_table["<blk>"]
|
||||||
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -28,7 +28,7 @@ popd
|
|||||||
2. Export to ncnn
|
2. Export to ncnn
|
||||||
|
|
||||||
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
|
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
@ -64,7 +64,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train2 import add_model_arguments, get_params, get_transducer_model
|
from train2 import add_model_arguments, get_params, get_transducer_model
|
||||||
@ -75,7 +75,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import setup_logger, str2bool
|
from icefall.utils import num_tokens, setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -121,10 +121,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -244,12 +244,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -29,7 +29,7 @@ popd
|
|||||||
2. Export the model to ONNX
|
2. Export the model to ONNX
|
||||||
|
|
||||||
./pruned_transducer_stateless7_streaming/export-onnx-zh.py \
|
./pruned_transducer_stateless7_streaming/export-onnx-zh.py \
|
||||||
--lang-dir $repo/data/lang_char_bpe \
|
--tokens $repo/data/lang_char_bpe/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -60,6 +60,7 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
import onnx
|
import onnx
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -76,8 +77,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.utils import num_tokens, setup_logger, str2bool
|
||||||
from icefall.utils import setup_logger, str2bool
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -134,10 +134,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lang-dir",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_char",
|
default="data/lang_char/tokens.txt",
|
||||||
help="The lang dir",
|
help="The tokens.txt file",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -493,9 +493,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
# Load tokens.txt here
|
||||||
params.blank_id = 0
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
params.vocab_size = max(lexicon.tokens) + 1
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
|
params.blank_id = token_table["<blk>"]
|
||||||
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
@ -661,7 +666,7 @@ def main():
|
|||||||
quantize_dynamic(
|
quantize_dynamic(
|
||||||
model_input=decoder_filename,
|
model_input=decoder_filename,
|
||||||
model_output=decoder_filename_int8,
|
model_output=decoder_filename_int8,
|
||||||
op_types_to_quantize=["MatMul"],
|
op_types_to_quantize=["MatMul", "Gather"],
|
||||||
weight_type=QuantType.QInt8,
|
weight_type=QuantType.QInt8,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -27,7 +27,7 @@ popd
|
|||||||
2. Export the model to ONNX
|
2. Export the model to ONNX
|
||||||
|
|
||||||
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -48,8 +48,8 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
import onnx
|
import onnx
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
@ -65,7 +65,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import setup_logger, str2bool
|
from icefall.utils import num_tokens, setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -122,10 +122,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -481,12 +481,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
@ -652,7 +654,7 @@ def main():
|
|||||||
quantize_dynamic(
|
quantize_dynamic(
|
||||||
model_input=decoder_filename,
|
model_input=decoder_filename,
|
||||||
model_output=decoder_filename_int8,
|
model_output=decoder_filename_int8,
|
||||||
op_types_to_quantize=["MatMul"],
|
op_types_to_quantize=["MatMul", "Gather"],
|
||||||
weight_type=QuantType.QInt8,
|
weight_type=QuantType.QInt8,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -139,8 +139,8 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import k2
|
||||||
import onnxruntime
|
import onnxruntime
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner
|
from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner
|
||||||
@ -154,7 +154,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -211,10 +211,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -675,12 +675,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -20,7 +20,7 @@ You can generate the checkpoint with the following command:
|
|||||||
|
|
||||||
./pruned_transducer_stateless7_streaming/export.py \
|
./pruned_transducer_stateless7_streaming/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -29,7 +29,7 @@ Usage of this script:
|
|||||||
(1) greedy search
|
(1) greedy search
|
||||||
./pruned_transducer_stateless7_streaming/pretrained.py \
|
./pruned_transducer_stateless7_streaming/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
@ -37,7 +37,7 @@ Usage of this script:
|
|||||||
(2) beam search
|
(2) beam search
|
||||||
./pruned_transducer_stateless7_streaming/pretrained.py \
|
./pruned_transducer_stateless7_streaming/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -46,7 +46,7 @@ Usage of this script:
|
|||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./pruned_transducer_stateless7_streaming/pretrained.py \
|
./pruned_transducer_stateless7_streaming/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -55,7 +55,7 @@ Usage of this script:
|
|||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./pruned_transducer_stateless7_streaming/pretrained.py \
|
./pruned_transducer_stateless7_streaming/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method fast_beam_search \
|
--method fast_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -75,7 +75,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
@ -88,7 +87,7 @@ from beam_search import (
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -106,9 +105,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="""Path to tokens.txt.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -225,13 +224,14 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = token_table["<unk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@ -286,6 +286,12 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
if params.method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
@ -297,8 +303,8 @@ def main():
|
|||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "modified_beam_search":
|
elif params.method == "modified_beam_search":
|
||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
@ -307,16 +313,16 @@ def main():
|
|||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
else:
|
else:
|
||||||
for i in range(num_waves):
|
for i in range(num_waves):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@ -337,12 +343,11 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
|||||||
@ -28,7 +28,7 @@ popd
|
|||||||
2. Export to ncnn
|
2. Export to ncnn
|
||||||
|
|
||||||
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
|
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
@ -64,7 +64,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train2 import add_model_arguments, get_params, get_transducer_model
|
from train2 import add_model_arguments, get_params, get_transducer_model
|
||||||
@ -75,7 +75,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import setup_logger, str2bool
|
from icefall.utils import num_tokens, setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -121,10 +121,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -244,12 +244,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -26,7 +26,7 @@ Usage:
|
|||||||
|
|
||||||
./pruned_transducer_stateless8/export.py \
|
./pruned_transducer_stateless8/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless8/exp \
|
--exp-dir ./pruned_transducer_stateless8/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 9 \
|
--avg 9 \
|
||||||
--jit 1
|
--jit 1
|
||||||
@ -45,7 +45,7 @@ for how to use the exported models outside of icefall.
|
|||||||
|
|
||||||
./pruned_transducer_stateless8/export.py \
|
./pruned_transducer_stateless8/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless8/exp \
|
--exp-dir ./pruned_transducer_stateless8/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -86,7 +86,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
@ -98,7 +98,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -155,10 +155,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -198,12 +198,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -20,7 +20,7 @@ You can generate the checkpoint with the following command:
|
|||||||
|
|
||||||
./pruned_transducer_stateless8/export.py \
|
./pruned_transducer_stateless8/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless8/exp \
|
--exp-dir ./pruned_transducer_stateless8/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -29,7 +29,7 @@ Usage of this script:
|
|||||||
(1) greedy search
|
(1) greedy search
|
||||||
./pruned_transducer_stateless8/pretrained.py \
|
./pruned_transducer_stateless8/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
@ -37,7 +37,7 @@ Usage of this script:
|
|||||||
(2) beam search
|
(2) beam search
|
||||||
./pruned_transducer_stateless8/pretrained.py \
|
./pruned_transducer_stateless8/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -46,7 +46,7 @@ Usage of this script:
|
|||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./pruned_transducer_stateless8/pretrained.py \
|
./pruned_transducer_stateless8/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -55,7 +55,7 @@ Usage of this script:
|
|||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./pruned_transducer_stateless8/pretrained.py \
|
./pruned_transducer_stateless8/pretrained.py \
|
||||||
--checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \
|
--checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method fast_beam_search \
|
--method fast_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -75,7 +75,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
@ -88,7 +87,7 @@ from beam_search import (
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -106,9 +105,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="""Path to tokens.txt.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -225,13 +224,14 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = token_table["<unk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@ -286,6 +286,12 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
if params.method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
@ -297,8 +303,8 @@ def main():
|
|||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "modified_beam_search":
|
elif params.method == "modified_beam_search":
|
||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
@ -307,16 +313,16 @@ def main():
|
|||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
else:
|
else:
|
||||||
for i in range(num_waves):
|
for i in range(num_waves):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@ -337,12 +343,11 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
|||||||
@ -22,7 +22,7 @@
|
|||||||
Usage:
|
Usage:
|
||||||
./transducer/export.py \
|
./transducer/export.py \
|
||||||
--exp-dir ./transducer/exp \
|
--exp-dir ./transducer/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 34 \
|
--epoch 34 \
|
||||||
--avg 11
|
--avg 11
|
||||||
|
|
||||||
@ -46,7 +46,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
@ -55,7 +55,7 @@ from model import Transducer
|
|||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.utils import AttributeDict, str2bool
|
from icefall.utils import AttributeDict, num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -90,10 +90,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -191,12 +191,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -19,7 +19,7 @@ Usage:
|
|||||||
|
|
||||||
./transducer/pretrained.py \
|
./transducer/pretrained.py \
|
||||||
--checkpoint ./transducer/exp/pretrained.pt \
|
--checkpoint ./transducer/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav \
|
/path/to/bar.wav \
|
||||||
@ -36,8 +36,8 @@ import logging
|
|||||||
import math
|
import math
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import beam_search, greedy_search
|
from beam_search import beam_search, greedy_search
|
||||||
@ -48,7 +48,7 @@ from model import Transducer
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.utils import AttributeDict
|
from icefall.utils import AttributeDict, num_tokens
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -66,11 +66,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.
|
help="Path to tokens.txt.",
|
||||||
Used only when method is ctc-decoding.
|
|
||||||
""",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -204,12 +202,14 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@ -257,6 +257,12 @@ def main():
|
|||||||
x=features, x_lens=feature_lengths
|
x=features, x_lens=feature_lengths
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
num_waves = encoder_out.size(0)
|
num_waves = encoder_out.size(0)
|
||||||
hyps = []
|
hyps = []
|
||||||
for i in range(num_waves):
|
for i in range(num_waves):
|
||||||
@ -272,12 +278,11 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
|||||||
@ -22,7 +22,7 @@
|
|||||||
Usage:
|
Usage:
|
||||||
./transducer_stateless/export.py \
|
./transducer_stateless/export.py \
|
||||||
--exp-dir ./transducer_stateless/exp \
|
--exp-dir ./transducer_stateless/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -46,7 +46,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
@ -56,7 +56,7 @@ from model import Transducer
|
|||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.utils import AttributeDict, str2bool
|
from icefall.utils import AttributeDict, num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -91,10 +91,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -191,12 +191,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -20,7 +20,7 @@ Usage:
|
|||||||
(1) greedy search
|
(1) greedy search
|
||||||
./transducer_stateless/pretrained.py \
|
./transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame 1 \
|
--max-sym-per-frame 1 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -29,7 +29,7 @@ Usage:
|
|||||||
(2) beam search
|
(2) beam search
|
||||||
./transducer_stateless/pretrained.py \
|
./transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -38,7 +38,7 @@ Usage:
|
|||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./transducer_stateless/pretrained.py \
|
./transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -47,7 +47,7 @@ Usage:
|
|||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./transducer_stateless/pretrained.py \
|
./transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method fast_beam_search \
|
--method fast_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -67,7 +67,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
@ -80,6 +79,8 @@ from beam_search import (
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.utils import num_tokens
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -96,9 +97,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="""Path to tokens.txt.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -213,12 +214,14 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@ -273,6 +276,12 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
if params.method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
hyp_list = fast_beam_search_one_best(
|
hyp_list = fast_beam_search_one_best(
|
||||||
@ -318,12 +327,11 @@ def main():
|
|||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
hyp_list.append(hyp)
|
hyp_list.append(hyp)
|
||||||
|
|
||||||
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
|
hyps = [token_ids_to_words(hyp) for hyp in hyp_list]
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
|||||||
@ -22,7 +22,7 @@
|
|||||||
Usage:
|
Usage:
|
||||||
./transducer_stateless2/export.py \
|
./transducer_stateless2/export.py \
|
||||||
--exp-dir ./transducer_stateless2/exp \
|
--exp-dir ./transducer_stateless2/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -46,12 +46,12 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from train import get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -86,10 +86,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -123,12 +123,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -20,7 +20,7 @@ Usage:
|
|||||||
(1) greedy search
|
(1) greedy search
|
||||||
./transducer_stateless2/pretrained.py \
|
./transducer_stateless2/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame 1 \
|
--max-sym-per-frame 1 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -29,7 +29,7 @@ Usage:
|
|||||||
(2) beam search
|
(2) beam search
|
||||||
./transducer_stateless2/pretrained.py \
|
./transducer_stateless2/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -38,7 +38,7 @@ Usage:
|
|||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./transducer_stateless2/pretrained.py \
|
./transducer_stateless2/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -47,7 +47,7 @@ Usage:
|
|||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./transducer_stateless2/pretrained.py \
|
./transducer_stateless2/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method fast_beam_search \
|
--method fast_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -67,7 +67,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
@ -80,6 +79,8 @@ from beam_search import (
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.utils import num_tokens
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -96,9 +97,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="""Path to tokens.txt.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -213,12 +214,14 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@ -273,6 +276,12 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
if params.method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
hyp_list = fast_beam_search_one_best(
|
hyp_list = fast_beam_search_one_best(
|
||||||
@ -318,12 +327,11 @@ def main():
|
|||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
hyp_list.append(hyp)
|
hyp_list.append(hyp)
|
||||||
|
|
||||||
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
|
hyps = [token_ids_to_words(hyp) for hyp in hyp_list]
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
|||||||
@ -22,7 +22,7 @@
|
|||||||
Usage:
|
Usage:
|
||||||
./transducer_stateless_multi_datasets/export.py \
|
./transducer_stateless_multi_datasets/export.py \
|
||||||
--exp-dir ./transducer_stateless_multi_datasets/exp \
|
--exp-dir ./transducer_stateless_multi_datasets/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -47,7 +47,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
@ -57,7 +57,7 @@ from model import Transducer
|
|||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.utils import AttributeDict, str2bool
|
from icefall.utils import AttributeDict, num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -92,10 +92,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -192,12 +192,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -20,7 +20,7 @@ Usage:
|
|||||||
(1) greedy search
|
(1) greedy search
|
||||||
./transducer_stateless_multi_datasets/pretrained.py \
|
./transducer_stateless_multi_datasets/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame 1 \
|
--max-sym-per-frame 1 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -29,7 +29,7 @@ Usage:
|
|||||||
(2) beam search
|
(2) beam search
|
||||||
./transducer_stateless_multi_datasets/pretrained.py \
|
./transducer_stateless_multi_datasets/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -38,7 +38,7 @@ Usage:
|
|||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./transducer_stateless_multi_datasets/pretrained.py \
|
./transducer_stateless_multi_datasets/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -47,7 +47,7 @@ Usage:
|
|||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./transducer_stateless_multi_datasets/pretrained.py \
|
./transducer_stateless_multi_datasets/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method fast_beam_search \
|
--method fast_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -67,7 +67,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
@ -80,6 +79,8 @@ from beam_search import (
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.utils import num_tokens
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -96,9 +97,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="""Path to tokens.txt.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -213,12 +214,14 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@ -273,6 +276,12 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
if params.method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
hyp_list = fast_beam_search_one_best(
|
hyp_list = fast_beam_search_one_best(
|
||||||
@ -318,12 +327,11 @@ def main():
|
|||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
hyp_list.append(hyp)
|
hyp_list.append(hyp)
|
||||||
|
|
||||||
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
|
hyps = [token_ids_to_words(hyp) for hyp in hyp_list]
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
|||||||
@ -19,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
pushd $repo
|
pushd $repo
|
||||||
git lfs pull --include "data/lang_bpe_500/tokens.txt"
|
|
||||||
git lfs pull --include "exp/pretrained.pt"
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
cd exp
|
cd exp
|
||||||
@ -74,7 +73,6 @@ import onnx
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from export import num_tokens
|
|
||||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_model, get_params
|
from train import add_model_arguments, get_model, get_params
|
||||||
@ -86,7 +84,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
|||||||
@ -19,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
pushd $repo
|
pushd $repo
|
||||||
git lfs pull --include "data/lang_bpe_500/tokens.txt"
|
|
||||||
git lfs pull --include "exp/pretrained.pt"
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
cd exp
|
cd exp
|
||||||
@ -71,7 +70,6 @@ import onnx
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from export import num_tokens
|
|
||||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_model, get_params
|
from train import add_model_arguments, get_model, get_params
|
||||||
@ -83,7 +81,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import make_pad_mask, str2bool
|
from icefall.utils import make_pad_mask, num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
|||||||
@ -160,7 +160,6 @@ with the following commands:
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import re
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
@ -176,27 +175,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import make_pad_mask, str2bool
|
from icefall.utils import make_pad_mask, num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def num_tokens(
|
|
||||||
token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
|
|
||||||
) -> int:
|
|
||||||
"""Return the number of tokens excluding those from
|
|
||||||
disambiguation symbols.
|
|
||||||
|
|
||||||
Caution:
|
|
||||||
0 is not a token ID so it is excluded from the return value.
|
|
||||||
"""
|
|
||||||
symbols = token_table.symbols
|
|
||||||
ans = []
|
|
||||||
for s in symbols:
|
|
||||||
if not disambig_pattern.match(s):
|
|
||||||
ans.append(token_table[s])
|
|
||||||
num_tokens = len(ans)
|
|
||||||
if 0 in ans:
|
|
||||||
num_tokens -= 1
|
|
||||||
return num_tokens
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -487,6 +466,8 @@ def main():
|
|||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
else:
|
else:
|
||||||
assert params.avg > 0, params.avg
|
assert params.avg > 0, params.avg
|
||||||
start = params.epoch - params.avg
|
start = params.epoch - params.avg
|
||||||
|
|||||||
@ -410,10 +410,20 @@ def main():
|
|||||||
raise ValueError(f"Unsupported decoding method: {params.method}")
|
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
if params.method == "ctc-decoding":
|
||||||
words = " ".join(hyp)
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = words.replace("▁", " ").strip()
|
words = "".join(hyp)
|
||||||
s += f"{filename}:\n{words}\n\n"
|
words = words.replace("▁", " ").strip()
|
||||||
|
s += f"{filename}:\n{words}\n\n"
|
||||||
|
elif params.method in [
|
||||||
|
"1best",
|
||||||
|
"nbest-rescoring",
|
||||||
|
"whole-lattice-rescoring",
|
||||||
|
]:
|
||||||
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
|
words = " ".join(hyp)
|
||||||
|
words = words.replace("▁", " ").strip()
|
||||||
|
s += f"{filename}:\n{words}\n\n"
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
|||||||
@ -33,7 +33,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
pushd $repo
|
pushd $repo
|
||||||
git lfs pull --include "data/lang_bpe_500/tokens.txt"
|
|
||||||
git lfs pull --include "exp/pretrained.pt"
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
cd exp
|
cd exp
|
||||||
|
|||||||
@ -19,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
pushd $repo
|
pushd $repo
|
||||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
|
||||||
git lfs pull --include "exp/pretrained.pt"
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
cd exp
|
cd exp
|
||||||
@ -29,7 +28,7 @@ popd
|
|||||||
2. Export the model to ONNX
|
2. Export the model to ONNX
|
||||||
|
|
||||||
./zipformer/export-onnx-streaming.py \
|
./zipformer/export-onnx-streaming.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
|
|||||||
@ -31,7 +31,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
pushd $repo
|
pushd $repo
|
||||||
git lfs pull --include "data/lang_bpe_500/tokens.txt"
|
|
||||||
git lfs pull --include "exp/pretrained.pt"
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
cd exp
|
cd exp
|
||||||
|
|||||||
@ -274,7 +274,7 @@ def main():
|
|||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
token_table = k2.SymbolTable.from_file(params.tokens)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
params.vocab_size = num_tokens(token_table)
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for blank
|
||||||
params.blank_id = token_table["<blk>"]
|
params.blank_id = token_table["<blk>"]
|
||||||
assert params.blank_id == 0
|
assert params.blank_id == 0
|
||||||
|
|
||||||
@ -429,10 +429,20 @@ def main():
|
|||||||
raise ValueError(f"Unsupported decoding method: {params.method}")
|
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
if params.method == "ctc-decoding":
|
||||||
words = " ".join(hyp)
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = words.replace("▁", " ").strip()
|
words = "".join(hyp)
|
||||||
s += f"{filename}:\n{words}\n\n"
|
words = words.replace("▁", " ").strip()
|
||||||
|
s += f"{filename}:\n{words}\n\n"
|
||||||
|
elif params.method in [
|
||||||
|
"1best",
|
||||||
|
"nbest-rescoring",
|
||||||
|
"whole-lattice-rescoring",
|
||||||
|
]:
|
||||||
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
|
words = " ".join(hyp)
|
||||||
|
words = words.replace("▁", " ").strip()
|
||||||
|
s += f"{filename}:\n{words}\n\n"
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
|||||||
@ -26,7 +26,7 @@ Usage:
|
|||||||
|
|
||||||
./zipformer_mmi/export.py \
|
./zipformer_mmi/export.py \
|
||||||
--exp-dir ./zipformer_mmi/exp \
|
--exp-dir ./zipformer_mmi/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 9 \
|
--avg 9 \
|
||||||
--jit 1
|
--jit 1
|
||||||
@ -45,7 +45,7 @@ for how to use the exported models outside of icefall.
|
|||||||
|
|
||||||
./zipformer_mmi/export.py \
|
./zipformer_mmi/export.py \
|
||||||
--exp-dir ./zipformer_mmi/exp \
|
--exp-dir ./zipformer_mmi/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -86,7 +86,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_ctc_model, get_params
|
from train import add_model_arguments, get_ctc_model, get_params
|
||||||
@ -97,7 +97,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -154,10 +154,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -190,12 +190,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
|||||||
@ -21,7 +21,7 @@ You can generate the checkpoint with the following command:
|
|||||||
|
|
||||||
./zipformer_mmi/export.py \
|
./zipformer_mmi/export.py \
|
||||||
--exp-dir ./zipformer_mmi/exp \
|
--exp-dir ./zipformer_mmi/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -30,14 +30,14 @@ Usage of this script:
|
|||||||
(1) 1best
|
(1) 1best
|
||||||
./zipformer_mmi/pretrained.py \
|
./zipformer_mmi/pretrained.py \
|
||||||
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
|
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method 1best \
|
--method 1best \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
(2) nbest
|
(2) nbest
|
||||||
./zipformer_mmi/pretrained.py \
|
./zipformer_mmi/pretrained.py \
|
||||||
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
|
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--nbest-scale 1.2 \
|
--nbest-scale 1.2 \
|
||||||
--method nbest \
|
--method nbest \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -45,7 +45,7 @@ Usage of this script:
|
|||||||
(3) nbest-rescoring-LG
|
(3) nbest-rescoring-LG
|
||||||
./zipformer_mmi/pretrained.py \
|
./zipformer_mmi/pretrained.py \
|
||||||
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
|
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--nbest-scale 1.2 \
|
--nbest-scale 1.2 \
|
||||||
--method nbest-rescoring-LG \
|
--method nbest-rescoring-LG \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -53,7 +53,7 @@ Usage of this script:
|
|||||||
(4) nbest-rescoring-3-gram
|
(4) nbest-rescoring-3-gram
|
||||||
./zipformer_mmi/pretrained.py \
|
./zipformer_mmi/pretrained.py \
|
||||||
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
|
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--nbest-scale 1.2 \
|
--nbest-scale 1.2 \
|
||||||
--method nbest-rescoring-3-gram \
|
--method nbest-rescoring-3-gram \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -61,7 +61,7 @@ Usage of this script:
|
|||||||
(5) nbest-rescoring-4-gram
|
(5) nbest-rescoring-4-gram
|
||||||
./zipformer_mmi/pretrained.py \
|
./zipformer_mmi/pretrained.py \
|
||||||
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
|
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--nbest-scale 1.2 \
|
--nbest-scale 1.2 \
|
||||||
--method nbest-rescoring-4-gram \
|
--method nbest-rescoring-4-gram \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -83,7 +83,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from decode import get_decoding_params
|
from decode import get_decoding_params
|
||||||
@ -97,7 +96,7 @@ from icefall.decode import (
|
|||||||
one_best_decoding,
|
one_best_decoding,
|
||||||
)
|
)
|
||||||
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
|
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
|
||||||
from icefall.utils import get_texts
|
from icefall.utils import get_texts, num_tokens
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -115,9 +114,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="""Path to tokens.txt.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -247,13 +246,14 @@ def main():
|
|||||||
params.update(get_decoding_params())
|
params.update(get_decoding_params())
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = token_table["<unk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@ -298,8 +298,6 @@ def main():
|
|||||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||||
|
|
||||||
bpe_model = spm.SentencePieceProcessor()
|
|
||||||
bpe_model.load(str(params.lang_dir / "bpe.model"))
|
|
||||||
mmi_graph_compiler = MmiTrainingGraphCompiler(
|
mmi_graph_compiler = MmiTrainingGraphCompiler(
|
||||||
params.lang_dir,
|
params.lang_dir,
|
||||||
uniq_filename="lexicon.txt",
|
uniq_filename="lexicon.txt",
|
||||||
@ -313,6 +311,12 @@ def main():
|
|||||||
if not hasattr(HP, "lm_scores"):
|
if not hasattr(HP, "lm_scores"):
|
||||||
HP.lm_scores = HP.scores.clone()
|
HP.lm_scores = HP.scores.clone()
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
method = params.method
|
method = params.method
|
||||||
assert method in (
|
assert method in (
|
||||||
"1best",
|
"1best",
|
||||||
@ -390,14 +394,11 @@ def main():
|
|||||||
#
|
#
|
||||||
# token_ids is a lit-of-list of IDs
|
# token_ids is a lit-of-list of IDs
|
||||||
token_ids = get_texts(best_path)
|
token_ids = get_texts(best_path)
|
||||||
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
hyps = [token_ids_to_words(ids) for ids in token_ids]
|
||||||
hyps = bpe_model.decode(token_ids)
|
|
||||||
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
|
|
||||||
hyps = [s.split() for s in hyps]
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
|||||||
@ -498,7 +498,7 @@ def main():
|
|||||||
quantize_dynamic(
|
quantize_dynamic(
|
||||||
model_input=decoder_filename,
|
model_input=decoder_filename,
|
||||||
model_output=decoder_filename_int8,
|
model_output=decoder_filename_int8,
|
||||||
op_types_to_quantize=["MatMul"],
|
op_types_to_quantize=["MatMul", "Gather"],
|
||||||
weight_type=QuantType.QInt8,
|
weight_type=QuantType.QInt8,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -320,7 +320,7 @@ def main():
|
|||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
words = "".join(hyp)
|
||||||
s += f"{filename}:\n{words}\n\n"
|
s += f"{filename}:\n{words}\n\n"
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user