diff --git a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh index c68ccc954..f6fe8c9b2 100755 --- a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh +++ b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh @@ -38,7 +38,7 @@ log "Decode with models exported by torch.jit.trace()" for m in ctc-decoding 1best; do ./conformer_ctc3/jit_pretrained.py \ --model-filename $repo/exp/jit_trace.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ + --words-file $repo/data/lang_bpe_500/words.txt \ --HLG $repo/data/lang_bpe_500/HLG.pt \ --bpe-model $repo/data/lang_bpe_500/bpe.model \ --G $repo/data/lm/G_4_gram.pt \ @@ -53,7 +53,7 @@ log "Export to torchscript model" ./conformer_ctc3/export.py \ --exp-dir $repo/exp \ - --lang-dir $repo/data/lang_bpe_500 \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --jit-trace 1 \ --epoch 99 \ --avg 1 \ @@ -80,9 +80,9 @@ done for m in ctc-decoding 1best; do ./conformer_ctc3/pretrained.py \ --checkpoint $repo/exp/pretrained.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ + --words-file $repo/data/lang_bpe_500/words.txt \ --HLG $repo/data/lang_bpe_500/HLG.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --G $repo/data/lm/G_4_gram.pt \ --method $m \ --sample-rate 16000 \ @@ -93,7 +93,7 @@ done echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then mkdir -p conformer_ctc3/exp ln -s $PWD/$repo/exp/pretrained.pt conformer_ctc3/exp/epoch-999.pt ln -s $PWD/$repo/data/lang_bpe_500 data/ diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh index 4cd2c4bec..d547bdd45 100755 --- a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh +++ b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh @@ -31,7 +31,7 @@ log "Test exporting with torch.jit.trace()" ./lstm_transducer_stateless2/export.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -55,7 +55,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -68,7 +68,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh index 6792c7088..412e3ad56 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh index dbf678d72..243b669ed 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh @@ -36,7 +36,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -49,7 +49,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh index b6d477afe..2d0f80304 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh @@ -35,7 +35,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -48,7 +48,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh index efa4b53f0..3d5814c48 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh @@ -30,14 +30,14 @@ popd log "Export to torchscript model" ./pruned_transducer_stateless3/export.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 ./pruned_transducer_stateless3/export.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit-trace 1 @@ -74,7 +74,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -87,7 +87,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh index 511fe0c9e..3d2442d54 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh @@ -32,7 +32,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --num-encoder-layers 18 \ --dim-feedforward 2048 \ --nhead 8 \ @@ -51,7 +51,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav \ diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh index 2bc179c86..961dde4f4 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh @@ -33,7 +33,7 @@ log "Export to torchscript model" ./pruned_transducer_stateless7/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 @@ -56,7 +56,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -69,7 +69,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh index 192438353..ba7139efb 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh @@ -37,7 +37,7 @@ log "Export to torchscript model" ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 @@ -74,7 +74,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -87,7 +87,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh index 7d2853c17..1ecbc4798 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh @@ -36,7 +36,7 @@ log "Export to torchscript model" ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 @@ -72,7 +72,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -85,7 +85,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh index e1e4e1f10..37b192a57 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh @@ -37,7 +37,7 @@ log "Export to torchscript model" ./pruned_transducer_stateless7_streaming/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --decode-chunk-len 32 \ --epoch 99 \ --avg 1 \ @@ -81,7 +81,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --decode-chunk-len 32 \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ @@ -95,7 +95,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --decode-chunk-len 32 \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh index 5d9485692..4f2bfac24 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh @@ -41,7 +41,7 @@ log "Decode with models exported by torch.jit.script()" log "Export to torchscript model" ./pruned_transducer_stateless8/export.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model false \ --epoch 99 \ --avg 1 \ @@ -65,7 +65,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -78,7 +78,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh index 77cd59506..5cbdad16d 100755 --- a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh +++ b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh @@ -32,7 +32,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --simulate-streaming 1 \ --causal-convolution 1 \ $repo/test_wavs/1089-134686-0001.wav \ @@ -47,7 +47,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --simulate-streaming 1 \ --causal-convolution 1 \ $repo/test_wavs/1089-134686-0001.wav \ diff --git a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh index b4aca1b6b..ff77855a2 100755 --- a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh +++ b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh index a58b8ec56..c59921055 100755 --- a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh +++ b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh @@ -37,7 +37,7 @@ log "Export to torchscript model" ./zipformer_mmi/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 @@ -61,7 +61,7 @@ for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescor --method $method \ --checkpoint $repo/exp/pretrained.pt \ --lang-dir $repo/data/lang_bpe_500 \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-conformer-ctc.sh index 125d1f3b1..a4959aa01 100755 --- a/.github/scripts/run-pre-trained-conformer-ctc.sh +++ b/.github/scripts/run-pre-trained-conformer-ctc.sh @@ -27,7 +27,7 @@ log "CTC decoding" --method ctc-decoding \ --num-classes 500 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.flac \ $repo/test_wavs/1221-135766-0001.flac \ $repo/test_wavs/1221-135766-0002.flac @@ -38,7 +38,7 @@ log "HLG decoding" --method 1best \ --num-classes 500 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --words-file $repo/data/lang_bpe_500/words.txt \ --HLG $repo/data/lang_bpe_500/HLG.pt \ $repo/test_wavs/1089-134686-0001.flac \ diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh index 89115e88d..7b686328d 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh index 85e2c89e6..a8eeeb514 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-pre-trained-transducer-stateless.sh b/.github/scripts/run-pre-trained-transducer-stateless.sh index 41456f11b..2e2360435 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-pre-trained-transducer.sh b/.github/scripts/run-pre-trained-transducer.sh index 1331c966c..b865f8d13 100755 --- a/.github/scripts/run-pre-trained-transducer.sh +++ b/.github/scripts/run-pre-trained-transducer.sh @@ -27,7 +27,7 @@ log "Beam search decoding" --method beam_search \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh index 90097c752..a3a2d3080 100755 --- a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh +++ b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh @@ -17,7 +17,6 @@ git lfs install git clone $repo_url repo=$(basename $repo_url) - log "Display test files" tree $repo/ ls -lh $repo/test_wavs/*.wav @@ -29,12 +28,11 @@ popd log "Test exporting to ONNX format" -./pruned_transducer_stateless2/export.py \ +./pruned_transducer_stateless2/export-onnx.py \ --exp-dir $repo/exp \ --lang-dir $repo/data/lang_char \ --epoch 99 \ - --avg 1 \ - --onnx 1 + --avg 1 log "Export to torchscript model" @@ -59,19 +57,17 @@ log "Decode with ONNX models" ./pruned_transducer_stateless2/onnx_check.py \ --jit-filename $repo/exp/cpu_jit.pt \ - --onnx-encoder-filename $repo/exp/encoder.onnx \ - --onnx-decoder-filename $repo/exp/decoder.onnx \ - --onnx-joiner-filename $repo/exp/joiner.onnx \ - --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \ - --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx + --onnx-encoder-filename $repo/exp/encoder-epoch-10-avg-2.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-10-avg-2.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-10-avg-2.onnx \ + --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \ + --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj-epoch-10-avg-2.onnx ./pruned_transducer_stateless2/onnx_pretrained.py \ --tokens $repo/data/lang_char/tokens.txt \ - --encoder-model-filename $repo/exp/encoder.onnx \ - --decoder-model-filename $repo/exp/decoder.onnx \ - --joiner-model-filename $repo/exp/joiner.onnx \ - --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ - --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ $repo/test_wavs/DEV_T0000000000.wav \ $repo/test_wavs/DEV_T0000000001.wav \ $repo/test_wavs/DEV_T0000000002.wav @@ -104,9 +100,9 @@ for sym in 1 2 3; do --lang-dir $repo/data/lang_char \ --decoding-method greedy_search \ --max-sym-per-frame $sym \ - $repo/test_wavs/DEV_T0000000000.wav \ - $repo/test_wavs/DEV_T0000000001.wav \ - $repo/test_wavs/DEV_T0000000002.wav + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav done for method in modified_beam_search beam_search fast_beam_search; do @@ -117,7 +113,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --beam-size 4 \ --checkpoint $repo/exp/epoch-99.pt \ --lang-dir $repo/data/lang_char \ - $repo/test_wavs/DEV_T0000000000.wav \ - $repo/test_wavs/DEV_T0000000001.wav \ - $repo/test_wavs/DEV_T0000000002.wav + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav done diff --git a/.github/scripts/test-ncnn-export.sh b/.github/scripts/test-ncnn-export.sh index ac16131d0..4073c594a 100755 --- a/.github/scripts/test-ncnn-export.sh +++ b/.github/scripts/test-ncnn-export.sh @@ -45,7 +45,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" cd exp @@ -56,11 +55,10 @@ log "Export via torch.jit.trace()" ./conv_emformer_transducer_stateless2/export-for-ncnn.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ - \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --num-encoder-layers 12 \ --chunk-length 32 \ --cnn-module-kernel 31 \ @@ -91,7 +89,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt" cd exp @@ -102,7 +99,7 @@ log "Export via torch.jit.trace()" ./lstm_transducer_stateless2/export-for-ncnn.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 @@ -140,7 +137,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained.pt" cd exp @@ -148,7 +144,7 @@ ln -s pretrained.pt epoch-99.pt popd ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --exp-dir $repo/exp \ --use-averaged-model 0 \ --epoch 99 \ @@ -199,7 +195,7 @@ ln -s pretrained.pt epoch-9999.pt popd ./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \ - --lang-dir $repo/data/lang_char_bpe \ + --tokens $repo/data/lang_char_bpe/tokens.txt \ --exp-dir $repo/exp \ --use-averaged-model 0 \ --epoch 9999 \ diff --git a/.github/scripts/test-onnx-export.sh b/.github/scripts/test-onnx-export.sh index 39467c44a..fcfc11fa6 100755 --- a/.github/scripts/test-onnx-export.sh +++ b/.github/scripts/test-onnx-export.sh @@ -10,7 +10,123 @@ log() { cd egs/librispeech/ASR +log "==========================================================================" +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) +pushd $repo +git lfs pull --include "exp/pretrained.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +log "Export via torch.jit.script()" +./zipformer/export.py \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --jit 1 + +log "Test export to ONNX format" +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal False \ + --chunk-size "16,32,64,-1" \ + --left-context-frames "64,128,256,-1" + +ls -lh $repo/exp + +log "Run onnx_check.py" + +./zipformer/onnx_check.py \ + --jit-filename $repo/exp/jit_script.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx + +log "Run onnx_pretrained.py" + +./zipformer/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav + +rm -rf $repo + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +log "Test export streaming model to ONNX format" +./zipformer/export-onnx-streaming.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal True \ + --chunk-size 16 \ + --left-context-frames 64 + +ls -lh $repo/exp + +log "Run onnx_pretrained-streaming.py" + +./zipformer/onnx_pretrained-streaming.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1-chunk-16-left-64.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1-chunk-16-left-64.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1-chunk-16-left-64.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav + +rm -rf $repo + +log "--------------------------------------------------------------------------" log "==========================================================================" repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 @@ -39,7 +155,7 @@ log "Export via torch.jit.trace()" log "Test exporting to ONNX format" ./pruned_transducer_stateless7_streaming/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -88,7 +204,7 @@ popd log "Export via torch.jit.script()" ./pruned_transducer_stateless3/export.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 9999 \ --avg 1 \ --exp-dir $repo/exp/ \ @@ -97,7 +213,7 @@ log "Export via torch.jit.script()" log "Test exporting to ONNX format" ./pruned_transducer_stateless3/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 9999 \ --avg 1 \ --exp-dir $repo/exp/ @@ -126,7 +242,6 @@ log "Run onnx_pretrained.py" rm -rf $repo log "--------------------------------------------------------------------------" - log "==========================================================================" repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url @@ -143,7 +258,7 @@ popd log "Export via torch.jit.script()" ./pruned_transducer_stateless5/export.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -159,7 +274,7 @@ log "Export via torch.jit.script()" log "Test exporting to ONNX format" ./pruned_transducer_stateless5/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -205,7 +320,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained.pt" cd exp @@ -215,7 +329,7 @@ popd log "Export via torch.jit.script()" ./pruned_transducer_stateless7/export.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -226,7 +340,7 @@ log "Export via torch.jit.script()" log "Test exporting to ONNX format" ./pruned_transducer_stateless7/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -270,7 +384,7 @@ popd log "Test exporting to ONNX format" ./conv_emformer_transducer_stateless2/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -310,7 +424,7 @@ popd log "Export via torch.jit.trace()" ./lstm_transducer_stateless2/export.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -320,7 +434,7 @@ log "Export via torch.jit.trace()" log "Test exporting to ONNX format" ./lstm_transducer_stateless2/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index 8a2c94829..57f15fe87 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -44,11 +44,6 @@ jobs: with: fetch-depth: 0 - - name: Install graphviz - shell: bash - run: | - sudo apt-get -qq install graphviz - - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: @@ -70,6 +65,7 @@ jobs: pip install --no-binary protobuf protobuf==3.20.* pip install --no-deps --force-reinstall https://huggingface.co/csukuangfj/k2/resolve/main/cpu/k2-1.24.3.dev20230508+cpu.torch1.13.1-cp38-cp38-linux_x86_64.whl + pip install kaldifeat==1.25.0.dev20230726+cpu.torch1.13.1 -f https://csukuangfj.github.io/kaldifeat/cpu.html - name: Run yesno recipe shell: bash @@ -78,9 +74,75 @@ jobs: export PYTHONPATH=$PWD:$PYTHONPATH echo $PYTHONPATH - cd egs/yesno/ASR ./prepare.sh python3 ./tdnn/train.py python3 ./tdnn/decode.py - # TODO: Check that the WER is less than some value + + - name: Test exporting to pretrained.pt + shell: bash + working-directory: ${{github.workspace}} + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + echo $PYTHONPATH + + cd egs/yesno/ASR + python3 ./tdnn/export.py --epoch 14 --avg 2 + + python3 ./tdnn/pretrained.py \ + --checkpoint ./tdnn/exp/pretrained.pt \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + + - name: Test exporting to torchscript + shell: bash + working-directory: ${{github.workspace}} + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + echo $PYTHONPATH + + cd egs/yesno/ASR + python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1 + + python3 ./tdnn/jit_pretrained.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + + - name: Test exporting to onnx + shell: bash + working-directory: ${{github.workspace}} + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + echo $PYTHONPATH + + cd egs/yesno/ASR + python3 ./tdnn/export_onnx.py --epoch 14 --avg 2 + + echo "Test float32 model" + python3 ./tdnn/onnx_pretrained.py \ + --nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + + + echo "Test int8 model" + python3 ./tdnn/onnx_pretrained.py \ + --nn-model ./tdnn/exp/model-epoch-14-avg-2.int8.onnx \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + + - name: Show generated files + shell: bash + working-directory: ${{github.workspace}} + run: | + cd egs/yesno/ASR + ls -lh tdnn/exp diff --git a/docs/source/decoding-with-langugage-models/LODR.rst b/docs/source/decoding-with-langugage-models/LODR.rst index 7ffa0c128..b6625ee1d 100644 --- a/docs/source/decoding-with-langugage-models/LODR.rst +++ b/docs/source/decoding-with-langugage-models/LODR.rst @@ -4,59 +4,59 @@ LODR for RNN Transducer ======================= -As a type of E2E model, neural transducers are usually considered as having an internal -language model, which learns the language level information on the training corpus. -In real-life scenario, there is often a mismatch between the training corpus and the target corpus space. +As a type of E2E model, neural transducers are usually considered as having an internal +language model, which learns the language level information on the training corpus. +In real-life scenario, there is often a mismatch between the training corpus and the target corpus space. This mismatch can be a problem when decoding for neural transducer models with language models as its internal language can act "against" the external LM. In this tutorial, we show how to use `Low-order Density Ratio `_ to alleviate this effect to further improve the performance -of langugae model integration. +of langugae model integration. .. note:: - This tutorial is based on the recipe + This tutorial is based on the recipe `pruned_transducer_stateless7_streaming `_, - which is a streaming transducer model trained on `LibriSpeech`_. + which is a streaming transducer model trained on `LibriSpeech`_. However, you can easily apply LODR to other recipes. If you encounter any problems, please open an issue here `icefall `__. .. note:: - For simplicity, the training and testing corpus in this tutorial are the same (`LibriSpeech`_). However, - you can change the testing set to any other domains (e.g `GigaSpeech`_) and prepare the language models + For simplicity, the training and testing corpus in this tutorial are the same (`LibriSpeech`_). However, + you can change the testing set to any other domains (e.g `GigaSpeech`_) and prepare the language models using that corpus. -First, let's have a look at some background information. As the predecessor of LODR, Density Ratio (DR) is first proposed `here `_ +First, let's have a look at some background information. As the predecessor of LODR, Density Ratio (DR) is first proposed `here `_ to address the language information mismatch between the training corpus (source domain) and the testing corpus (target domain). Assuming that the source domain and the test domain are acoustically similar, DR derives the following formular for decoding with Bayes' theorem: .. math:: - \text{score}\left(y_u|\mathit{x},y\right) = - \log p\left(y_u|\mathit{x},y_{1:u-1}\right) + - \lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) - + \text{score}\left(y_u|\mathit{x},y\right) = + \log p\left(y_u|\mathit{x},y_{1:u-1}\right) + + \lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) - \lambda_2 \log p_{\text{Source LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) -where :math:`\lambda_1` and :math:`\lambda_2` are the weights of LM scores for target domain and source domain respectively. -Here, the source domain LM is trained on the training corpus. The only difference in the above formular compared to +where :math:`\lambda_1` and :math:`\lambda_2` are the weights of LM scores for target domain and source domain respectively. +Here, the source domain LM is trained on the training corpus. The only difference in the above formular compared to shallow fusion is the subtraction of the source domain LM. -Some works treat the predictor and the joiner of the neural transducer as its internal LM. However, the LM is +Some works treat the predictor and the joiner of the neural transducer as its internal LM. However, the LM is considered to be weak and can only capture low-level language information. Therefore, `LODR `__ proposed to use a low-order n-gram LM as an approximation of the ILM of the neural transducer. This leads to the following formula during decoding for transducer model: .. math:: - \text{score}\left(y_u|\mathit{x},y\right) = - \log p_{rnnt}\left(y_u|\mathit{x},y_{1:u-1}\right) + - \lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) - + \text{score}\left(y_u|\mathit{x},y\right) = + \log p_{rnnt}\left(y_u|\mathit{x},y_{1:u-1}\right) + + \lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) - \lambda_2 \log p_{\text{bi-gram}}\left(y_u|\mathit{x},y_{1:u-1}\right) -In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Comared to DR, +In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Comared to DR, the only difference lies in the choice of source domain LM. According to the original `paper `_, LODR achieves similar performance compared DR in both intra-domain and cross-domain settings. As a bi-gram is much faster to evaluate, LODR is usually much faster. @@ -85,7 +85,7 @@ To test the model, let's have a look at the decoding results **without** using L --avg 1 \ --use-averaged-model False \ --exp-dir $exp_dir \ - --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model + --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \ --max-duration 600 \ --decode-chunk-len 32 \ --decoding-method modified_beam_search @@ -99,17 +99,17 @@ The following WERs are achieved on test-clean and test-other: $ For test-other, WER of different settings are: $ beam_size_4 7.93 best for test-other -Then, we download the external language model and bi-gram LM that are necessary for LODR. +Then, we download the external language model and bi-gram LM that are necessary for LODR. Note that the bi-gram is estimated on the LibriSpeech 960 hours' text. .. code-block:: bash $ # download the external LM - $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm $ # create a symbolic link so that the checkpoint can be loaded $ pushd icefall-librispeech-rnn-lm/exp $ git lfs pull --include "pretrained.pt" - $ ln -s pretrained.pt epoch-99.pt + $ ln -s pretrained.pt epoch-99.pt $ popd $ $ # download the bi-gram @@ -122,7 +122,7 @@ Note that the bi-gram is estimated on the LibriSpeech 960 hours' text. Then, we perform LODR decoding by setting ``--decoding-method`` to ``modified_beam_search_lm_LODR``: .. code-block:: bash - + $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp $ lm_dir=./icefall-librispeech-rnn-lm/exp $ lm_scale=0.42 @@ -135,8 +135,8 @@ Then, we perform LODR decoding by setting ``--decoding-method`` to ``modified_be --exp-dir $exp_dir \ --max-duration 600 \ --decode-chunk-len 32 \ - --decoding-method modified_beam_search_lm_LODR \ - --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model + --decoding-method modified_beam_search_LODR \ + --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \ --use-shallow-fusion 1 \ --lm-type rnn \ --lm-exp-dir $lm_dir \ @@ -181,4 +181,4 @@ indeed **further improves** the WER. We can do even better if we increase ``--be - 6.38 * - 12 - 2.4 - - 6.23 \ No newline at end of file + - 6.23 diff --git a/docs/source/decoding-with-langugage-models/rescoring.rst b/docs/source/decoding-with-langugage-models/rescoring.rst index ee2e2113c..02eba9129 100644 --- a/docs/source/decoding-with-langugage-models/rescoring.rst +++ b/docs/source/decoding-with-langugage-models/rescoring.rst @@ -48,7 +48,7 @@ As usual, we first test the model's performance without external LM. This can be --avg 1 \ --use-averaged-model False \ --exp-dir $exp_dir \ - --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model + --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \ --max-duration 600 \ --decode-chunk-len 32 \ --decoding-method modified_beam_search @@ -101,7 +101,7 @@ is set to `False`. --max-duration 600 \ --decode-chunk-len 32 \ --decoding-method modified_beam_search_lm_rescore \ - --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model + --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \ --use-shallow-fusion 0 \ --lm-type rnn \ --lm-exp-dir $lm_dir \ @@ -173,7 +173,7 @@ Then we can performn LM rescoring + LODR by changing the decoding method to `mod --max-duration 600 \ --decode-chunk-len 32 \ --decoding-method modified_beam_search_lm_rescore_LODR \ - --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model + --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \ --use-shallow-fusion 0 \ --lm-type rnn \ --lm-exp-dir $lm_dir \ diff --git a/docs/source/decoding-with-langugage-models/shallow-fusion.rst b/docs/source/decoding-with-langugage-models/shallow-fusion.rst index 0d2837372..f15e3f1d9 100644 --- a/docs/source/decoding-with-langugage-models/shallow-fusion.rst +++ b/docs/source/decoding-with-langugage-models/shallow-fusion.rst @@ -46,7 +46,7 @@ To test the model, let's have a look at the decoding results without using LM. T --avg 1 \ --use-averaged-model False \ --exp-dir $exp_dir \ - --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model + --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \ --max-duration 600 \ --decode-chunk-len 32 \ --decoding-method modified_beam_search @@ -95,7 +95,7 @@ To use shallow fusion for decoding, we can execute the following command: --max-duration 600 \ --decode-chunk-len 32 \ --decoding-method modified_beam_search_lm_shallow_fusion \ - --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model + --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \ --use-shallow-fusion 1 \ --lm-type rnn \ --lm-exp-dir $lm_dir \ diff --git a/docs/source/docker/img/docker-hub.png b/docs/source/docker/img/docker-hub.png new file mode 100644 index 000000000..a9e7715b0 Binary files /dev/null and b/docs/source/docker/img/docker-hub.png differ diff --git a/docs/source/docker/index.rst b/docs/source/docker/index.rst new file mode 100644 index 000000000..2c92a4cbc --- /dev/null +++ b/docs/source/docker/index.rst @@ -0,0 +1,17 @@ +.. _icefall_docker: + +Docker +====== + +This section describes how to use pre-built docker images to run `icefall`_. + +.. hint:: + + If you only have CPUs available, you can still use the pre-built docker + images. + +.. toctree:: + :maxdepth: 2 + + ./intro.rst + diff --git a/docs/source/docker/intro.rst b/docs/source/docker/intro.rst new file mode 100644 index 000000000..b09247d85 --- /dev/null +++ b/docs/source/docker/intro.rst @@ -0,0 +1,171 @@ +Introduction +============= + +We have pre-built docker images hosted at the following address: + + ``_ + +.. figure:: img/docker-hub.png + :width: 600 + :align: center + +You can find the ``Dockerfile`` at ``_. + +We describe the following items in this section: + + - How to view available tags + - How to download pre-built docker images + - How to run the `yesno`_ recipe within a docker container on ``CPU`` + +View available tags +=================== + +You can use the following command to view available tags: + +.. code-block:: bash + + curl -s 'https://registry.hub.docker.com/v2/repositories/k2fsa/icefall/tags/'|jq '."results"[]["name"]' + +which will give you something like below: + +.. code-block:: bash + + "torch2.0.0-cuda11.7" + "torch1.12.1-cuda11.3" + "torch1.9.0-cuda10.2" + "torch1.13.0-cuda11.6" + +.. hint:: + + Available tags will be updated when there are new releases of `torch`_. + +Please select an appropriate combination of `torch`_ and CUDA. + +Download a docker image +======================= + +Suppose that you select the tag ``torch1.13.0-cuda11.6``, you can use +the following command to download it: + +.. code-block:: bash + + sudo docker image pull k2fsa/icefall:torch1.13.0-cuda11.6 + +Run a docker image with GPU +=========================== + +.. code-block:: bash + + sudo docker run --gpus all --rm -it k2fsa/icefall:torch1.13.0-cuda11.6 /bin/bash + +Run a docker image with CPU +=========================== + +.. code-block:: bash + + sudo docker run --rm -it k2fsa/icefall:torch1.13.0-cuda11.6 /bin/bash + +Run yesno within a docker container +=================================== + +After starting the container, the following interface is presented: + +.. code-block:: bash + + root@60c947eac59c:/workspace/icefall# + +It shows the current user is ``root`` and the current working directory +is ``/workspace/icefall``. + +Update the code +--------------- + +Please first run: + +.. code-block:: bash + + root@60c947eac59c:/workspace/icefall# git pull + +so that your local copy contains the latest code. + +Data preparation +---------------- + +Now we can use + +.. code-block:: bash + + root@60c947eac59c:/workspace/icefall# cd egs/yesno/ASR/ + +to switch to the ``yesno`` recipe and run + +.. code-block:: bash + + root@60c947eac59c:/workspace/icefall/egs/yesno/ASR# ./prepare.sh + +.. hint:: + + If you are running without GPU, it may report the following error: + + .. code-block:: bash + + File "/opt/conda/lib/python3.9/site-packages/k2/__init__.py", line 23, in + from _k2 import DeterminizeWeightPushingType + ImportError: libcuda.so.1: cannot open shared object file: No such file or directory + + We can use the following command to fix it: + + .. code-block:: bash + + root@60c947eac59c:/workspace/icefall/egs/yesno/ASR# ln -s /opt/conda/lib/stubs/libcuda.so /opt/conda/lib/stubs/libcuda.so.1 + +The logs of running ``./prepare.sh`` are listed below: + +.. literalinclude:: ./log/log-preparation.txt + +Training +-------- + +After preparing the data, we can start training with the following command + +.. code-block:: bash + + root@60c947eac59c:/workspace/icefall/egs/yesno/ASR# ./tdnn/train.py + +All of the training logs are given below: + +.. hint:: + + It is running on CPU and it takes only 16 seconds for this run. + +.. literalinclude:: ./log/log-train-2023-08-01-01-55-27 + + +Decoding +-------- + +After training, we can decode the trained model with + +.. code-block:: bash + + root@60c947eac59c:/workspace/icefall/egs/yesno/ASR# ./tdnn/decode.py + +The decoding logs are given below: + +.. code-block:: bash + + 2023-08-01 02:06:22,400 INFO [decode.py:263] Decoding started + 2023-08-01 02:06:22,400 INFO [decode.py:264] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lm_dir': PosixPath('data/lm'), 'feature_dim': 23, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'export': False, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '4c05309499a08454997adf500b56dcc629e35ae5', 'k2-git-date': 'Tue Jul 25 16:23:36 2023', 'lhotse-version': '1.16.0.dev+git.7640d663.clean', 'torch-version': '1.13.0', 'torch-cuda-available': False, 'torch-cuda-version': '11.6', 'python-version': '3.9', 'icefall-git-branch': 'master', 'icefall-git-sha1': '375520d-clean', 'icefall-git-date': 'Fri Jul 28 07:43:08 2023', 'icefall-path': '/workspace/icefall', 'k2-path': '/opt/conda/lib/python3.9/site-packages/k2/__init__.py', 'lhotse-path': '/opt/conda/lib/python3.9/site-packages/lhotse/__init__.py', 'hostname': '60c947eac59c', 'IP address': '172.17.0.2'}} + 2023-08-01 02:06:22,401 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt + 2023-08-01 02:06:22,403 INFO [decode.py:273] device: cpu + 2023-08-01 02:06:22,406 INFO [decode.py:291] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt'] + 2023-08-01 02:06:22,424 INFO [asr_datamodule.py:218] About to get test cuts + 2023-08-01 02:06:22,425 INFO [asr_datamodule.py:252] About to get test cuts + 2023-08-01 02:06:22,504 INFO [decode.py:204] batch 0/?, cuts processed until now is 4 + [W NNPACK.cpp:53] Could not initialize NNPACK! Reason: Unsupported hardware. + 2023-08-01 02:06:22,687 INFO [decode.py:241] The transcripts are stored in tdnn/exp/recogs-test_set.txt + 2023-08-01 02:06:22,688 INFO [utils.py:564] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ] + 2023-08-01 02:06:22,690 INFO [decode.py:249] Wrote detailed error stats to tdnn/exp/errs-test_set.txt + 2023-08-01 02:06:22,690 INFO [decode.py:316] Done! + +Congratulations! You have finished successfully running `icefall`_ within a docker container. diff --git a/docs/source/index.rst b/docs/source/index.rst index a7d365a15..0fa8fdd1c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -21,9 +21,11 @@ speech recognition recipes using `k2 `_. :caption: Contents: installation/index + docker/index faqs model-export/index + .. toctree:: :maxdepth: 3 @@ -38,4 +40,4 @@ speech recognition recipes using `k2 `_. .. toctree:: :maxdepth: 2 - decoding-with-langugage-models/index \ No newline at end of file + decoding-with-langugage-models/index diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst index 534b674f9..5a034ef5b 100644 --- a/docs/source/installation/index.rst +++ b/docs/source/installation/index.rst @@ -3,6 +3,11 @@ Installation ============ +.. hint:: + + We also provide :ref:`icefall_docker` support, which has already setup + the environment for you. + .. hint:: We have a colab notebook guiding you step by step to setup the environment. diff --git a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py index 387c14acf..9caacb78b 100755 --- a/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py +++ b/egs/aidatatang_200zh/ASR/local/compute_fbank_aidatatang_200zh.py @@ -32,7 +32,7 @@ import torch from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse.recipes.utils import read_manifests_if_cached -from icefall.utils import get_executor +from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -42,7 +42,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80): +def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80, perturb_speed: bool = False): src_dir = Path("data/manifests/aidatatang_200zh") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -85,7 +85,8 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80): recordings=m["recordings"], supervisions=m["supervisions"], ) - if "train" in partition: + if "train" in partition and perturb_speed: + logging.info(f"Doing speed perturb") cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) @@ -109,7 +110,12 @@ def get_args(): default=80, help="""The number of mel bins for Fbank""", ) - + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) return parser.parse_args() @@ -119,4 +125,6 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - compute_fbank_aidatatang_200zh(num_mel_bins=args.num_mel_bins) + compute_fbank_aidatatang_200zh( + num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed + ) diff --git a/egs/aidatatang_200zh/ASR/prepare.sh b/egs/aidatatang_200zh/ASR/prepare.sh index 46ecd5769..2eb0b3718 100755 --- a/egs/aidatatang_200zh/ASR/prepare.sh +++ b/egs/aidatatang_200zh/ASR/prepare.sh @@ -77,7 +77,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 4: Compute fbank for aidatatang_200zh" if [ ! -f data/fbank/.aidatatang_200zh.done ]; then mkdir -p data/fbank - ./local/compute_fbank_aidatatang_200zh.py + ./local/compute_fbank_aidatatang_200zh.py --perturb-speed True touch data/fbank/.aidatatang_200zh.done fi fi diff --git a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py index 037971927..6a9bb4f42 100755 --- a/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py +++ b/egs/aishell/ASR/local/compute_fbank_aidatatang_200zh.py @@ -32,7 +32,7 @@ import torch from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse.recipes.utils import read_manifests_if_cached -from icefall.utils import get_executor +from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -42,7 +42,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80): +def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80, perturb_speed: bool = False): src_dir = Path("data/manifests") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -85,7 +85,8 @@ def compute_fbank_aidatatang_200zh(num_mel_bins: int = 80): recordings=m["recordings"], supervisions=m["supervisions"], ) - if "train" in partition: + if "train" in partition and perturb_speed: + logging.info(f"Doing speed perturb") cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) @@ -109,7 +110,12 @@ def get_args(): default=80, help="""The number of mel bins for Fbank""", ) - + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) return parser.parse_args() @@ -119,4 +125,6 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - compute_fbank_aidatatang_200zh(num_mel_bins=args.num_mel_bins) + compute_fbank_aidatatang_200zh( + num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed + ) diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py index 115ca1031..c7000da1c 100755 --- a/egs/aishell/ASR/local/compute_fbank_aishell.py +++ b/egs/aishell/ASR/local/compute_fbank_aishell.py @@ -32,7 +32,7 @@ import torch from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse.recipes.utils import read_manifests_if_cached -from icefall.utils import get_executor +from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -42,7 +42,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_aishell(num_mel_bins: int = 80): +def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False): src_dir = Path("data/manifests") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -81,7 +81,8 @@ def compute_fbank_aishell(num_mel_bins: int = 80): recordings=m["recordings"], supervisions=m["supervisions"], ) - if "train" in partition: + if "train" in partition and perturb_speed: + logging.info(f"Doing speed perturb") cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) @@ -104,7 +105,12 @@ def get_args(): default=80, help="""The number of mel bins for Fbank""", ) - + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) return parser.parse_args() @@ -114,4 +120,6 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - compute_fbank_aishell(num_mel_bins=args.num_mel_bins) + compute_fbank_aishell( + num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed + ) diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index b763d72c1..ff8e1301d 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -114,7 +114,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Compute fbank for aishell" if [ ! -f data/fbank/.aishell.done ]; then mkdir -p data/fbank - ./local/compute_fbank_aishell.py + ./local/compute_fbank_aishell.py --perturb-speed True touch data/fbank/.aishell.done fi fi diff --git a/egs/aishell/ASR/prepare_aidatatang_200zh.sh b/egs/aishell/ASR/prepare_aidatatang_200zh.sh index f1d4d18a7..ec89450df 100755 --- a/egs/aishell/ASR/prepare_aidatatang_200zh.sh +++ b/egs/aishell/ASR/prepare_aidatatang_200zh.sh @@ -53,7 +53,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Process aidatatang_200zh" if [ ! -f data/fbank/.aidatatang_200zh_fbank.done ]; then mkdir -p data/fbank - ./local/compute_fbank_aidatatang_200zh.py + ./local/compute_fbank_aidatatang_200zh.py --perturb-speed True touch data/fbank/.aidatatang_200zh_fbank.done fi fi diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export.py b/egs/aishell/ASR/pruned_transducer_stateless7/export.py deleted file mode 100755 index 1b0e8d3b9..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/export.py +++ /dev/null @@ -1,321 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -(1) Export to torchscript model using torch.jit.script() - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --lang-dir data/lang_char \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("cpu_jit.pt")`. - -Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python -are on CPU. You can use `to("cuda")` to move them to a CUDA device. - -Check -https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -(2) Export `model.state_dict()` - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --lang-dir data/lang_char \ - --epoch 20 \ - --avg 10 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. - -To use the generated file with `pruned_transducer_stateless7/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - ./pruned_transducer_stateless7/decode.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --lang-dir data/lang_char - -Check ./pretrained.py for its usage. - -Note: If you don't want to train a model from scratch, we have -provided one for you. You can get it at - -https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21 - -with the following commands: - - sudo apt-get install git-lfs - git lfs install - git clone https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21 - # You will find the pre-trained model in icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21exp -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -import torch.nn as nn -from scaling_converter import convert_scaled_to_non_scaled -from train2 import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named cpu_jit.pt - - Check ./jit_pretrained.py for how to use it. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit is True: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torchscript. Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export.py b/egs/aishell/ASR/pruned_transducer_stateless7/export.py new file mode 120000 index 000000000..2713792e6 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/export.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py deleted file mode 100644 index cc54027d6..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py +++ /dev/null @@ -1,348 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --lang-dir data/lang_char \ - --epoch 20 \ - --avg 10 - -Usage of this script: - -(1) greedy search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`. - -Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by -./pruned_transducer_stateless7/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.lexicon import Lexicon -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--lang-dir", - type=str, - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - token_table = lexicon.token_table - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp_tokens = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp_tokens = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens] - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py new file mode 120000 index 000000000..068f0f57f --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/pretrained.py \ No newline at end of file diff --git a/egs/aishell2/ASR/local/compute_fbank_aishell2.py b/egs/aishell2/ASR/local/compute_fbank_aishell2.py index ec0c584ca..1fb1621ff 100755 --- a/egs/aishell2/ASR/local/compute_fbank_aishell2.py +++ b/egs/aishell2/ASR/local/compute_fbank_aishell2.py @@ -32,7 +32,7 @@ import torch from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse.recipes.utils import read_manifests_if_cached -from icefall.utils import get_executor +from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -42,7 +42,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_aishell2(num_mel_bins: int = 80): +def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False): src_dir = Path("data/manifests") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -81,7 +81,8 @@ def compute_fbank_aishell2(num_mel_bins: int = 80): recordings=m["recordings"], supervisions=m["supervisions"], ) - if "train" in partition: + if "train" in partition and perturb_speed: + logging.info(f"Doing speed perturb") cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) @@ -104,6 +105,12 @@ def get_args(): default=80, help="""The number of mel bins for Fbank""", ) + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) return parser.parse_args() @@ -114,4 +121,6 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - compute_fbank_aishell2(num_mel_bins=args.num_mel_bins) + compute_fbank_aishell2( + num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed + ) diff --git a/egs/aishell2/ASR/prepare.sh b/egs/aishell2/ASR/prepare.sh index 3e8e840ab..42631c864 100755 --- a/egs/aishell2/ASR/prepare.sh +++ b/egs/aishell2/ASR/prepare.sh @@ -101,7 +101,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Compute fbank for aishell2" if [ ! -f data/fbank/.aishell2.done ]; then mkdir -p data/fbank - ./local/compute_fbank_aishell2.py + ./local/compute_fbank_aishell2.py --perturb-speed True touch data/fbank/.aishell2.done fi fi diff --git a/egs/aishell4/ASR/local/compute_fbank_aishell4.py b/egs/aishell4/ASR/local/compute_fbank_aishell4.py index 400c406f0..f19163988 100755 --- a/egs/aishell4/ASR/local/compute_fbank_aishell4.py +++ b/egs/aishell4/ASR/local/compute_fbank_aishell4.py @@ -32,7 +32,7 @@ import torch from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig from lhotse.recipes.utils import read_manifests_if_cached -from icefall.utils import get_executor +from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -42,7 +42,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_aishell4(num_mel_bins: int = 80): +def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False): src_dir = Path("data/manifests/aishell4") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -83,10 +83,12 @@ def compute_fbank_aishell4(num_mel_bins: int = 80): recordings=m["recordings"], supervisions=m["supervisions"], ) - if "train" in partition: + if "train" in partition and perturb_speed: + logging.info(f"Doing speed perturb") cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) + cut_set = cut_set.compute_and_store_features( extractor=extractor, storage_path=f"{output_dir}/{prefix}_feats_{partition}", @@ -113,6 +115,12 @@ def get_args(): default=80, help="""The number of mel bins for Fbank""", ) + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) return parser.parse_args() @@ -123,4 +131,6 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - compute_fbank_aishell4(num_mel_bins=args.num_mel_bins) + compute_fbank_aishell4( + num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed + ) diff --git a/egs/aishell4/ASR/prepare.sh b/egs/aishell4/ASR/prepare.sh index cb2b73a3e..1b1ec0005 100755 --- a/egs/aishell4/ASR/prepare.sh +++ b/egs/aishell4/ASR/prepare.sh @@ -107,7 +107,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Compute fbank for aishell4" if [ ! -f data/fbank/.aishell4.done ]; then mkdir -p data/fbank - ./local/compute_fbank_aishell4.py + ./local/compute_fbank_aishell4.py --perturb-speed True touch data/fbank/.aishell4.done fi fi diff --git a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py index 96115a230..f8c10648a 100755 --- a/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py +++ b/egs/alimeeting/ASR/local/compute_fbank_alimeeting.py @@ -32,7 +32,7 @@ import torch from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter from lhotse.recipes.utils import read_manifests_if_cached -from icefall.utils import get_executor +from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -42,7 +42,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_alimeeting(num_mel_bins: int = 80): +def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False): src_dir = Path("data/manifests/alimeeting") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) @@ -82,7 +82,8 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80): recordings=m["recordings"], supervisions=m["supervisions"], ) - if "train" in partition: + if "train" in partition and perturb_speed: + logging.info(f"Doing speed perturb") cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) @@ -114,6 +115,12 @@ def get_args(): default=80, help="""The number of mel bins for Fbank""", ) + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) return parser.parse_args() @@ -124,4 +131,6 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() - compute_fbank_alimeeting(num_mel_bins=args.num_mel_bins) + compute_fbank_alimeeting( + num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed + ) diff --git a/egs/alimeeting/ASR/prepare.sh b/egs/alimeeting/ASR/prepare.sh index 604cc92c6..1709733c7 100755 --- a/egs/alimeeting/ASR/prepare.sh +++ b/egs/alimeeting/ASR/prepare.sh @@ -97,7 +97,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Compute fbank for alimeeting" if [ ! -f data/fbank/.alimeeting.done ]; then mkdir -p data/fbank - ./local/compute_fbank_alimeeting.py + ./local/compute_fbank_alimeeting.py --perturb-speed True touch data/fbank/.alimeeting.done fi fi diff --git a/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py b/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py index c6aa2ab36..833d11c72 100755 --- a/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py +++ b/egs/alimeeting/ASR_v2/local/compute_fbank_alimeeting.py @@ -25,6 +25,7 @@ It looks for manifests in the directory data/manifests. The generated fbank features are saved in data/fbank. """ +import argparse import logging from pathlib import Path @@ -39,6 +40,8 @@ from lhotse.features.kaldifeat import ( ) from lhotse.recipes.utils import read_manifests_if_cached +from icefall.utils import str2bool + # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. # Do this outside of main() in case it needs to take effect @@ -48,7 +51,7 @@ torch.set_num_interop_threads(1) torch.multiprocessing.set_sharing_strategy("file_system") -def compute_fbank_ami(): +def compute_fbank_ami(perturb_speed: bool = False): src_dir = Path("data/manifests") output_dir = Path("data/fbank") @@ -84,8 +87,12 @@ def compute_fbank_ami(): suffix="jsonl.gz", ) - def _extract_feats(cuts: CutSet, storage_path: Path, manifest_path: Path) -> None: - cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1) + def _extract_feats( + cuts: CutSet, storage_path: Path, manifest_path: Path, speed_perturb: bool + ) -> None: + if speed_perturb: + logging.info(f"Doing speed perturb") + cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1) _ = cuts.compute_and_store_features_batch( extractor=extractor, storage_path=storage_path, @@ -109,6 +116,7 @@ def compute_fbank_ami(): cuts_ihm, output_dir / "feats_train_ihm", src_dir / "cuts_train_ihm.jsonl.gz", + perturb_speed, ) logging.info("Processing train split IHM + reverberated IHM") @@ -117,6 +125,7 @@ def compute_fbank_ami(): cuts_ihm_rvb, output_dir / "feats_train_ihm_rvb", src_dir / "cuts_train_ihm_rvb.jsonl.gz", + perturb_speed, ) logging.info("Processing train split SDM") @@ -129,6 +138,7 @@ def compute_fbank_ami(): cuts_sdm, output_dir / "feats_train_sdm", src_dir / "cuts_train_sdm.jsonl.gz", + perturb_speed, ) logging.info("Processing train split GSS") @@ -141,6 +151,7 @@ def compute_fbank_ami(): cuts_gss, output_dir / "feats_train_gss", src_dir / "cuts_train_gss.jsonl.gz", + perturb_speed, ) logging.info("Preparing test cuts: IHM, SDM, GSS (optional)") @@ -186,8 +197,21 @@ def compute_fbank_ami(): ) +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) + return parser.parse_args() + + if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_ami() + args = get_args() + + compute_fbank_ami(perturb_speed=args.perturb_speed) diff --git a/egs/alimeeting/ASR_v2/prepare.sh b/egs/alimeeting/ASR_v2/prepare.sh index 76a108771..1098840f8 100755 --- a/egs/alimeeting/ASR_v2/prepare.sh +++ b/egs/alimeeting/ASR_v2/prepare.sh @@ -85,7 +85,7 @@ fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Compute fbank for alimeeting" mkdir -p data/fbank - python local/compute_fbank_alimeeting.py + python local/compute_fbank_alimeeting.py --perturb-speed True log "Combine features from train splits" lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\ gzip -c > data/manifests/cuts_train_all.jsonl.gz diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 1b8e690bd..b945f43fd 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -90,6 +90,11 @@ You can use to deploy it. | greedy_search | 2.23 | 4.96 | --epoch 40 --avg 16 | | modified_beam_search | 2.21 | 4.91 | --epoch 40 --avg 16 | | fast_beam_search | 2.24 | 4.93 | --epoch 40 --avg 16 | +| modified_beam_search_shallow_fusion | 2.01 | 4.37 | --epoch 40 --avg 16 --beam-size 12 --lm-scale 0.3 | +| modified_beam_search_LODR | 1.94 | 4.17 | --epoch 40 --avg 16 --beam-size 12 --lm-scale 0.52 --LODR-scale -0.26 | +| modified_beam_search_rescore | 2.04 | 4.39 | --epoch 40 --avg 16 --beam-size 12 | +| modified_beam_search_rescore_LODR | 2.01 | 4.33 | --epoch 40 --avg 16 --beam-size 12 | + The training command is: ```bash @@ -119,6 +124,8 @@ for m in greedy_search modified_beam_search fast_beam_search; do done ``` +To decode with external language models, please refer to the documentation [here](https://k2-fsa.github.io/icefall/decoding-with-langugage-models/index.html). + ##### small-scaled model, number of model parameters: 23285615, i.e., 23.3 M The tensorboard log can be found at diff --git a/egs/librispeech/ASR/conformer_ctc/export.py b/egs/librispeech/ASR/conformer_ctc/export.py index fbcbd7b29..f0bb97560 100755 --- a/egs/librispeech/ASR/conformer_ctc/export.py +++ b/egs/librispeech/ASR/conformer_ctc/export.py @@ -23,12 +23,13 @@ import argparse import logging from pathlib import Path +import k2 import torch from conformer import Conformer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -63,11 +64,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_bpe_500", - help="""It contains language related input files such as "lexicon.txt" - """, + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -98,16 +98,16 @@ def get_params() -> AttributeDict: def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) params = get_params() params.update(vars(args)) logging.info(params) - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + num_classes = num_tokens(token_table) + 1 # +1 for the blank device = torch.device("cpu") if torch.cuda.is_available(): diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index 30def9c40..df3e4d819 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -24,7 +24,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from conformer import Conformer @@ -70,11 +69,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model. - Used only when method is ctc-decoding. - """, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -83,10 +80,9 @@ def get_parser(): default="1best", help="""Decoding method. Possible values are: - (0) ctc-decoding - Use CTC decoding. It uses a sentence - piece model, i.e., lang_dir/bpe.model, to convert - word pieces to words. It needs neither a lexicon - nor an n-gram LM. + (0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file + to convert tokens to actual words or characters. It needs + neither a lexicon nor an n-gram LM. (1) 1best - Use the best path as decoding output. Only the transformer encoder output is used for decoding. We call it HLG decoding. @@ -297,6 +293,7 @@ def main(): waves = [w.to(device) for w in waves] logging.info("Decoding started") + hyps = [] features = fbank(waves) features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) @@ -313,10 +310,17 @@ def main(): if params.method == "ctc-decoding": logging.info("Use CTC decoding") - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(params.bpe_model) max_token_id = params.num_classes - 1 + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + H = k2.ctc_topo( max_token=max_token_id, modified=params.num_classes > 500, @@ -337,9 +341,9 @@ def main(): best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) - token_ids = get_texts(best_path) - hyps = bpe_model.decode(token_ids) - hyps = [s.split() for s in hyps] + hyp_tokens = get_texts(best_path) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method in [ "1best", "whole-lattice-rescoring", @@ -408,16 +412,16 @@ def main(): ) best_path = next(iter(best_path_dict.values())) - hyps = get_texts(best_path) word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + hyp_tokens = get_texts(best_path) + for hyp in hyp_tokens: + hyps.append(" ".join([word_sym_table[i] for i in hyp])) else: raise ValueError(f"Unsupported decoding method: {params.method}") s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/conformer_ctc2/export.py b/egs/librispeech/ASR/conformer_ctc2/export.py index 7892b03c6..26a95dbfa 100755 --- a/egs/librispeech/ASR/conformer_ctc2/export.py +++ b/egs/librispeech/ASR/conformer_ctc2/export.py @@ -23,6 +23,7 @@ Usage: ./conformer_ctc2/export.py \ --exp-dir ./conformer_ctc2/exp \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -46,6 +47,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from conformer import Conformer from decode import get_params @@ -56,8 +58,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -123,10 +124,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_bpe_500", - help="The lang dir", + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -143,14 +144,14 @@ def get_parser(): def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) params = get_params() params.update(vars(args)) - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + num_classes = num_tokens(token_table) + 1 # +1 for the blank device = torch.device("cpu") if torch.cuda.is_available(): diff --git a/egs/librispeech/ASR/conformer_ctc3/export.py b/egs/librispeech/ASR/conformer_ctc3/export.py index c5b95d981..5cb9b4b6d 100755 --- a/egs/librispeech/ASR/conformer_ctc3/export.py +++ b/egs/librispeech/ASR/conformer_ctc3/export.py @@ -25,7 +25,7 @@ Usage: ./conformer_ctc3/export.py \ --exp-dir ./conformer_ctc3/exp \ - --lang-dir data/lang_bpe_500 \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 \ --jit-trace 1 @@ -36,7 +36,7 @@ It will generates the file: `jit_trace.pt`. ./conformer_ctc3/export.py \ --exp-dir ./conformer_ctc3/exp \ - --lang-dir data/lang_bpe_500 \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -62,6 +62,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_ctc_model, get_params @@ -72,8 +73,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -130,10 +130,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", + "--tokens", + type=str, + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -171,9 +171,10 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + num_classes = num_tokens(token_table) + 1 # +1 for the blank params.vocab_size = num_classes if params.streaming_model: diff --git a/egs/librispeech/ASR/conformer_ctc3/pretrained.py b/egs/librispeech/ASR/conformer_ctc3/pretrained.py index 880945ea0..c37b99cce 100755 --- a/egs/librispeech/ASR/conformer_ctc3/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc3/pretrained.py @@ -24,7 +24,7 @@ Usage (for non-streaming mode): (1) ctc-decoding ./conformer_ctc3/pretrained.py \ --checkpoint conformer_ctc3/exp/pretrained.pt \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method ctc-decoding \ --sample-rate 16000 \ test_wavs/1089-134686-0001.wav @@ -67,7 +67,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from decode import get_decoding_params @@ -114,11 +113,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model. - Used only when method is ctc-decoding. - """, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -127,10 +124,9 @@ def get_parser(): default="1best", help="""Decoding method. Possible values are: - (0) ctc-decoding - Use CTC decoding. It uses a sentence - piece model, i.e., lang_dir/bpe.model, to convert - word pieces to words. It needs neither a lexicon - nor an n-gram LM. + (0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file + to convert tokens to actual words or characters. It needs + neither a lexicon nor an n-gram LM. (1) 1best - Use the best path as decoding output. Only the transformer encoder output is used for decoding. We call it HLG decoding. @@ -316,6 +312,7 @@ def main(): waves = [w.to(device) for w in waves] logging.info("Decoding started") + hyps = [] features = fbank(waves) feature_lengths = [f.size(0) for f in features] @@ -348,10 +345,17 @@ def main(): if params.method == "ctc-decoding": logging.info("Use CTC decoding") - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(params.bpe_model) max_token_id = params.num_classes - 1 + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + H = k2.ctc_topo( max_token=max_token_id, modified=False, @@ -372,9 +376,9 @@ def main(): best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) - token_ids = get_texts(best_path) - hyps = bpe_model.decode(token_ids) - hyps = [s.split() for s in hyps] + hyp_tokens = get_texts(best_path) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method in [ "1best", "nbest-rescoring", @@ -439,16 +443,16 @@ def main(): ) best_path = next(iter(best_path_dict.values())) - hyps = get_texts(best_path) word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + hyp_tokens = get_texts(best_path) + for hyp in hyp_tokens: + hyps.append(" ".join([word_sym_table[i] for i in hyp])) else: raise ValueError(f"Unsupported decoding method: {params.method}") s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py index 09a3e96b0..67fcc35a4 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py @@ -22,7 +22,7 @@ Usage: ./conv_emformer_transducer_stateless/export.py \ --exp-dir ./conv_emformer_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 10 \ --use-averaged-model=True \ @@ -62,7 +62,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model @@ -72,7 +72,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -118,10 +118,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -166,12 +166,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py index 8fbb02f14..85dbd4661 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py @@ -8,7 +8,7 @@ for more details about how to use this file. Usage: ./conv_emformer_transducer_stateless2/export-for-ncnn.py \ --exp-dir ./conv_emformer_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 10 \ --use-averaged-model=True \ @@ -37,7 +37,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model @@ -48,7 +48,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -94,10 +94,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -217,12 +217,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py index ad0b45bd9..cfd365207 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py @@ -18,7 +18,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" cd exp @@ -28,7 +27,7 @@ popd 2. Export the model to ONNX ./conv_emformer_transducer_stateless2/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -55,14 +54,14 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from decoder import Decoder +from emformer import Emformer from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model -from emformer import Emformer from icefall.checkpoint import ( average_checkpoints, @@ -70,7 +69,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -127,10 +126,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -484,12 +483,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py index b53426c75..8e5b14903 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py @@ -22,7 +22,7 @@ Usage: ./conv_emformer_transducer_stateless2/export.py \ --exp-dir ./conv_emformer_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 10 \ --use-averaged-model=True \ @@ -62,7 +62,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -73,7 +73,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -119,10 +119,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -167,12 +167,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py index db92ac696..5d7e2dfcd 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./conv_emformer_transducer_stateless2/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/export.py b/egs/librispeech/ASR/lstm_transducer_stateless/export.py index e338342cc..c007220d5 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/export.py @@ -26,7 +26,7 @@ Usage: ./lstm_transducer_stateless/export.py \ --exp-dir ./lstm_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 35 \ --avg 10 \ --jit-trace 1 @@ -38,7 +38,7 @@ It will generate 3 files: `encoder_jit_trace.pt`, ./lstm_transducer_stateless/export.py \ --exp-dir ./lstm_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 35 \ --avg 10 @@ -79,7 +79,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled @@ -91,7 +91,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -148,10 +148,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -266,12 +266,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py index b3a34a9e3..119fcf1fd 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./lstm_transducer_stateless/pretrained.py \ --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./lstm_transducer_stateless/pretrained.py \ --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./lstm_transducer_stateless/pretrained.py \ --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./lstm_transducer_stateless/pretrained.py \ --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -66,7 +66,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -79,6 +78,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -95,9 +96,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -214,13 +215,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -275,6 +277,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -286,8 +294,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -296,16 +304,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -326,12 +334,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py index 08bfcb204..2b8c92208 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py @@ -29,7 +29,7 @@ popd ./lstm_transducer_stateless2/export-for-ncnn.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -49,7 +49,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -60,7 +60,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -106,10 +106,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -221,12 +221,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py index f068f6a0f..89ced388c 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py @@ -613,7 +613,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py index acaff8540..6b6cb893f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./lstm_transducer_stateless2/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -52,8 +52,8 @@ import logging from pathlib import Path from typing import Dict, Optional, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from decoder import Decoder @@ -68,7 +68,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -125,10 +125,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -437,12 +437,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -607,7 +608,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py index 0adc68112..5712da25e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py @@ -27,7 +27,7 @@ Usage: ./lstm_transducer_stateless2/export.py \ --exp-dir ./lstm_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 35 \ --avg 10 \ --jit-trace 1 @@ -39,7 +39,7 @@ It will generate 3 files: `encoder_jit_trace.pt`, ./lstm_transducer_stateless2/export.py \ --exp-dir ./lstm_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 35 \ --avg 10 @@ -80,7 +80,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled @@ -92,7 +92,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -149,10 +149,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -267,12 +267,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py index f3f272b9f..5d6d97320 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./lstm_transducer_stateless2/pretrained.py \ --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./lstm_transducer_stateless2/pretrained.py \ --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./lstm_transducer_stateless2/pretrained.py \ --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./lstm_transducer_stateless2/pretrained.py \ --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -69,7 +69,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -82,6 +81,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -98,9 +99,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -217,13 +218,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -278,6 +280,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -289,8 +297,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -299,16 +307,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -329,12 +337,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py index a82cad043..21eaa049b 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py @@ -26,7 +26,7 @@ Usage: ./lstm_transducer_stateless3/export.py \ --exp-dir ./lstm_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 40 \ --avg 20 \ --jit-trace 1 @@ -38,7 +38,7 @@ It will generate 3 files: `encoder_jit_trace.pt`, ./lstm_transducer_stateless3/export.py \ --exp-dir ./lstm_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 40 \ --avg 20 @@ -79,7 +79,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled @@ -91,7 +91,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -148,10 +148,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to tokens.txt.", ) parser.add_argument( @@ -266,12 +266,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py index f49e9c518..29a0d4d1a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./lstm_transducer_stateless3/pretrained.py \ --checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./lstm_transducer_stateless3/pretrained.py \ --checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./lstm_transducer_stateless3/pretrained.py \ --checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -79,6 +79,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -95,9 +97,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -214,13 +216,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -275,6 +278,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -286,8 +295,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -296,16 +305,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -326,12 +335,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py index 3612a2bfd..ec2c9d580 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py @@ -22,7 +22,7 @@ Usage: ./prunted_stateless_emformer_rnnt/export.py \ --exp-dir ./prunted_stateless_emformer_rnnt/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -48,7 +48,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model @@ -58,7 +58,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -115,10 +115,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -154,13 +154,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py index a3ebe9d8c..282238c13 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py @@ -508,7 +508,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py index a19f9ab9a..4b20e3a2b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless/export.py \ --exp-dir ./pruned_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,12 +47,12 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -87,10 +87,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -135,13 +135,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index 2ed1725b4..02f9f1b03 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./pruned_transducer_stateless/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./pruned_transducer_stateless/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./pruned_transducer_stateless/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./pruned_transducer_stateless/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -66,7 +66,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -79,7 +78,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -97,9 +96,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -237,13 +236,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.simulate_streaming: assert ( @@ -314,6 +314,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -325,8 +331,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -335,16 +341,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -365,12 +371,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index 984caf5f2..e02afa892 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless2/export.py \ --exp-dir ./pruned_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,12 +47,12 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -98,10 +98,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -145,12 +145,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py index 013964720..029f55ba0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -66,7 +66,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -79,7 +78,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -97,9 +96,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -238,13 +237,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.simulate_streaming: assert ( @@ -315,6 +315,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -326,8 +332,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -336,16 +342,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -366,12 +372,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py index 9645b7801..26dea7e11 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless3/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 9999 \ --avg 1 \ --exp-dir $repo/exp/ @@ -48,8 +48,8 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from conformer import Conformer @@ -59,7 +59,7 @@ from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import setup_logger +from icefall.utils import num_tokens, setup_logger def get_parser(): @@ -105,10 +105,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -393,12 +393,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -518,7 +520,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index f30c9df6a..925b15646 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 \ --jit 1 @@ -44,7 +44,7 @@ It will also generate 3 other files: `encoder_jit_script.pt`, ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 \ --jit-trace 1 @@ -56,7 +56,7 @@ It will generates 3 files: `encoder_jit_trace.pt`, ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -97,14 +97,14 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -150,10 +150,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -342,12 +342,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index 7c3dfc660..abda4e2d4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless3/pretrained.py \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless3/pretrained.py \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless3/pretrained.py \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless3/pretrained.py \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,7 +87,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -106,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -247,13 +246,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.simulate_streaming: assert ( @@ -324,6 +324,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -335,8 +341,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -345,16 +351,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -375,12 +381,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py index 8f33f5b05..08d736f52 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless4/export.py \ --exp-dir ./pruned_transducer_stateless4/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -48,7 +48,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -59,7 +59,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -116,10 +116,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -164,12 +164,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py index 938ff2f16..549fb13c9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless5/export-onnx-streaming.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -58,13 +58,13 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from conformer import Conformer -from onnxruntime.quantization import QuantType, quantize_dynamic from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -74,7 +74,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -131,10 +131,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -489,12 +489,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -662,7 +664,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py index 20fd8dff8..fff0fcdd5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless5/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -55,13 +55,13 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from conformer import Conformer -from onnxruntime.quantization import QuantType, quantize_dynamic from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -71,7 +71,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -128,10 +128,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -416,12 +416,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -586,7 +588,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py index 54f656859..e5223be26 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -48,7 +48,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -59,7 +59,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -116,10 +116,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -164,12 +164,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py index 74a2210c3..304fa8693 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -66,7 +66,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -79,6 +78,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -95,9 +96,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -214,13 +215,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -275,6 +277,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -286,8 +294,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -296,16 +304,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -326,12 +334,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py index 4d0d8326c..38f48b2ed 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless6/export.py \ --exp-dir ./pruned_transducer_stateless6/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,12 +47,12 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -98,10 +98,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -135,12 +135,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py index d2db92820..11c885f4d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang +# Zengrui Jin) """ This script exports a transducer model from PyTorch to ONNX. @@ -18,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained-epoch-30-avg-9.pt" cd exp @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless7/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -50,8 +50,8 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from decoder import Decoder @@ -66,7 +66,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -123,10 +123,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -411,12 +410,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -581,7 +580,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py index 3e3160e7e..eb4c4d282 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang +# Zengrui Jin) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -26,7 +27,7 @@ Usage: ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +46,7 @@ for how to use the exported models outside of icefall. ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -65,7 +66,7 @@ you can do: --avg 1 \ --max-duration 600 \ --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model + --tokens data/lang_bpe_500/tokens.txt \ Check ./pretrained.py for its usage. @@ -86,7 +87,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled @@ -98,7 +99,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -155,10 +156,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -198,12 +198,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -292,7 +292,7 @@ def main(): model.to("cpu") model.eval() - if params.jit is True: + if params.jit: convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py index d05bafcfb..86c922cda 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Zengrui Jin) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -20,7 +21,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +30,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless7/pretrained.py \ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +38,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless7/pretrained.py \ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +47,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless7/pretrained.py \ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +56,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless7/pretrained.py \ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +76,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,7 +88,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens def get_parser(): @@ -106,9 +106,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -225,13 +225,13 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -286,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -297,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -307,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -337,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py index c1607699f..51e62d6a8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -86,7 +86,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -97,7 +97,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -154,10 +154,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -197,12 +197,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py index 2f1b1a49f..78e0fa778 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless7_ctc/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless7_ctc/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless7_ctc/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless7_ctc/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,6 +87,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -104,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -223,13 +224,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -284,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -295,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -305,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -335,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py index 5d460edb5..904c1deae 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py @@ -22,14 +22,14 @@ You can use the following command to get the exported models: ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 Usage of this script: (1) ctc-decoding -./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --bpe-model data/lang_bpe_500/bpe.model \ --method ctc-decoding \ @@ -38,7 +38,7 @@ Usage of this script: /path/to/bar.wav (2) 1best -./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ @@ -48,7 +48,7 @@ Usage of this script: /path/to/bar.wav (3) nbest-rescoring -./bruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./bruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ @@ -60,7 +60,7 @@ Usage of this script: (4) whole-lattice-rescoring -./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py index 05df8cfff..9f35cf63e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 @@ -86,7 +86,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -97,7 +97,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -154,10 +154,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -197,12 +197,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py index 630a7f735..d3033b888 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py @@ -28,7 +28,7 @@ Usage: ./pruned_transducer_stateless7_ctc_bs/export_onnx.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 \ --onnx 1 @@ -48,7 +48,7 @@ Check `onnx_check.py` for how to use them. (2) Export to ONNX format which can be used in Triton Server ./pruned_transducer_stateless7_ctc_bs/export_onnx.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 \ --onnx-triton 1 @@ -86,9 +86,10 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn +from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -98,8 +99,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool -from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -156,10 +156,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -728,12 +728,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py index ea0fe9164..5d240cf30 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,6 +87,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -104,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -223,13 +224,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -284,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -295,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -305,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -335,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py index 412631ba1..914107526 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py @@ -22,14 +22,14 @@ You can use the following command to get the exported models: ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 Usage of this script: (1) ctc-decoding -./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --bpe-model data/lang_bpe_500/bpe.model \ --method ctc-decoding \ @@ -38,7 +38,7 @@ Usage of this script: /path/to/bar.wav (2) 1best -./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ @@ -48,7 +48,7 @@ Usage of this script: /path/to/bar.wav (3) nbest-rescoring -./bruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./bruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ @@ -60,7 +60,7 @@ Usage of this script: (4) whole-lattice-rescoring -./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py index 3444f8193..02029c108 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -396,6 +396,12 @@ def decode_one_batch( The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -907,6 +913,7 @@ def main(): ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") logging.info(f"lm filename: {ngram_file_name}") ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search elif params.decoding_method == "modified_beam_search_LODR": lm_filename = f"{params.tokens_ngram}gram.fst.txt" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py index e196f8b7d..07de57a86 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py @@ -66,6 +66,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model @@ -76,8 +77,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -123,10 +123,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="The tokens.txt file", ) parser.add_argument( @@ -246,9 +246,14 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + # Load id of the token and the vocab size + # is defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py index 4a16a97fb..9a6b31268 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -28,7 +28,7 @@ popd 2. Export to ncnn ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --exp-dir $repo/exp \ --use-averaged-model 0 \ --epoch 99 \ @@ -64,7 +64,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model @@ -75,7 +75,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -121,10 +121,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -244,12 +244,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py index 04d97808d..8653126de 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py @@ -29,7 +29,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless7_streaming/export-onnx-zh.py \ - --lang-dir $repo/data/lang_char_bpe \ + --tokens $repo/data/lang_char_bpe/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -60,6 +60,7 @@ import logging from pathlib import Path from typing import Dict, List, Tuple +import k2 import onnx import torch import torch.nn as nn @@ -76,8 +77,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -134,10 +134,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="The tokens.txt file", ) parser.add_argument( @@ -493,9 +493,14 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + # Load id of the token and the vocab size + # is defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -661,7 +666,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py index e71bcaf29..6f84d79b4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py @@ -27,7 +27,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless7_streaming/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -48,8 +48,8 @@ import logging from pathlib import Path from typing import Dict, List, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from decoder import Decoder @@ -65,7 +65,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -122,10 +122,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -481,12 +481,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -652,7 +654,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py index c191b5bcc..59a7eb589 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py @@ -139,8 +139,8 @@ import argparse import logging from pathlib import Path +import k2 import onnxruntime -import sentencepiece as spm import torch import torch.nn as nn from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner @@ -154,7 +154,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -211,10 +211,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -675,12 +675,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py index fb77fdd42..bc42e8d05 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless7_streaming/export.py \ --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless7_streaming/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless7_streaming/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless7_streaming/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless7_streaming/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,7 +87,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -106,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -225,13 +224,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -286,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -297,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -307,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -337,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py index 4a16a97fb..9a6b31268 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py @@ -28,7 +28,7 @@ popd 2. Export to ncnn ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --exp-dir $repo/exp \ --use-averaged-model 0 \ --epoch 99 \ @@ -64,7 +64,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model @@ -75,7 +75,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -121,10 +121,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -244,12 +244,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py index d4a228b47..d9697680b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless8/export.py \ --exp-dir ./pruned_transducer_stateless8/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. ./pruned_transducer_stateless8/export.py \ --exp-dir ./pruned_transducer_stateless8/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -86,7 +86,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled @@ -98,7 +98,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -155,10 +155,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -198,12 +198,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py index 486d9d74e..64b38c9d5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless8/export.py \ --exp-dir ./pruned_transducer_stateless8/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless8/pretrained.py \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless8/pretrained.py \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless8/pretrained.py \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless8/pretrained.py \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,7 +87,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -106,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -225,13 +224,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -286,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -297,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -307,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -337,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/transducer/export.py b/egs/librispeech/ASR/transducer/export.py index 6db0272f0..3b9e4a5dc 100755 --- a/egs/librispeech/ASR/transducer/export.py +++ b/egs/librispeech/ASR/transducer/export.py @@ -22,7 +22,7 @@ Usage: ./transducer/export.py \ --exp-dir ./transducer/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 34 \ --avg 11 @@ -46,7 +46,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from conformer import Conformer from decoder import Decoder @@ -55,7 +55,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -90,10 +90,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -191,12 +191,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py index 511610245..c2413f5de 100755 --- a/egs/librispeech/ASR/transducer/pretrained.py +++ b/egs/librispeech/ASR/transducer/pretrained.py @@ -19,7 +19,7 @@ Usage: ./transducer/pretrained.py \ --checkpoint ./transducer/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav \ @@ -36,8 +36,8 @@ import logging import math from typing import List +import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import beam_search, greedy_search @@ -48,7 +48,7 @@ from model import Transducer from torch.nn.utils.rnn import pad_sequence from icefall.env import get_env_info -from icefall.utils import AttributeDict +from icefall.utils import AttributeDict, num_tokens def get_parser(): @@ -66,11 +66,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model. - Used only when method is ctc-decoding. - """, + help="Path to tokens.txt.", ) parser.add_argument( @@ -204,12 +202,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -257,6 +257,12 @@ def main(): x=features, x_lens=feature_lengths ) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + num_waves = encoder_out.size(0) hyps = [] for i in range(num_waves): @@ -272,12 +278,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py index 89359f1a4..c397eb171 100755 --- a/egs/librispeech/ASR/transducer_stateless/export.py +++ b/egs/librispeech/ASR/transducer_stateless/export.py @@ -22,7 +22,7 @@ Usage: ./transducer_stateless/export.py \ --exp-dir ./transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -46,7 +46,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from conformer import Conformer @@ -56,7 +56,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -91,10 +91,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -191,12 +191,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py index 915a6069d..5898dd0f5 100755 --- a/egs/librispeech/ASR/transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./transducer_stateless/pretrained.py \ --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ --max-sym-per-frame 1 \ /path/to/foo.wav \ @@ -29,7 +29,7 @@ Usage: (2) beam search ./transducer_stateless/pretrained.py \ --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -38,7 +38,7 @@ Usage: (3) modified beam search ./transducer_stateless/pretrained.py \ --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -47,7 +47,7 @@ Usage: (4) fast beam search ./transducer_stateless/pretrained.py \ --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -67,7 +67,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -80,6 +79,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -96,9 +97,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -213,12 +214,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -273,6 +276,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_list = fast_beam_search_one_best( @@ -318,12 +327,11 @@ def main(): raise ValueError(f"Unsupported method: {params.method}") hyp_list.append(hyp) - hyps = [sp.decode(hyp).split() for hyp in hyp_list] + hyps = [token_ids_to_words(hyp) for hyp in hyp_list] s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/transducer_stateless2/export.py b/egs/librispeech/ASR/transducer_stateless2/export.py index d33d02642..f4b6f5554 100755 --- a/egs/librispeech/ASR/transducer_stateless2/export.py +++ b/egs/librispeech/ASR/transducer_stateless2/export.py @@ -22,7 +22,7 @@ Usage: ./transducer_stateless2/export.py \ --exp-dir ./transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -46,12 +46,12 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -86,10 +86,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -123,12 +123,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py index 0738f30c0..b69b347ef 100755 --- a/egs/librispeech/ASR/transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./transducer_stateless2/pretrained.py \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ --max-sym-per-frame 1 \ /path/to/foo.wav \ @@ -29,7 +29,7 @@ Usage: (2) beam search ./transducer_stateless2/pretrained.py \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -38,7 +38,7 @@ Usage: (3) modified beam search ./transducer_stateless2/pretrained.py \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -47,7 +47,7 @@ Usage: (4) fast beam search ./transducer_stateless2/pretrained.py \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -67,7 +67,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -80,6 +79,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -96,9 +97,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -213,12 +214,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -273,6 +276,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_list = fast_beam_search_one_best( @@ -318,12 +327,11 @@ def main(): raise ValueError(f"Unsupported method: {params.method}") hyp_list.append(hyp) - hyps = [sp.decode(hyp).split() for hyp in hyp_list] + hyps = [token_ids_to_words(hyp) for hyp in hyp_list] s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py index 3735ef452..6d31dfe34 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py @@ -22,7 +22,7 @@ Usage: ./transducer_stateless_multi_datasets/export.py \ --exp-dir ./transducer_stateless_multi_datasets/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,7 +47,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from conformer import Conformer @@ -57,7 +57,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -92,10 +92,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -192,12 +192,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py index 8c7726367..4f29d6f1f 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./transducer_stateless_multi_datasets/pretrained.py \ --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ --max-sym-per-frame 1 \ /path/to/foo.wav \ @@ -29,7 +29,7 @@ Usage: (2) beam search ./transducer_stateless_multi_datasets/pretrained.py \ --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -38,7 +38,7 @@ Usage: (3) modified beam search ./transducer_stateless_multi_datasets/pretrained.py \ --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -47,7 +47,7 @@ Usage: (4) fast beam search ./transducer_stateless_multi_datasets/pretrained.py \ --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -67,7 +67,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -80,6 +79,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -96,9 +97,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -213,12 +214,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -273,6 +276,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_list = fast_beam_search_one_best( @@ -318,12 +327,11 @@ def main(): raise ValueError(f"Unsupported method: {params.method}") hyp_list.append(hyp) - hyps = [sp.decode(hyp).split() for hyp in hyp_list] + hyps = [token_ids_to_words(hyp) for hyp in hyp_list] s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index 93680602e..2cc157e7a 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -115,9 +115,14 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, ) -from train import add_model_arguments, get_params, get_model +from train import add_model_arguments, get_model, get_params +from icefall import LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -273,8 +278,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -302,6 +306,47 @@ def get_parser(): fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + add_model_arguments(parser) return parser @@ -314,6 +359,9 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -342,6 +390,12 @@ def decode_one_batch( The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -425,10 +479,7 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -445,6 +496,50 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) else: batch_size = encoder_out.size(0) @@ -483,6 +578,16 @@ def decode_one_batch( key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} + elif params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"beam_size_{params.beam_size}_{key}"] = hyps + return ans else: return {f"beam_size_{params.beam_size}": hyps} @@ -494,6 +599,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -543,6 +651,9 @@ def decode_dataset( decoding_graph=decoding_graph, word_table=word_table, batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, ) for name, hyps in hyps_dict.items(): @@ -559,9 +670,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -594,8 +703,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -614,6 +722,7 @@ def save_results( def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -628,6 +737,10 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", ) params.res_dir = params.exp_dir / params.decoding_method @@ -656,13 +769,19 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + if params.use_shallow_fusion: + params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -690,9 +809,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -719,9 +838,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -768,6 +887,54 @@ def main(): model.to(device) model.eval() + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": lexicon = Lexicon(params.lang_dir) @@ -780,9 +947,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None word_table = None @@ -811,6 +976,9 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, ) save_results( diff --git a/egs/librispeech/ASR/zipformer/decode_stream.py b/egs/librispeech/ASR/zipformer/decode_stream.py index 946db275c..d6918bf32 100644 --- a/egs/librispeech/ASR/zipformer/decode_stream.py +++ b/egs/librispeech/ASR/zipformer/decode_stream.py @@ -79,12 +79,12 @@ class DecodeStream(object): self.pad_length = 7 + 2 * 3 if params.decoding_method == "greedy_search": - self.hyp = [params.blank_id] * params.context_size + self.hyp = [-1] * (params.context_size - 1) + [params.blank_id] elif params.decoding_method == "modified_beam_search": self.hyps = HypothesisList() self.hyps.add( Hypothesis( - ys=[params.blank_id] * params.context_size, + ys=[-1] * (params.context_size - 1) + [params.blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), ) ) diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index 3eb06f68c..a951aeef3 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -19,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/tokens.txt" git lfs pull --include "exp/pretrained.pt" cd exp @@ -74,7 +73,6 @@ import onnx import torch import torch.nn as nn from decoder import Decoder -from export import num_tokens from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_model, get_params @@ -86,7 +84,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py index 724fdd2a6..e0d664009 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-onnx.py @@ -19,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/tokens.txt" git lfs pull --include "exp/pretrained.pt" cd exp @@ -71,7 +70,6 @@ import onnx import torch import torch.nn as nn from decoder import Decoder -from export import num_tokens from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_model, get_params @@ -83,7 +81,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import make_pad_mask, str2bool +from icefall.utils import make_pad_mask, num_tokens, str2bool def get_parser(): diff --git a/egs/librispeech/ASR/zipformer/export.py b/egs/librispeech/ASR/zipformer/export.py index 4a48d5bad..2b8d1aaf3 100755 --- a/egs/librispeech/ASR/zipformer/export.py +++ b/egs/librispeech/ASR/zipformer/export.py @@ -160,7 +160,6 @@ with the following commands: import argparse import logging -import re from pathlib import Path from typing import List, Tuple @@ -176,27 +175,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import make_pad_mask, str2bool - - -def num_tokens( - token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$") -) -> int: - """Return the number of tokens excluding those from - disambiguation symbols. - - Caution: - 0 is not a token ID so it is excluded from the return value. - """ - symbols = token_table.symbols - ans = [] - for s in symbols: - if not disambig_pattern.match(s): - ans.append(token_table[s]) - num_tokens = len(ans) - if 0 in ans: - num_tokens -= 1 - return num_tokens +from icefall.utils import make_pad_mask, num_tokens, str2bool def get_parser(): @@ -487,6 +466,8 @@ def main(): device=device, ) ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: assert params.avg > 0, params.avg start = params.epoch - params.avg diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py index 904d8cd76..660a4bfc6 100755 --- a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py @@ -410,10 +410,20 @@ def main(): raise ValueError(f"Unsupported decoding method: {params.method}") s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - words = words.replace("▁", " ").strip() - s += f"{filename}:\n{words}\n\n" + if params.method == "ctc-decoding": + for filename, hyp in zip(params.sound_files, hyps): + words = "".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/zipformer/onnx_check.py b/egs/librispeech/ASR/zipformer/onnx_check.py index b38b875d0..93bd3a211 100755 --- a/egs/librispeech/ASR/zipformer/onnx_check.py +++ b/egs/librispeech/ASR/zipformer/onnx_check.py @@ -33,7 +33,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/tokens.txt" git lfs pull --include "exp/pretrained.pt" cd exp diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py index 2ce4506a8..500b2cd09 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py @@ -19,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained.pt" cd exp @@ -29,7 +28,7 @@ popd 2. Export the model to ONNX ./zipformer/export-onnx-streaming.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained.py b/egs/librispeech/ASR/zipformer/onnx_pretrained.py index e8a521460..032b07721 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained.py @@ -31,7 +31,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/tokens.txt" git lfs pull --include "exp/pretrained.pt" cd exp diff --git a/egs/librispeech/ASR/zipformer/pretrained_ctc.py b/egs/librispeech/ASR/zipformer/pretrained_ctc.py index be239e9c3..9dff2e6fc 100755 --- a/egs/librispeech/ASR/zipformer/pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/pretrained_ctc.py @@ -274,7 +274,7 @@ def main(): params.update(vars(args)) token_table = k2.SymbolTable.from_file(params.tokens) - params.vocab_size = num_tokens(token_table) + params.vocab_size = num_tokens(token_table) + 1 # +1 for blank params.blank_id = token_table[""] assert params.blank_id == 0 @@ -429,10 +429,20 @@ def main(): raise ValueError(f"Unsupported decoding method: {params.method}") s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - words = words.replace("▁", " ").strip() - s += f"{filename}:\n{words}\n\n" + if params.method == "ctc-decoding": + for filename, hyp in zip(params.sound_files, hyps): + words = "".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/zipformer_mmi/export.py b/egs/librispeech/ASR/zipformer_mmi/export.py index 0af7bd367..1aec56420 100755 --- a/egs/librispeech/ASR/zipformer_mmi/export.py +++ b/egs/librispeech/ASR/zipformer_mmi/export.py @@ -26,7 +26,7 @@ Usage: ./zipformer_mmi/export.py \ --exp-dir ./zipformer_mmi/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. ./zipformer_mmi/export.py \ --exp-dir ./zipformer_mmi/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -86,7 +86,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_ctc_model, get_params @@ -97,7 +97,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -154,10 +154,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -190,12 +190,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/zipformer_mmi/pretrained.py b/egs/librispeech/ASR/zipformer_mmi/pretrained.py index 0e7fd0daf..3ba4da5dd 100755 --- a/egs/librispeech/ASR/zipformer_mmi/pretrained.py +++ b/egs/librispeech/ASR/zipformer_mmi/pretrained.py @@ -21,7 +21,7 @@ You can generate the checkpoint with the following command: ./zipformer_mmi/export.py \ --exp-dir ./zipformer_mmi/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -30,14 +30,14 @@ Usage of this script: (1) 1best ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method 1best \ /path/to/foo.wav \ /path/to/bar.wav (2) nbest ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --nbest-scale 1.2 \ --method nbest \ /path/to/foo.wav \ @@ -45,7 +45,7 @@ Usage of this script: (3) nbest-rescoring-LG ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --nbest-scale 1.2 \ --method nbest-rescoring-LG \ /path/to/foo.wav \ @@ -53,7 +53,7 @@ Usage of this script: (4) nbest-rescoring-3-gram ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --nbest-scale 1.2 \ --method nbest-rescoring-3-gram \ /path/to/foo.wav \ @@ -61,7 +61,7 @@ Usage of this script: (5) nbest-rescoring-4-gram ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --nbest-scale 1.2 \ --method nbest-rescoring-4-gram \ /path/to/foo.wav \ @@ -83,7 +83,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from decode import get_decoding_params @@ -97,7 +96,7 @@ from icefall.decode import ( one_best_decoding, ) from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import get_texts +from icefall.utils import get_texts, num_tokens def get_parser(): @@ -115,9 +114,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -247,13 +246,14 @@ def main(): params.update(get_decoding_params()) params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -298,8 +298,6 @@ def main(): features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(str(params.lang_dir / "bpe.model")) mmi_graph_compiler = MmiTrainingGraphCompiler( params.lang_dir, uniq_filename="lexicon.txt", @@ -313,6 +311,12 @@ def main(): if not hasattr(HP, "lm_scores"): HP.lm_scores = HP.scores.clone() + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + method = params.method assert method in ( "1best", @@ -390,14 +394,11 @@ def main(): # # token_ids is a lit-of-list of IDs token_ids = get_texts(best_path) - # hyps is a list of str, e.g., ['xxx yyy zzz', ...] - hyps = bpe_model.decode(token_ids) - # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] - hyps = [s.split() for s in hyps] + hyps = [token_ids_to_words(ids) for ids in token_ids] + s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py index 93ce750f8..5de3c23a9 100755 --- a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py +++ b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py @@ -16,6 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import logging import re from pathlib import Path @@ -24,6 +25,7 @@ from lhotse import CutSet, SupervisionSegment from lhotse.recipes.utils import read_manifests_if_cached from icefall import setup_logger +from icefall.utils import str2bool # Similar text filtering and normalization procedure as in: # https://github.com/SpeechColab/WenetSpeech/blob/main/toolkits/kaldi/wenetspeech_data_prep.sh @@ -45,7 +47,7 @@ def has_no_oov( return oov_pattern.search(sup.text) is None -def preprocess_wenet_speech(): +def preprocess_wenet_speech(perturb_speed: bool = False): src_dir = Path("data/manifests") output_dir = Path("data/fbank") output_dir.mkdir(exist_ok=True) @@ -110,7 +112,7 @@ def preprocess_wenet_speech(): ) # Run data augmentation that needs to be done in the # time domain. - if partition not in ["DEV", "TEST_NET", "TEST_MEETING"]: + if partition not in ["DEV", "TEST_NET", "TEST_MEETING"] and perturb_speed: logging.info( f"Speed perturb for {partition} with factors 0.9 and 1.1 " "(Perturbing may take 8 minutes and saving may take 20 minutes)" @@ -120,10 +122,22 @@ def preprocess_wenet_speech(): cut_set.to_file(raw_cuts_path) +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) + return parser.parse_args() + + def main(): setup_logger(log_filename="./log-preprocess-wenetspeech") - preprocess_wenet_speech() + args = get_args() + preprocess_wenet_speech(perturb_speed=args.perturb_speed) logging.info("Done") diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh index f7b521794..097a59a5f 100755 --- a/egs/wenetspeech/ASR/prepare.sh +++ b/egs/wenetspeech/ASR/prepare.sh @@ -91,7 +91,7 @@ fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Preprocess WenetSpeech manifest" if [ ! -f data/fbank/.preprocess_complete ]; then - python3 ./local/preprocess_wenetspeech.py + python3 ./local/preprocess_wenetspeech.py --perturb-speed True touch data/fbank/.preprocess_complete fi fi diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py index fad66986b..760fad974 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py @@ -498,7 +498,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py index bc499f3dd..c3d67ad92 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -320,7 +320,7 @@ def main(): s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) + words = "".join(hyp) s += f"{filename}:\n{words}\n\n" logging.info(s) diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py index d5efb41df..f520607af 100755 --- a/egs/yesno/ASR/tdnn/decode.py +++ b/egs/yesno/ASR/tdnn/decode.py @@ -65,7 +65,6 @@ def get_params() -> AttributeDict: { "exp_dir": Path("tdnn/exp/"), "lang_dir": Path("data/lang_phone"), - "lm_dir": Path("data/lm"), "feature_dim": 23, "search_beam": 20, "output_beam": 8, diff --git a/egs/yesno/ASR/tdnn/export.py b/egs/yesno/ASR/tdnn/export.py new file mode 100755 index 000000000..c40cf8cd1 --- /dev/null +++ b/egs/yesno/ASR/tdnn/export.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 + +""" +This file is for exporting trained models to a checkpoint +or to a torchscript model. + +(1) Generate the checkpoint tdnn/exp/pretrained.pt + +./tdnn/export.py \ + --epoch 14 \ + --avg 2 + +See ./tdnn/pretrained.py for how to use the generated file. + +(2) Generate torchscript model tdnn/exp/cpu_jit.pt + +./tdnn/export.py \ + --epoch 14 \ + --avg 2 \ + --jit 1 + +See ./tdnn/jit_pretrained.py for how to use the generated file. +""" + +import argparse +import logging + +import torch +from model import Tdnn +from train import get_params + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.lexicon import Lexicon +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=14, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=2, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + """, + ) + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + + params = get_params() + params.update(vars(args)) + + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + + model = Tdnn( + num_features=params.feature_dim, + num_classes=max_token_id + 1, # +1 for the blank symbol + ) + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/yesno/ASR/tdnn/export_onnx.py b/egs/yesno/ASR/tdnn/export_onnx.py new file mode 100755 index 000000000..2436ca81b --- /dev/null +++ b/egs/yesno/ASR/tdnn/export_onnx.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 + +""" +This file is for exporting trained models to onnx. + +Usage: + + ./tdnn/export_onnx.py \ + --epoch 14 \ + --avg 2 + +The above command generates the following two files: + - ./exp/model-epoch-14-avg-2.onnx + - ./exp/model-epoch-14-avg-2.int8.onnx + +See ./tdnn/onnx_pretrained.py for how to use them. +""" + +import argparse +import logging +from typing import Dict + +import onnx +import torch +from model import Tdnn +from onnxruntime.quantization import QuantType, quantize_dynamic +from train import get_params + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.lexicon import Lexicon + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=14, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=2, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = str(value) + + onnx.save(model, filename) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + + params = get_params() + params.update(vars(args)) + + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + + model = Tdnn( + num_features=params.feature_dim, + num_classes=max_token_id + 1, # +1 for the blank symbol + ) + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + model.to("cpu") + model.eval() + + N = 1 + T = 100 + C = params.feature_dim + x = torch.rand(N, T, C) + + opset_version = 13 + onnx_filename = f"{params.exp_dir}/model-epoch-{params.epoch}-avg-{params.avg}.onnx" + torch.onnx.export( + model, + x, + onnx_filename, + verbose=False, + opset_version=opset_version, + input_names=["x"], + output_names=["log_prob"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "log_prob": {0: "N", 1: "T"}, + }, + ) + + logging.info(f"Saved to {onnx_filename}") + meta_data = { + "model_type": "tdnn", + "version": "1", + "model_author": "k2-fsa", + "comment": "non-streaming tdnn for the yesno recipe", + "vocab_size": max_token_id + 1, + } + + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=onnx_filename, meta_data=meta_data) + + logging.info("Generate int8 quantization models") + onnx_filename_int8 = ( + f"{params.exp_dir}/model-epoch-{params.epoch}-avg-{params.avg}.int8.onnx" + ) + + quantize_dynamic( + model_input=onnx_filename, + model_output=onnx_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + logging.info(f"Saved to {onnx_filename_int8}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/yesno/ASR/tdnn/jit_pretrained.py b/egs/yesno/ASR/tdnn/jit_pretrained.py new file mode 100755 index 000000000..84390fca5 --- /dev/null +++ b/egs/yesno/ASR/tdnn/jit_pretrained.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use a torchscript model for decoding. + +Usage: + + ./tdnn/jit_pretrained.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +Note that to generate ./tdnn/exp/cpu_jit.pt, +you can use ./export.py --jit 1 +""" + +import argparse +import logging +from typing import List +import math + + +import k2 +import kaldifeat +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + +from icefall.decode import get_lattice, one_best_decoding +from icefall.utils import AttributeDict, get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="""Path to the torchscript model. + You can use ./tdnn/export.py --jit 1 + to obtain it + """, + ) + + parser.add_argument( + "--words-file", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. ", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 23, + "num_classes": 4, # [, N, SIL, Y] + "sample_rate": 8000, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Loading torchscript model") + model = torch.jit.load(args.nn_model) + model.eval() + model.to(device) + + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + # Note: We don't use key padding mask for attention during decoding + nnet_output = model(features) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/yesno/ASR/tdnn/onnx_pretrained.py b/egs/yesno/ASR/tdnn/onnx_pretrained.py new file mode 100755 index 000000000..626473b6e --- /dev/null +++ b/egs/yesno/ASR/tdnn/onnx_pretrained.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use an ONNX model for decoding with onnxruntime. + +Usage: + +(1) Use a not quantized ONNX model, i.e., a float32 model + ./tdnn/onnx_pretrained.py \ + --nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +(2) Use a quantized ONNX model, i.e., an int8 model + + ./tdnn/onnx_pretrained.py \ + --nn-model ./tdnn/exp/model-epoch-14-avg-2.int8.onnx \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +Note that to generate ./tdnn/exp/model-epoch-14-avg-2.onnx, +and ./tdnn/exp/model-epoch-14-avg-2.onnx, +you can use ./export_onnx.py --epoch 14 --avg 2 +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import onnxruntime as ort +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + +from icefall.decode import get_lattice, one_best_decoding +from icefall.utils import AttributeDict, get_texts + + +class OnnxModel: + def __init__(self, nn_model: str): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + self.model = ort.InferenceSession( + nn_model, + sess_options=self.session_opts, + ) + + meta = self.model.get_modelmeta().custom_metadata_map + self.vocab_size = int(meta["vocab_size"]) + + def run( + self, + x: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor log_prob of shape (N, T, C) + """ + out = self.model.run( + [ + self.model.get_outputs()[0].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + }, + ) + return torch.from_numpy(out[0]) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="""Path to the torchscript model. + You can use ./tdnn/export.py --jit 1 + to obtain it + """, + ) + + parser.add_argument( + "--words-file", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. ", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 23, + "sample_rate": 8000, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def main(): + parser = get_parser() + args = parser.parse_args() + params = get_params() + params.update(vars(args)) + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + logging.info(f"device: {device}") + + logging.info(f"Loading onnx model {params.nn_model}") + model = OnnxModel(params.nn_model) + + logging.info(f"Loading HLG from {args.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + # Note: We don't use key padding mask for attention during decoding + nnet_output = model.run(features) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py index 65be77db1..987c49de6 100755 --- a/egs/yesno/ASR/tdnn/pretrained.py +++ b/egs/yesno/ASR/tdnn/pretrained.py @@ -15,6 +15,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +This file shows how to use a checkpoint for decoding. + +Usage: + + ./tdnn/pretrained.py \ + --checkpoint ./tdnn/exp/pretrained.pt \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +Note that to generate ./tdnn/exp/pretrained.pt, +you can use ./export.py +""" import argparse import logging @@ -43,7 +58,8 @@ def get_parser(): required=True, help="Path to the checkpoint. " "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + "icefall.checkpoint.save_checkpoint(). " + "You can use ./tdnn/export.py to obtain it.", ) parser.add_argument( @@ -61,8 +77,7 @@ def get_parser(): nargs="+", help="The input sound file(s) to transcribe. " "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + "For example, wav and flac are supported. ", ) return parser @@ -99,14 +114,19 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + # We use only the first channel - ans.append(wave[0]) + ans.append(wave[0].contiguous()) return ans +@torch.no_grad() def main(): parser = get_parser() args = parser.parse_args() @@ -159,8 +179,7 @@ def main(): features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) # Note: We don't use key padding mask for attention during decoding - with torch.no_grad(): - nnet_output = model(features) + nnet_output = model(features) batch_size = nnet_output.shape[0] supervision_segments = torch.tensor( diff --git a/icefall/utils.py b/icefall/utils.py index 0feff9dc8..b01cd2770 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -2060,3 +2060,23 @@ def symlink_or_copy(exp_dir: Path, src: str, dst: str): except OSError: copyfile(src=exp_dir / src, dst=exp_dir / dst) os.close(dir_fd) + + +def num_tokens( + token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$") +) -> int: + """Return the number of tokens excluding those from + disambiguation symbols. + + Caution: + 0 is not a token ID so it is excluded from the return value. + """ + symbols = token_table.symbols + ans = [] + for s in symbols: + if not disambig_pattern.match(s): + ans.append(token_table[s]) + num_tokens = len(ans) + if 0 in ans: + num_tokens -= 1 + return num_tokens