From a81396b482c799b2ace2cefb79859be827b16f00 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sat, 12 Aug 2023 16:53:59 +0800 Subject: [PATCH] Use tokens.txt to replace bpe.model (#1162) --- ...n-librispeech-conformer-ctc3-2022-11-28.sh | 10 +- ...h-lstm-transducer-stateless2-2022-09-03.sh | 6 +- ...-pruned-transducer-stateless-2022-03-12.sh | 4 +- ...pruned-transducer-stateless2-2022-04-29.sh | 4 +- ...pruned-transducer-stateless3-2022-04-29.sh | 4 +- ...pruned-transducer-stateless3-2022-05-13.sh | 8 +- ...pruned-transducer-stateless5-2022-05-13.sh | 4 +- ...pruned-transducer-stateless7-2022-11-11.sh | 6 +- ...ed-transducer-stateless7-ctc-2022-12-01.sh | 6 +- ...transducer-stateless7-ctc-bs-2023-01-29.sh | 6 +- ...nsducer-stateless7-streaming-2022-12-29.sh | 6 +- ...pruned-transducer-stateless8-2022-11-14.sh | 6 +- ...pruned-transducer-stateless2-2022-06-26.sh | 4 +- ...speech-transducer-stateless2-2022-04-19.sh | 4 +- ...un-librispeech-zipformer-mmi-2022-12-08.sh | 4 +- .../scripts/run-pre-trained-conformer-ctc.sh | 4 +- ...d-transducer-stateless-librispeech-100h.sh | 4 +- ...d-transducer-stateless-librispeech-960h.sh | 4 +- .../run-pre-trained-transducer-stateless.sh | 4 +- .github/scripts/run-pre-trained-transducer.sh | 2 +- ...enetspeech-pruned-transducer-stateless2.sh | 36 +- .github/scripts/test-ncnn-export.sh | 12 +- .github/scripts/test-onnx-export.sh | 138 ++++++- .../pruned_transducer_stateless7/export.py | 322 +--------------- .../pretrained.py | 349 +----------------- egs/librispeech/ASR/conformer_ctc/export.py | 18 +- .../ASR/conformer_ctc/pretrained.py | 40 +- egs/librispeech/ASR/conformer_ctc2/export.py | 19 +- egs/librispeech/ASR/conformer_ctc3/export.py | 23 +- .../ASR/conformer_ctc3/pretrained.py | 42 ++- .../export.py | 22 +- .../export-for-ncnn.py | 22 +- .../export-onnx.py | 25 +- .../export.py | 22 +- .../onnx_pretrained.py | 2 +- .../ASR/lstm_transducer_stateless/export.py | 25 +- .../lstm_transducer_stateless/pretrained.py | 49 +-- .../export-for-ncnn.py | 23 +- .../export-onnx-zh.py | 2 +- .../lstm_transducer_stateless2/export-onnx.py | 25 +- .../ASR/lstm_transducer_stateless2/export.py | 25 +- .../lstm_transducer_stateless2/pretrained.py | 49 +-- .../ASR/lstm_transducer_stateless3/export.py | 25 +- .../lstm_transducer_stateless3/pretrained.py | 46 ++- .../pruned_stateless_emformer_rnnt2/export.py | 23 +- .../export-onnx.py | 2 +- .../ASR/pruned_transducer_stateless/export.py | 24 +- .../pruned_transducer_stateless/pretrained.py | 49 +-- .../pruned_transducer_stateless2/export.py | 22 +- .../pretrained.py | 49 +-- .../export-onnx.py | 24 +- .../pruned_transducer_stateless3/export.py | 26 +- .../pretrained.py | 51 +-- .../pruned_transducer_stateless4/export.py | 22 +- .../export-onnx-streaming.py | 26 +- .../export-onnx.py | 26 +- .../pruned_transducer_stateless5/export.py | 22 +- .../pretrained.py | 49 +-- .../pruned_transducer_stateless6/export.py | 22 +- .../export-onnx.py | 27 +- .../pruned_transducer_stateless7/export.py | 30 +- .../pretrained.py | 55 +-- .../export.py | 24 +- .../pretrained.py | 51 +-- .../pretrained_ctc.py | 10 +- .../export.py | 24 +- .../export_onnx.py | 26 +- .../pretrained.py | 51 +-- .../pretrained_ctc.py | 10 +- .../export-for-ncnn-zh.py | 21 +- .../export-for-ncnn.py | 22 +- .../export-onnx-zh.py | 25 +- .../export-onnx.py | 24 +- .../export.py | 20 +- .../pretrained.py | 51 +-- .../export-for-ncnn.py | 22 +- .../pruned_transducer_stateless8/export.py | 24 +- .../pretrained.py | 51 +-- egs/librispeech/ASR/transducer/export.py | 22 +- egs/librispeech/ASR/transducer/pretrained.py | 33 +- .../ASR/transducer_stateless/export.py | 22 +- .../ASR/transducer_stateless/pretrained.py | 36 +- .../ASR/transducer_stateless2/export.py | 22 +- .../ASR/transducer_stateless2/pretrained.py | 36 +- .../export.py | 22 +- .../pretrained.py | 36 +- .../ASR/zipformer/export-onnx-streaming.py | 4 +- egs/librispeech/ASR/zipformer/export-onnx.py | 4 +- egs/librispeech/ASR/zipformer/export.py | 25 +- .../ASR/zipformer/jit_pretrained_ctc.py | 18 +- egs/librispeech/ASR/zipformer/onnx_check.py | 1 - .../zipformer/onnx_pretrained-streaming.py | 3 +- .../ASR/zipformer/onnx_pretrained.py | 1 - .../ASR/zipformer/pretrained_ctc.py | 20 +- egs/librispeech/ASR/zipformer_mmi/export.py | 24 +- .../ASR/zipformer_mmi/pretrained.py | 47 +-- .../export-onnx.py | 2 +- .../pretrained.py | 2 +- icefall/utils.py | 20 + 99 files changed, 1243 insertions(+), 1623 deletions(-) mode change 100755 => 120000 egs/aishell/ASR/pruned_transducer_stateless7/export.py mode change 100644 => 120000 egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py diff --git a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh index c68ccc954..f6fe8c9b2 100755 --- a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh +++ b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh @@ -38,7 +38,7 @@ log "Decode with models exported by torch.jit.trace()" for m in ctc-decoding 1best; do ./conformer_ctc3/jit_pretrained.py \ --model-filename $repo/exp/jit_trace.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ + --words-file $repo/data/lang_bpe_500/words.txt \ --HLG $repo/data/lang_bpe_500/HLG.pt \ --bpe-model $repo/data/lang_bpe_500/bpe.model \ --G $repo/data/lm/G_4_gram.pt \ @@ -53,7 +53,7 @@ log "Export to torchscript model" ./conformer_ctc3/export.py \ --exp-dir $repo/exp \ - --lang-dir $repo/data/lang_bpe_500 \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --jit-trace 1 \ --epoch 99 \ --avg 1 \ @@ -80,9 +80,9 @@ done for m in ctc-decoding 1best; do ./conformer_ctc3/pretrained.py \ --checkpoint $repo/exp/pretrained.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ + --words-file $repo/data/lang_bpe_500/words.txt \ --HLG $repo/data/lang_bpe_500/HLG.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --G $repo/data/lm/G_4_gram.pt \ --method $m \ --sample-rate 16000 \ @@ -93,7 +93,7 @@ done echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then mkdir -p conformer_ctc3/exp ln -s $PWD/$repo/exp/pretrained.pt conformer_ctc3/exp/epoch-999.pt ln -s $PWD/$repo/data/lang_bpe_500 data/ diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh index 4cd2c4bec..d547bdd45 100755 --- a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh +++ b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh @@ -31,7 +31,7 @@ log "Test exporting with torch.jit.trace()" ./lstm_transducer_stateless2/export.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -55,7 +55,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -68,7 +68,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh index 6792c7088..412e3ad56 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh index dbf678d72..243b669ed 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh @@ -36,7 +36,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -49,7 +49,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh index b6d477afe..2d0f80304 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh @@ -35,7 +35,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -48,7 +48,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh index efa4b53f0..3d5814c48 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh @@ -30,14 +30,14 @@ popd log "Export to torchscript model" ./pruned_transducer_stateless3/export.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 ./pruned_transducer_stateless3/export.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit-trace 1 @@ -74,7 +74,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -87,7 +87,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh index 511fe0c9e..3d2442d54 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh @@ -32,7 +32,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --num-encoder-layers 18 \ --dim-feedforward 2048 \ --nhead 8 \ @@ -51,7 +51,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav \ diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh index 2bc179c86..961dde4f4 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh @@ -33,7 +33,7 @@ log "Export to torchscript model" ./pruned_transducer_stateless7/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 @@ -56,7 +56,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -69,7 +69,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh index 192438353..ba7139efb 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh @@ -37,7 +37,7 @@ log "Export to torchscript model" ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 @@ -74,7 +74,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -87,7 +87,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh index 7d2853c17..1ecbc4798 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh @@ -36,7 +36,7 @@ log "Export to torchscript model" ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 @@ -72,7 +72,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -85,7 +85,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh index e1e4e1f10..37b192a57 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh @@ -37,7 +37,7 @@ log "Export to torchscript model" ./pruned_transducer_stateless7_streaming/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --decode-chunk-len 32 \ --epoch 99 \ --avg 1 \ @@ -81,7 +81,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --decode-chunk-len 32 \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ @@ -95,7 +95,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --decode-chunk-len 32 \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh index 5d9485692..4f2bfac24 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh @@ -41,7 +41,7 @@ log "Decode with models exported by torch.jit.script()" log "Export to torchscript model" ./pruned_transducer_stateless8/export.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model false \ --epoch 99 \ --avg 1 \ @@ -65,7 +65,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -78,7 +78,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh index 77cd59506..5cbdad16d 100755 --- a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh +++ b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh @@ -32,7 +32,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --simulate-streaming 1 \ --causal-convolution 1 \ $repo/test_wavs/1089-134686-0001.wav \ @@ -47,7 +47,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --simulate-streaming 1 \ --causal-convolution 1 \ $repo/test_wavs/1089-134686-0001.wav \ diff --git a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh index b4aca1b6b..ff77855a2 100755 --- a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh +++ b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh index a58b8ec56..c59921055 100755 --- a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh +++ b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh @@ -37,7 +37,7 @@ log "Export to torchscript model" ./zipformer_mmi/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 @@ -61,7 +61,7 @@ for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescor --method $method \ --checkpoint $repo/exp/pretrained.pt \ --lang-dir $repo/data/lang_bpe_500 \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-conformer-ctc.sh index 125d1f3b1..a4959aa01 100755 --- a/.github/scripts/run-pre-trained-conformer-ctc.sh +++ b/.github/scripts/run-pre-trained-conformer-ctc.sh @@ -27,7 +27,7 @@ log "CTC decoding" --method ctc-decoding \ --num-classes 500 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.flac \ $repo/test_wavs/1221-135766-0001.flac \ $repo/test_wavs/1221-135766-0002.flac @@ -38,7 +38,7 @@ log "HLG decoding" --method 1best \ --num-classes 500 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --words-file $repo/data/lang_bpe_500/words.txt \ --HLG $repo/data/lang_bpe_500/HLG.pt \ $repo/test_wavs/1089-134686-0001.flac \ diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh index 89115e88d..7b686328d 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh index 85e2c89e6..a8eeeb514 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-pre-trained-transducer-stateless.sh b/.github/scripts/run-pre-trained-transducer-stateless.sh index 41456f11b..2e2360435 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-pre-trained-transducer.sh b/.github/scripts/run-pre-trained-transducer.sh index 1331c966c..b865f8d13 100755 --- a/.github/scripts/run-pre-trained-transducer.sh +++ b/.github/scripts/run-pre-trained-transducer.sh @@ -27,7 +27,7 @@ log "Beam search decoding" --method beam_search \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh index 90097c752..a3a2d3080 100755 --- a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh +++ b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh @@ -17,7 +17,6 @@ git lfs install git clone $repo_url repo=$(basename $repo_url) - log "Display test files" tree $repo/ ls -lh $repo/test_wavs/*.wav @@ -29,12 +28,11 @@ popd log "Test exporting to ONNX format" -./pruned_transducer_stateless2/export.py \ +./pruned_transducer_stateless2/export-onnx.py \ --exp-dir $repo/exp \ --lang-dir $repo/data/lang_char \ --epoch 99 \ - --avg 1 \ - --onnx 1 + --avg 1 log "Export to torchscript model" @@ -59,19 +57,17 @@ log "Decode with ONNX models" ./pruned_transducer_stateless2/onnx_check.py \ --jit-filename $repo/exp/cpu_jit.pt \ - --onnx-encoder-filename $repo/exp/encoder.onnx \ - --onnx-decoder-filename $repo/exp/decoder.onnx \ - --onnx-joiner-filename $repo/exp/joiner.onnx \ - --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \ - --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx + --onnx-encoder-filename $repo/exp/encoder-epoch-10-avg-2.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-10-avg-2.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-10-avg-2.onnx \ + --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \ + --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj-epoch-10-avg-2.onnx ./pruned_transducer_stateless2/onnx_pretrained.py \ --tokens $repo/data/lang_char/tokens.txt \ - --encoder-model-filename $repo/exp/encoder.onnx \ - --decoder-model-filename $repo/exp/decoder.onnx \ - --joiner-model-filename $repo/exp/joiner.onnx \ - --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ - --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ $repo/test_wavs/DEV_T0000000000.wav \ $repo/test_wavs/DEV_T0000000001.wav \ $repo/test_wavs/DEV_T0000000002.wav @@ -104,9 +100,9 @@ for sym in 1 2 3; do --lang-dir $repo/data/lang_char \ --decoding-method greedy_search \ --max-sym-per-frame $sym \ - $repo/test_wavs/DEV_T0000000000.wav \ - $repo/test_wavs/DEV_T0000000001.wav \ - $repo/test_wavs/DEV_T0000000002.wav + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav done for method in modified_beam_search beam_search fast_beam_search; do @@ -117,7 +113,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --beam-size 4 \ --checkpoint $repo/exp/epoch-99.pt \ --lang-dir $repo/data/lang_char \ - $repo/test_wavs/DEV_T0000000000.wav \ - $repo/test_wavs/DEV_T0000000001.wav \ - $repo/test_wavs/DEV_T0000000002.wav + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav done diff --git a/.github/scripts/test-ncnn-export.sh b/.github/scripts/test-ncnn-export.sh index ac16131d0..4073c594a 100755 --- a/.github/scripts/test-ncnn-export.sh +++ b/.github/scripts/test-ncnn-export.sh @@ -45,7 +45,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" cd exp @@ -56,11 +55,10 @@ log "Export via torch.jit.trace()" ./conv_emformer_transducer_stateless2/export-for-ncnn.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ - \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --num-encoder-layers 12 \ --chunk-length 32 \ --cnn-module-kernel 31 \ @@ -91,7 +89,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt" cd exp @@ -102,7 +99,7 @@ log "Export via torch.jit.trace()" ./lstm_transducer_stateless2/export-for-ncnn.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 @@ -140,7 +137,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained.pt" cd exp @@ -148,7 +144,7 @@ ln -s pretrained.pt epoch-99.pt popd ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --exp-dir $repo/exp \ --use-averaged-model 0 \ --epoch 99 \ @@ -199,7 +195,7 @@ ln -s pretrained.pt epoch-9999.pt popd ./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \ - --lang-dir $repo/data/lang_char_bpe \ + --tokens $repo/data/lang_char_bpe/tokens.txt \ --exp-dir $repo/exp \ --use-averaged-model 0 \ --epoch 9999 \ diff --git a/.github/scripts/test-onnx-export.sh b/.github/scripts/test-onnx-export.sh index 39467c44a..fcfc11fa6 100755 --- a/.github/scripts/test-onnx-export.sh +++ b/.github/scripts/test-onnx-export.sh @@ -10,7 +10,123 @@ log() { cd egs/librispeech/ASR +log "==========================================================================" +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) +pushd $repo +git lfs pull --include "exp/pretrained.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +log "Export via torch.jit.script()" +./zipformer/export.py \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --jit 1 + +log "Test export to ONNX format" +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal False \ + --chunk-size "16,32,64,-1" \ + --left-context-frames "64,128,256,-1" + +ls -lh $repo/exp + +log "Run onnx_check.py" + +./zipformer/onnx_check.py \ + --jit-filename $repo/exp/jit_script.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx + +log "Run onnx_pretrained.py" + +./zipformer/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav + +rm -rf $repo + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +log "Test export streaming model to ONNX format" +./zipformer/export-onnx-streaming.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal True \ + --chunk-size 16 \ + --left-context-frames 64 + +ls -lh $repo/exp + +log "Run onnx_pretrained-streaming.py" + +./zipformer/onnx_pretrained-streaming.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1-chunk-16-left-64.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1-chunk-16-left-64.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1-chunk-16-left-64.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav + +rm -rf $repo + +log "--------------------------------------------------------------------------" log "==========================================================================" repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 @@ -39,7 +155,7 @@ log "Export via torch.jit.trace()" log "Test exporting to ONNX format" ./pruned_transducer_stateless7_streaming/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -88,7 +204,7 @@ popd log "Export via torch.jit.script()" ./pruned_transducer_stateless3/export.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 9999 \ --avg 1 \ --exp-dir $repo/exp/ \ @@ -97,7 +213,7 @@ log "Export via torch.jit.script()" log "Test exporting to ONNX format" ./pruned_transducer_stateless3/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 9999 \ --avg 1 \ --exp-dir $repo/exp/ @@ -126,7 +242,6 @@ log "Run onnx_pretrained.py" rm -rf $repo log "--------------------------------------------------------------------------" - log "==========================================================================" repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url @@ -143,7 +258,7 @@ popd log "Export via torch.jit.script()" ./pruned_transducer_stateless5/export.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -159,7 +274,7 @@ log "Export via torch.jit.script()" log "Test exporting to ONNX format" ./pruned_transducer_stateless5/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -205,7 +320,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained.pt" cd exp @@ -215,7 +329,7 @@ popd log "Export via torch.jit.script()" ./pruned_transducer_stateless7/export.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -226,7 +340,7 @@ log "Export via torch.jit.script()" log "Test exporting to ONNX format" ./pruned_transducer_stateless7/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -270,7 +384,7 @@ popd log "Test exporting to ONNX format" ./conv_emformer_transducer_stateless2/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -310,7 +424,7 @@ popd log "Export via torch.jit.trace()" ./lstm_transducer_stateless2/export.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -320,7 +434,7 @@ log "Export via torch.jit.trace()" log "Test exporting to ONNX format" ./lstm_transducer_stateless2/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export.py b/egs/aishell/ASR/pruned_transducer_stateless7/export.py deleted file mode 100755 index 1b0e8d3b9..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/export.py +++ /dev/null @@ -1,321 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -(1) Export to torchscript model using torch.jit.script() - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --lang-dir data/lang_char \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("cpu_jit.pt")`. - -Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python -are on CPU. You can use `to("cuda")` to move them to a CUDA device. - -Check -https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -(2) Export `model.state_dict()` - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --lang-dir data/lang_char \ - --epoch 20 \ - --avg 10 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. - -To use the generated file with `pruned_transducer_stateless7/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - ./pruned_transducer_stateless7/decode.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --lang-dir data/lang_char - -Check ./pretrained.py for its usage. - -Note: If you don't want to train a model from scratch, we have -provided one for you. You can get it at - -https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21 - -with the following commands: - - sudo apt-get install git-lfs - git lfs install - git clone https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21 - # You will find the pre-trained model in icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21exp -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -import torch.nn as nn -from scaling_converter import convert_scaled_to_non_scaled -from train2 import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named cpu_jit.pt - - Check ./jit_pretrained.py for how to use it. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit is True: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torchscript. Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export.py b/egs/aishell/ASR/pruned_transducer_stateless7/export.py new file mode 120000 index 000000000..2713792e6 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/export.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py deleted file mode 100644 index cc54027d6..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py +++ /dev/null @@ -1,348 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --lang-dir data/lang_char \ - --epoch 20 \ - --avg 10 - -Usage of this script: - -(1) greedy search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`. - -Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by -./pruned_transducer_stateless7/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.lexicon import Lexicon -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--lang-dir", - type=str, - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - token_table = lexicon.token_table - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp_tokens = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp_tokens = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens] - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py new file mode 120000 index 000000000..068f0f57f --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc/export.py b/egs/librispeech/ASR/conformer_ctc/export.py index fbcbd7b29..f0bb97560 100755 --- a/egs/librispeech/ASR/conformer_ctc/export.py +++ b/egs/librispeech/ASR/conformer_ctc/export.py @@ -23,12 +23,13 @@ import argparse import logging from pathlib import Path +import k2 import torch from conformer import Conformer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -63,11 +64,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_bpe_500", - help="""It contains language related input files such as "lexicon.txt" - """, + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -98,16 +98,16 @@ def get_params() -> AttributeDict: def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) params = get_params() params.update(vars(args)) logging.info(params) - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + num_classes = num_tokens(token_table) + 1 # +1 for the blank device = torch.device("cpu") if torch.cuda.is_available(): diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index 30def9c40..df3e4d819 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -24,7 +24,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from conformer import Conformer @@ -70,11 +69,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model. - Used only when method is ctc-decoding. - """, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -83,10 +80,9 @@ def get_parser(): default="1best", help="""Decoding method. Possible values are: - (0) ctc-decoding - Use CTC decoding. It uses a sentence - piece model, i.e., lang_dir/bpe.model, to convert - word pieces to words. It needs neither a lexicon - nor an n-gram LM. + (0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file + to convert tokens to actual words or characters. It needs + neither a lexicon nor an n-gram LM. (1) 1best - Use the best path as decoding output. Only the transformer encoder output is used for decoding. We call it HLG decoding. @@ -297,6 +293,7 @@ def main(): waves = [w.to(device) for w in waves] logging.info("Decoding started") + hyps = [] features = fbank(waves) features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) @@ -313,10 +310,17 @@ def main(): if params.method == "ctc-decoding": logging.info("Use CTC decoding") - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(params.bpe_model) max_token_id = params.num_classes - 1 + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + H = k2.ctc_topo( max_token=max_token_id, modified=params.num_classes > 500, @@ -337,9 +341,9 @@ def main(): best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) - token_ids = get_texts(best_path) - hyps = bpe_model.decode(token_ids) - hyps = [s.split() for s in hyps] + hyp_tokens = get_texts(best_path) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method in [ "1best", "whole-lattice-rescoring", @@ -408,16 +412,16 @@ def main(): ) best_path = next(iter(best_path_dict.values())) - hyps = get_texts(best_path) word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + hyp_tokens = get_texts(best_path) + for hyp in hyp_tokens: + hyps.append(" ".join([word_sym_table[i] for i in hyp])) else: raise ValueError(f"Unsupported decoding method: {params.method}") s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/conformer_ctc2/export.py b/egs/librispeech/ASR/conformer_ctc2/export.py index 7892b03c6..26a95dbfa 100755 --- a/egs/librispeech/ASR/conformer_ctc2/export.py +++ b/egs/librispeech/ASR/conformer_ctc2/export.py @@ -23,6 +23,7 @@ Usage: ./conformer_ctc2/export.py \ --exp-dir ./conformer_ctc2/exp \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -46,6 +47,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from conformer import Conformer from decode import get_params @@ -56,8 +58,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -123,10 +124,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_bpe_500", - help="The lang dir", + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -143,14 +144,14 @@ def get_parser(): def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) params = get_params() params.update(vars(args)) - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + num_classes = num_tokens(token_table) + 1 # +1 for the blank device = torch.device("cpu") if torch.cuda.is_available(): diff --git a/egs/librispeech/ASR/conformer_ctc3/export.py b/egs/librispeech/ASR/conformer_ctc3/export.py index c5b95d981..5cb9b4b6d 100755 --- a/egs/librispeech/ASR/conformer_ctc3/export.py +++ b/egs/librispeech/ASR/conformer_ctc3/export.py @@ -25,7 +25,7 @@ Usage: ./conformer_ctc3/export.py \ --exp-dir ./conformer_ctc3/exp \ - --lang-dir data/lang_bpe_500 \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 \ --jit-trace 1 @@ -36,7 +36,7 @@ It will generates the file: `jit_trace.pt`. ./conformer_ctc3/export.py \ --exp-dir ./conformer_ctc3/exp \ - --lang-dir data/lang_bpe_500 \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -62,6 +62,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_ctc_model, get_params @@ -72,8 +73,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -130,10 +130,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", + "--tokens", + type=str, + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -171,9 +171,10 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + num_classes = num_tokens(token_table) + 1 # +1 for the blank params.vocab_size = num_classes if params.streaming_model: diff --git a/egs/librispeech/ASR/conformer_ctc3/pretrained.py b/egs/librispeech/ASR/conformer_ctc3/pretrained.py index 880945ea0..c37b99cce 100755 --- a/egs/librispeech/ASR/conformer_ctc3/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc3/pretrained.py @@ -24,7 +24,7 @@ Usage (for non-streaming mode): (1) ctc-decoding ./conformer_ctc3/pretrained.py \ --checkpoint conformer_ctc3/exp/pretrained.pt \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method ctc-decoding \ --sample-rate 16000 \ test_wavs/1089-134686-0001.wav @@ -67,7 +67,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from decode import get_decoding_params @@ -114,11 +113,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model. - Used only when method is ctc-decoding. - """, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -127,10 +124,9 @@ def get_parser(): default="1best", help="""Decoding method. Possible values are: - (0) ctc-decoding - Use CTC decoding. It uses a sentence - piece model, i.e., lang_dir/bpe.model, to convert - word pieces to words. It needs neither a lexicon - nor an n-gram LM. + (0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file + to convert tokens to actual words or characters. It needs + neither a lexicon nor an n-gram LM. (1) 1best - Use the best path as decoding output. Only the transformer encoder output is used for decoding. We call it HLG decoding. @@ -316,6 +312,7 @@ def main(): waves = [w.to(device) for w in waves] logging.info("Decoding started") + hyps = [] features = fbank(waves) feature_lengths = [f.size(0) for f in features] @@ -348,10 +345,17 @@ def main(): if params.method == "ctc-decoding": logging.info("Use CTC decoding") - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(params.bpe_model) max_token_id = params.num_classes - 1 + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + H = k2.ctc_topo( max_token=max_token_id, modified=False, @@ -372,9 +376,9 @@ def main(): best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) - token_ids = get_texts(best_path) - hyps = bpe_model.decode(token_ids) - hyps = [s.split() for s in hyps] + hyp_tokens = get_texts(best_path) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method in [ "1best", "nbest-rescoring", @@ -439,16 +443,16 @@ def main(): ) best_path = next(iter(best_path_dict.values())) - hyps = get_texts(best_path) word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + hyp_tokens = get_texts(best_path) + for hyp in hyp_tokens: + hyps.append(" ".join([word_sym_table[i] for i in hyp])) else: raise ValueError(f"Unsupported decoding method: {params.method}") s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py index 09a3e96b0..67fcc35a4 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py @@ -22,7 +22,7 @@ Usage: ./conv_emformer_transducer_stateless/export.py \ --exp-dir ./conv_emformer_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 10 \ --use-averaged-model=True \ @@ -62,7 +62,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model @@ -72,7 +72,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -118,10 +118,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -166,12 +166,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py index 8fbb02f14..85dbd4661 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py @@ -8,7 +8,7 @@ for more details about how to use this file. Usage: ./conv_emformer_transducer_stateless2/export-for-ncnn.py \ --exp-dir ./conv_emformer_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 10 \ --use-averaged-model=True \ @@ -37,7 +37,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model @@ -48,7 +48,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -94,10 +94,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -217,12 +217,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py index ad0b45bd9..cfd365207 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py @@ -18,7 +18,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" cd exp @@ -28,7 +27,7 @@ popd 2. Export the model to ONNX ./conv_emformer_transducer_stateless2/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -55,14 +54,14 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from decoder import Decoder +from emformer import Emformer from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model -from emformer import Emformer from icefall.checkpoint import ( average_checkpoints, @@ -70,7 +69,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -127,10 +126,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -484,12 +483,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py index b53426c75..8e5b14903 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py @@ -22,7 +22,7 @@ Usage: ./conv_emformer_transducer_stateless2/export.py \ --exp-dir ./conv_emformer_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 10 \ --use-averaged-model=True \ @@ -62,7 +62,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -73,7 +73,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -119,10 +119,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -167,12 +167,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py index db92ac696..5d7e2dfcd 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./conv_emformer_transducer_stateless2/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/export.py b/egs/librispeech/ASR/lstm_transducer_stateless/export.py index e338342cc..c007220d5 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/export.py @@ -26,7 +26,7 @@ Usage: ./lstm_transducer_stateless/export.py \ --exp-dir ./lstm_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 35 \ --avg 10 \ --jit-trace 1 @@ -38,7 +38,7 @@ It will generate 3 files: `encoder_jit_trace.pt`, ./lstm_transducer_stateless/export.py \ --exp-dir ./lstm_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 35 \ --avg 10 @@ -79,7 +79,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled @@ -91,7 +91,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -148,10 +148,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -266,12 +266,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py index b3a34a9e3..119fcf1fd 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./lstm_transducer_stateless/pretrained.py \ --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./lstm_transducer_stateless/pretrained.py \ --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./lstm_transducer_stateless/pretrained.py \ --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./lstm_transducer_stateless/pretrained.py \ --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -66,7 +66,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -79,6 +78,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -95,9 +96,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -214,13 +215,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -275,6 +277,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -286,8 +294,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -296,16 +304,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -326,12 +334,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py index 08bfcb204..2b8c92208 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py @@ -29,7 +29,7 @@ popd ./lstm_transducer_stateless2/export-for-ncnn.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -49,7 +49,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -60,7 +60,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -106,10 +106,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -221,12 +221,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py index f068f6a0f..89ced388c 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py @@ -613,7 +613,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py index acaff8540..6b6cb893f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./lstm_transducer_stateless2/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -52,8 +52,8 @@ import logging from pathlib import Path from typing import Dict, Optional, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from decoder import Decoder @@ -68,7 +68,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -125,10 +125,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -437,12 +437,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -607,7 +608,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py index 0adc68112..5712da25e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py @@ -27,7 +27,7 @@ Usage: ./lstm_transducer_stateless2/export.py \ --exp-dir ./lstm_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 35 \ --avg 10 \ --jit-trace 1 @@ -39,7 +39,7 @@ It will generate 3 files: `encoder_jit_trace.pt`, ./lstm_transducer_stateless2/export.py \ --exp-dir ./lstm_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 35 \ --avg 10 @@ -80,7 +80,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled @@ -92,7 +92,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -149,10 +149,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -267,12 +267,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py index f3f272b9f..5d6d97320 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./lstm_transducer_stateless2/pretrained.py \ --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./lstm_transducer_stateless2/pretrained.py \ --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./lstm_transducer_stateless2/pretrained.py \ --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./lstm_transducer_stateless2/pretrained.py \ --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -69,7 +69,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -82,6 +81,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -98,9 +99,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -217,13 +218,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -278,6 +280,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -289,8 +297,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -299,16 +307,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -329,12 +337,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py index a82cad043..21eaa049b 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py @@ -26,7 +26,7 @@ Usage: ./lstm_transducer_stateless3/export.py \ --exp-dir ./lstm_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 40 \ --avg 20 \ --jit-trace 1 @@ -38,7 +38,7 @@ It will generate 3 files: `encoder_jit_trace.pt`, ./lstm_transducer_stateless3/export.py \ --exp-dir ./lstm_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 40 \ --avg 20 @@ -79,7 +79,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled @@ -91,7 +91,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -148,10 +148,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to tokens.txt.", ) parser.add_argument( @@ -266,12 +266,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py index f49e9c518..29a0d4d1a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./lstm_transducer_stateless3/pretrained.py \ --checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./lstm_transducer_stateless3/pretrained.py \ --checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./lstm_transducer_stateless3/pretrained.py \ --checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -79,6 +79,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -95,9 +97,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -214,13 +216,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -275,6 +278,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -286,8 +295,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -296,16 +305,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -326,12 +335,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py index 3612a2bfd..ec2c9d580 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py @@ -22,7 +22,7 @@ Usage: ./prunted_stateless_emformer_rnnt/export.py \ --exp-dir ./prunted_stateless_emformer_rnnt/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -48,7 +48,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model @@ -58,7 +58,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -115,10 +115,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -154,13 +154,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py index a3ebe9d8c..282238c13 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py @@ -508,7 +508,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py index a19f9ab9a..4b20e3a2b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless/export.py \ --exp-dir ./pruned_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,12 +47,12 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -87,10 +87,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -135,13 +135,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index 2ed1725b4..02f9f1b03 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./pruned_transducer_stateless/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./pruned_transducer_stateless/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./pruned_transducer_stateless/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./pruned_transducer_stateless/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -66,7 +66,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -79,7 +78,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -97,9 +96,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -237,13 +236,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.simulate_streaming: assert ( @@ -314,6 +314,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -325,8 +331,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -335,16 +341,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -365,12 +371,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index 984caf5f2..e02afa892 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless2/export.py \ --exp-dir ./pruned_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,12 +47,12 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -98,10 +98,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -145,12 +145,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py index 013964720..029f55ba0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -66,7 +66,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -79,7 +78,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -97,9 +96,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -238,13 +237,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.simulate_streaming: assert ( @@ -315,6 +315,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -326,8 +332,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -336,16 +342,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -366,12 +372,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py index 9645b7801..26dea7e11 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless3/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 9999 \ --avg 1 \ --exp-dir $repo/exp/ @@ -48,8 +48,8 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from conformer import Conformer @@ -59,7 +59,7 @@ from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import setup_logger +from icefall.utils import num_tokens, setup_logger def get_parser(): @@ -105,10 +105,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -393,12 +393,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -518,7 +520,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index f30c9df6a..925b15646 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 \ --jit 1 @@ -44,7 +44,7 @@ It will also generate 3 other files: `encoder_jit_script.pt`, ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 \ --jit-trace 1 @@ -56,7 +56,7 @@ It will generates 3 files: `encoder_jit_trace.pt`, ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -97,14 +97,14 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -150,10 +150,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -342,12 +342,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index 7c3dfc660..abda4e2d4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless3/pretrained.py \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless3/pretrained.py \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless3/pretrained.py \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless3/pretrained.py \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,7 +87,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -106,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -247,13 +246,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.simulate_streaming: assert ( @@ -324,6 +324,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -335,8 +341,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -345,16 +351,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -375,12 +381,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py index 8f33f5b05..08d736f52 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless4/export.py \ --exp-dir ./pruned_transducer_stateless4/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -48,7 +48,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -59,7 +59,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -116,10 +116,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -164,12 +164,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py index 938ff2f16..549fb13c9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless5/export-onnx-streaming.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -58,13 +58,13 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from conformer import Conformer -from onnxruntime.quantization import QuantType, quantize_dynamic from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -74,7 +74,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -131,10 +131,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -489,12 +489,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -662,7 +664,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py index 20fd8dff8..fff0fcdd5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless5/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -55,13 +55,13 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from conformer import Conformer -from onnxruntime.quantization import QuantType, quantize_dynamic from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -71,7 +71,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -128,10 +128,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -416,12 +416,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -586,7 +588,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py index 54f656859..e5223be26 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -48,7 +48,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -59,7 +59,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -116,10 +116,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -164,12 +164,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py index 74a2210c3..304fa8693 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -66,7 +66,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -79,6 +78,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -95,9 +96,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -214,13 +215,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -275,6 +277,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -286,8 +294,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -296,16 +304,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -326,12 +334,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py index 4d0d8326c..38f48b2ed 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless6/export.py \ --exp-dir ./pruned_transducer_stateless6/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,12 +47,12 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -98,10 +98,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -135,12 +135,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py index d2db92820..11c885f4d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang +# Zengrui Jin) """ This script exports a transducer model from PyTorch to ONNX. @@ -18,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained-epoch-30-avg-9.pt" cd exp @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless7/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -50,8 +50,8 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from decoder import Decoder @@ -66,7 +66,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -123,10 +123,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -411,12 +410,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -581,7 +580,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py index 3e3160e7e..eb4c4d282 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang +# Zengrui Jin) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -26,7 +27,7 @@ Usage: ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +46,7 @@ for how to use the exported models outside of icefall. ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -65,7 +66,7 @@ you can do: --avg 1 \ --max-duration 600 \ --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model + --tokens data/lang_bpe_500/tokens.txt \ Check ./pretrained.py for its usage. @@ -86,7 +87,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled @@ -98,7 +99,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -155,10 +156,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -198,12 +198,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -292,7 +292,7 @@ def main(): model.to("cpu") model.eval() - if params.jit is True: + if params.jit: convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py index d05bafcfb..86c922cda 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Zengrui Jin) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -20,7 +21,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +30,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless7/pretrained.py \ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +38,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless7/pretrained.py \ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +47,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless7/pretrained.py \ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +56,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless7/pretrained.py \ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +76,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,7 +88,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens def get_parser(): @@ -106,9 +106,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -225,13 +225,13 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -286,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -297,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -307,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -337,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py index c1607699f..51e62d6a8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -86,7 +86,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -97,7 +97,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -154,10 +154,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -197,12 +197,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py index 2f1b1a49f..78e0fa778 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless7_ctc/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless7_ctc/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless7_ctc/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless7_ctc/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,6 +87,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -104,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -223,13 +224,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -284,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -295,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -305,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -335,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py index 5d460edb5..904c1deae 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py @@ -22,14 +22,14 @@ You can use the following command to get the exported models: ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 Usage of this script: (1) ctc-decoding -./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --bpe-model data/lang_bpe_500/bpe.model \ --method ctc-decoding \ @@ -38,7 +38,7 @@ Usage of this script: /path/to/bar.wav (2) 1best -./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ @@ -48,7 +48,7 @@ Usage of this script: /path/to/bar.wav (3) nbest-rescoring -./bruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./bruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ @@ -60,7 +60,7 @@ Usage of this script: (4) whole-lattice-rescoring -./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py index 05df8cfff..9f35cf63e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 @@ -86,7 +86,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -97,7 +97,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -154,10 +154,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -197,12 +197,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py index 630a7f735..d3033b888 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py @@ -28,7 +28,7 @@ Usage: ./pruned_transducer_stateless7_ctc_bs/export_onnx.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 \ --onnx 1 @@ -48,7 +48,7 @@ Check `onnx_check.py` for how to use them. (2) Export to ONNX format which can be used in Triton Server ./pruned_transducer_stateless7_ctc_bs/export_onnx.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 \ --onnx-triton 1 @@ -86,9 +86,10 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn +from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -98,8 +99,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool -from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -156,10 +156,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -728,12 +728,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py index ea0fe9164..5d240cf30 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,6 +87,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -104,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -223,13 +224,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -284,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -295,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -305,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -335,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py index 412631ba1..914107526 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py @@ -22,14 +22,14 @@ You can use the following command to get the exported models: ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 Usage of this script: (1) ctc-decoding -./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --bpe-model data/lang_bpe_500/bpe.model \ --method ctc-decoding \ @@ -38,7 +38,7 @@ Usage of this script: /path/to/bar.wav (2) 1best -./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ @@ -48,7 +48,7 @@ Usage of this script: /path/to/bar.wav (3) nbest-rescoring -./bruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./bruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ @@ -60,7 +60,7 @@ Usage of this script: (4) whole-lattice-rescoring -./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py index e196f8b7d..07de57a86 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py @@ -66,6 +66,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model @@ -76,8 +77,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -123,10 +123,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="The tokens.txt file", ) parser.add_argument( @@ -246,9 +246,14 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + # Load id of the token and the vocab size + # is defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py index 4a16a97fb..9a6b31268 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -28,7 +28,7 @@ popd 2. Export to ncnn ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --exp-dir $repo/exp \ --use-averaged-model 0 \ --epoch 99 \ @@ -64,7 +64,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model @@ -75,7 +75,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -121,10 +121,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -244,12 +244,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py index 04d97808d..8653126de 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py @@ -29,7 +29,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless7_streaming/export-onnx-zh.py \ - --lang-dir $repo/data/lang_char_bpe \ + --tokens $repo/data/lang_char_bpe/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -60,6 +60,7 @@ import logging from pathlib import Path from typing import Dict, List, Tuple +import k2 import onnx import torch import torch.nn as nn @@ -76,8 +77,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -134,10 +134,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="The tokens.txt file", ) parser.add_argument( @@ -493,9 +493,14 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + # Load id of the token and the vocab size + # is defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -661,7 +666,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py index e71bcaf29..6f84d79b4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py @@ -27,7 +27,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless7_streaming/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -48,8 +48,8 @@ import logging from pathlib import Path from typing import Dict, List, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from decoder import Decoder @@ -65,7 +65,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -122,10 +122,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -481,12 +481,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -652,7 +654,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py index c191b5bcc..59a7eb589 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py @@ -139,8 +139,8 @@ import argparse import logging from pathlib import Path +import k2 import onnxruntime -import sentencepiece as spm import torch import torch.nn as nn from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner @@ -154,7 +154,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -211,10 +211,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -675,12 +675,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py index fb77fdd42..bc42e8d05 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless7_streaming/export.py \ --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless7_streaming/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless7_streaming/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless7_streaming/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless7_streaming/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,7 +87,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -106,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -225,13 +224,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -286,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -297,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -307,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -337,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py index 4a16a97fb..9a6b31268 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py @@ -28,7 +28,7 @@ popd 2. Export to ncnn ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --exp-dir $repo/exp \ --use-averaged-model 0 \ --epoch 99 \ @@ -64,7 +64,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model @@ -75,7 +75,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -121,10 +121,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -244,12 +244,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py index d4a228b47..d9697680b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless8/export.py \ --exp-dir ./pruned_transducer_stateless8/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. ./pruned_transducer_stateless8/export.py \ --exp-dir ./pruned_transducer_stateless8/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -86,7 +86,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled @@ -98,7 +98,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -155,10 +155,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -198,12 +198,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py index 486d9d74e..64b38c9d5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless8/export.py \ --exp-dir ./pruned_transducer_stateless8/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless8/pretrained.py \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless8/pretrained.py \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless8/pretrained.py \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless8/pretrained.py \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,7 +87,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -106,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -225,13 +224,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -286,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -297,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -307,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -337,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/transducer/export.py b/egs/librispeech/ASR/transducer/export.py index 6db0272f0..3b9e4a5dc 100755 --- a/egs/librispeech/ASR/transducer/export.py +++ b/egs/librispeech/ASR/transducer/export.py @@ -22,7 +22,7 @@ Usage: ./transducer/export.py \ --exp-dir ./transducer/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 34 \ --avg 11 @@ -46,7 +46,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from conformer import Conformer from decoder import Decoder @@ -55,7 +55,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -90,10 +90,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -191,12 +191,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py index 511610245..c2413f5de 100755 --- a/egs/librispeech/ASR/transducer/pretrained.py +++ b/egs/librispeech/ASR/transducer/pretrained.py @@ -19,7 +19,7 @@ Usage: ./transducer/pretrained.py \ --checkpoint ./transducer/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav \ @@ -36,8 +36,8 @@ import logging import math from typing import List +import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import beam_search, greedy_search @@ -48,7 +48,7 @@ from model import Transducer from torch.nn.utils.rnn import pad_sequence from icefall.env import get_env_info -from icefall.utils import AttributeDict +from icefall.utils import AttributeDict, num_tokens def get_parser(): @@ -66,11 +66,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model. - Used only when method is ctc-decoding. - """, + help="Path to tokens.txt.", ) parser.add_argument( @@ -204,12 +202,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -257,6 +257,12 @@ def main(): x=features, x_lens=feature_lengths ) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + num_waves = encoder_out.size(0) hyps = [] for i in range(num_waves): @@ -272,12 +278,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py index 89359f1a4..c397eb171 100755 --- a/egs/librispeech/ASR/transducer_stateless/export.py +++ b/egs/librispeech/ASR/transducer_stateless/export.py @@ -22,7 +22,7 @@ Usage: ./transducer_stateless/export.py \ --exp-dir ./transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -46,7 +46,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from conformer import Conformer @@ -56,7 +56,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -91,10 +91,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -191,12 +191,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py index 915a6069d..5898dd0f5 100755 --- a/egs/librispeech/ASR/transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./transducer_stateless/pretrained.py \ --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ --max-sym-per-frame 1 \ /path/to/foo.wav \ @@ -29,7 +29,7 @@ Usage: (2) beam search ./transducer_stateless/pretrained.py \ --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -38,7 +38,7 @@ Usage: (3) modified beam search ./transducer_stateless/pretrained.py \ --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -47,7 +47,7 @@ Usage: (4) fast beam search ./transducer_stateless/pretrained.py \ --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -67,7 +67,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -80,6 +79,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -96,9 +97,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -213,12 +214,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -273,6 +276,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_list = fast_beam_search_one_best( @@ -318,12 +327,11 @@ def main(): raise ValueError(f"Unsupported method: {params.method}") hyp_list.append(hyp) - hyps = [sp.decode(hyp).split() for hyp in hyp_list] + hyps = [token_ids_to_words(hyp) for hyp in hyp_list] s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/transducer_stateless2/export.py b/egs/librispeech/ASR/transducer_stateless2/export.py index d33d02642..f4b6f5554 100755 --- a/egs/librispeech/ASR/transducer_stateless2/export.py +++ b/egs/librispeech/ASR/transducer_stateless2/export.py @@ -22,7 +22,7 @@ Usage: ./transducer_stateless2/export.py \ --exp-dir ./transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -46,12 +46,12 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -86,10 +86,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -123,12 +123,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py index 0738f30c0..b69b347ef 100755 --- a/egs/librispeech/ASR/transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./transducer_stateless2/pretrained.py \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ --max-sym-per-frame 1 \ /path/to/foo.wav \ @@ -29,7 +29,7 @@ Usage: (2) beam search ./transducer_stateless2/pretrained.py \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -38,7 +38,7 @@ Usage: (3) modified beam search ./transducer_stateless2/pretrained.py \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -47,7 +47,7 @@ Usage: (4) fast beam search ./transducer_stateless2/pretrained.py \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -67,7 +67,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -80,6 +79,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -96,9 +97,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -213,12 +214,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -273,6 +276,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_list = fast_beam_search_one_best( @@ -318,12 +327,11 @@ def main(): raise ValueError(f"Unsupported method: {params.method}") hyp_list.append(hyp) - hyps = [sp.decode(hyp).split() for hyp in hyp_list] + hyps = [token_ids_to_words(hyp) for hyp in hyp_list] s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py index 3735ef452..6d31dfe34 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py @@ -22,7 +22,7 @@ Usage: ./transducer_stateless_multi_datasets/export.py \ --exp-dir ./transducer_stateless_multi_datasets/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,7 +47,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from conformer import Conformer @@ -57,7 +57,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -92,10 +92,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -192,12 +192,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py index 8c7726367..4f29d6f1f 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./transducer_stateless_multi_datasets/pretrained.py \ --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ --max-sym-per-frame 1 \ /path/to/foo.wav \ @@ -29,7 +29,7 @@ Usage: (2) beam search ./transducer_stateless_multi_datasets/pretrained.py \ --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -38,7 +38,7 @@ Usage: (3) modified beam search ./transducer_stateless_multi_datasets/pretrained.py \ --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -47,7 +47,7 @@ Usage: (4) fast beam search ./transducer_stateless_multi_datasets/pretrained.py \ --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -67,7 +67,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -80,6 +79,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -96,9 +97,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -213,12 +214,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -273,6 +276,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_list = fast_beam_search_one_best( @@ -318,12 +327,11 @@ def main(): raise ValueError(f"Unsupported method: {params.method}") hyp_list.append(hyp) - hyps = [sp.decode(hyp).split() for hyp in hyp_list] + hyps = [token_ids_to_words(hyp) for hyp in hyp_list] s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index 3eb06f68c..a951aeef3 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -19,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/tokens.txt" git lfs pull --include "exp/pretrained.pt" cd exp @@ -74,7 +73,6 @@ import onnx import torch import torch.nn as nn from decoder import Decoder -from export import num_tokens from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_model, get_params @@ -86,7 +84,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py index 724fdd2a6..e0d664009 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-onnx.py @@ -19,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/tokens.txt" git lfs pull --include "exp/pretrained.pt" cd exp @@ -71,7 +70,6 @@ import onnx import torch import torch.nn as nn from decoder import Decoder -from export import num_tokens from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_model, get_params @@ -83,7 +81,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import make_pad_mask, str2bool +from icefall.utils import make_pad_mask, num_tokens, str2bool def get_parser(): diff --git a/egs/librispeech/ASR/zipformer/export.py b/egs/librispeech/ASR/zipformer/export.py index 4a48d5bad..2b8d1aaf3 100755 --- a/egs/librispeech/ASR/zipformer/export.py +++ b/egs/librispeech/ASR/zipformer/export.py @@ -160,7 +160,6 @@ with the following commands: import argparse import logging -import re from pathlib import Path from typing import List, Tuple @@ -176,27 +175,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import make_pad_mask, str2bool - - -def num_tokens( - token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$") -) -> int: - """Return the number of tokens excluding those from - disambiguation symbols. - - Caution: - 0 is not a token ID so it is excluded from the return value. - """ - symbols = token_table.symbols - ans = [] - for s in symbols: - if not disambig_pattern.match(s): - ans.append(token_table[s]) - num_tokens = len(ans) - if 0 in ans: - num_tokens -= 1 - return num_tokens +from icefall.utils import make_pad_mask, num_tokens, str2bool def get_parser(): @@ -487,6 +466,8 @@ def main(): device=device, ) ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: assert params.avg > 0, params.avg start = params.epoch - params.avg diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py index 904d8cd76..660a4bfc6 100755 --- a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py @@ -410,10 +410,20 @@ def main(): raise ValueError(f"Unsupported decoding method: {params.method}") s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - words = words.replace("▁", " ").strip() - s += f"{filename}:\n{words}\n\n" + if params.method == "ctc-decoding": + for filename, hyp in zip(params.sound_files, hyps): + words = "".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/zipformer/onnx_check.py b/egs/librispeech/ASR/zipformer/onnx_check.py index b38b875d0..93bd3a211 100755 --- a/egs/librispeech/ASR/zipformer/onnx_check.py +++ b/egs/librispeech/ASR/zipformer/onnx_check.py @@ -33,7 +33,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/tokens.txt" git lfs pull --include "exp/pretrained.pt" cd exp diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py index 2ce4506a8..500b2cd09 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py @@ -19,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained.pt" cd exp @@ -29,7 +28,7 @@ popd 2. Export the model to ONNX ./zipformer/export-onnx-streaming.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained.py b/egs/librispeech/ASR/zipformer/onnx_pretrained.py index e8a521460..032b07721 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained.py @@ -31,7 +31,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/tokens.txt" git lfs pull --include "exp/pretrained.pt" cd exp diff --git a/egs/librispeech/ASR/zipformer/pretrained_ctc.py b/egs/librispeech/ASR/zipformer/pretrained_ctc.py index be239e9c3..9dff2e6fc 100755 --- a/egs/librispeech/ASR/zipformer/pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/pretrained_ctc.py @@ -274,7 +274,7 @@ def main(): params.update(vars(args)) token_table = k2.SymbolTable.from_file(params.tokens) - params.vocab_size = num_tokens(token_table) + params.vocab_size = num_tokens(token_table) + 1 # +1 for blank params.blank_id = token_table[""] assert params.blank_id == 0 @@ -429,10 +429,20 @@ def main(): raise ValueError(f"Unsupported decoding method: {params.method}") s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - words = words.replace("▁", " ").strip() - s += f"{filename}:\n{words}\n\n" + if params.method == "ctc-decoding": + for filename, hyp in zip(params.sound_files, hyps): + words = "".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/zipformer_mmi/export.py b/egs/librispeech/ASR/zipformer_mmi/export.py index 0af7bd367..1aec56420 100755 --- a/egs/librispeech/ASR/zipformer_mmi/export.py +++ b/egs/librispeech/ASR/zipformer_mmi/export.py @@ -26,7 +26,7 @@ Usage: ./zipformer_mmi/export.py \ --exp-dir ./zipformer_mmi/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. ./zipformer_mmi/export.py \ --exp-dir ./zipformer_mmi/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -86,7 +86,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_ctc_model, get_params @@ -97,7 +97,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -154,10 +154,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -190,12 +190,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/zipformer_mmi/pretrained.py b/egs/librispeech/ASR/zipformer_mmi/pretrained.py index 0e7fd0daf..3ba4da5dd 100755 --- a/egs/librispeech/ASR/zipformer_mmi/pretrained.py +++ b/egs/librispeech/ASR/zipformer_mmi/pretrained.py @@ -21,7 +21,7 @@ You can generate the checkpoint with the following command: ./zipformer_mmi/export.py \ --exp-dir ./zipformer_mmi/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -30,14 +30,14 @@ Usage of this script: (1) 1best ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method 1best \ /path/to/foo.wav \ /path/to/bar.wav (2) nbest ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --nbest-scale 1.2 \ --method nbest \ /path/to/foo.wav \ @@ -45,7 +45,7 @@ Usage of this script: (3) nbest-rescoring-LG ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --nbest-scale 1.2 \ --method nbest-rescoring-LG \ /path/to/foo.wav \ @@ -53,7 +53,7 @@ Usage of this script: (4) nbest-rescoring-3-gram ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --nbest-scale 1.2 \ --method nbest-rescoring-3-gram \ /path/to/foo.wav \ @@ -61,7 +61,7 @@ Usage of this script: (5) nbest-rescoring-4-gram ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --nbest-scale 1.2 \ --method nbest-rescoring-4-gram \ /path/to/foo.wav \ @@ -83,7 +83,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from decode import get_decoding_params @@ -97,7 +96,7 @@ from icefall.decode import ( one_best_decoding, ) from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import get_texts +from icefall.utils import get_texts, num_tokens def get_parser(): @@ -115,9 +114,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -247,13 +246,14 @@ def main(): params.update(get_decoding_params()) params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -298,8 +298,6 @@ def main(): features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(str(params.lang_dir / "bpe.model")) mmi_graph_compiler = MmiTrainingGraphCompiler( params.lang_dir, uniq_filename="lexicon.txt", @@ -313,6 +311,12 @@ def main(): if not hasattr(HP, "lm_scores"): HP.lm_scores = HP.scores.clone() + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + method = params.method assert method in ( "1best", @@ -390,14 +394,11 @@ def main(): # # token_ids is a lit-of-list of IDs token_ids = get_texts(best_path) - # hyps is a list of str, e.g., ['xxx yyy zzz', ...] - hyps = bpe_model.decode(token_ids) - # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] - hyps = [s.split() for s in hyps] + hyps = [token_ids_to_words(ids) for ids in token_ids] + s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py index fad66986b..760fad974 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py @@ -498,7 +498,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py index bc499f3dd..c3d67ad92 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -320,7 +320,7 @@ def main(): s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) + words = "".join(hyp) s += f"{filename}:\n{words}\n\n" logging.info(s) diff --git a/icefall/utils.py b/icefall/utils.py index 0feff9dc8..b01cd2770 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -2060,3 +2060,23 @@ def symlink_or_copy(exp_dir: Path, src: str, dst: str): except OSError: copyfile(src=exp_dir / src, dst=exp_dir / dst) os.close(dir_fd) + + +def num_tokens( + token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$") +) -> int: + """Return the number of tokens excluding those from + disambiguation symbols. + + Caution: + 0 is not a token ID so it is excluded from the return value. + """ + symbols = token_table.symbols + ans = [] + for s in symbols: + if not disambig_pattern.match(s): + ans.append(token_table[s]) + num_tokens = len(ans) + if 0 in ans: + num_tokens -= 1 + return num_tokens