Update checkpoint.py to support decompose argument

Daniel Povey 2022-05-31 14:25:45 +08:00
parent 1651fe0d42
commit 8d4c987e21

@@ -351,6 +351,7 @@ def update_averaged_model(
     params: Dict[str, Tensor],
     model_cur: Union[nn.Module, DDP],
     model_avg: nn.Module,
+    decompose: bool = False
 ) -> None:
     """Update the averaged model:
       model_avg = model_cur * (average_period / batch_idx_train)
@@ -363,6 +364,12 @@
         The current model.
       model_avg:
         The averaged model to be updated.
+      decompose:
+        If true, do the averaging after decomposing each non-scalar tensor into
+        a log-magnitude and a direction (note: the magnitude is computed with an
+        epsilon of 1e-5). You should give the same argument to
+        average_checkpoints_with_averaged_model() when you use the averaged
+        model.
     """
     weight_cur = params.average_period / params.batch_idx_train
     weight_avg = 1 - weight_cur
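For intuition, the update documented above is a periodic running average: every `average_period` batches, `model_avg` absorbs the current model with weight `average_period / batch_idx_train`. Here is a standalone numeric sketch of that rule; the batch counts and toy tensors are illustrative, not part of the commit:

```python
import torch

# Toy check of the running-average update, assuming (made-up numbers)
# average_period = 100 and batch_idx_train = 400: the current model
# contributes 1/4 of the weight and the old average 3/4.
average_period, batch_idx_train = 100, 400
weight_cur = average_period / batch_idx_train  # 0.25
weight_avg = 1 - weight_cur                    # 0.75

avg = torch.tensor([1.0, 1.0])  # stand-in for one model_avg parameter
cur = torch.tensor([2.0, 0.0])  # stand-in for the matching model_cur parameter
avg.mul_(weight_avg).add_(cur, alpha=weight_cur)
print(avg)  # tensor([1.2500, 0.7500])
```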
@@ -378,6 +385,7 @@
         state_dict_2=cur,
         weight_1=weight_avg,
         weight_2=weight_cur,
+        decompose=decompose
     )
@@ -385,6 +393,7 @@ def average_checkpoints_with_averaged_model(
     filename_start: str,
     filename_end: str,
     device: torch.device = torch.device("cpu"),
+    decompose: bool = False,
 ) -> Dict[str, Tensor]:
     """Average model parameters over the range with given
     start model (excluded) and end model.
@@ -416,6 +425,12 @@
       is saved by :func:`save_checkpoint`.
     device:
       Move checkpoints to this device before averaging.
+    decompose:
+      If true, do the averaging after decomposing each non-scalar tensor into
+      a log-magnitude and a direction (note: the magnitude is computed with an
+      epsilon of 1e-5). You should give the same argument that you gave to
+      update_averaged_model() during training, so that the extrapolation here
+      matches the way the running average was accumulated.
     """
     state_dict_start = torch.load(filename_start, map_location=device)
     state_dict_end = torch.load(filename_end, map_location=device)
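The hunk below is the heart of the change on the decoding side. Note why `weight_start` is negative: if `model_avg` at batch `n` is (approximately) the mean of the per-batch models over batches `1..n`, then the mean over the window `(start, end]` is `(end * avg_end - start * avg_start) / (end - start)`, i.e. an extrapolation with `weight_end = end / interval > 1` and `weight_start = 1 - weight_end < 0`. A scalar sketch with made-up batch counts:

```python
# Made-up batch counts; avg_start / avg_end stand in for averaged parameters.
batch_idx_train_start, batch_idx_train_end = 1000, 3000
interval = batch_idx_train_end - batch_idx_train_start  # 2000
weight_end = batch_idx_train_end / interval             # 1.5
weight_start = 1 - weight_end                           # -0.5

avg_start, avg_end = 0.2, 0.8
window_mean = weight_end * avg_end + weight_start * avg_start
# Identical to (3000 * 0.8 - 1000 * 0.2) / 2000, the mean over (start, end].
print(window_mean)  # 1.1
```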
@@ -425,22 +440,52 @@
     interval = batch_idx_train_end - batch_idx_train_start
     assert interval > 0, interval
     weight_end = batch_idx_train_end / interval
+    # note: weight_start will be negative.
     weight_start = 1 - weight_end
     model_end = state_dict_end["model_avg"]
     model_start = state_dict_start["model_avg"]
-    avg = model_end
     # scale the weight to avoid overflow
     average_state_dict(
-        state_dict_1=avg,
+        state_dict_1=model_end,
         state_dict_2=model_start,
-        weight_1=1.0,
-        weight_2=weight_start / weight_end,
-        scaling_factor=weight_end,
+        weight_1=weight_end,
+        weight_2=weight_start,
+        decompose=decompose
     )
-    return avg
+    # model_end contains averaged model
+    return model_end
+
+
+def average_tensor(
+    t1: Tensor,
+    t2: Tensor,
+    weight_1: float,
+    weight_2: float,
+    decompose: bool):
+    """
+    Computes, in place,
+      t1[:] = weight_1 * t1 + weight_2 * t2
+    If decompose == True and t1 and t2 have numel() > 1, does this after
+    decomposing them into a log-magnitude and a direction (note: the magnitude
+    is computed with an epsilon of 1e-5).
+    """
+    if t1.numel() == 1 or not decompose:
+        t1.mul_(weight_1)
+        t1.add_(t2, alpha=weight_2)
+    else:
+        eps = 1.0e-05
+        scale_1 = (t1 ** 2).mean().sqrt() + eps
+        direction_1 = t1 / scale_1
+        scale_2 = (t2 ** 2).mean().sqrt() + eps
+        direction_2 = t2 / scale_2
+        log_scale_1 = scale_1.log()
+        log_scale_2 = scale_2.log()
+        average_tensor(log_scale_1, log_scale_2, weight_1, weight_2, False)
+        average_tensor(direction_1, direction_2, weight_1, weight_2, False)
+        t1.copy_(log_scale_1.exp() * direction_1)
 
 
 def average_state_dict(
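To see what the decomposed path in `average_tensor()` does, take two tensors that point the same way but differ in magnitude by 100x: the direction averages to itself, while the magnitude is averaged in log space, giving the geometric rather than the arithmetic mean. A quick check, assuming the patched file is importable as `icefall.checkpoint`:

```python
import torch
from icefall.checkpoint import average_tensor  # assumes the patched module

t1 = torch.full((4,), 1.0)    # magnitude ~1
t2 = torch.full((4,), 100.0)  # magnitude ~100
average_tensor(t1, t2, weight_1=0.5, weight_2=0.5, decompose=True)
# Log-space magnitude average: exp(0.5 * log(1) + 0.5 * log(100)) ~= 10,
# versus 50.5 for the plain weighted average.
print(t1)  # ~tensor([10., 10., 10., 10.])
```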
@@ -448,12 +493,17 @@ def average_state_dict(
     state_dict_2: Dict[str, Tensor],
     weight_1: float,
     weight_2: float,
-    scaling_factor: float = 1.0,
+    decompose: bool = False,
 ) -> Dict[str, Tensor]:
     """Average two state_dict with given weights:
       state_dict_1 = (state_dict_1 * weight_1 + state_dict_2 * weight_2)
-        * scaling_factor
+    The weights do not have to be positive.
     It is an in-place operation on state_dict_1 itself.
+    If decompose == True, we do this operation after decomposing
+    each non-scalar tensor into a log-magnitude and a direction (note:
+    the magnitude is computed with an epsilon of 1e-5).
     """
     # Identify shared parameters. Two parameters are said to be shared
     # if they have the same data_ptr
@@ -466,8 +516,8 @@ def average_state_dict(
     uniqued_names = list(uniqued.values())
     for k in uniqued_names:
-        state_dict_1[k] *= weight_1
-        state_dict_1[k] += (
-            state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
-        )
-        state_dict_1[k] *= scaling_factor
+        average_tensor(state_dict_1[k],
+                       state_dict_2[k].to(device=state_dict_1[k].device),
+                       weight_1,
+                       weight_2,
+                       decompose=decompose)
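Putting it together, `average_state_dict()` now blends two state_dicts tensor by tensor through `average_tensor()`, averaging each shared tensor (same `data_ptr()`) only once. A hypothetical call, again assuming the patched `icefall.checkpoint` is importable; the dicts here are toy stand-ins for real checkpoints:

```python
import torch
from icefall.checkpoint import average_state_dict  # assumes the patched module

sd1 = {"weight": torch.randn(8, 8), "bias": torch.zeros(8)}
sd2 = {"weight": torch.randn(8, 8), "bias": torch.ones(8)}

# In-place 70/30 blend of sd1 and sd2 using the decomposed path;
# scalar tensors (numel() == 1) would bypass the decomposition.
average_state_dict(sd1, sd2, weight_1=0.7, weight_2=0.3, decompose=True)
print(sd1["bias"][:3])
```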