From fd4ebf3bfe4e09a4be39c98a5e345d5afe086909 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 25 Jan 2024 08:31:08 +0000 Subject: [PATCH] add manifest dir option --- egs/aishell/ASR/local/compute_fbank_aishell.py | 14 ++++++++++++-- egs/aishell/ASR/prepare.sh | 11 ++++++----- egs/aishell/ASR/whisper/decode.py | 2 ++ egs/aishell/ASR/whisper/train.py | 3 +++ egs/librispeech/ASR/local/compute_fbank_musan.py | 16 +++++++++++++--- 5 files changed, 36 insertions(+), 10 deletions(-) diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py index 1a8ce1e8f..3c48f0aa1 100755 --- a/egs/aishell/ASR/local/compute_fbank_aishell.py +++ b/egs/aishell/ASR/local/compute_fbank_aishell.py @@ -50,10 +50,13 @@ torch.set_num_interop_threads(1) def compute_fbank_aishell( - num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False + num_mel_bins: int = 80, + perturb_speed: bool = False, + whisper_fbank: bool = False, + output_dir: str = "data/fbank", ): src_dir = Path("data/manifests") - output_dir = Path("data/fbank") + output_dir = Path(output_dir) num_jobs = min(15, os.cpu_count()) dataset_parts = ( @@ -130,6 +133,12 @@ def get_args(): default=False, help="Use WhisperFbank instead of Fbank. Default: False.", ) + parser.add_argument( + "--output-dir", + type=str, + default="data/fbank", + help="Output directory. Default: data/fbank.", + ) return parser.parse_args() @@ -143,4 +152,5 @@ if __name__ == "__main__": num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank, + output_dir=args.output_dir, ) diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index f0578f4d6..b7be89bc8 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -379,12 +379,13 @@ fi # whisper large-v3 using 128 mel bins, others using 80 mel bins whisper_mel_bins=80 +output_dir=data/fbank_whisper if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then log "Stage 30: Compute ${whisper_mel_bins} dim fbank for whisper model fine-tuning" - if [ ! -f data/fbank/.aishell.whisper.done ]; then - mkdir -p data/fbank - ./local/compute_fbank_aishell.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true - ./local/compute_fbank_musan.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true - touch data/fbank/.aishell.whisper.done + if [ ! -f $output_dir/.aishell.whisper.done ]; then + mkdir -p $output_dir + ./local/compute_fbank_aishell.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true --output-dir $output_dir + ./local/compute_fbank_musan.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true --output-dir $output_dir + touch $output_dir/.aishell.whisper.done fi fi diff --git a/egs/aishell/ASR/whisper/decode.py b/egs/aishell/ASR/whisper/decode.py index 07e28a8d4..7f841dcb7 100755 --- a/egs/aishell/ASR/whisper/decode.py +++ b/egs/aishell/ASR/whisper/decode.py @@ -28,6 +28,7 @@ python3 ./whisper/decode.py \ --exp-dir whisper/exp_large_v2 \ --model-name large-v2 \ --epoch 999 --avg 1 \ + --manifest-dir data/fbank_whisper \ --beam-size 10 --max-duration 50 # Command for decoding using pretrained models (before fine-tuning): @@ -36,6 +37,7 @@ python3 ./whisper/decode.py \ --exp-dir whisper/exp_large_v2 \ --model-name large-v2 \ --epoch -1 --avg 1 \ + --manifest-dir data/fbank_whisper \ --remove-whisper-encoder-input-length-restriction False \ --beam-size 10 --max-duration 50 diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py index edea7e7ef..d16793eb2 100755 --- a/egs/aishell/ASR/whisper/train.py +++ b/egs/aishell/ASR/whisper/train.py @@ -23,6 +23,7 @@ torchrun --nproc-per-node 8 ./whisper/train.py \ --max-duration 200 \ --exp-dir whisper/exp_large_v2 \ --model-name large-v2 \ + --manifest-dir data/fbank_whisper \ --deepspeed \ --deepspeed_config ./whisper/ds_config_zero1.json @@ -30,6 +31,7 @@ torchrun --nproc-per-node 8 ./whisper/train.py \ torchrun --nproc-per-node 8 ./whisper/train.py \ --max-duration 200 \ --exp-dir whisper/exp_medium \ + --manifest-dir data/fbank_whisper \ --base-lr 1e-5 \ --model-name medium """ @@ -253,6 +255,7 @@ def get_params() -> AttributeDict: params = AttributeDict( { "frame_shift_ms": 10.0, + "subsampling_factor": 2, "allowed_excess_duration_ratio": 0.1, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index a1b695243..d7781687f 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -54,9 +54,11 @@ def is_cut_long(c: MonoCut) -> bool: return c.duration > 5 -def compute_fbank_musan(num_mel_bins: int = 80, whisper_fbank: bool = False): +def compute_fbank_musan( + num_mel_bins: int = 80, whisper_fbank: bool = False, output_dir: str = "data/fbank" +): src_dir = Path("data/manifests") - output_dir = Path("data/fbank") + output_dir = Path(output_dir) num_jobs = min(15, os.cpu_count()) dataset_parts = ( @@ -129,6 +131,12 @@ def get_args(): default=False, help="Use WhisperFbank instead of Fbank. Default: False.", ) + parser.add_argument( + "--output-dir", + type=str, + default="data/fbank", + help="Output directory. Default: data/fbank.", + ) return parser.parse_args() @@ -138,5 +146,7 @@ if __name__ == "__main__": logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() compute_fbank_musan( - num_mel_bins=args.num_mel_bins, whisper_fbank=args.whisper_fbank + num_mel_bins=args.num_mel_bins, + whisper_fbank=args.whisper_fbank, + output_dir=args.output_dir, )