diff --git a/egs/multi_zh-hans/ASR/prepare.sh b/egs/multi_zh-hans/ASR/prepare.sh
index 2f632fba0..d9fce3251 100755
--- a/egs/multi_zh-hans/ASR/prepare.sh
+++ b/egs/multi_zh-hans/ASR/prepare.sh
@@ -106,13 +106,10 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   log "Stage 5: Prepare AISHELL-4"
   if [ -e ../../aishell4/ASR/data/fbank/.fbank.done ]; then
     cd data/fbank
-    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_train) .
-    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_dev) .
     ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_test) .
     ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_L.jsonl.gz) .
     ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_M.jsonl.gz) .
     ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_S.jsonl.gz) .
-    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_dev.jsonl.gz) .
     ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_test.jsonl.gz) .
     cd ../..
   else
diff --git a/egs/multi_zh-hans/ASR/whisper/decode.py b/egs/multi_zh-hans/ASR/whisper/decode.py
index 2acee2e83..1452c86a3 100644
--- a/egs/multi_zh-hans/ASR/whisper/decode.py
+++ b/egs/multi_zh-hans/ASR/whisper/decode.py
@@ -496,7 +496,7 @@ def main():
 
     test_sets = test_sets_cuts.keys()
     test_dls = [
-        data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
+        data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_long_utt))
         for cuts_name in test_sets
     ]
 
diff --git a/egs/multi_zh-hans/ASR/whisper/train.py b/egs/multi_zh-hans/ASR/whisper/train.py
index 87a09fa23..40d0dc893 100644
--- a/egs/multi_zh-hans/ASR/whisper/train.py
+++ b/egs/multi_zh-hans/ASR/whisper/train.py
@@ -34,7 +34,7 @@ torchrun --nproc-per-node 8 ./whisper/train.py \
   --model-name medium
 """
 
-
+import os
 import argparse
 import copy
 import logging
@@ -151,6 +151,15 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--pretrained-model-path",
+        type=str,
+        default=None,
+        help="""The path to the pretrained model if it is not None. Training will
+        start from this model. e.g. ./wenetspeech/ASR/whisper/exp_large_v2/epoch-4-avg-3.pt
+        """,
+    )
+
     parser.add_argument(
         "--base-lr", type=float, default=1e-5, help="The base learning rate."
     )
@@ -617,6 +626,7 @@ def train_one_epoch(
                     f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
                     tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
                 )
+                os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}")
 
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
@@ -749,6 +759,16 @@ def run(rank, world_size, args):
     replace_whisper_encoder_forward()
     model = whisper.load_model(params.model_name, "cpu")
     del model.alignment_heads
+
+    if params.pretrained_model_path:
+        checkpoint = torch.load(
+            params.pretrained_model_path, map_location="cpu"
+        )
+        if "model" not in checkpoint:
+            model.load_state_dict(checkpoint, strict=True)
+        else:
+            load_checkpoint(params.pretrained_model_path, model)
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
@@ -900,6 +920,7 @@ def run(rank, world_size, args):
                 f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
                 tag=f"epoch-{params.cur_epoch}",
            )
+            os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
        else:
            save_checkpoint(
                params=params,
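
Note on the decode.py change: the test dataloaders previously filtered out short utterances (remove_short_utt) and now filter out long ones instead. The body of remove_long_utt is not shown in this patch; below is a minimal sketch of what such a predicate could look like, assuming lhotse-style cuts with a duration attribute and Whisper's fixed 30-second input window as the cutoff.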
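
    def remove_long_utt(c):
        # Assumed helper (not part of this patch): keep only cuts that fit
        # Whisper's 30 s log-mel context; longer cuts are dropped before
        # the test dataloaders are built.
        if c.duration > 30.0:
            return False
        return True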
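
On the train.py side, the two os.system("rm -rf ...") calls delete the DeepSpeed checkpoint directory named by the tag once it has been converted to a single consolidated .pt file (the conversion call whose arguments appear in the surrounding context lines), so only the single-file weights remain on disk. The --pretrained-model-path loader handles both checkpoint formats: a plain state dict is loaded directly via model.load_state_dict, while an icefall-style checkpoint (a dict containing a "model" key) goes through load_checkpoint.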