mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
CI for streaming zipformer CTC + HLG decoding
This commit is contained in:
parent
557bf292a2
commit
d1410c52e7
41
.github/scripts/librispeech/ASR/run.sh
vendored
41
.github/scripts/librispeech/ASR/run.sh
vendored
@@ -64,6 +64,46 @@ function run_diagnostics() {
|
|||||||
--print-diagnostics 1
|
--print-diagnostics 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# CI check for streaming zipformer CTC + HLG decoding with ONNX models.
#
# Steps performed, in order:
#   1. Clone the pre-trained model repository from Hugging Face
#      (git-lfs is initialised first).
#   2. Delete any *.onnx files already present in the checkout's
#      exp-ctc-rnnt-small directory, then list the directory.
#   3. Run ./zipformer/export-onnx-streaming-ctc.py against that directory
#      (epoch 30, avg 3, chunk size 16, 128 left-context frames).
#   4. Decode 0.wav, 1.wav and 8k.wav from the checkout's test_wavs with
#      ./zipformer/onnx_pretrained_ctc_HLG_streaming.py, using the int8 model.
#   5. Remove the checkout.
#
# Relies on `log`, which is defined elsewhere in this script, and on being
# run from a directory where ./zipformer/ resolves.
#
# Sets the (non-local) variables repo_url and repo, matching the other
# test_* functions in this script.
function test_streaming_zipformer_ctc_hlg() {
  repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18

  log "Downloading pre-trained model from $repo_url"
  git lfs install
  git clone "$repo_url"
  repo=$(basename "$repo_url")

  # -f: do not fail when the glob matches no file, so a checkout without
  # *.onnx files does not make this step exit non-zero.
  rm -f "$repo"/exp-ctc-rnnt-small/*.onnx
  ls -lh "$repo/exp-ctc-rnnt-small"

  # export models to onnx
  ./zipformer/export-onnx-streaming-ctc.py \
    --tokens "$repo/data/lang_bpe_500/tokens.txt" \
    --epoch 30 \
    --avg 3 \
    --exp-dir "$repo/exp-ctc-rnnt-small" \
    --causal 1 \
    --use-ctc 1 \
    --chunk-size 16 \
    --left-context-frames 128 \
    --num-encoder-layers 2,2,2,2,2,2 \
    --feedforward-dim 512,768,768,768,768,768 \
    --encoder-dim 192,256,256,256,256,256 \
    --encoder-unmasked-dim 192,192,192,192,192,192

  ls -lh "$repo/exp-ctc-rnnt-small"

  # Decode each test wave with the int8 model. The model file name below
  # encodes the epoch/avg/chunk/left-context values passed to the export
  # step above, so the two must be kept in sync.
  for wav in 0.wav 1.wav 8k.wav; do
    python3 ./zipformer/onnx_pretrained_ctc_HLG_streaming.py \
      --nn-model "$repo/exp-ctc-rnnt-small/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx" \
      --words "$repo/data/lang_bpe_500/words.txt" \
      --HLG "$repo/data/lang_bpe_500/HLG.fst" \
      "$repo/test_wavs/$wav"
  done

  rm -rf "$repo"
}
|
||||||
|
|
||||||
function test_pruned_transducer_stateless_2022_03_12() {
|
function test_pruned_transducer_stateless_2022_03_12() {
|
||||||
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
|
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
|
||||||
|
|
||||||
@@ -1577,6 +1617,7 @@ function test_transducer_bpe_500_2021_12_23() {
|
|||||||
|
|
||||||
prepare_data
|
prepare_data
|
||||||
run_diagnostics
|
run_diagnostics
|
||||||
|
test_streaming_zipformer_ctc_hlg
|
||||||
test_pruned_transducer_stateless_2022_03_12
|
test_pruned_transducer_stateless_2022_03_12
|
||||||
test_pruned_transducer_stateless2_2022_04_29
|
test_pruned_transducer_stateless2_2022_04_29
|
||||||
test_pruned_transducer_stateless3_2022_04_29
|
test_pruned_transducer_stateless3_2022_04_29
|
||||||
|
@@ -27,10 +27,10 @@ popd
|
|||||||
2. Export the model to ONNX
|
2. Export the model to ONNX
|
||||||
|
|
||||||
./zipformer/export-onnx-streaming-ctc.py \
|
./zipformer/export-onnx-streaming-ctc.py \
|
||||||
--tokens ./data/lang_bpe_500/tokens.txt \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 3 \
|
--avg 3 \
|
||||||
--exp-dir zipformer/exp-ctc-rnnt-small \
|
--exp-dir $repo/exp-ctc-rnnt-small \
|
||||||
--causal 1 \
|
--causal 1 \
|
||||||
--use-ctc 1 \
|
--use-ctc 1 \
|
||||||
--chunk-size 16 \
|
--chunk-size 16 \
|
||||||
@@ -107,8 +107,7 @@ def get_parser():
|
|||||||
type=str,
|
type=str,
|
||||||
help="The input sound file to transcribe. "
|
help="The input sound file to transcribe. "
|
||||||
"Supported formats are those supported by torchaudio.load(). "
|
"Supported formats are those supported by torchaudio.load(). "
|
||||||
"For example, wav and flac are supported. "
|
"For example, wav and flac are supported. ",
|
||||||
"The sample rate has to be 16kHz.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@@ -311,9 +310,13 @@ def read_sound_files(
|
|||||||
ans = []
|
ans = []
|
||||||
for f in filenames:
|
for f in filenames:
|
||||||
wave, sample_rate = torchaudio.load(f)
|
wave, sample_rate = torchaudio.load(f)
|
||||||
assert (
|
if sample_rate != expected_sample_rate:
|
||||||
sample_rate == expected_sample_rate
|
logging.info(f"Resample {sample_rate} to {expected_sample_rate}")
|
||||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
wave = torchaudio.functional.resample(
|
||||||
|
wave,
|
||||||
|
orig_freq=sample_rate,
|
||||||
|
new_freq=expected_sample_rate,
|
||||||
|
)
|
||||||
# We use only the first channel
|
# We use only the first channel
|
||||||
ans.append(wave[0].contiguous())
|
ans.append(wave[0].contiguous())
|
||||||
return ans
|
return ans
|
||||||
|
Loading…
x
Reference in New Issue
Block a user