diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index b4c468e43..817fffd0d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -311,7 +311,7 @@ def get_parser(): parser.add_argument( "--average-period", type=int, - default=100, + default=200, help="""Update the averaged model, namely `model_avg`, after processing this number of batches. `model_avg` is a separate version of model, in which each floating-point parameter is the average of all the @@ -905,7 +905,7 @@ def run(rank, world_size, args): model_avg: Optional[nn.Module] = None if rank == 0: # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) + model_avg = copy.deepcopy(model) assert params.start_epoch > 0, params.start_epoch checkpoints = load_checkpoint_if_available( diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 23a1fa0c4..618284e74 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -86,7 +86,7 @@ def save_checkpoint( } if model_avg is not None: - checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict() + checkpoint["model_avg"] = model_avg.state_dict() if params: for k, v in params.items():