Remove decomposition code from checkpoint.py; restore double precision model_avg

Daniel Povey 2022-06-01 14:01:58 +08:00
parent 03e07e80ce
commit ca09b9798f
2 changed files with 15 additions and 70 deletions

train.py

@@ -781,7 +781,6 @@ def train_one_epoch(
                 params=params,
                 model_cur=model,
                 model_avg=model_avg,
-                decompose=True
             )

         if (
@@ -905,7 +904,7 @@ def run(rank, world_size, args):
     model_avg: Optional[nn.Module] = None
     if rank == 0:
         # model_avg is only used with rank 0
-        model_avg = copy.deepcopy(model)
+        model_avg = copy.deepcopy(model).to(torch.float64)

     assert params.start_epoch > 0, params.start_epoch
     checkpoints = load_checkpoint_if_available(
checkpoint.py

@@ -86,7 +86,7 @@ def save_checkpoint(
     }

     if model_avg is not None:
-        checkpoint["model_avg"] = model_avg.state_dict()
+        checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()

     if params:
         for k, v in params.items():
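Serializing model_avg in float32 halves its footprint in the checkpoint while the in-memory copy stays float64. A round-trip sketch (toy module, not from the commit) showing that load_state_dict() copies values into the existing float64 parameters, so double precision is restored on resume; note that nn.Module.to(), unlike Tensor.to(), converts a module in place, which is why this sketch casts a deep copy:

import copy

import torch
import torch.nn as nn

model_avg = nn.Linear(4, 4).to(torch.float64)

# Save a single-precision snapshot; the deepcopy keeps the float64 original intact.
state = copy.deepcopy(model_avg).to(torch.float32).state_dict()

# load_state_dict() casts the float32 values into the existing float64 parameters.
model_avg.load_state_dict(state)
print(next(model_avg.parameters()).dtype)  # torch.float64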
@@ -351,7 +351,6 @@ def update_averaged_model(
     params: Dict[str, Tensor],
     model_cur: Union[nn.Module, DDP],
     model_avg: nn.Module,
-    decompose: bool = False
 ) -> None:
     """Update the averaged model:
       model_avg = model_cur * (average_period / batch_idx_train)
@@ -364,12 +363,6 @@ def update_averaged_model(
         The current model.
       model_avg:
         The averaged model to be updated.
-      decompose:
-        If true, do the averaging after decomposing each non-scalar tensor into
-        a log-magnitude and a direction (note: the magnitude is computed with an
-        epsilon of 1e-5).  You should give the same argument to
-        average_checkpoints_with_averaged_model() when you use the averaged
-        model.
     """
     weight_cur = params.average_period / params.batch_idx_train
     weight_avg = 1 - weight_cur
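For concreteness, a worked example of these weights (illustrative numbers): assuming updates land exactly every average_period batches, at the k-th update batch_idx_train = k * average_period, so weight_cur = 1/k and the rule reduces to a running uniform mean of the sampled models.

average_period = 100
batch_idx_train = 5_000                          # the 50th update
weight_cur = average_period / batch_idx_train    # 0.02, i.e. 1/50
weight_avg = 1 - weight_cur                      # 0.98
# model_avg <- 0.98 * model_avg + 0.02 * model_cur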
@@ -385,7 +378,6 @@ def update_averaged_model(
         state_dict_2=cur,
         weight_1=weight_avg,
         weight_2=weight_cur,
-        decompose=decompose
     )
@@ -393,7 +385,6 @@ def average_checkpoints_with_averaged_model(
     filename_start: str,
     filename_end: str,
     device: torch.device = torch.device("cpu"),
-    decompose: bool = False,
 ) -> Dict[str, Tensor]:
     """Average model parameters over the range with given
     start model (excluded) and end model.
@@ -425,12 +416,6 @@ def average_checkpoints_with_averaged_model(
         is saved by :func:`save_checkpoint`.
       device:
         Move checkpoints to this device before averaging.
-      decompose:
-        If true, do the averaging after decomposing each non-scalar tensor into
-        a log-magnitude and a direction (note: the magnitude is computed with an
-        epsilon of 1e-5).  You should give the same argument to
-        average_checkpoints_with_averaged_model() when you use the averaged
-        model.
     """
     state_dict_start = torch.load(filename_start, map_location=device)
     state_dict_end = torch.load(filename_end, map_location=device)
@@ -440,56 +425,22 @@ def average_checkpoints_with_averaged_model(
     interval = batch_idx_train_end - batch_idx_train_start
     assert interval > 0, interval
     weight_end = batch_idx_train_end / interval
-    # note: weight_start will be negative.
     weight_start = 1 - weight_end
     model_end = state_dict_end["model_avg"]
     model_start = state_dict_start["model_avg"]
+    avg = model_end

+    # scale the weight to avoid overflow
     average_state_dict(
-        state_dict_1=model_end,
+        state_dict_1=avg,
         state_dict_2=model_start,
-        weight_1=weight_end,
-        weight_2=weight_start,
-        decompose=decompose
+        weight_1=1.0,
+        weight_2=weight_start / weight_end,
+        scaling_factor=weight_end,
     )
-    # model_end contains averaged model
-    return model_end
-
-
-def average_tensor(
-    t1: Tensor,
-    t2: Tensor,
-    weight_1: float,
-    weight_2: float,
-    decompose: bool):
-    """
-    Computes, in-place,
-       t1[:] = weight_1 * t1 + weight_2 * t2
-    If decompose == True and t1 and t2 have numel() > 1, does this after
-    decomposing them into a log-magnitude and a direction (note: the magnitude
-    is computed with an epsilon of 1e-5).
-    """
-    if t1.numel() == 1 or not decompose:
-        t1.mul_(weight_1)
-        t1.add_(t2, alpha=weight_2)
-    else:
-        # do this in double precision to reduce roundoff error.
-        output = t1
-        t1 = t1.to(torch.float64)
-        t2 = t2.to(torch.float64)
-        eps = 1.0e-05
-        scale_1 = (t1 ** 2).mean().sqrt() + eps
-        direction_1 = t1 / scale_1
-        scale_2 = (t2 ** 2).mean().sqrt() + eps
-        direction_2 = t2 / scale_2
-        log_scale_1 = scale_1.log()
-        log_scale_2 = scale_2.log()
-        average_tensor(log_scale_1, log_scale_2, weight_1, weight_2, False)
-        average_tensor(direction_1, direction_2, weight_1, weight_2, False)
-        output.copy_((log_scale_1.exp() * direction_1).to(dtype=t1.dtype))
+
+    return avg
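Since model_avg at step t is the running mean of the models sampled up to t, the mean over the window (start, end] is (end * avg_end - start * avg_start) / interval, i.e. avg_end * weight_end + avg_start * weight_start, with weight_start negative. The rewritten call computes the same thing but keeps the intermediates near the parameters' own scale, applying the large weight_end once at the end via scaling_factor. A quick check with illustrative values (not from the commit):

batch_idx_train_start = 100_000
batch_idx_train_end = 110_000
interval = batch_idx_train_end - batch_idx_train_start   # 10_000
weight_end = batch_idx_train_end / interval              # 11.0
weight_start = 1 - weight_end                            # -10.0

m_end, m_start = 0.6, 0.5   # stand-ins for one averaged parameter
direct = m_end * weight_end + m_start * weight_start
rescaled = (m_end * 1.0 + m_start * (weight_start / weight_end)) * weight_end
assert abs(direct - rescaled) < 1e-12   # both equal 1.6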
@@ -497,17 +448,12 @@ def average_state_dict(
     state_dict_2: Dict[str, Tensor],
     weight_1: float,
     weight_2: float,
-    decompose: bool = False,
+    scaling_factor: float = 1.0,
 ) -> Dict[str, Tensor]:
     """Average two state_dict with given weights:
       state_dict_1 = (state_dict_1 * weight_1 + state_dict_2 * weight_2)
-    The weights do not have to be positive.
+        * scaling_factor
     It is an in-place operation on state_dict_1 itself.
-    If decompose == True, we do this operation after decomposing
-    each non-scalar tensor into a log-magnitude and a direction (note:
-    the magnitude is computed with an epsilon of 1e-5).
     """
     # Identify shared parameters. Two parameters are said to be shared
     # if they have the same data_ptr
@@ -520,8 +466,8 @@ def average_state_dict(
     uniqued_names = list(uniqued.values())
     for k in uniqued_names:
-        average_tensor(state_dict_1[k],
-                       state_dict_2[k].to(device=state_dict_1[k].device),
-                       weight_1,
-                       weight_2,
-                       decompose=decompose)
+        state_dict_1[k] *= weight_1
+        state_dict_1[k] += (
+            state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
+        )
+        state_dict_1[k] *= scaling_factor
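The data_ptr() dedup in the surrounding context matters for this in-place form: tied parameters (e.g., a shared embedding appearing under two names) point at the same storage, and running the scale-and-add twice would corrupt the average. A toy sketch of the idea (hypothetical names, mirroring rather than copying the code):

import torch

state_dict = {"embed.weight": torch.ones(4, 2)}
state_dict["out.weight"] = state_dict["embed.weight"]   # tied: same storage

# Keep one name per underlying storage, first occurrence wins.
uniqued = {}
for name, tensor in state_dict.items():
    uniqued.setdefault(tensor.data_ptr(), name)

print(list(uniqued.values()))   # ['embed.weight']: in-place ops run once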