CI for streaming zipformer CTC + HLG decoding

This commit is contained in:
Fangjun Kuang 2024-03-18 19:44:17 +08:00
parent 557bf292a2
commit d1410c52e7
2 changed files with 51 additions and 7 deletions

View File

@ -64,6 +64,46 @@ function run_diagnostics() {
--print-diagnostics 1
}
# CI test: export a streaming zipformer CTC model to ONNX and decode
# test waves with HLG. Downloads the pre-trained checkpoint from
# Hugging Face, re-exports it, runs the streaming HLG decoder on each
# sample wav, and cleans up afterwards.
# Requires: git-lfs, and the `log` helper defined earlier in this file.
function test_streaming_zipformer_ctc_hlg() {
  repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18

  log "Downloading pre-trained model from $repo_url"
  git lfs install
  git clone "$repo_url"
  repo=$(basename "$repo_url")

  # Delete the ONNX models shipped with the repo so the export step
  # below is genuinely exercised (we must not accidentally test
  # pre-built artifacts).
  rm "$repo"/exp-ctc-rnnt-small/*.onnx
  ls -lh "$repo"/exp-ctc-rnnt-small

  # Export the model to ONNX.
  # NOTE(review): the model-architecture flags below must match the
  # checkpoint's training configuration — keep them in sync with the
  # Hugging Face repo if the checkpoint is ever updated.
  ./zipformer/export-onnx-streaming-ctc.py \
    --tokens "$repo"/data/lang_bpe_500/tokens.txt \
    --epoch 30 \
    --avg 3 \
    --exp-dir "$repo"/exp-ctc-rnnt-small \
    --causal 1 \
    --use-ctc 1 \
    --chunk-size 16 \
    --left-context-frames 128 \
    --num-encoder-layers 2,2,2,2,2,2 \
    --feedforward-dim 512,768,768,768,768,768 \
    --encoder-dim 192,256,256,256,256,256 \
    --encoder-unmasked-dim 192,192,192,192,192,192

  ls -lh "$repo"/exp-ctc-rnnt-small

  # Decode each test wave with the int8 ONNX model + HLG.
  # 8k.wav checks that the decoder resamples non-16kHz input.
  for wav in 0.wav 1.wav 8k.wav; do
    python3 ./zipformer/onnx_pretrained_ctc_HLG_streaming.py \
      --nn-model "$repo"/exp-ctc-rnnt-small/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
      --words "$repo"/data/lang_bpe_500/words.txt \
      --HLG "$repo"/data/lang_bpe_500/HLG.fst \
      "$repo"/test_wavs/$wav
  done

  rm -rf "$repo"
}
function test_pruned_transducer_stateless_2022_03_12() {
  repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
@ -1577,6 +1617,7 @@ function test_transducer_bpe_500_2021_12_23() {
prepare_data
run_diagnostics
test_streaming_zipformer_ctc_hlg
test_pruned_transducer_stateless_2022_03_12
test_pruned_transducer_stateless2_2022_04_29
test_pruned_transducer_stateless3_2022_04_29

View File

@ -27,10 +27,10 @@ popd
2. Export the model to ONNX 2. Export the model to ONNX
./zipformer/export-onnx-streaming-ctc.py \ ./zipformer/export-onnx-streaming-ctc.py \
--tokens ./data/lang_bpe_500/tokens.txt \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 30 \ --epoch 30 \
--avg 3 \ --avg 3 \
--exp-dir zipformer/exp-ctc-rnnt-small \ --exp-dir $repo/exp-ctc-rnnt-small \
--causal 1 \ --causal 1 \
--use-ctc 1 \ --use-ctc 1 \
--chunk-size 16 \ --chunk-size 16 \
@ -107,8 +107,7 @@ def get_parser():
type=str, type=str,
help="The input sound file to transcribe. " help="The input sound file to transcribe. "
"Supported formats are those supported by torchaudio.load(). " "Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. " "For example, wav and flac are supported. ",
"The sample rate has to be 16kHz.",
) )
return parser return parser
@ -311,9 +310,13 @@ def read_sound_files(
ans = [] ans = []
for f in filenames: for f in filenames:
wave, sample_rate = torchaudio.load(f) wave, sample_rate = torchaudio.load(f)
assert ( if sample_rate != expected_sample_rate:
sample_rate == expected_sample_rate logging.info(f"Resample {sample_rate} to {expected_sample_rate}")
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" wave = torchaudio.functional.resample(
wave,
orig_freq=sample_rate,
new_freq=expected_sample_rate,
)
# We use only the first channel # We use only the first channel
ans.append(wave[0].contiguous()) ans.append(wave[0].contiguous())
return ans return ans