load checkpoint from specific path

This commit is contained in:
Yuekai Zhang 2024-03-05 16:37:29 +08:00
parent 73a7687d8a
commit 50b575a2f1
3 changed files with 23 additions and 5 deletions

View File

@ -106,13 +106,10 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare AISHELL-4" log "Stage 5: Prepare AISHELL-4"
if [ -e ../../aishell4/ASR/data/fbank/.fbank.done ]; then if [ -e ../../aishell4/ASR/data/fbank/.fbank.done ]; then
cd data/fbank cd data/fbank
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_train) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_dev) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_test) . ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_test) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_L.jsonl.gz) . ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_L.jsonl.gz) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_M.jsonl.gz) . ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_M.jsonl.gz) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_S.jsonl.gz) . ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_S.jsonl.gz) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_dev.jsonl.gz) .
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_test.jsonl.gz) . ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_test.jsonl.gz) .
cd ../.. cd ../..
else else

View File

@ -496,7 +496,7 @@ def main():
test_sets = test_sets_cuts.keys() test_sets = test_sets_cuts.keys()
test_dls = [ test_dls = [
data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt)) data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_long_utt))
for cuts_name in test_sets for cuts_name in test_sets
] ]

View File

@ -34,7 +34,7 @@ torchrun --nproc-per-node 8 ./whisper/train.py \
--model-name medium --model-name medium
""" """
import os
import argparse import argparse
import copy import copy
import logging import logging
@ -151,6 +151,15 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--pretrained-model-path",
type=str,
default=None,
help="""The path to the pretrained model if it is not None. Training will
start from this model. e.g. ./wenetspeech/ASR/whisper/exp_large_v2/epoch-4-avg-3.pt
""",
)
parser.add_argument( parser.add_argument(
"--base-lr", type=float, default=1e-5, help="The base learning rate." "--base-lr", type=float, default=1e-5, help="The base learning rate."
) )
@ -617,6 +626,7 @@ def train_one_epoch(
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt", f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}", tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
) )
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}")
try: try:
with torch.cuda.amp.autocast(enabled=params.use_fp16): with torch.cuda.amp.autocast(enabled=params.use_fp16):
@ -749,6 +759,16 @@ def run(rank, world_size, args):
replace_whisper_encoder_forward() replace_whisper_encoder_forward()
model = whisper.load_model(params.model_name, "cpu") model = whisper.load_model(params.model_name, "cpu")
del model.alignment_heads del model.alignment_heads
if params.pretrained_model_path:
checkpoint = torch.load(
params.pretrained_model_path, map_location="cpu"
)
if "model" not in checkpoint:
model.load_state_dict(checkpoint, strict=True)
else:
load_checkpoint(params.pretrained_model_path, model)
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
@ -900,6 +920,7 @@ def run(rank, world_size, args):
f"{params.exp_dir}/epoch-{params.cur_epoch}.pt", f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
tag=f"epoch-{params.cur_epoch}", tag=f"epoch-{params.cur_epoch}",
) )
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
else: else:
save_checkpoint( save_checkpoint(
params=params, params=params,