mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
Use single precision for model average; increase average-period to 200.
This commit is contained in:
parent
ab9eb0d52c
commit
b2259184b5
@ -311,7 +311,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--average-period",
|
||||
type=int,
|
||||
default=100,
|
||||
default=200,
|
||||
help="""Update the averaged model, namely `model_avg`, after processing
|
||||
this number of batches. `model_avg` is a separate version of model,
|
||||
in which each floating-point parameter is the average of all the
|
||||
@ -905,7 +905,7 @@ def run(rank, world_size, args):
|
||||
model_avg: Optional[nn.Module] = None
|
||||
if rank == 0:
|
||||
# model_avg is only used with rank 0
|
||||
model_avg = copy.deepcopy(model).to(torch.float64)
|
||||
model_avg = copy.deepcopy(model)
|
||||
|
||||
assert params.start_epoch > 0, params.start_epoch
|
||||
checkpoints = load_checkpoint_if_available(
|
||||
|
@ -86,7 +86,7 @@ def save_checkpoint(
|
||||
}
|
||||
|
||||
if model_avg is not None:
|
||||
checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()
|
||||
checkpoint["model_avg"] = model_avg.state_dict()
|
||||
|
||||
if params:
|
||||
for k, v in params.items():
|
||||
|
Loading…
x
Reference in New Issue
Block a user