Use tokens.txt to replace bpe.model (#1162)

zr_jin 2023-08-12 16:53:59 +08:00 committed by GitHub
parent d6b28a11a7
commit a81396b482
99 changed files with 1243 additions and 1623 deletions

View File

@ -38,7 +38,7 @@ log "Decode with models exported by torch.jit.trace()"
for m in ctc-decoding 1best; do
./conformer_ctc3/jit_pretrained.py \
--model-filename $repo/exp/jit_trace.pt \
--words-file $repo/data/lang_bpe_500/words.txt \
--HLG $repo/data/lang_bpe_500/HLG.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--G $repo/data/lm/G_4_gram.pt \
@ -53,7 +53,7 @@ log "Export to torchscript model"
./conformer_ctc3/export.py \
--exp-dir $repo/exp \
--lang-dir $repo/data/lang_bpe_500 \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--jit-trace 1 \
--epoch 99 \
--avg 1 \
@ -80,9 +80,9 @@ done
for m in ctc-decoding 1best; do
./conformer_ctc3/pretrained.py \
--checkpoint $repo/exp/pretrained.pt \
--words-file $repo/data/lang_bpe_500/words.txt \
--HLG $repo/data/lang_bpe_500/HLG.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--G $repo/data/lm/G_4_gram.pt \
--method $m \
--sample-rate 16000 \
@ -93,7 +93,7 @@ done
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
mkdir -p conformer_ctc3/exp
ln -s $PWD/$repo/exp/pretrained.pt conformer_ctc3/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/

View File

@ -31,7 +31,7 @@ log "Test exporting with torch.jit.trace()"
./lstm_transducer_stateless2/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
@ -55,7 +55,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -68,7 +68,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -28,7 +28,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -36,7 +36,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -49,7 +49,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -35,7 +35,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -48,7 +48,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -30,14 +30,14 @@ popd
log "Export to torchscript model"
./pruned_transducer_stateless3/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--jit 1
./pruned_transducer_stateless3/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--jit-trace 1
@ -74,7 +74,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -87,7 +87,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -32,7 +32,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--num-encoder-layers 18 \
--dim-feedforward 2048 \
--nhead 8 \
@ -51,7 +51,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav \

View File

@ -33,7 +33,7 @@ log "Export to torchscript model"
./pruned_transducer_stateless7/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--jit 1
@ -56,7 +56,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -69,7 +69,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -37,7 +37,7 @@ log "Export to torchscript model"
./pruned_transducer_stateless7_ctc/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--jit 1
@ -74,7 +74,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -87,7 +87,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -36,7 +36,7 @@ log "Export to torchscript model"
./pruned_transducer_stateless7_ctc_bs/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--jit 1
@ -72,7 +72,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -85,7 +85,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -37,7 +37,7 @@ log "Export to torchscript model"
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--decode-chunk-len 32 \
--epoch 99 \
--avg 1 \
@ -81,7 +81,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--decode-chunk-len 32 \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
@ -95,7 +95,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--decode-chunk-len 32 \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \

View File

@ -41,7 +41,7 @@ log "Decode with models exported by torch.jit.script()"
log "Export to torchscript model"
./pruned_transducer_stateless8/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model false \
--epoch 99 \
--avg 1 \
@ -65,7 +65,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -78,7 +78,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -32,7 +32,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--simulate-streaming 1 \
--causal-convolution 1 \
$repo/test_wavs/1089-134686-0001.wav \
@ -47,7 +47,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--simulate-streaming 1 \
--causal-convolution 1 \
$repo/test_wavs/1089-134686-0001.wav \

View File

@ -28,7 +28,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -37,7 +37,7 @@ log "Export to torchscript model"
./zipformer_mmi/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--jit 1
@ -61,7 +61,7 @@ for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescor
--method $method \
--checkpoint $repo/exp/pretrained.pt \
--lang-dir $repo/data/lang_bpe_500 \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -27,7 +27,7 @@ log "CTC decoding"
--method ctc-decoding \
--num-classes 500 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.flac \
$repo/test_wavs/1221-135766-0001.flac \
$repo/test_wavs/1221-135766-0002.flac
@ -38,7 +38,7 @@ log "HLG decoding"
--method 1best \
--num-classes 500 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--words-file $repo/data/lang_bpe_500/words.txt \
--HLG $repo/data/lang_bpe_500/HLG.pt \
$repo/test_wavs/1089-134686-0001.flac \

View File

@ -28,7 +28,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -28,7 +28,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -28,7 +28,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -27,7 +27,7 @@ log "Beam search decoding"
--method beam_search \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -17,7 +17,6 @@ git lfs install
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
ls -lh $repo/test_wavs/*.wav
@ -29,12 +28,11 @@ popd
log "Test exporting to ONNX format"
./pruned_transducer_stateless2/export.py \
./pruned_transducer_stateless2/export-onnx.py \
--exp-dir $repo/exp \
--lang-dir $repo/data/lang_char \
--epoch 99 \
--avg 1 \
--onnx 1
--avg 1
log "Export to torchscript model"
@ -59,19 +57,17 @@ log "Decode with ONNX models"
./pruned_transducer_stateless2/onnx_check.py \
--jit-filename $repo/exp/cpu_jit.pt \
--onnx-encoder-filename $repo/exp/encoder.onnx \
--onnx-decoder-filename $repo/exp/decoder.onnx \
--onnx-joiner-filename $repo/exp/joiner.onnx \
--onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \
--onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx
--onnx-encoder-filename $repo/exp/encoder-epoch-10-avg-2.onnx \
--onnx-decoder-filename $repo/exp/decoder-epoch-10-avg-2.onnx \
--onnx-joiner-filename $repo/exp/joiner-epoch-10-avg-2.onnx \
--onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \
--onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj-epoch-10-avg-2.onnx
./pruned_transducer_stateless2/onnx_pretrained.py \
--tokens $repo/data/lang_char/tokens.txt \
--encoder-model-filename $repo/exp/encoder.onnx \
--decoder-model-filename $repo/exp/decoder.onnx \
--joiner-model-filename $repo/exp/joiner.onnx \
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
$repo/test_wavs/DEV_T0000000000.wav \
$repo/test_wavs/DEV_T0000000001.wav \
$repo/test_wavs/DEV_T0000000002.wav
@ -104,9 +100,9 @@ for sym in 1 2 3; do
--lang-dir $repo/data/lang_char \
--decoding-method greedy_search \
--max-sym-per-frame $sym \
$repo/test_wavs/DEV_T0000000000.wav \
$repo/test_wavs/DEV_T0000000001.wav \
$repo/test_wavs/DEV_T0000000002.wav
done
for method in modified_beam_search beam_search fast_beam_search; do
@ -117,7 +113,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--beam-size 4 \
--checkpoint $repo/exp/epoch-99.pt \
--lang-dir $repo/data/lang_char \
$repo/test_wavs/DEV_T0000000000.wav \
$repo/test_wavs/DEV_T0000000001.wav \
$repo/test_wavs/DEV_T0000000002.wav
done

View File

@ -45,7 +45,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
cd exp
@ -56,11 +55,10 @@ log "Export via torch.jit.trace()"
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
\
--tokens $repo/data/lang_bpe_500/tokens.txt \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
@ -91,7 +89,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
cd exp
@ -102,7 +99,7 @@ log "Export via torch.jit.trace()"
./lstm_transducer_stateless2/export-for-ncnn.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--use-averaged-model 0
@ -140,7 +137,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
@ -148,7 +144,7 @@ ln -s pretrained.pt epoch-99.pt
popd
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--exp-dir $repo/exp \
--use-averaged-model 0 \
--epoch 99 \
@ -199,7 +195,7 @@ ln -s pretrained.pt epoch-9999.pt
popd
./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
--lang-dir $repo/data/lang_char_bpe \
--tokens $repo/data/lang_char_bpe/tokens.txt \
--exp-dir $repo/exp \
--use-averaged-model 0 \
--epoch 9999 \

View File

@ -10,7 +10,123 @@ log() {
cd egs/librispeech/ASR
log "=========================================================================="
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
log "Downloading pre-trained model from $repo_url"
git lfs install
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
log "Export via torch.jit.script()"
./zipformer/export.py \
--exp-dir $repo/exp \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--jit 1
log "Test export to ONNX format"
./zipformer/export-onnx.py \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
--num-heads "4,4,4,8,4,4" \
--encoder-dim "192,256,384,512,384,256" \
--query-head-dim 32 \
--value-head-dim 12 \
--pos-head-dim 4 \
--pos-dim 48 \
--encoder-unmasked-dim "192,192,256,256,256,192" \
--cnn-module-kernel "31,31,15,15,15,31" \
--decoder-dim 512 \
--joiner-dim 512 \
--causal False \
--chunk-size "16,32,64,-1" \
--left-context-frames "64,128,256,-1"
ls -lh $repo/exp
log "Run onnx_check.py"
./zipformer/onnx_check.py \
--jit-filename $repo/exp/jit_script.pt \
--onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx
log "Run onnx_pretrained.py"
./zipformer/onnx_pretrained.py \
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav
rm -rf $repo
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
log "Downloading pre-trained model from $repo_url"
git lfs install
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
log "Test export streaming model to ONNX format"
./zipformer/export-onnx-streaming.py \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
--num-heads "4,4,4,8,4,4" \
--encoder-dim "192,256,384,512,384,256" \
--query-head-dim 32 \
--value-head-dim 12 \
--pos-head-dim 4 \
--pos-dim 48 \
--encoder-unmasked-dim "192,192,256,256,256,192" \
--cnn-module-kernel "31,31,15,15,15,31" \
--decoder-dim 512 \
--joiner-dim 512 \
--causal True \
--chunk-size 16 \
--left-context-frames 64
ls -lh $repo/exp
log "Run onnx_pretrained-streaming.py"
./zipformer/onnx_pretrained-streaming.py \
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1-chunk-16-left-64.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1-chunk-16-left-64.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1-chunk-16-left-64.onnx \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav
rm -rf $repo
log "--------------------------------------------------------------------------"
log "=========================================================================="
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
@ -39,7 +155,7 @@ log "Export via torch.jit.trace()"
log "Test exporting to ONNX format"
./pruned_transducer_stateless7_streaming/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
@ -88,7 +204,7 @@ popd
log "Export via torch.jit.script()"
./pruned_transducer_stateless3/export.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 9999 \
--avg 1 \
--exp-dir $repo/exp/ \
@ -97,7 +213,7 @@ log "Export via torch.jit.script()"
log "Test exporting to ONNX format"
./pruned_transducer_stateless3/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 9999 \
--avg 1 \
--exp-dir $repo/exp/
@ -126,7 +242,6 @@ log "Run onnx_pretrained.py"
rm -rf $repo
log "--------------------------------------------------------------------------"
log "=========================================================================="
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
@ -143,7 +258,7 @@ popd
log "Export via torch.jit.script()"
./pruned_transducer_stateless5/export.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
@ -159,7 +274,7 @@ log "Export via torch.jit.script()"
log "Test exporting to ONNX format"
./pruned_transducer_stateless5/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
@ -205,7 +320,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
@ -215,7 +329,7 @@ popd
log "Export via torch.jit.script()"
./pruned_transducer_stateless7/export.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
@ -226,7 +340,7 @@ log "Export via torch.jit.script()"
log "Test exporting to ONNX format"
./pruned_transducer_stateless7/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
@ -270,7 +384,7 @@ popd
log "Test exporting to ONNX format"
./conv_emformer_transducer_stateless2/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
@ -310,7 +424,7 @@ popd
log "Export via torch.jit.trace()"
./lstm_transducer_stateless2/export.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
@ -320,7 +434,7 @@ log "Export via torch.jit.trace()"
log "Test exporting to ONNX format"
./lstm_transducer_stateless2/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \

View File

@ -1,321 +0,0 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
(1) Export to torchscript model using torch.jit.script()
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--lang-dir data/lang_char \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("cpu_jit.pt")`.
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
Check
https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
(2) Export `model.state_dict()`
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--lang-dir data/lang_char \
--epoch 20 \
--avg 10
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
To use the generated file with `pruned_transducer_stateless7/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./pruned_transducer_stateless7/decode.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--lang-dir data/lang_char
Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21
# You will find the pre-trained model in icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21/exp
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train2 import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate a file named cpu_jit.pt
Check ./jit_pretrained.py for how to use it.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
if params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torchscript. Export model.state_dict()")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/export.py

View File

@ -1,348 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with the following command:
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--lang-dir data/lang_char \
--epoch 20 \
--avg 10
Usage of this script:
(1) greedy search
./pruned_transducer_stateless7/pretrained.py \
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
--lang-dir ./data/lang_char \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) beam search
./pruned_transducer_stateless7/pretrained.py \
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
--lang-dir ./data/lang_char \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search
./pruned_transducer_stateless7/pretrained.py \
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
--lang-dir ./data/lang_char \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./pruned_transducer_stateless7/pretrained.py \
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
--lang-dir ./data/lang_char \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`.
Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by
./pruned_transducer_stateless7/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--lang-dir",
type=str,
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
token_table = lexicon.token_table
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
num_waves = encoder_out.size(0)
hyps = []
msg = f"Using {params.method}"
if params.method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp_tokens = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp_tokens = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/pretrained.py

View File

@ -23,12 +23,13 @@ import argparse
import logging
from pathlib import Path
import k2
import torch
from conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
from icefall.utils import AttributeDict, num_tokens, str2bool
def get_parser():
@ -63,11 +64,10 @@ def get_parser():
)
parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_bpe_500",
help="""It contains language related input files such as "lexicon.txt"
""",
required=True,
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -98,16 +98,16 @@ def get_params() -> AttributeDict:
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params()
params.update(vars(args))
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
num_classes = num_tokens(token_table) + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():

View File

@ -24,7 +24,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from conformer import Conformer
@ -70,11 +69,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -83,10 +80,9 @@ def get_parser():
default="1best",
help="""Decoding method.
Possible values are:
(0) ctc-decoding - Use CTC decoding. It uses a sentence
piece model, i.e., lang_dir/bpe.model, to convert
word pieces to words. It needs neither a lexicon
nor an n-gram LM.
(0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file
to convert tokens to actual words or characters. It needs
neither a lexicon nor an n-gram LM.
(1) 1best - Use the best path as decoding output. Only
the transformer encoder output is used for decoding.
We call it HLG decoding.
@ -297,6 +293,7 @@ def main():
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
hyps = []
features = fbank(waves)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
@ -313,10 +310,17 @@ def main():
if params.method == "ctc-decoding":
logging.info("Use CTC decoding")
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(params.bpe_model)
max_token_id = params.num_classes - 1
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("", " ").strip()
H = k2.ctc_topo(
max_token=max_token_id,
modified=params.num_classes > 500,
@ -337,9 +341,9 @@ def main():
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
token_ids = get_texts(best_path)
hyps = bpe_model.decode(token_ids)
hyps = [s.split() for s in hyps]
hyp_tokens = get_texts(best_path)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method in [
"1best",
"whole-lattice-rescoring",
@ -408,16 +412,16 @@ def main():
)
best_path = next(iter(best_path_dict.values()))
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
hyp_tokens = get_texts(best_path)
for hyp in hyp_tokens:
hyps.append(" ".join([word_sym_table[i] for i in hyp]))
else:
raise ValueError(f"Unsupported decoding method: {params.method}")
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -23,6 +23,7 @@
Usage:
./conformer_ctc2/export.py \
--exp-dir ./conformer_ctc2/exp \
--tokens ./data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -46,6 +47,7 @@ import argparse
import logging
from pathlib import Path
import k2
import torch
from conformer import Conformer
from decode import get_params
@ -56,8 +58,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -123,10 +124,10 @@ def get_parser():
)
parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_bpe_500",
help="The lang dir",
required=True,
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -143,14 +144,14 @@ def get_parser():
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params()
params.update(vars(args))
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
num_classes = num_tokens(token_table) + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():

View File

@ -25,7 +25,7 @@ Usage:
./conformer_ctc3/export.py \
--exp-dir ./conformer_ctc3/exp \
--lang-dir data/lang_bpe_500 \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10 \
--jit-trace 1
@ -36,7 +36,7 @@ It will generate the file: `jit_trace.pt`.
./conformer_ctc3/export.py \
--exp-dir ./conformer_ctc3/exp \
--lang-dir data/lang_bpe_500 \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -62,6 +62,7 @@ import argparse
import logging
from pathlib import Path
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_ctc_model, get_params
@ -72,8 +73,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -130,10 +130,10 @@ def get_parser():
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
"--tokens",
type=str,
required=True,
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -171,9 +171,10 @@ def main():
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
num_classes = num_tokens(token_table) + 1 # +1 for the blank
params.vocab_size = num_classes
if params.streaming_model:

View File

@ -24,7 +24,7 @@ Usage (for non-streaming mode):
(1) ctc-decoding
./conformer_ctc3/pretrained.py \
--checkpoint conformer_ctc3/exp/pretrained.pt \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method ctc-decoding \
--sample-rate 16000 \
test_wavs/1089-134686-0001.wav
@ -67,7 +67,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from decode import get_decoding_params
@ -114,11 +113,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -127,10 +124,9 @@ def get_parser():
default="1best",
help="""Decoding method.
Possible values are:
(0) ctc-decoding - Use CTC decoding. It uses a sentence
piece model, i.e., lang_dir/bpe.model, to convert
word pieces to words. It needs neither a lexicon
nor an n-gram LM.
(0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file
to convert tokens to actual words or characters. It needs
neither a lexicon nor an n-gram LM.
(1) 1best - Use the best path as decoding output. Only
the transformer encoder output is used for decoding.
We call it HLG decoding.
@ -316,6 +312,7 @@ def main():
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
hyps = []
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
@ -348,10 +345,17 @@ def main():
if params.method == "ctc-decoding":
logging.info("Use CTC decoding")
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(params.bpe_model)
max_token_id = params.num_classes - 1
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("", " ").strip()
H = k2.ctc_topo(
max_token=max_token_id,
modified=False,
@ -372,9 +376,9 @@ def main():
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
token_ids = get_texts(best_path)
hyps = bpe_model.decode(token_ids)
hyps = [s.split() for s in hyps]
hyp_tokens = get_texts(best_path)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method in [
"1best",
"nbest-rescoring",
@ -439,16 +443,16 @@ def main():
)
best_path = next(iter(best_path_dict.values()))
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
hyp_tokens = get_texts(best_path)
for hyp in hyp_tokens:
hyps.append(" ".join([word_sym_table[i] for i in hyp]))
else:
raise ValueError(f"Unsupported decoding method: {params.method}")
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -22,7 +22,7 @@
Usage:
./conv_emformer_transducer_stateless/export.py \
--exp-dir ./conv_emformer_transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 10 \
--use-averaged-model=True \
@ -62,7 +62,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from train import add_model_arguments, get_params, get_transducer_model
@ -72,7 +72,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -118,10 +118,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
required=True,
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -166,12 +166,12 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
# Load id of the <blk> token and the vocab size
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -8,7 +8,7 @@ for more details about how to use this file.
Usage:
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
--exp-dir ./conv_emformer_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 10 \
--use-averaged-model=True \
@ -37,7 +37,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train2 import add_model_arguments, get_params, get_transducer_model
@ -48,7 +48,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@ -94,10 +94,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
required=True,
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -217,12 +217,12 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
# Load id of the <blk> token and the vocab size
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -18,7 +18,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
cd exp
@ -28,7 +27,7 @@ popd
2. Export the model to ONNX
./conv_emformer_transducer_stateless2/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
@ -55,14 +54,14 @@ import logging
from pathlib import Path
from typing import Dict, Tuple
import k2
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
from emformer import Emformer
from scaling_converter import convert_scaled_to_non_scaled
from train2 import add_model_arguments, get_params, get_transducer_model
from emformer import Emformer
from icefall.checkpoint import (
average_checkpoints,
@ -70,7 +69,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@ -127,10 +126,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
required=True,
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -484,12 +483,12 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
# Load id of the <blk> token and the vocab size
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
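One reason the ONNX exporters need blank_id and vocab_size at all is to stamp them into the exported model, so a downstream runtime can decode without a separate tokenizer file at load time. A hedged sketch of writing such metadata with the onnx package; the path and keys are illustrative, not the exporter's actual ones.

import onnx

model = onnx.load("decoder.onnx")  # placeholder path
for key, value in {"vocab_size": "500", "blank_id": "0"}.items():
    prop = model.metadata_props.add()  # ONNX key/value metadata entry
    prop.key = key
    prop.value = value
onnx.save(model, "decoder.onnx")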

View File

@ -22,7 +22,7 @@
Usage:
./conv_emformer_transducer_stateless2/export.py \
--exp-dir ./conv_emformer_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 10 \
--use-averaged-model=True \
@ -62,7 +62,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
@ -73,7 +73,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -119,10 +119,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
required=True,
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -167,12 +167,12 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
# Load id of the <blk> token and the vocab size
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -28,7 +28,7 @@ popd
2. Export the model to ONNX
./conv_emformer_transducer_stateless2/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \

View File

@ -26,7 +26,7 @@ Usage:
./lstm_transducer_stateless/export.py \
--exp-dir ./lstm_transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 35 \
--avg 10 \
--jit-trace 1
@ -38,7 +38,7 @@ It will generate 3 files: `encoder_jit_trace.pt`,
./lstm_transducer_stateless/export.py \
--exp-dir ./lstm_transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 35 \
--avg 10
@ -79,7 +79,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
@ -91,7 +91,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -148,10 +148,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -266,12 +266,13 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
# Load id of the <blk> token and the vocab size, <blk> is
# defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
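The --jit-trace 1 path in the usage text above saves encoder_jit_trace.pt and friends via torch.jit.trace, which records a module's behavior on example inputs and writes a self-contained .pt. A toy sketch with a stand-in for the real LSTM encoder; the input shape is assumed.

import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0

encoder = ToyEncoder().eval()
example = torch.rand(1, 100, 80)  # (batch, time, feature); illustrative
traced = torch.jit.trace(encoder, example)
traced.save("encoder_jit_trace.pt")  # file name taken from the usage text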

View File

@ -20,7 +20,7 @@ Usage:
(1) greedy search
./lstm_transducer_stateless/pretrained.py \
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -28,7 +28,7 @@ Usage:
(2) beam search
./lstm_transducer_stateless/pretrained.py \
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -37,7 +37,7 @@ Usage:
(3) modified beam search
./lstm_transducer_stateless/pretrained.py \
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -46,7 +46,7 @@ Usage:
(4) fast beam search
./lstm_transducer_stateless/pretrained.py \
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -66,7 +66,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -79,6 +78,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import num_tokens
def get_parser():
parser = argparse.ArgumentParser(
@ -95,9 +96,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -214,13 +215,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -275,6 +277,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -286,8 +294,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -296,16 +304,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -326,12 +334,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -29,7 +29,7 @@ popd
./lstm_transducer_stateless2/export-for-ncnn.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
@ -49,7 +49,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
@ -60,7 +60,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@ -106,10 +106,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -221,12 +221,13 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
# Load id of the <blk> token and the vocab size, <blk> is
# defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -613,7 +613,7 @@ def main():
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul"],
op_types_to_quantize=["MatMul", "Gather"],
weight_type=QuantType.QInt8,
)
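This hunk, repeated in the other export scripts below, extends dynamic int8 quantization to Gather nodes. The transducer decoder is mostly an embedding lookup, which ONNX exports as a Gather node, so quantizing only MatMul would leave the embedding weights in float32. A hedged sketch of the call; the file names are placeholders.

from onnxruntime.quantization import QuantType, quantize_dynamic

quantize_dynamic(
    model_input="decoder.onnx",
    model_output="decoder.int8.onnx",
    # Include Gather so the decoder's embedding table is quantized too.
    op_types_to_quantize=["MatMul", "Gather"],
    weight_type=QuantType.QInt8,
)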

View File

@ -28,7 +28,7 @@ popd
2. Export the model to ONNX
./lstm_transducer_stateless2/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
@ -52,8 +52,8 @@ import logging
from pathlib import Path
from typing import Dict, Optional, Tuple
import k2
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
@ -68,7 +68,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@ -125,10 +125,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -437,12 +437,13 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
# Load id of the <blk> token and the vocab size, <blk> is
# defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
@ -607,7 +608,7 @@ def main():
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul"],
op_types_to_quantize=["MatMul", "Gather"],
weight_type=QuantType.QInt8,
)

View File

@ -27,7 +27,7 @@ Usage:
./lstm_transducer_stateless2/export.py \
--exp-dir ./lstm_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--epoch 35 \
--avg 10 \
--jit-trace 1
@ -39,7 +39,7 @@ It will generate 3 files: `encoder_jit_trace.pt`,
./lstm_transducer_stateless2/export.py \
--exp-dir ./lstm_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--epoch 35 \
--avg 10
@ -80,7 +80,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
@ -92,7 +92,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -149,10 +149,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -267,12 +267,13 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
# Load id of the <blk> token and the vocab size, <blk> is
# defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -20,7 +20,7 @@ Usage:
(1) greedy search
./lstm_transducer_stateless2/pretrained.py \
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -28,7 +28,7 @@ Usage:
(2) beam search
./lstm_transducer_stateless2/pretrained.py \
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -37,7 +37,7 @@ Usage:
(3) modified beam search
./lstm_transducer_stateless2/pretrained.py \
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -46,7 +46,7 @@ Usage:
(4) fast beam search
./lstm_transducer_stateless2/pretrained.py \
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -69,7 +69,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -82,6 +81,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import num_tokens
def get_parser():
parser = argparse.ArgumentParser(
@ -98,9 +99,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -217,13 +218,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -278,6 +280,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -289,8 +297,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -299,16 +307,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -329,12 +337,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -26,7 +26,7 @@ Usage:
./lstm_transducer_stateless3/export.py \
--exp-dir ./lstm_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 40 \
--avg 20 \
--jit-trace 1
@ -38,7 +38,7 @@ It will generate 3 files: `encoder_jit_trace.pt`,
./lstm_transducer_stateless3/export.py \
--exp-dir ./lstm_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 40 \
--avg 20
@ -79,7 +79,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
@ -91,7 +91,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -148,10 +148,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to tokens.txt.",
)
parser.add_argument(
@ -266,12 +266,13 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
# Load id of the <blk> token and the vocab size, <blk> is
# defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -20,7 +20,7 @@ Usage:
(1) greedy search
./lstm_transducer_stateless3/pretrained.py \
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -28,7 +28,7 @@ Usage:
(2) beam search
./lstm_transducer_stateless3/pretrained.py \
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -37,7 +37,7 @@ Usage:
(3) modified beam search
./lstm_transducer_stateless3/pretrained.py \
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -79,6 +79,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import num_tokens
def get_parser():
parser = argparse.ArgumentParser(
@ -95,9 +97,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -214,13 +216,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -275,6 +278,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -286,8 +295,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -296,16 +305,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -326,12 +335,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -22,7 +22,7 @@
Usage:
./pruned_stateless_emformer_rnnt/export.py \
--exp-dir ./pruned_stateless_emformer_rnnt/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -48,7 +48,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from train import add_model_arguments, get_params, get_transducer_model
@ -58,7 +58,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -115,10 +115,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -154,13 +154,12 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
# Load id of the <blk> token and the vocab size
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -508,7 +508,7 @@ def main():
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul"],
op_types_to_quantize=["MatMul", "Gather"],
weight_type=QuantType.QInt8,
)

View File

@ -22,7 +22,7 @@
Usage:
./pruned_transducer_stateless/export.py \
--exp-dir ./pruned_transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -47,12 +47,12 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -87,10 +87,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -135,13 +135,13 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
# Load id of the <blk> token and the vocab size, <blk> is
# defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.streaming_model:
assert params.causal_convolution

View File

@ -20,7 +20,7 @@ Usage:
(1) greedy search
./pruned_transducer_stateless/pretrained.py \
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -28,7 +28,7 @@ Usage:
(2) beam search
./pruned_transducer_stateless/pretrained.py \
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -37,7 +37,7 @@ Usage:
(3) modified beam search
./pruned_transducer_stateless/pretrained.py \
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -46,7 +46,7 @@ Usage:
(4) fast beam search
./pruned_transducer_stateless/pretrained.py \
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -66,7 +66,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -79,7 +78,7 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -97,9 +96,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -237,13 +236,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.simulate_streaming:
assert (
@ -314,6 +314,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -325,8 +331,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -335,16 +341,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -365,12 +371,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -22,7 +22,7 @@
Usage:
./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -47,12 +47,12 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -98,10 +98,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -145,12 +145,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.streaming_model:
assert params.causal_convolution

View File

@ -20,7 +20,7 @@ Usage:
(1) greedy search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -28,7 +28,7 @@ Usage:
(2) beam search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -37,7 +37,7 @@ Usage:
(3) modified beam search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -46,7 +46,7 @@ Usage:
(4) fast beam search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -66,7 +66,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -79,7 +78,7 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -97,9 +96,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -238,13 +237,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.simulate_streaming:
assert (
@ -315,6 +315,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -326,8 +332,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -336,16 +342,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -366,12 +372,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -28,7 +28,7 @@ popd
2. Export the model to ONNX
./pruned_transducer_stateless3/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 9999 \
--avg 1 \
--exp-dir $repo/exp/
@ -48,8 +48,8 @@ import logging
from pathlib import Path
from typing import Dict, Tuple
import k2
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from conformer import Conformer
@ -59,7 +59,7 @@ from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import setup_logger
from icefall.utils import num_tokens, setup_logger
def get_parser():
@ -105,10 +105,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -393,12 +393,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
@ -518,7 +520,7 @@ def main():
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul"],
op_types_to_quantize=["MatMul", "Gather"],
weight_type=QuantType.QInt8,
)

View File

@ -26,7 +26,7 @@ Usage:
./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10 \
--jit 1
@ -44,7 +44,7 @@ It will also generate 3 other files: `encoder_jit_script.pt`,
./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10 \
--jit-trace 1
@ -56,7 +56,7 @@ It will generate 3 files: `encoder_jit_trace.pt`,
./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -97,14 +97,14 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -150,10 +150,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
@ -342,12 +342,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.streaming_model:
assert params.causal_convolution
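Alongside --jit-trace, the usage text of this script offers --jit 1, which produces encoder_jit_script.pt and friends via torch.jit.script: the module is compiled from its Python source rather than recorded from one traced run, so data-dependent control flow survives. A toy sketch.

import torch
import torch.nn as nn

class ToyJoiner(nn.Module):
    def forward(self, enc: torch.Tensor, dec: torch.Tensor) -> torch.Tensor:
        # Control flow like this is kept by script but frozen by trace.
        if enc.size(0) != dec.size(0):
            raise ValueError("batch size mismatch")
        return torch.tanh(enc + dec)

scripted = torch.jit.script(ToyJoiner())
scripted.save("joiner_jit_script.pt")  # name modeled on the usage text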

View File

@ -20,7 +20,7 @@ You can generate the checkpoint with the following command:
./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -29,7 +29,7 @@ Usage of this script:
(1) greedy search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -37,7 +37,7 @@ Usage of this script:
(2) beam search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -46,7 +46,7 @@ Usage of this script:
(3) modified beam search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -55,7 +55,7 @@ Usage of this script:
(4) fast beam search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -75,7 +75,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -88,7 +87,7 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -106,9 +105,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -247,13 +246,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.simulate_streaming:
assert (
@ -324,6 +324,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -335,8 +341,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -345,16 +351,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -375,12 +381,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -22,7 +22,7 @@
Usage:
./pruned_transducer_stateless4/export.py \
--exp-dir ./pruned_transducer_stateless4/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -48,7 +48,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
@ -59,7 +59,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -116,10 +116,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -164,12 +164,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.streaming_model:
assert params.causal_convolution

View File

@ -28,7 +28,7 @@ popd
2. Export the model to ONNX
./pruned_transducer_stateless5/export-onnx-streaming.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
@ -58,13 +58,13 @@ import logging
from pathlib import Path
from typing import Dict, Tuple
import k2
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from conformer import Conformer
from onnxruntime.quantization import QuantType, quantize_dynamic
from decoder import Decoder
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
@ -74,7 +74,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@ -131,10 +131,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -489,12 +489,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
@ -662,7 +664,7 @@ def main():
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul"],
op_types_to_quantize=["MatMul", "Gather"],
weight_type=QuantType.QInt8,
)

View File

@ -28,7 +28,7 @@ popd
2. Export the model to ONNX
./pruned_transducer_stateless5/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
@ -55,13 +55,13 @@ import logging
from pathlib import Path
from typing import Dict, Tuple
import k2
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from conformer import Conformer
from onnxruntime.quantization import QuantType, quantize_dynamic
from decoder import Decoder
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
@ -71,7 +71,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@ -128,10 +128,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -416,12 +416,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
@ -586,7 +588,7 @@ def main():
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul"],
op_types_to_quantize=["MatMul", "Gather"],
weight_type=QuantType.QInt8,
)

View File

@ -22,7 +22,7 @@
Usage:
./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -48,7 +48,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
@ -59,7 +59,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -116,10 +116,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -164,12 +164,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.streaming_model:
assert params.causal_convolution

View File

@ -20,7 +20,7 @@ Usage:
(1) greedy search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -28,7 +28,7 @@ Usage:
(2) beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -37,7 +37,7 @@ Usage:
(3) modified beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -46,7 +46,7 @@ Usage:
(4) fast beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -66,7 +66,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -79,6 +78,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import num_tokens
def get_parser():
parser = argparse.ArgumentParser(
@ -95,9 +96,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -214,13 +215,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -275,6 +277,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -286,8 +294,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -296,16 +304,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -326,12 +334,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -22,7 +22,7 @@
Usage:
./pruned_transducer_stateless6/export.py \
--exp-dir ./pruned_transducer_stateless6/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -47,12 +47,12 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -98,10 +98,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -135,12 +135,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
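
Note: the recurring `num_tokens(token_table) + 1` line deserves a word. As I read icefall.utils.num_tokens, it counts the usable symbols in the table, skipping disambiguation symbols (#0, #1, ...) and the blank at id 0, so callers add the blank back with `+ 1`. A hedged re-implementation sketch (the real helper may differ in detail):

import re

import k2

def count_tokens(token_table: k2.SymbolTable) -> int:
    disambig = re.compile(r"^#\d+$")  # disambiguation symbols: #0, #1, ...
    ids = [token_table[s] for s in token_table.symbols if not disambig.match(s)]
    if 0 in ids:       # drop the blank at id 0 ...
        ids.remove(0)  # ... which callers add back via "+ 1"
    return len(ids)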

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
# Copyright 2023 Xiaomi Corporation (Authors: Fangjun Kuang,
#                                             Zengrui Jin)
"""
This script exports a transducer model from PyTorch to ONNX.
@ -18,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-epoch-30-avg-9.pt"
cd exp
@ -28,7 +28,7 @@ popd
2. Export the model to ONNX
./pruned_transducer_stateless7/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
@ -50,8 +50,8 @@ import logging
from pathlib import Path
from typing import Dict, Tuple
import k2
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
@ -66,7 +66,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@ -123,10 +123,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -411,12 +410,12 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
# Load id of the <blk> token and the vocab size
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
@ -581,7 +580,7 @@ def main():
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul"],
op_types_to_quantize=["MatMul", "Gather"],
weight_type=QuantType.QInt8,
)
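
Note: besides the tokens.txt migration, this file extends int8 dynamic quantization from MatMul to Gather nodes. The stateless decoder is essentially an embedding lookup, which ONNX expresses as Gather, so quantizing only MatMul would leave the decoder weights in float32. A sketch of the call, assuming a decoder.onnx produced by the export step above:

from onnxruntime.quantization import QuantType, quantize_dynamic

quantize_dynamic(
    model_input="decoder.onnx",
    model_output="decoder.int8.onnx",
    op_types_to_quantize=["MatMul", "Gather"],  # Gather covers the embedding table
    weight_type=QuantType.QInt8,
)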

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
# Copyright 2021 Xiaomi Corporation (Authors: Fangjun Kuang,
#                                             Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -26,7 +27,7 @@ Usage:
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
@ -45,7 +46,7 @@ for how to use the exported models outside of icefall.
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -65,7 +66,7 @@ you can do:
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
--tokens data/lang_bpe_500/tokens.txt
Check ./pretrained.py for its usage.
@ -86,7 +87,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
@ -98,7 +99,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -155,10 +156,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -198,12 +198,12 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
# Load id of the <blk> token and the vocab size
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
@ -292,7 +292,7 @@ def main():
model.to("cpu")
model.eval()
if params.jit is True:
if params.jit:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
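
Note: the `if params.jit is True:` → `if params.jit:` change is safe because the flag is parsed with str2bool and is therefore a real bool; the plain truthiness test is the idiomatic spelling. A self-contained toy sketch of the guarded TorchScript export (a dummy module stands in for the transducer):

import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

jit = True
if jit:  # equivalent to `if jit is True:` for a bool
    scripted = torch.jit.script(Toy())
    scripted.save("cpu_jit.pt")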

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
#                                        Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -20,7 +21,7 @@ You can generate the checkpoint with the following command:
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -29,7 +30,7 @@ Usage of this script:
(1) greedy search
./pruned_transducer_stateless7/pretrained.py \
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -37,7 +38,7 @@ Usage of this script:
(2) beam search
./pruned_transducer_stateless7/pretrained.py \
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -46,7 +47,7 @@ Usage of this script:
(3) modified beam search
./pruned_transducer_stateless7/pretrained.py \
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -55,7 +56,7 @@ Usage of this script:
(4) fast beam search
./pruned_transducer_stateless7/pretrained.py \
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -75,7 +76,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -88,7 +88,7 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool
from icefall.utils import num_tokens
def get_parser():
@ -106,9 +106,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -225,13 +225,13 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
# Load id of the <blk> token and the vocab size
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -286,6 +286,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -297,8 +303,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -307,16 +313,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -337,12 +343,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -26,7 +26,7 @@ Usage:
./pruned_transducer_stateless7_ctc/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
@ -45,7 +45,7 @@ for how to use the exported models outside of icefall.
./pruned_transducer_stateless7_ctc/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -86,7 +86,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
@ -97,7 +97,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -154,10 +154,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -197,12 +197,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -20,7 +20,7 @@ You can generate the checkpoint with the following command:
./pruned_transducer_stateless7_ctc/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -29,7 +29,7 @@ Usage of this script:
(1) greedy search
./pruned_transducer_stateless7_ctc/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -37,7 +37,7 @@ Usage of this script:
(2) beam search
./pruned_transducer_stateless7_ctc/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -46,7 +46,7 @@ Usage of this script:
(3) modified beam search
./pruned_transducer_stateless7_ctc/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -55,7 +55,7 @@ Usage of this script:
(4) fast beam search
./pruned_transducer_stateless7_ctc/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -75,7 +75,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -88,6 +87,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import num_tokens
def get_parser():
parser = argparse.ArgumentParser(
@ -104,9 +105,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -223,13 +224,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -284,6 +286,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -295,8 +303,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -305,16 +313,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -335,12 +343,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -22,14 +22,14 @@ You can use the following command to get the exported models:
./pruned_transducer_stateless7_ctc/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
Usage of this script:
(1) ctc-decoding
./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--bpe-model data/lang_bpe_500/bpe.model \
--method ctc-decoding \
@ -38,7 +38,7 @@ Usage of this script:
/path/to/bar.wav
(2) 1best
./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \
@ -48,7 +48,7 @@ Usage of this script:
/path/to/bar.wav
(3) nbest-rescoring
./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \
@ -60,7 +60,7 @@ Usage of this script:
(4) whole-lattice-rescoring
./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \

View File

@ -26,7 +26,7 @@ Usage:
./pruned_transducer_stateless7_ctc_bs/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 13 \
--jit 1
@ -45,7 +45,7 @@ for how to use the exported models outside of icefall.
./pruned_transducer_stateless7_ctc_bs/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 13
@ -86,7 +86,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
@ -97,7 +97,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -154,10 +154,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -197,12 +197,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -28,7 +28,7 @@ Usage:
./pruned_transducer_stateless7_ctc_bs/export_onnx.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 13 \
--onnx 1
@ -48,7 +48,7 @@ Check `onnx_check.py` for how to use them.
(2) Export to ONNX format which can be used in Triton Server
./pruned_transducer_stateless7_ctc_bs/export_onnx.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 13 \
--onnx-triton 1
@ -86,9 +86,10 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
import torch.nn as nn
from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
@ -98,8 +99,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -156,10 +156,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -728,12 +728,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -20,7 +20,7 @@ You can generate the checkpoint with the following command:
./pruned_transducer_stateless7_ctc_bs/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 13
@ -29,7 +29,7 @@ Usage of this script:
(1) greedy search
./pruned_transducer_stateless7_ctc_bs/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -37,7 +37,7 @@ Usage of this script:
(2) beam search
./pruned_transducer_stateless7_ctc_bs/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -46,7 +46,7 @@ Usage of this script:
(3) modified beam search
./pruned_transducer_stateless7_ctc_bs/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -55,7 +55,7 @@ Usage of this script:
(4) fast beam search
./pruned_transducer_stateless7_ctc_bs/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -75,7 +75,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -88,6 +87,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import num_tokens
def get_parser():
parser = argparse.ArgumentParser(
@ -104,9 +105,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -223,13 +224,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -284,6 +286,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -295,8 +303,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -305,16 +313,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -335,12 +343,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -22,14 +22,14 @@ You can use the following command to get the exported models:
./pruned_transducer_stateless7_ctc_bs/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
Usage of this script:
(1) ctc-decoding
./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \
./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--bpe-model data/lang_bpe_500/bpe.model \
--method ctc-decoding \
@ -38,7 +38,7 @@ Usage of this script:
/path/to/bar.wav
(2) 1best
./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \
./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \
@ -48,7 +48,7 @@ Usage of this script:
/path/to/bar.wav
(3) nbest-rescoring
./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \
./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \
@ -60,7 +60,7 @@ Usage of this script:
(4) whole-lattice-rescoring
./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \
./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \

View File

@ -66,6 +66,7 @@ import argparse
import logging
from pathlib import Path
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train2 import add_model_arguments, get_params, get_transducer_model
@ -76,8 +77,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@ -123,10 +123,10 @@ def get_parser():
)
parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_char",
help="The lang dir",
default="data/lang_char/tokens.txt",
help="The tokens.txt file",
)
parser.add_argument(
@ -246,9 +246,14 @@ def main():
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
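
Note: this file is a char-based recipe, so the migration also drops the Lexicon dependency: the old code took `max(lexicon.tokens) + 1`, the new code counts entries in tokens.txt (the copied `local/train_bpe_model.py` comment is a small inaccuracy here, since this recipe uses lang_char). As I understand the Lexicon helper, it excludes disambiguation symbols and the blank, so for a contiguous id range both computations agree. A toy check with invented symbols:

import k2

with open("toy_tokens.txt", "w", encoding="utf-8") as f:
    f.write("<blk> 0\n<unk> 1\nA 2\nB 3\n#0 4\n")

token_table = k2.SymbolTable.from_file("toy_tokens.txt")
real = [token_table[s] for s in token_table.symbols if not s.startswith("#")]

old_vocab_size = max(i for i in real if i != 0) + 1  # Lexicon-style: 4
new_vocab_size = (len(real) - 1) + 1                 # num_tokens-style: 4
assert old_vocab_size == new_vocab_size == 4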

View File

@ -28,7 +28,7 @@ popd
2. Export to ncnn
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--exp-dir $repo/exp \
--use-averaged-model 0 \
--epoch 99 \
@ -64,7 +64,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train2 import add_model_arguments, get_params, get_transducer_model
@ -75,7 +75,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@ -121,10 +121,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -244,12 +244,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -29,7 +29,7 @@ popd
2. Export the model to ONNX
./pruned_transducer_stateless7_streaming/export-onnx-zh.py \
--lang-dir $repo/data/lang_char_bpe \
--tokens $repo/data/lang_char_bpe/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
@ -60,6 +60,7 @@ import logging
from pathlib import Path
from typing import Dict, List, Tuple
import k2
import onnx
import torch
import torch.nn as nn
@ -76,8 +77,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@ -134,10 +134,10 @@ def get_parser():
)
parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_char",
help="The lang dir",
default="data/lang_char/tokens.txt",
help="The tokens.txt file",
)
parser.add_argument(
@ -493,9 +493,14 @@ def main():
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
@ -661,7 +666,7 @@ def main():
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul"],
op_types_to_quantize=["MatMul", "Gather"],
weight_type=QuantType.QInt8,
)

View File

@ -27,7 +27,7 @@ popd
2. Export the model to ONNX
./pruned_transducer_stateless7_streaming/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
@ -48,8 +48,8 @@ import logging
from pathlib import Path
from typing import Dict, List, Tuple
import k2
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
@ -65,7 +65,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@ -122,10 +122,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -481,12 +481,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
@ -652,7 +654,7 @@ def main():
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul"],
op_types_to_quantize=["MatMul", "Gather"],
weight_type=QuantType.QInt8,
)

View File

@ -139,8 +139,8 @@ import argparse
import logging
from pathlib import Path
import k2
import onnxruntime
import sentencepiece as spm
import torch
import torch.nn as nn
from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner
@ -154,7 +154,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -211,10 +211,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
@ -675,12 +675,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -20,7 +20,7 @@ You can generate the checkpoint with the following command:
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -29,7 +29,7 @@ Usage of this script:
(1) greedy search
./pruned_transducer_stateless7_streaming/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -37,7 +37,7 @@ Usage of this script:
(2) beam search
./pruned_transducer_stateless7_streaming/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -46,7 +46,7 @@ Usage of this script:
(3) modified beam search
./pruned_transducer_stateless7_streaming/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -55,7 +55,7 @@ Usage of this script:
(4) fast beam search
./pruned_transducer_stateless7_streaming/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -75,7 +75,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -88,7 +87,7 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -106,9 +105,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -225,13 +224,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -286,6 +286,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -297,8 +303,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -307,16 +313,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -337,12 +343,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -28,7 +28,7 @@ popd
2. Export to ncnn
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--exp-dir $repo/exp \
--use-averaged-model 0 \
--epoch 99 \
@ -64,7 +64,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train2 import add_model_arguments, get_params, get_transducer_model
@ -75,7 +75,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@ -121,10 +121,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -244,12 +244,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -26,7 +26,7 @@ Usage:
./pruned_transducer_stateless8/export.py \
--exp-dir ./pruned_transducer_stateless8/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
@ -45,7 +45,7 @@ for how to use the exported models outside of icefall.
./pruned_transducer_stateless8/export.py \
--exp-dir ./pruned_transducer_stateless8/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -86,7 +86,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
@ -98,7 +98,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -155,10 +155,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -198,12 +198,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -20,7 +20,7 @@ You can generate the checkpoint with the following command:
./pruned_transducer_stateless8/export.py \
--exp-dir ./pruned_transducer_stateless8/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -29,7 +29,7 @@ Usage of this script:
(1) greedy search
./pruned_transducer_stateless8/pretrained.py \
--checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -37,7 +37,7 @@ Usage of this script:
(2) beam search
./pruned_transducer_stateless8/pretrained.py \
--checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -46,7 +46,7 @@ Usage of this script:
(3) modified beam search
./pruned_transducer_stateless8/pretrained.py \
--checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -55,7 +55,7 @@ Usage of this script:
(4) fast beam search
./pruned_transducer_stateless8/pretrained.py \
--checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -75,7 +75,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -88,7 +87,7 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -106,9 +105,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -225,13 +224,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -286,6 +286,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -297,8 +303,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -307,16 +313,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -337,12 +343,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -22,7 +22,7 @@
Usage:
./transducer/export.py \
--exp-dir ./transducer/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 34 \
--avg 11
@ -46,7 +46,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from conformer import Conformer
from decoder import Decoder
@ -55,7 +55,7 @@ from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import AttributeDict, str2bool
from icefall.utils import AttributeDict, num_tokens, str2bool
def get_parser():
@ -90,10 +90,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -191,12 +191,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -19,7 +19,7 @@ Usage:
./transducer/pretrained.py \
--checkpoint ./transducer/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav \
@ -36,8 +36,8 @@ import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import beam_search, greedy_search
@ -48,7 +48,7 @@ from model import Transducer
from torch.nn.utils.rnn import pad_sequence
from icefall.env import get_env_info
from icefall.utils import AttributeDict
from icefall.utils import AttributeDict, num_tokens
def get_parser():
@ -66,11 +66,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
help="Path to tokens.txt.",
)
parser.add_argument(
@ -204,12 +202,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -257,6 +257,12 @@ def main():
x=features, x_lens=feature_lengths
)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
num_waves = encoder_out.size(0)
hyps = []
for i in range(num_waves):
@ -272,12 +278,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -22,7 +22,7 @@
Usage:
./transducer_stateless/export.py \
--exp-dir ./transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -46,7 +46,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
import torch.nn as nn
from conformer import Conformer
@ -56,7 +56,7 @@ from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import AttributeDict, str2bool
from icefall.utils import AttributeDict, num_tokens, str2bool
def get_parser():
@ -91,10 +91,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -191,12 +191,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -20,7 +20,7 @@ Usage:
(1) greedy search
./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
--max-sym-per-frame 1 \
/path/to/foo.wav \
@ -29,7 +29,7 @@ Usage:
(2) beam search
./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -38,7 +38,7 @@ Usage:
(3) modified beam search
./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -47,7 +47,7 @@ Usage:
(4) fast beam search
./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -67,7 +67,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -80,6 +79,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model
from icefall.utils import num_tokens
def get_parser():
parser = argparse.ArgumentParser(
@ -96,9 +97,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -213,12 +214,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -273,6 +276,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_list = fast_beam_search_one_best(
@ -318,12 +327,11 @@ def main():
raise ValueError(f"Unsupported method: {params.method}")
hyp_list.append(hyp)
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
hyps = [token_ids_to_words(hyp) for hyp in hyp_list]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -22,7 +22,7 @@
Usage:
./transducer_stateless2/export.py \
--exp-dir ./transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -46,12 +46,12 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -86,10 +86,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
@ -123,12 +123,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -20,7 +20,7 @@ Usage:
(1) greedy search
./transducer_stateless2/pretrained.py \
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
--max-sym-per-frame 1 \
/path/to/foo.wav \
@ -29,7 +29,7 @@ Usage:
(2) beam search
./transducer_stateless2/pretrained.py \
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -38,7 +38,7 @@ Usage:
(3) modified beam search
./transducer_stateless2/pretrained.py \
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -47,7 +47,7 @@ Usage:
(4) fast beam search
./transducer_stateless2/pretrained.py \
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -67,7 +67,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -80,6 +79,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model
from icefall.utils import num_tokens
def get_parser():
parser = argparse.ArgumentParser(
@ -96,9 +97,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -213,12 +214,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -273,6 +276,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_list = fast_beam_search_one_best(
@ -318,12 +327,11 @@ def main():
raise ValueError(f"Unsupported method: {params.method}")
hyp_list.append(hyp)
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
hyps = [token_ids_to_words(hyp) for hyp in hyp_list]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -22,7 +22,7 @@
Usage:
./transducer_stateless_multi_datasets/export.py \
--exp-dir ./transducer_stateless_multi_datasets/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -47,7 +47,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
import torch.nn as nn
from conformer import Conformer
@ -57,7 +57,7 @@ from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import AttributeDict, str2bool
from icefall.utils import AttributeDict, num_tokens, str2bool
def get_parser():
@ -92,10 +92,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -192,12 +192,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -20,7 +20,7 @@ Usage:
(1) greedy search
./transducer_stateless_multi_datasets/pretrained.py \
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
--max-sym-per-frame 1 \
/path/to/foo.wav \
@ -29,7 +29,7 @@ Usage:
(2) beam search
./transducer_stateless_multi_datasets/pretrained.py \
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -38,7 +38,7 @@ Usage:
(3) modified beam search
./transducer_stateless_multi_datasets/pretrained.py \
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -47,7 +47,7 @@ Usage:
(4) fast beam search
./transducer_stateless_multi_datasets/pretrained.py \
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -67,7 +67,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -80,6 +79,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model
from icefall.utils import num_tokens
def get_parser():
parser = argparse.ArgumentParser(
@ -96,9 +97,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -213,12 +214,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -273,6 +276,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_list = fast_beam_search_one_best(
@ -318,12 +327,11 @@ def main():
raise ValueError(f"Unsupported method: {params.method}")
hyp_list.append(hyp)
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
hyps = [token_ids_to_words(hyp) for hyp in hyp_list]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -19,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/pretrained.pt"
cd exp
@ -74,7 +73,6 @@ import onnx
import torch
import torch.nn as nn
from decoder import Decoder
from export import num_tokens
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_model, get_params
@ -86,7 +84,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():

View File

@ -19,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/pretrained.pt"
cd exp
@ -71,7 +70,6 @@ import onnx
import torch
import torch.nn as nn
from decoder import Decoder
from export import num_tokens
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_model, get_params
@ -83,7 +81,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, str2bool
from icefall.utils import make_pad_mask, num_tokens, str2bool
def get_parser():

View File

@ -160,7 +160,6 @@ with the following commands:
import argparse
import logging
import re
from pathlib import Path
from typing import List, Tuple
@ -176,27 +175,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, str2bool
def num_tokens(
token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
) -> int:
"""Return the number of tokens excluding those from
disambiguation symbols.
Caution:
0 is not a token ID so it is excluded from the return value.
"""
symbols = token_table.symbols
ans = []
for s in symbols:
if not disambig_pattern.match(s):
ans.append(token_table[s])
num_tokens = len(ans)
if 0 in ans:
num_tokens -= 1
return num_tokens
from icefall.utils import make_pad_mask, num_tokens, str2bool
def get_parser():
@ -487,6 +466,8 @@ def main():
device=device,
)
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
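The new elif branch above loads a single epoch checkpoint directly when --avg is 1, skipping the averaging path. A hedged sketch of the selection logic with simplified signatures; the real code uses load_checkpoint and the averaging helpers in icefall.checkpoint, and checkpoints are assumed to store the model under a "model" key as icefall does:

import torch

def load_model_state(exp_dir: str, epoch: int, avg: int, model: torch.nn.Module) -> None:
    if avg == 1:
        # Exactly one checkpoint requested: load it as-is.
        state = torch.load(f"{exp_dir}/epoch-{epoch}.pt", map_location="cpu")
        model.load_state_dict(state["model"])
    else:
        assert avg > 0, avg
        # Average the weights of the last `avg` epoch checkpoints.
        names = [f"{exp_dir}/epoch-{i}.pt" for i in range(epoch - avg + 1, epoch + 1)]
        avg_state = None
        for name in names:
            state = torch.load(name, map_location="cpu")["model"]
            if avg_state is None:
                avg_state = {k: v.float().clone() for k, v in state.items()}
            else:
                for k in avg_state:
                    avg_state[k] += state[k].float()
        for k in avg_state:
            avg_state[k] /= len(names)
        model.load_state_dict(avg_state)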

View File

@ -410,10 +410,20 @@ def main():
raise ValueError(f"Unsupported decoding method: {params.method}")
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
words = words.replace("▁", " ").strip()
s += f"{filename}:\n{words}\n\n"
if params.method == "ctc-decoding":
for filename, hyp in zip(params.sound_files, hyps):
words = "".join(hyp)
words = words.replace("▁", " ").strip()
s += f"{filename}:\n{words}\n\n"
elif params.method in [
"1best",
"nbest-rescoring",
"whole-lattice-rescoring",
]:
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
words = words.replace("▁", " ").strip()
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -33,7 +33,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/pretrained.pt"
cd exp

View File

@ -19,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
@ -29,7 +28,7 @@ popd
2. Export the model to ONNX
./zipformer/export-onnx-streaming.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \

View File

@ -31,7 +31,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/pretrained.pt"
cd exp

View File

@ -274,7 +274,7 @@ def main():
params.update(vars(args))
token_table = k2.SymbolTable.from_file(params.tokens)
params.vocab_size = num_tokens(token_table)
params.vocab_size = num_tokens(token_table) + 1 # +1 for blank
params.blank_id = token_table["<blk>"]
assert params.blank_id == 0
@ -429,10 +429,20 @@ def main():
raise ValueError(f"Unsupported decoding method: {params.method}")
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
words = words.replace("▁", " ").strip()
s += f"{filename}:\n{words}\n\n"
if params.method == "ctc-decoding":
for filename, hyp in zip(params.sound_files, hyps):
words = "".join(hyp)
words = words.replace("▁", " ").strip()
s += f"{filename}:\n{words}\n\n"
elif params.method in [
"1best",
"nbest-rescoring",
"whole-lattice-rescoring",
]:
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
words = words.replace("▁", " ").strip()
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -26,7 +26,7 @@ Usage:
./zipformer_mmi/export.py \
--exp-dir ./zipformer_mmi/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
@ -45,7 +45,7 @@ for how to use the exported models outside of icefall.
./zipformer_mmi/export.py \
--exp-dir ./zipformer_mmi/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -86,7 +86,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_ctc_model, get_params
@ -97,7 +97,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -154,10 +154,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -190,12 +190,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -21,7 +21,7 @@ You can generate the checkpoint with the following command:
./zipformer_mmi/export.py \
--exp-dir ./zipformer_mmi/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -30,14 +30,14 @@ Usage of this script:
(1) 1best
./zipformer_mmi/pretrained.py \
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method 1best \
/path/to/foo.wav \
/path/to/bar.wav
(2) nbest
./zipformer_mmi/pretrained.py \
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--nbest-scale 1.2 \
--method nbest \
/path/to/foo.wav \
@ -45,7 +45,7 @@ Usage of this script:
(3) nbest-rescoring-LG
./zipformer_mmi/pretrained.py \
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--nbest-scale 1.2 \
--method nbest-rescoring-LG \
/path/to/foo.wav \
@ -53,7 +53,7 @@ Usage of this script:
(4) nbest-rescoring-3-gram
./zipformer_mmi/pretrained.py \
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--nbest-scale 1.2 \
--method nbest-rescoring-3-gram \
/path/to/foo.wav \
@ -61,7 +61,7 @@ Usage of this script:
(5) nbest-rescoring-4-gram
./zipformer_mmi/pretrained.py \
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--nbest-scale 1.2 \
--method nbest-rescoring-4-gram \
/path/to/foo.wav \
@ -83,7 +83,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from decode import get_decoding_params
@ -97,7 +96,7 @@ from icefall.decode import (
one_best_decoding,
)
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
from icefall.utils import get_texts
from icefall.utils import get_texts, num_tokens
def get_parser():
@ -115,9 +114,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -247,13 +246,14 @@ def main():
params.update(get_decoding_params())
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -298,8 +298,6 @@ def main():
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir / "bpe.model"))
mmi_graph_compiler = MmiTrainingGraphCompiler(
params.lang_dir,
uniq_filename="lexicon.txt",
@ -313,6 +311,12 @@ def main():
if not hasattr(HP, "lm_scores"):
HP.lm_scores = HP.scores.clone()
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
method = params.method
assert method in (
"1best",
@ -390,14 +394,11 @@ def main():
#
# token_ids is a list-of-list of IDs
token_ids = get_texts(best_path)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
hyps = [token_ids_to_words(ids) for ids in token_ids]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -498,7 +498,7 @@ def main():
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul"],
op_types_to_quantize=["MatMul", "Gather"],
weight_type=QuantType.QInt8,
)
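Adding "Gather" to op_types_to_quantize extends int8 dynamic quantization to the decoder's embedding lookup, which a MatMul-only pass leaves in float32. A minimal sketch; the file names are placeholders, while quantize_dynamic and its keyword arguments are the actual onnxruntime API used above:

from onnxruntime.quantization import QuantType, quantize_dynamic

quantize_dynamic(
    model_input="decoder.onnx",        # placeholder path
    model_output="decoder.int8.onnx",  # placeholder path
    # "Gather" covers embedding tables; "MatMul" covers linear layers.
    op_types_to_quantize=["MatMul", "Gather"],
    weight_type=QuantType.QInt8,
)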

View File

@ -320,7 +320,7 @@ def main():
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
words = "".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)

Some files were not shown because too many files have changed in this diff.