mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
load checkpoint from specific path
This commit is contained in:
parent
73a7687d8a
commit
50b575a2f1
@ -106,13 +106,10 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Prepare AISHELL-4"
|
||||
if [ -e ../../aishell4/ASR/data/fbank/.fbank.done ]; then
|
||||
cd data/fbank
|
||||
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_train) .
|
||||
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_dev) .
|
||||
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_test) .
|
||||
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_L.jsonl.gz) .
|
||||
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_M.jsonl.gz) .
|
||||
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_S.jsonl.gz) .
|
||||
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_dev.jsonl.gz) .
|
||||
ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_test.jsonl.gz) .
|
||||
cd ../..
|
||||
else
|
||||
|
@ -496,7 +496,7 @@ def main():
|
||||
|
||||
test_sets = test_sets_cuts.keys()
|
||||
test_dls = [
|
||||
data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
|
||||
data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_long_utt))
|
||||
for cuts_name in test_sets
|
||||
]
|
||||
|
||||
|
@ -34,7 +34,7 @@ torchrun --nproc-per-node 8 ./whisper/train.py \
|
||||
--model-name medium
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
@ -151,6 +151,15 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pretrained-model-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="""The path to the pretrained model if it is not None. Training will
|
||||
start from this model. e.g. ./wenetspeech/ASR/whisper/exp_large_v2/epoch-4-avg-3.pt
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--base-lr", type=float, default=1e-5, help="The base learning rate."
|
||||
)
|
||||
@ -617,6 +626,7 @@ def train_one_epoch(
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
|
||||
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
||||
)
|
||||
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}")
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
@ -749,6 +759,16 @@ def run(rank, world_size, args):
|
||||
replace_whisper_encoder_forward()
|
||||
model = whisper.load_model(params.model_name, "cpu")
|
||||
del model.alignment_heads
|
||||
|
||||
if params.pretrained_model_path:
|
||||
checkpoint = torch.load(
|
||||
params.pretrained_model_path, map_location="cpu"
|
||||
)
|
||||
if "model" not in checkpoint:
|
||||
model.load_state_dict(checkpoint, strict=True)
|
||||
else:
|
||||
load_checkpoint(params.pretrained_model_path, model)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
@ -900,6 +920,7 @@ def run(rank, world_size, args):
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
|
||||
tag=f"epoch-{params.cur_epoch}",
|
||||
)
|
||||
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
|
||||
else:
|
||||
save_checkpoint(
|
||||
params=params,
|
||||
|
Loading…
x
Reference in New Issue
Block a user