diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py
index 21adb7752..a7a8ef149 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py
@@ -655,8 +655,12 @@ def load_model_params(
     dst_state_dict = model.state_dict()
     for module in init_modules:
         logging.info(f"Loading parameters starting with prefix {module}")
-        src_keys = [k for k in src_state_dict.keys() if k.startswith(module)]
-        dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)]
+        src_keys = [
+            k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
+        ]
+        dst_keys = [
+            k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
+        ]
         assert set(src_keys) == set(dst_keys)  # two sets should match exactly
         for key in src_keys:
             dst_state_dict[key] = src_state_dict.pop(key)
@@ -1077,23 +1081,26 @@ def run(rank, world_size, args):
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
     # load model parameters for model fine-tuning
     if params.do_finetune:
         modules = params.init_modules.split(",") if params.init_modules else None
         checkpoints = load_model_params(
             ckpt=params.finetune_ckpt, model=model, init_modules=modules
         )
+        if rank == 0:
+            # model_avg is only used with rank 0
+            model_avg = copy.deepcopy(model).to(torch.float64)
     else:
         assert params.start_epoch > 0, params.start_epoch
         checkpoints = load_checkpoint_if_available(
             params=params, model=model, model_avg=model_avg
         )
-
-    assert params.save_every_n >= params.average_period
-    model_avg: Optional[nn.Module] = None
-    if rank == 0:
-        # model_avg is only used with rank 0
-        model_avg = copy.deepcopy(model).to(torch.float64)
 
     model.to(device)
     if world_size > 1:
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py
index c943a84af..ba91980d3 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py
@@ -498,8 +498,12 @@ def load_model_params(
    dst_state_dict = model.state_dict()
     for module in init_modules:
         logging.info(f"Loading parameters starting with prefix {module}")
-        src_keys = [k for k in src_state_dict.keys() if k.startswith(module)]
-        dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)]
+        src_keys = [
+            k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
+        ]
+        dst_keys = [
+            k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
+        ]
         assert set(src_keys) == set(dst_keys)  # two sets should match exactly
         for key in src_keys:
             dst_state_dict[key] = src_state_dict.pop(key)
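
The `load_model_params` change anchors the prefix match at a module boundary: a bare `startswith(module)` would also sweep up any sibling module whose name merely begins with the same string, and the `assert set(src_keys) == set(dst_keys)` could then fail (or silently copy unintended tensors) when the two models differ in those siblings. Appending `"."` restricts the match to keys inside the named module, and `.strip()` tolerates whitespace in comma-separated lists such as `--init-modules "encoder, decoder"`. Below is a minimal sketch of the difference, not part of the patch; the state-dict key names are toy examples, not the actual parameter names of these models:

```python
# Toy source state dict with two top-level modules that share a name prefix.
src_state_dict = {
    "encoder.layers.0.weight": 0,
    "encoder_proj.weight": 1,  # sibling module, same leading string
}

module = "encoder"

# Old behavior: the bare prefix also captures "encoder_proj.*".
loose = [k for k in src_state_dict if k.startswith(module)]
assert loose == ["encoder.layers.0.weight", "encoder_proj.weight"]

# New behavior: appending "." anchors the match at the module boundary,
# and .strip() handles entries like " encoder" from a comma-separated list.
exact = [k for k in src_state_dict if k.startswith(module.strip() + ".")]
assert exact == ["encoder.layers.0.weight"]
```

The second hunk in the LibriSpeech recipe moves the `model_avg` setup ahead of checkpoint loading, so that `load_checkpoint_if_available(..., model_avg=model_avg)` receives an initialized variable, and re-copies the model on rank 0 after `load_model_params` so the averaged model starts from the fine-tuning initialization rather than the random one.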