mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-09 09:04:19 +00:00
Use single precision (float32) instead of float64 for the averaged model `model_avg`; increase the default `--average-period` from 100 to 200.
This commit is contained in:
parent
ab9eb0d52c
commit
b2259184b5
@ -311,7 +311,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--average-period",
|
"--average-period",
|
||||||
type=int,
|
type=int,
|
||||||
default=100,
|
default=200,
|
||||||
help="""Update the averaged model, namely `model_avg`, after processing
|
help="""Update the averaged model, namely `model_avg`, after processing
|
||||||
this number of batches. `model_avg` is a separate version of model,
|
this number of batches. `model_avg` is a separate version of model,
|
||||||
in which each floating-point parameter is the average of all the
|
in which each floating-point parameter is the average of all the
|
||||||
@ -905,7 +905,7 @@ def run(rank, world_size, args):
|
|||||||
model_avg: Optional[nn.Module] = None
|
model_avg: Optional[nn.Module] = None
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
# model_avg is only used with rank 0
|
# model_avg is only used with rank 0
|
||||||
model_avg = copy.deepcopy(model).to(torch.float64)
|
model_avg = copy.deepcopy(model)
|
||||||
|
|
||||||
assert params.start_epoch > 0, params.start_epoch
|
assert params.start_epoch > 0, params.start_epoch
|
||||||
checkpoints = load_checkpoint_if_available(
|
checkpoints = load_checkpoint_if_available(
|
||||||
|
@ -86,7 +86,7 @@ def save_checkpoint(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if model_avg is not None:
|
if model_avg is not None:
|
||||||
checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()
|
checkpoint["model_avg"] = model_avg.state_dict()
|
||||||
|
|
||||||
if params:
|
if params:
|
||||||
for k, v in params.items():
|
for k, v in params.items():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user