https://github.com/k2-fsa/icefall.git
commit 50b575a2f1
parent 73a7687d8a

    load checkpoint from specific path
@@ -106,13 +106,10 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   log "Stage 5: Prepare AISHELL-4"
   if [ -e ../../aishell4/ASR/data/fbank/.fbank.done ]; then
     cd data/fbank
-    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_train) .
-    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_dev) .
     ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_feats_test) .
     ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_L.jsonl.gz) .
     ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_M.jsonl.gz) .
     ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_train_S.jsonl.gz) .
-    ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_dev.jsonl.gz) .
     ln -svf $(realpath ../../../../aishell4/ASR/data/fbank/aishell4_cuts_test.jsonl.gz) .
     cd ../..
   else
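The stage above only re-links AISHELL-4 artifacts that already exist under ../../aishell4/ASR/data/fbank. For reference, a minimal Python sketch of the same linking step, assuming it runs from the recipe root (the same directory the `.fbank.done` check uses), with the `-f`/`-v` behavior of `ln -svf` reproduced by hand:

    from pathlib import Path

    # Sketch of the linking step above: expose precomputed AISHELL-4
    # features/manifests inside data/fbank via symlinks.
    src = Path("../../aishell4/ASR/data/fbank").resolve()
    dst = Path("data/fbank")

    for name in [
        "aishell4_feats_test",
        "aishell4_cuts_train_L.jsonl.gz",
        "aishell4_cuts_train_M.jsonl.gz",
        "aishell4_cuts_train_S.jsonl.gz",
        "aishell4_cuts_test.jsonl.gz",
    ]:
        link = dst / name
        if link.is_symlink() or link.exists():
            link.unlink()                 # mimic -f: replace an existing link
        link.symlink_to(src / name)
        print(f"{link} -> {src / name}")  # mimic -v: report each link made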
@@ -496,7 +496,7 @@ def main():

     test_sets = test_sets_cuts.keys()
     test_dls = [
-        data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
+        data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_long_utt))
         for cuts_name in test_sets
     ]

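This hunk swaps the cut filter from remove_short_utt to remove_long_utt before building the test dataloaders. The predicate bodies are outside this diff; a minimal sketch of what such a length gate on lhotse cuts typically looks like in these recipes (the 30 s threshold is an assumption, chosen because Whisper consumes 30-second windows):

    import logging

    def remove_long_utt(c) -> bool:
        # Drop cuts that exceed Whisper's 30 s input window.
        # The threshold is illustrative; the recipe defines its own limit.
        if c.duration > 30.0:
            logging.warning(f"Excluding cut {c.id}, duration {c.duration:.1f}s")
            return False
        return True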
@@ -34,7 +34,7 @@ torchrun --nproc-per-node 8 ./whisper/train.py \
   --model-name medium
 """

+import os
 import argparse
 import copy
 import logging
@@ -151,6 +151,15 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--pretrained-model-path",
+        type=str,
+        default=None,
+        help="""The path to the pretrained model if it is not None. Training will
+        start from this model. e.g. ./wenetspeech/ASR/whisper/exp_large_v2/epoch-4-avg-3.pt
+        """,
+    )
+
     parser.add_argument(
         "--base-lr", type=float, default=1e-5, help="The base learning rate."
     )
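The new flag defaults to None, so existing invocations are unchanged. A quick stand-alone check of the parsing behavior, using a toy parser rather than the recipe's full get_parser() (the checkpoint path here is just the example from the help text):

    import argparse

    # Toy parser carrying only the new flag.
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained-model-path", type=str, default=None)

    # Unset: default None, i.e. training starts from the stock Whisper weights.
    assert parser.parse_args([]).pretrained_model_path is None

    # Set: training will warm-start from the given checkpoint.
    args = parser.parse_args(
        ["--pretrained-model-path", "exp_large_v2/epoch-4-avg-3.pt"]
    )
    assert args.pretrained_model_path == "exp_large_v2/epoch-4-avg-3.pt"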
@@ -617,6 +626,7 @@ def train_one_epoch(
                 f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
                 tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
             )
+            os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}")

         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
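Converting the DeepSpeed checkpoint to a consolidated fp32 .pt file leaves the sharded checkpoint directory behind, so the commit removes it with a shell call (hence the new `import os`). A portable, hypothetical equivalent without os.system:

    import shutil
    from pathlib import Path

    def cleanup_sharded_checkpoint(exp_dir: str, tag: str) -> None:
        """Remove DeepSpeed's sharded checkpoint directory once the
        consolidated fp32 state dict (<tag>.pt) has been written.

        Hypothetical helper mirroring the committed
        os.system(f"rm -rf {exp_dir}/{tag}") call.
        """
        shard_dir = Path(exp_dir) / tag
        consolidated = Path(exp_dir) / f"{tag}.pt"
        if consolidated.is_file() and shard_dir.is_dir():
            shutil.rmtree(shard_dir)

os.system keeps the change to one line; shutil.rmtree avoids shell quoting issues and works on non-POSIX hosts.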
@@ -749,6 +759,16 @@ def run(rank, world_size, args):
     replace_whisper_encoder_forward()
     model = whisper.load_model(params.model_name, "cpu")
     del model.alignment_heads
+
+    if params.pretrained_model_path:
+        checkpoint = torch.load(
+            params.pretrained_model_path, map_location="cpu"
+        )
+        if "model" not in checkpoint:
+            model.load_state_dict(checkpoint, strict=True)
+        else:
+            load_checkpoint(params.pretrained_model_path, model)
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

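The new branch accepts two on-disk formats: a bare state_dict (as produced by the DeepSpeed conversion above) and icefall's wrapped checkpoint, where load_checkpoint reads the weights from the "model" key and can also restore training state, which this sketch omits. A self-contained illustration of the same dispatch:

    import torch
    import torch.nn as nn

    def load_pretrained(model: nn.Module, path: str) -> None:
        """Sketch of the dispatch added in run(): accept either a bare
        state_dict or an icefall-style checkpoint with a "model" key."""
        checkpoint = torch.load(path, map_location="cpu")
        if "model" not in checkpoint:
            # A plain state_dict, e.g. one written by
            # convert_zero_checkpoint_to_fp32_state_dict.
            model.load_state_dict(checkpoint, strict=True)
        else:
            # icefall checkpoint: weights live under checkpoint["model"].
            model.load_state_dict(checkpoint["model"], strict=True)

    # Tiny round-trip check with a toy module:
    m = nn.Linear(2, 2)
    torch.save(m.state_dict(), "/tmp/bare.pt")
    torch.save({"model": m.state_dict()}, "/tmp/wrapped.pt")
    load_pretrained(m, "/tmp/bare.pt")
    load_pretrained(m, "/tmp/wrapped.pt")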
@@ -900,6 +920,7 @@ def run(rank, world_size, args):
                 f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
                 tag=f"epoch-{params.cur_epoch}",
             )
+            os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
         else:
             save_checkpoint(
                 params=params,