Use single precision for model average; increase average-period to 200.

This commit is contained in:
Daniel Povey 2022-05-31 14:31:46 +08:00
parent ab9eb0d52c
commit b2259184b5
2 changed files with 3 additions and 3 deletions

View File

@@ -311,7 +311,7 @@ def get_parser():
parser.add_argument(
"--average-period",
type=int,
default=100,
default=200,
help="""Update the averaged model, namely `model_avg`, after processing
this number of batches. `model_avg` is a separate version of model,
in which each floating-point parameter is the average of all the
@@ -905,7 +905,7 @@ def run(rank, world_size, args):
model_avg: Optional[nn.Module] = None
if rank == 0:
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model).to(torch.float64)
model_avg = copy.deepcopy(model)
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(

View File

@@ -86,7 +86,7 @@ def save_checkpoint(
}
if model_avg is not None:
checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()
checkpoint["model_avg"] = model_avg.state_dict()
if params:
for k, v in params.items():