From 557b35cefc9398827cef3a779edb5aa8a6ebfea8 Mon Sep 17 00:00:00 2001
From: Yuekai Zhang
Date: Mon, 15 Jan 2024 20:40:44 +0800
Subject: [PATCH] clean codes

---
 .../ASR/local/compute_fbank_aishell.py        |  18 +-
 .../local/compute_whisper_fbank_aishell.py    | 125 -----
 .../ASR/local/compute_whisper_fbank_musan.py  | 109 -----
 egs/aishell/ASR/prepare.sh                    | 427 +++++++++---------
 egs/aishell/ASR/whisper/decode.py             |  40 +-
 egs/aishell/ASR/whisper/model.py              |  13 +-
 egs/aishell/ASR/whisper/train.py              |  80 +---
 .../ASR/local/compute_fbank_musan.py          |  29 +-
 8 files changed, 266 insertions(+), 575 deletions(-)
 delete mode 100644 egs/aishell/ASR/local/compute_whisper_fbank_aishell.py
 delete mode 100644 egs/aishell/ASR/local/compute_whisper_fbank_musan.py
 mode change 100644 => 100755 egs/aishell/ASR/whisper/decode.py
 mode change 100644 => 100755 egs/aishell/ASR/whisper/model.py
 mode change 100644 => 100755 egs/aishell/ASR/whisper/train.py

diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py
index c7000da1c..0ca619d98 100755
--- a/egs/aishell/ASR/local/compute_fbank_aishell.py
+++ b/egs/aishell/ASR/local/compute_fbank_aishell.py
@@ -29,7 +29,7 @@ import os
 from pathlib import Path
 
 import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse import CutSet, Fbank, FbankConfig, WhisperFbank, WhisperFbankConfig, LilcomChunkyWriter
 from lhotse.recipes.utils import read_manifests_if_cached
 
 from icefall.utils import get_executor, str2bool
@@ -42,7 +42,7 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
 
 
-def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
+def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False):
     src_dir = Path("data/manifests")
     output_dir = Path("data/fbank")
     num_jobs = min(15, os.cpu_count())
@@ -68,8 +68,10 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
         list(manifests.keys()),
         dataset_parts,
     )
-
-    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+    if whisper_fbank:
+        extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
+    else:
+        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
 
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
@@ -111,6 +113,12 @@ def get_args():
         default=False,
         help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
     )
+    parser.add_argument(
+        "--whisper-fbank",
+        type=str2bool,
+        default=False,
+        help="Use WhisperFbank instead of Fbank. Default: False.",
+    )
 
     return parser.parse_args()
 
@@ -121,5 +129,5 @@ if __name__ == "__main__":
 
     args = get_args()
     compute_fbank_aishell(
-        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
+        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank
     )
diff --git a/egs/aishell/ASR/local/compute_whisper_fbank_aishell.py b/egs/aishell/ASR/local/compute_whisper_fbank_aishell.py
deleted file mode 100644
index 72f4b7acb..000000000
--- a/egs/aishell/ASR/local/compute_whisper_fbank_aishell.py
+++ /dev/null
@@ -1,125 +0,0 @@
-#!/usr/bin/env python3
-# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-"""
-This file computes fbank features of the aishell dataset.
-It looks for manifests in the directory data/manifests.
-
-The generated fbank features are saved in data/fbank.
-"""
-
-import argparse
-import logging
-import os
-from pathlib import Path
-
-import torch
-from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, LilcomChunkyWriter
-from lhotse.recipes.utils import read_manifests_if_cached
-
-from icefall.utils import get_executor, str2bool
-
-# Torch's multithreaded behavior needs to be disabled or
-# it wastes a lot of CPU and slow things down.
-# Do this outside of main() in case it needs to take effect
-# even when we are not invoking the main (e.g. when spawning subprocesses).
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
-
-def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
-    src_dir = Path("data/manifests")
-    output_dir = Path("data/fbank")
-    num_jobs = min(15, os.cpu_count())
-
-    dataset_parts = (
-        "train",
-        "dev",
-        "test",
-    )
-    prefix = "aishell"
-    suffix = "jsonl.gz"
-    manifests = read_manifests_if_cached(
-        dataset_parts=dataset_parts,
-        output_dir=src_dir,
-        prefix=prefix,
-        suffix=suffix,
-    )
-    assert manifests is not None
-
-    assert len(manifests) == len(dataset_parts), (
-        len(manifests),
-        len(dataset_parts),
-        list(manifests.keys()),
-        dataset_parts,
-    )
-
-    extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
-
-    with get_executor() as ex:  # Initialize the executor only once.
-        for partition, m in manifests.items():
-            if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
-                logging.info(f"{partition} already exists - skipping.")
-                continue
-            logging.info(f"Processing {partition}")
-            cut_set = CutSet.from_manifests(
-                recordings=m["recordings"],
-                supervisions=m["supervisions"],
-            )
-            if "train" in partition and perturb_speed:
-                logging.info(f"Doing speed perturb")
-                cut_set = (
-                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
-                )
-            cut_set = cut_set.compute_and_store_features(
-                extractor=extractor,
-                storage_path=f"{output_dir}/{prefix}_feats_{partition}",
-                # when an executor is specified, make more partitions
-                num_jobs=num_jobs if ex is None else 80,
-                executor=ex,
-                storage_type=LilcomChunkyWriter,
-            )
-            cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--num-mel-bins",
-        type=int,
-        default=80,
-        help="""The number of mel bins for Fbank""",
-    )
-    parser.add_argument(
-        "--perturb-speed",
-        type=str2bool,
-        default=False,
-        help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
-    )
-    return parser.parse_args()
-
-
-if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-
-    args = get_args()
-    compute_fbank_aishell(
-        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
-    )
diff --git a/egs/aishell/ASR/local/compute_whisper_fbank_musan.py b/egs/aishell/ASR/local/compute_whisper_fbank_musan.py
deleted file mode 100644
index 0378b359b..000000000
--- a/egs/aishell/ASR/local/compute_whisper_fbank_musan.py
+++ /dev/null
@@ -1,109 +0,0 @@
-#!/usr/bin/env python3
-# Copyright    2021  Xiaomi Corp.        (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-"""
-This file computes fbank features of the musan dataset.
-It looks for manifests in the directory data/manifests.
-
-The generated fbank features are saved in data/fbank.
-"""
-
-import logging
-import os
-from pathlib import Path
-
-import torch
-from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, LilcomChunkyWriter, MonoCut, combine
-from lhotse.recipes.utils import read_manifests_if_cached
-
-from icefall.utils import get_executor
-
-# Torch's multithreaded behavior needs to be disabled or
-# it wastes a lot of CPU and slow things down.
-# Do this outside of main() in case it needs to take effect
-# even when we are not invoking the main (e.g. when spawning subprocesses).
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
-
-def is_cut_long(c: MonoCut) -> bool:
-    return c.duration > 5
-
-
-def compute_fbank_musan():
-    src_dir = Path("data/manifests")
-    output_dir = Path("data/fbank")
-    num_jobs = min(15, os.cpu_count())
-    num_mel_bins = 80
-
-    dataset_parts = (
-        "music",
-        "speech",
-        "noise",
-    )
-    prefix = "musan"
-    suffix = "jsonl.gz"
-    manifests = read_manifests_if_cached(
-        dataset_parts=dataset_parts,
-        output_dir=src_dir,
-        prefix=prefix,
-        suffix=suffix,
-    )
-    assert manifests is not None
-
-    assert len(manifests) == len(dataset_parts), (
-        len(manifests),
-        len(dataset_parts),
-        list(manifests.keys()),
-        dataset_parts,
-    )
-
-    musan_cuts_path = output_dir / "musan_cuts.jsonl.gz"
-
-    if musan_cuts_path.is_file():
-        logging.info(f"{musan_cuts_path} already exists - skipping")
-        return
-
-    logging.info("Extracting features for Musan")
-
-    extractor = WhisperFbank(WhisperFbankConfig(device='cuda'))
-
-    with get_executor() as ex:  # Initialize the executor only once.
-        # create chunks of Musan with duration 5 - 10 seconds
-        musan_cuts = (
-            CutSet.from_manifests(
-                recordings=combine(part["recordings"] for part in manifests.values())
-            )
-            .cut_into_windows(10.0)
-            .filter(is_cut_long)
-            .compute_and_store_features(
-                extractor=extractor,
-                storage_path=f"{output_dir}/musan_feats",
-                num_jobs=num_jobs if ex is None else 80,
-                executor=ex,
-                storage_type=LilcomChunkyWriter,
-            )
-        )
-        musan_cuts.to_file(musan_cuts_path)
-
-
-if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-    compute_fbank_musan()
diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh
index 97dc721c2..aaeba39f8 100755
--- a/egs/aishell/ASR/prepare.sh
+++ b/egs/aishell/ASR/prepare.sh
@@ -1,5 +1,5 @@
 #!/usr/bin/env bash
-export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall
+
 # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
 export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 
@@ -83,9 +83,9 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   #
   # ln -sfv /path/to/musan $dl_dir/musan
   #
-  # if [ ! -d $dl_dir/musan ]; then
-  #   lhotse download musan $dl_dir
-  # fi
+  if [ ! -d $dl_dir/musan ]; then
+    lhotse download musan $dl_dir
+  fi
 fi
 
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
@@ -99,17 +99,17 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   fi
 fi
 
-# if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
-#   log "Stage 2: Prepare musan manifest"
-#   # We assume that you have downloaded the musan corpus
-#   # to data/musan
-#   if [ ! -f data/manifests/.musan_manifests.done ]; then
-#     log "It may take 6 minutes"
-#     mkdir -p data/manifests
-#     lhotse prepare musan $dl_dir/musan data/manifests
-#     touch data/manifests/.musan_manifests.done
-#   fi
-# fi
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Prepare musan manifest"
+  # We assume that you have downloaded the musan corpus
+  # to data/musan
+  if [ ! -f data/manifests/.musan_manifests.done ]; then
+    log "It may take 6 minutes"
+    mkdir -p data/manifests
+    lhotse prepare musan $dl_dir/musan data/manifests
+    touch data/manifests/.musan_manifests.done
+  fi
+fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   log "Stage 3: Compute fbank for aishell"
@@ -120,56 +120,29 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   fi
 fi
 
-# if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
-#   log "Stage 30: Compute whisper fbank for aishell"
-#   if [ ! -f data/fbank/.aishell.done ]; then
-#     mkdir -p data/fbank
-#     ./local/compute_whisper_fbank_aishell.py --perturb-speed True
-#     touch data/fbank/.aishell.done
-#   fi
-# fi
-
-if [ $stage -le 300 ] && [ $stop_stage -ge 300 ]; then
-  log "Stage 30: Compute whisper fbank for aishell"
-  if [ ! -f data/fbank/.aishell.done ]; then
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank for musan"
+  if [ ! -f data/fbank/.msuan.done ]; then
     mkdir -p data/fbank
-    ./local/compute_whisper_fbank_aishell.py --perturb-speed True --num-mel-bins 128
-    touch data/fbank/.aishell.done
+    ./local/compute_fbank_musan.py
+    touch data/fbank/.msuan.done
   fi
 fi
 
-# if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-#   log "Stage 4: Compute fbank for musan"
-#   if [ ! -f data/fbank/.msuan.done ]; then
-#     mkdir -p data/fbank
-#     ./local/compute_fbank_musan.py
-#     touch data/fbank/.msuan.done
-#   fi
-# fi
+lang_phone_dir=data/lang_phone
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Prepare phone based lang"
+  mkdir -p $lang_phone_dir
 
-# if [ $stage -le 40 ] && [ $stop_stage -ge 40 ]; then
-#   log "Stage 4: Compute fbank for musan"
-#   if [ ! -f data/fbank/.msuan.done ]; then
-#     mkdir -p data/fbank
-#     ./local/compute_whisper_fbank_musan.py
-#     touch data/fbank/.msuan.done
-#   fi
-# fi
+  (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
+    cat - $dl_dir/aishell/resource_aishell/lexicon.txt |
+    sort | uniq > $lang_phone_dir/lexicon.txt
 
-# lang_phone_dir=data/lang_phone
-# if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-#   log "Stage 5: Prepare phone based lang"
-#   mkdir -p $lang_phone_dir
+  ./local/generate_unique_lexicon.py --lang-dir $lang_phone_dir
 
-#   (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
-#     cat - $dl_dir/aishell/resource_aishell/lexicon.txt |
-#     sort | uniq > $lang_phone_dir/lexicon.txt
-#   ./local/generate_unique_lexicon.py --lang-dir $lang_phone_dir
-#   if [ ! -f $lang_phone_dir/L_disambig.pt ]; then
-#     ./local/prepare_lang.py --lang-dir $lang_phone_dir
-#   fi
+  if [ ! -f $lang_phone_dir/L_disambig.pt ]; then
+    ./local/prepare_lang.py --lang-dir $lang_phone_dir
+  fi
 
   # Train a bigram P for MMI training
 
@@ -182,93 +155,93 @@ fi
     cut -d " " -f 2- > $lang_phone_dir/transcript_words.txt
   fi
 
-#   if [ ! -f $lang_phone_dir/transcript_tokens.txt ]; then
-#     ./local/convert_transcript_words_to_tokens.py \
-#       --lexicon $lang_phone_dir/uniq_lexicon.txt \
-#       --transcript $lang_phone_dir/transcript_words.txt \
-#       --oov "<UNK>" \
-#       > $lang_phone_dir/transcript_tokens.txt
-#   fi
+  if [ ! -f $lang_phone_dir/transcript_tokens.txt ]; then
+    ./local/convert_transcript_words_to_tokens.py \
+      --lexicon $lang_phone_dir/uniq_lexicon.txt \
+      --transcript $lang_phone_dir/transcript_words.txt \
+      --oov "<UNK>" \
+      > $lang_phone_dir/transcript_tokens.txt
+  fi
 
-#   if [ ! -f $lang_phone_dir/P.arpa ]; then
-#     ./shared/make_kn_lm.py \
-#       -ngram-order 2 \
-#       -text $lang_phone_dir/transcript_tokens.txt \
-#       -lm $lang_phone_dir/P.arpa
-#   fi
+  if [ ! -f $lang_phone_dir/P.arpa ]; then
+    ./shared/make_kn_lm.py \
+      -ngram-order 2 \
+      -text $lang_phone_dir/transcript_tokens.txt \
+      -lm $lang_phone_dir/P.arpa
+  fi
 
-#   if [ ! -f $lang_phone_dir/P.fst.txt ]; then
-#     python3 -m kaldilm \
-#       --read-symbol-table="$lang_phone_dir/tokens.txt" \
-#       --disambig-symbol='#0' \
-#       --max-order=2 \
-#       $lang_phone_dir/P.arpa > $lang_phone_dir/P.fst.txt
-#   fi
-# fi
+  if [ ! -f $lang_phone_dir/P.fst.txt ]; then
+    python3 -m kaldilm \
+      --read-symbol-table="$lang_phone_dir/tokens.txt" \
+      --disambig-symbol='#0' \
+      --max-order=2 \
+      $lang_phone_dir/P.arpa > $lang_phone_dir/P.fst.txt
+  fi
+fi
 
-# lang_char_dir=data/lang_char
-# if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-#   log "Stage 6: Prepare char based lang"
-#   mkdir -p $lang_char_dir
-#   # We reuse words.txt from phone based lexicon
-#   # so that the two can share G.pt later.
+lang_char_dir=data/lang_char
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Prepare char based lang"
+  mkdir -p $lang_char_dir
+  # We reuse words.txt from phone based lexicon
+  # so that the two can share G.pt later.
-#   # The transcripts in training set, generated in stage 5
-#   cp $lang_phone_dir/transcript_words.txt $lang_char_dir/transcript_words.txt
+  # The transcripts in training set, generated in stage 5
+  cp $lang_phone_dir/transcript_words.txt $lang_char_dir/transcript_words.txt
 
-#   cat $dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt |
-#     cut -d " " -f 2- > $lang_char_dir/text
+  cat $dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt |
+    cut -d " " -f 2- > $lang_char_dir/text
 
-#   (echo '<eps> 0'; echo '!SIL 1'; echo '<SPOKEN_NOISE> 2'; echo '<UNK> 3';) \
-#     > $lang_char_dir/words.txt
+  (echo '<eps> 0'; echo '!SIL 1'; echo '<SPOKEN_NOISE> 2'; echo '<UNK> 3';) \
+    > $lang_char_dir/words.txt
 
-#   cat $lang_char_dir/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
-#     | awk '{print $1" "NR+3}' >> $lang_char_dir/words.txt
+  cat $lang_char_dir/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
+    | awk '{print $1" "NR+3}' >> $lang_char_dir/words.txt
 
-#   num_lines=$(< $lang_char_dir/words.txt wc -l)
-#   (echo "#0 $num_lines"; echo "<s> $(($num_lines + 1))"; echo "</s> $(($num_lines + 2))";) \
-#     >> $lang_char_dir/words.txt
+  num_lines=$(< $lang_char_dir/words.txt wc -l)
+  (echo "#0 $num_lines"; echo "<s> $(($num_lines + 1))"; echo "</s> $(($num_lines + 2))";) \
+    >> $lang_char_dir/words.txt
 
-#   if [ ! -f $lang_char_dir/L_disambig.pt ]; then
-#     ./local/prepare_char.py --lang-dir $lang_char_dir
-#   fi
-# fi
+  if [ ! -f $lang_char_dir/L_disambig.pt ]; then
+    ./local/prepare_char.py --lang-dir $lang_char_dir
+  fi
+fi
 
-# if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
-#   log "Stage 7: Prepare Byte BPE based lang"
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Prepare Byte BPE based lang"
 
-#   for vocab_size in ${vocab_sizes[@]}; do
-#     lang_dir=data/lang_bbpe_${vocab_size}
-#     mkdir -p $lang_dir
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bbpe_${vocab_size}
+    mkdir -p $lang_dir
 
-#     cp $lang_char_dir/words.txt $lang_dir
-#     cp $lang_char_dir/text $lang_dir
+    cp $lang_char_dir/words.txt $lang_dir
+    cp $lang_char_dir/text $lang_dir
 
-#     if [ ! -f $lang_dir/bbpe.model ]; then
-#       ./local/train_bbpe_model.py \
-#         --lang-dir $lang_dir \
-#         --vocab-size $vocab_size \
-#         --transcript $lang_dir/text
-#     fi
+    if [ ! -f $lang_dir/bbpe.model ]; then
+      ./local/train_bbpe_model.py \
+        --lang-dir $lang_dir \
+        --vocab-size $vocab_size \
+        --transcript $lang_dir/text
+    fi
 
-#     if [ ! -f $lang_dir/L_disambig.pt ]; then
-#       ./local/prepare_lang_bbpe.py --lang-dir $lang_dir
-#     fi
-#   done
-# fi
+    if [ ! -f $lang_dir/L_disambig.pt ]; then
+      ./local/prepare_lang_bbpe.py --lang-dir $lang_dir
+    fi
+  done
+fi
 
-# if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
-#   log "Stage 8: Prepare G"
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Prepare G"
 
-#   mkdir -p data/lm
+  mkdir -p data/lm
 
-#   # Train LM on transcripts
-#   if [ ! -f data/lm/3-gram.unpruned.arpa ]; then
-#     python3 ./shared/make_kn_lm.py \
-#       -ngram-order 3 \
-#       -text $lang_char_dir/transcript_words.txt \
-#       -lm data/lm/3-gram.unpruned.arpa
-#   fi
+  # Train LM on transcripts
+  if [ ! -f data/lm/3-gram.unpruned.arpa ]; then
+    python3 ./shared/make_kn_lm.py \
+      -ngram-order 3 \
+      -text $lang_char_dir/transcript_words.txt \
+      -lm data/lm/3-gram.unpruned.arpa
+  fi
 
   # We assume you have installed kaldilm, if not, please install
   # it using: pip install kaldilm
@@ -294,112 +267,124 @@ fi
   fi
 fi
 
-# if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
-#   log "Stage 9: Compile LG & HLG"
-#   ./local/compile_hlg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
-#   ./local/compile_hlg.py --lang-dir $lang_char_dir --lm G_3_gram_char
-#   for vocab_size in ${vocab_sizes[@]}; do
-#     lang_dir=data/lang_bbpe_${vocab_size}
-#     ./local/compile_hlg.py --lang-dir $lang_dir --lm G_3_gram_char
-#   done
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+  log "Stage 9: Compile LG & HLG"
+  ./local/compile_hlg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
+  ./local/compile_hlg.py --lang-dir $lang_char_dir --lm G_3_gram_char
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bbpe_${vocab_size}
+    ./local/compile_hlg.py --lang-dir $lang_dir --lm G_3_gram_char
+  done
 
-#   ./local/compile_lg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
-#   ./local/compile_lg.py --lang-dir $lang_char_dir --lm G_3_gram_char
-#   for vocab_size in ${vocab_sizes[@]}; do
-#     lang_dir=data/lang_bbpe_${vocab_size}
-#     ./local/compile_lg.py --lang-dir $lang_dir --lm G_3_gram_char
-#   done
-# fi
+  ./local/compile_lg.py --lang-dir $lang_phone_dir --lm G_3_gram_phone
+  ./local/compile_lg.py --lang-dir $lang_char_dir --lm G_3_gram_char
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bbpe_${vocab_size}
+    ./local/compile_lg.py --lang-dir $lang_dir --lm G_3_gram_char
+  done
+fi
 
-# if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
-#   log "Stage 10: Generate LM training data"
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+  log "Stage 10: Generate LM training data"
 
-#   log "Processing char based data"
-#   out_dir=data/lm_training_char
-#   mkdir -p $out_dir $dl_dir/lm
+  log "Processing char based data"
+  out_dir=data/lm_training_char
+  mkdir -p $out_dir $dl_dir/lm
 
-#   if [ ! -f $dl_dir/lm/aishell-train-word.txt ]; then
-#     cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt
-#   fi
+  if [ ! -f $dl_dir/lm/aishell-train-word.txt ]; then
+    cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt
+  fi
 
-#   # training words
-#   ./local/prepare_char_lm_training_data.py \
-#     --lang-char data/lang_char \
-#     --lm-data $dl_dir/lm/aishell-train-word.txt \
-#     --lm-archive $out_dir/lm_data.pt
+  # training words
+  ./local/prepare_char_lm_training_data.py \
+    --lang-char data/lang_char \
+    --lm-data $dl_dir/lm/aishell-train-word.txt \
+    --lm-archive $out_dir/lm_data.pt
 
-#   # valid words
-#   if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then
-#     aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
-#     aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid
-#     find $dl_dir/aishell/data_aishell/wav/dev -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_valid_uid
-#     awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_valid_uid $aishell_text |
-#       cut -d " " -f 2- > $dl_dir/lm/aishell-valid-word.txt
-#   fi
+  # valid words
+  if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then
+    aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
+    aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid
+    find $dl_dir/aishell/data_aishell/wav/dev -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_valid_uid
+    awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_valid_uid $aishell_text |
+      cut -d " " -f 2- > $dl_dir/lm/aishell-valid-word.txt
+  fi
 
-#   ./local/prepare_char_lm_training_data.py \
-#     --lang-char data/lang_char \
-#     --lm-data $dl_dir/lm/aishell-valid-word.txt \
-#     --lm-archive $out_dir/lm_data_valid.pt
+  ./local/prepare_char_lm_training_data.py \
+    --lang-char data/lang_char \
+    --lm-data $dl_dir/lm/aishell-valid-word.txt \
+    --lm-archive $out_dir/lm_data_valid.pt
 
-#   # test words
-#   if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then
-#     aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
-#     aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid
-#     find $dl_dir/aishell/data_aishell/wav/test -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_test_uid
-#     awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_test_uid $aishell_text |
-#       cut -d " " -f 2- > $dl_dir/lm/aishell-test-word.txt
-#   fi
+  # test words
+  if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then
+    aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
+    aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid
+    find $dl_dir/aishell/data_aishell/wav/test -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_test_uid
+    awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_test_uid $aishell_text |
+      cut -d " " -f 2- > $dl_dir/lm/aishell-test-word.txt
+  fi
 
-#   ./local/prepare_char_lm_training_data.py \
-#     --lang-char data/lang_char \
-#     --lm-data $dl_dir/lm/aishell-test-word.txt \
-#     --lm-archive $out_dir/lm_data_test.pt
-# fi
+  ./local/prepare_char_lm_training_data.py \
+    --lang-char data/lang_char \
+    --lm-data $dl_dir/lm/aishell-test-word.txt \
+    --lm-archive $out_dir/lm_data_test.pt
+fi
 
-# if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
-#   log "Stage 11: Sort LM training data"
-#   # Sort LM training data by sentence length in descending order
-#   # for ease of training.
-#   #
-#   # Sentence length equals to the number of tokens
-#   # in a sentence.
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+  log "Stage 11: Sort LM training data"
+  # Sort LM training data by sentence length in descending order
+  # for ease of training.
+  #
+  # Sentence length equals to the number of tokens
+  # in a sentence.
 
-#   out_dir=data/lm_training_char
-#   mkdir -p $out_dir
-#   ln -snf ../../../librispeech/ASR/local/sort_lm_training_data.py local/
+  out_dir=data/lm_training_char
+  mkdir -p $out_dir
+  ln -snf ../../../librispeech/ASR/local/sort_lm_training_data.py local/
 
-#   ./local/sort_lm_training_data.py \
-#     --in-lm-data $out_dir/lm_data.pt \
-#     --out-lm-data $out_dir/sorted_lm_data.pt \
-#     --out-statistics $out_dir/statistics.txt
+  ./local/sort_lm_training_data.py \
+    --in-lm-data $out_dir/lm_data.pt \
+    --out-lm-data $out_dir/sorted_lm_data.pt \
+    --out-statistics $out_dir/statistics.txt
 
-#   ./local/sort_lm_training_data.py \
-#     --in-lm-data $out_dir/lm_data_valid.pt \
-#     --out-lm-data $out_dir/sorted_lm_data-valid.pt \
-#     --out-statistics $out_dir/statistics-valid.txt
+  ./local/sort_lm_training_data.py \
+    --in-lm-data $out_dir/lm_data_valid.pt \
+    --out-lm-data $out_dir/sorted_lm_data-valid.pt \
+    --out-statistics $out_dir/statistics-valid.txt
 
-#   ./local/sort_lm_training_data.py \
-#     --in-lm-data $out_dir/lm_data_test.pt \
-#     --out-lm-data $out_dir/sorted_lm_data-test.pt \
-#     --out-statistics $out_dir/statistics-test.txt
-# fi
+  ./local/sort_lm_training_data.py \
+    --in-lm-data $out_dir/lm_data_test.pt \
+    --out-lm-data $out_dir/sorted_lm_data-test.pt \
+    --out-statistics $out_dir/statistics-test.txt
+fi
 
-# if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
-#   log "Stage 11: Train RNN LM model"
-#   python ../../../icefall/rnn_lm/train.py \
-#     --start-epoch 0 \
-#     --world-size 1 \
-#     --num-epochs 20 \
-#     --use-fp16 0 \
-#     --embedding-dim 512 \
-#     --hidden-dim 512 \
-#     --num-layers 2 \
-#     --batch-size 400 \
-#     --exp-dir rnnlm_char/exp \
-#     --lm-data $out_dir/sorted_lm_data.pt \
-#     --lm-data-valid $out_dir/sorted_lm_data-valid.pt \
-#     --vocab-size 4336 \
-#     --master-port 12345
-# fi
+if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
+  log "Stage 12: Train RNN LM model"
+  python ../../../icefall/rnn_lm/train.py \
+    --start-epoch 0 \
+    --world-size 1 \
+    --num-epochs 20 \
+    --use-fp16 0 \
+    --embedding-dim 512 \
+    --hidden-dim 512 \
+    --num-layers 2 \
+    --batch-size 400 \
+    --exp-dir rnnlm_char/exp \
+    --lm-data $out_dir/sorted_lm_data.pt \
+    --lm-data-valid $out_dir/sorted_lm_data-valid.pt \
+    --vocab-size 4336 \
+    --master-port 12345
+fi
+
+# whisper large-v3 uses 128 mel bins; other whisper models use 80 mel bins
+whisper_mel_bins=80
+if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
+  log "Stage 30: Compute ${whisper_mel_bins} dim fbank for whisper model fine-tuning"
+  if [ ! -f data/fbank/.aishell.whisper.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_aishell.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
+    ./local/compute_fbank_musan.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
+    touch data/fbank/.aishell.whisper.done
+  fi
+fi
\ No newline at end of file
diff --git a/egs/aishell/ASR/whisper/decode.py b/egs/aishell/ASR/whisper/decode.py
old mode 100644
new mode 100755
index 371350905..6c09b142b
--- a/egs/aishell/ASR/whisper/decode.py
+++ b/egs/aishell/ASR/whisper/decode.py
@@ -115,10 +115,8 @@ def remove_punctuation(text: str or List[str]):
             result_text.append(t)
         return result_text
     else:
-        raise Exception(f'不支持该类型{type(text)}')
+        raise Exception(f'Unsupported type: {type(text)}')
 
-
-# 将繁体中文总成简体中文
 def to_simple(text: str or List[str]):
     if isinstance(text, str):
         text = convert(text, 'zh-cn')
@@ -130,7 +128,7 @@ def to_simple(text: str or List[str]):
             result_text.append(t)
         return result_text
     else:
-        raise Exception(f'不支持该类型{type(text)}')
+        raise Exception(f'Unsupported type: {type(text)}')
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -192,27 +190,11 @@ def get_parser():
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
-            # parameters for conformer
-            "subsampling_factor": 4,
-            "feature_dim": 80,
-            "nhead": 4,
-            "attention_dim": 512,
-            "num_encoder_layers": 12,
-            "num_decoder_layers": 6,
-            "vgg_frontend": False,
-            "use_feat_batchnorm": True,
-            # parameters for decoder
-            "search_beam": 20,
-            "output_beam": 7,
-            "min_active_states": 30,
-            "max_active_states": 10000,
-            "use_double_scores": True,
             "env_info": get_env_info(),
         }
     )
     return params
 
-
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
@@ -264,12 +246,6 @@ def decode_one_batch(
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device, dtype=dtype).transpose(1, 2)
-    # pad feature to T = 3000
-    #T = 3000
-    #if feature.shape[2] < T:
-    #    feature = torch.cat([feature, torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2]).to(device, dtype=dtype)], 2)
-    print(feature.shape,23333)
-    # at entry, feature is (N, T, C)
     supervisions = batch["supervisions"]
     feature_len = supervisions["num_frames"]
 
@@ -281,7 +257,6 @@ def decode_one_batch(
 
         hyps = to_simple(hyps)
         hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
-        print(hyps, 233333333)
 
     key = "beam-search"
 
@@ -467,17 +442,14 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    # we need cut ids to display recognition results.
     args.return_cuts = True
     aishell = AishellAsrDataModule(args)
-    test_cuts = aishell.test_cuts()
-    test_dl = aishell.test_dataloaders(test_cuts)
     valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
 
-    #test_sets = ["test"]
-    #test_dls = [test_dl]
-    test_sets = ["valid"]
-    test_dls = [valid_dl]
+    test_dl = aishell.test_dataloaders(aishell.test_cuts())
+    test_sets = ["valid", "test"]
+    test_dls = [valid_dl, test_dl]
+
     for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(
             dl=test_dl,
diff --git a/egs/aishell/ASR/whisper/model.py b/egs/aishell/ASR/whisper/model.py
old mode 100644
new mode 100755
index 2f8fea38c..9ec412513
--- a/egs/aishell/ASR/whisper/model.py
+++ b/egs/aishell/ASR/whisper/model.py
@@ -168,10 +168,8 @@ class AudioEncoder(nn.Module):
         x = F.gelu(self.conv2(x))
         x = x.permute(0, 2, 1)
 
-        # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
-
+        # change whisper to process audio with any length
         x = (x + self.positional_embedding[:x.shape[1],:]).to(x.dtype)
-        #x = (x + self.positional_embedding).to(x.dtype)
 
         for block in self.blocks:
             x = block(x)
@@ -224,7 +222,6 @@ class TextDecoder(nn.Module):
 
         return logits
 
-
 class Whisper(nn.Module):
     def __init__(self, dims: ModelDimensions):
         super().__init__()
@@ -315,7 +312,6 @@ class Whisper(nn.Module):
         self.decoder.apply(install_hooks)
         return cache, hooks
 
-    #detect_language = detect_language_function
     transcribe = transcribe_function
     decode = decode_function
 
@@ -432,9 +428,4 @@ def load_model(
     model = Whisper(dims)
     model.load_state_dict(checkpoint["model_state_dict"])
 
-    # if alignment_heads is not None:
-    #     model.set_alignment_heads(alignment_heads)
-
-    return model.to(device)
-
-
+    return model.to(device)
\ No newline at end of file
diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py
old mode 100644
new mode 100755
index 6c76d3cff..8b133b2a4
--- a/egs/aishell/ASR/whisper/train.py
+++ b/egs/aishell/ASR/whisper/train.py
@@ -51,26 +51,24 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from typing import List
-#from aishell import AIShell
-#from asr_datamodule import AsrDataModule
+
 from asr_datamodule import AishellAsrDataModule
-#from decoder import Decoder
-#from joiner import Joiner
+
 from lhotse import CutSet, load_manifest
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
-#from model import Transducer
+
 from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.functional import pad as pad_tensor
 from torch.utils.tensorboard import SummaryWriter
-#from zipformer import Zipformer
+
 from icefall import diagnostics
-#from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
@@ -109,13 +107,6 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
 
-    parser.add_argument(
-        "--master-port",
-        type=int,
-        default=12354,
-        help="Master port to use for DDP training.",
-    )
-
     parser.add_argument(
         "--tensorboard",
         type=str2bool,
@@ -209,19 +200,6 @@ def get_parser():
         help="Add hooks to check for infinite module outputs and gradients.",
     )
 
-    parser.add_argument(
-        "--save-every-n",
-        type=int,
-        default=4000,
-        help="""Save checkpoint after processing this number of batches"
of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - parser.add_argument( "--keep-last-k", type=int, @@ -314,11 +292,7 @@ def get_params() -> AttributeDict: "batch_idx_train": 0, "log_interval": 50, "reset_interval": 200, - "valid_interval": 999999999999999999, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 100, + "valid_interval": 9999999, "env_info": get_env_info(), } ) @@ -491,26 +465,25 @@ def compute_loss( device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] - # at entry, feature is (N, T, C) + assert feature.ndim == 3 feature = feature.to(device) feature = feature.transpose(1, 2) # (N, C, T) - # pad feature from B,80,T to B,80,3000 - #feature = torch.nn.functional.pad(feature, (0, 3000 - feature.shape[-1])) + supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) batch_idx_train = params.batch_idx_train - warm_step = params.warm_step texts = batch["supervisions"]["text"] # remove spaces in texts texts = [text.replace(" ", "") for text in texts] - #print(texts) + text_tokens_list = [list(tokenizer.sot_sequence_including_notimestamps) + tokenizer.encode(text) + [tokenizer.eot] for text in texts] # convert it to torch tensor text_tokens_list = [torch.LongTensor(text_tokens) for text_tokens in text_tokens_list] + # 50256 is the index of for all whisper models prev_outputs_tokens = _batch_tensors( [tokens[:-1] for tokens in text_tokens_list], pad_value=50256 ) @@ -522,9 +495,10 @@ def compute_loss( ) decoder_criterion = LabelSmoothingLoss(ignore_index=50256, label_smoothing=0.1, reduction="sum") + + # ignore the first 3 tokens, which are always , , ignore_prefix_size = 3 with torch.set_grad_enabled(is_training): - encoder_out = model.encoder(feature) text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out) loss = decoder_criterion(text_logits, target_tokens.to(device)) @@ -697,27 +671,6 @@ def train_one_epoch( model_avg=model_avg, ) - # if ( - # params.batch_idx_train > 0 - # and params.batch_idx_train % params.save_every_n == 0 - # ): - # save_checkpoint_with_global_batch_idx( - # out_dir=params.exp_dir, - # global_batch_idx=params.batch_idx_train, - # model=model, - # model_avg=model_avg, - # params=params, - # optimizer=optimizer, - # scheduler=scheduler, - # sampler=train_dl.sampler, - # scaler=scaler, - # rank=rank, - # ) - # remove_checkpoints( - # out_dir=params.exp_dir, - # topk=params.keep_last_k, - # rank=rank, - # ) if batch_idx % 100 == 0 and params.use_fp16 and not params.deepspeed: # If the grad scale was less than 1, try increasing it. 
The _growth_interval # of the grad scaler is configurable, but we can't configure it to have different @@ -791,8 +744,7 @@ def run(rank, world_size, args): logging.info(params) logging.info("About to create model") - # TODO download model only on rank 0 - # TODO may change compute validation loss using multiple cards + model = load_model(params.model_name) del model.alignment_heads num_param = sum([p.numel() for p in model.parameters()]) @@ -802,7 +754,6 @@ def run(rank, world_size, args): model.is_multilingual, num_languages=model.num_languages, language="zh", task="transcribe" ) - assert params.save_every_n >= params.average_period model_avg: Optional[nn.Module] = None if rank == 0: # model_avg is only used with rank 0 @@ -863,9 +814,8 @@ def run(rank, world_size, args): else: sampler_state_dict = None - - train_dl = aishell.train_dataloaders(aishell.train_cuts(), rank=rank, world_size=world_size) - valid_dl = aishell.valid_dataloaders(aishell.valid_cuts(), rank=rank, world_size=world_size) + train_dl = aishell.train_dataloaders(aishell.train_cuts()) + valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index 62036467e..9001aa214 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -28,7 +28,7 @@ import os from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, MonoCut, combine +from lhotse import CutSet, Fbank, FbankConfig, WhisperFbank, WhisperFbankConfig, LilcomChunkyWriter, MonoCut, combine from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor @@ -45,11 +45,10 @@ def is_cut_long(c: MonoCut) -> bool: return c.duration > 5 -def compute_fbank_musan(): +def compute_fbank_musan(num_mel_bins: int = 80, whisper_fbank: bool = False): src_dir = Path("data/manifests") output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) - num_mel_bins = 80 dataset_parts = ( "music", @@ -81,7 +80,10 @@ def compute_fbank_musan(): logging.info("Extracting features for Musan") - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + if whisper_fbank: + extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda')) + else: + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) with get_executor() as ex: # Initialize the executor only once. # create chunks of Musan with duration 5 - 10 seconds @@ -101,9 +103,26 @@ def compute_fbank_musan(): ) musan_cuts.to_file(musan_cuts_path) +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bins for Fbank""", + ) + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) + return parser.parse_args() if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_musan() + compute_fbank_musan( + num_mel_bins=args.num_mel_bins, whisper_fbank=args.whisper_fbank + )