diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 147bcf658..cc61b3b32 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -867,10 +867,6 @@ def run(rank, world_size, args): model = DDP(model, device_ids=[rank]) model.device = device - if rank == 0: - model_avg.to(device) - model_avg.device = device - optimizer = Eve(model.parameters(), lr=params.initial_lr) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 5b562ccc8..ba3823ffc 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -467,5 +467,7 @@ def average_state_dict( uniqued_names = list(uniqued.values()) for k in uniqued_names: state_dict_1[k] *= weight_1 - state_dict_1[k] += state_dict_2[k] * weight_2 + state_dict_1[k] += ( + state_dict_2[k].to(device=state_dict_1[k].device) * weight_2 + ) state_dict_1[k] *= scaling_factor