Fix model avg (#1317)

* Fix a bug in model_avg during finetuning by swapping the order of two steps: the pre-trained model is now loaded first, and the avg model (model_avg) is initialized afterwards, so that it is copied from the loaded parameters

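* Only match the exact module prefix: a key is selected only if it starts with the module name followed by ".", so a module whose name merely begins with the same characters is no longer matched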
This commit is contained in:
marcoyang1998 2023-10-18 17:36:14 +08:00 committed by GitHub
parent 807816fec0
commit 52c24df61d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 4 deletions

View File

@ -655,8 +655,12 @@ def load_model_params(
dst_state_dict = model.state_dict()
for module in init_modules:
logging.info(f"Loading parameters starting with prefix {module}")
src_keys = [k for k in src_state_dict.keys() if k.startswith(module)]
dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)]
src_keys = [
k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
]
dst_keys = [
k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
]
assert set(src_keys) == set(dst_keys) # two sets should match exactly
for key in src_keys:
dst_state_dict[key] = src_state_dict.pop(key)
@ -1089,6 +1093,9 @@ def run(rank, world_size, args):
checkpoints = load_model_params(
ckpt=params.finetune_ckpt, model=model, init_modules=modules
)
if rank == 0:
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model).to(torch.float64)
else:
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(

View File

@ -498,8 +498,12 @@ def load_model_params(
dst_state_dict = model.state_dict()
for module in init_modules:
logging.info(f"Loading parameters starting with prefix {module}")
src_keys = [k for k in src_state_dict.keys() if k.startswith(module)]
dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)]
src_keys = [
k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
]
dst_keys = [
k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
]
assert set(src_keys) == set(dst_keys) # two sets should match exactly
for key in src_keys:
dst_state_dict[key] = src_state_dict.pop(key)