diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py index 4e261dbc1..a7a8ef149 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py @@ -655,8 +655,12 @@ def load_model_params( dst_state_dict = model.state_dict() for module in init_modules: logging.info(f"Loading parameters starting with prefix {module}") - src_keys = [k for k in src_state_dict.keys() if k.startswith(module)] - dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)] + src_keys = [ + k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") + ] + dst_keys = [ + k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") + ] assert set(src_keys) == set(dst_keys) # two sets should match exactly for key in src_keys: dst_state_dict[key] = src_state_dict.pop(key) @@ -1089,6 +1093,9 @@ def run(rank, world_size, args): checkpoints = load_model_params( ckpt=params.finetune_ckpt, model=model, init_modules=modules ) + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) else: assert params.start_epoch > 0, params.start_epoch checkpoints = load_checkpoint_if_available( diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py index c943a84af..ba91980d3 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py @@ -498,8 +498,12 @@ def load_model_params( dst_state_dict = model.state_dict() for module in init_modules: logging.info(f"Loading parameters starting with prefix {module}") - src_keys = [k for k in src_state_dict.keys() if k.startswith(module)] - dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)] + src_keys = [ + k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") + ] + dst_keys = [ + k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") + ] assert set(src_keys) == set(dst_keys) # two sets should match exactly for key in src_keys: dst_state_dict[key] = src_state_dict.pop(key)