mirror of https://github.com/k2-fsa/icefall.git

Update checkpoint.py to support decompose argument

parent 1651fe0d42
commit 8d4c987e21

checkpoint.py
@@ -351,6 +351,7 @@ def update_averaged_model(
     params: Dict[str, Tensor],
     model_cur: Union[nn.Module, DDP],
     model_avg: nn.Module,
+    decompose: bool = False
 ) -> None:
     """Update the averaged model:
       model_avg = model_cur * (average_period / batch_idx_train)
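Note (not part of the patch): since weight_cur = average_period / batch_idx_train equals 1/n at the n-th update, this rule keeps model_avg equal to the plain mean of the models sampled every average_period batches. A scalar sanity check with made-up values:

    average_period = 100
    models = [3.0, 5.0, 7.0]  # stand-ins for model_cur at successive update points
    model_avg = 0.0
    for n, model_cur in enumerate(models, start=1):
        batch_idx_train = n * average_period
        weight_cur = average_period / batch_idx_train  # equals 1/n at update n
        model_avg = (1 - weight_cur) * model_avg + weight_cur * model_cur
    assert abs(model_avg - sum(models) / len(models)) < 1e-12  # uniform mean, 5.0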
@@ -363,6 +364,12 @@ def update_averaged_model(
         The current model.
       model_avg:
         The averaged model to be updated.
+      decompose:
+        If true, do the averaging after decomposing each non-scalar tensor into
+        a log-magnitude and a direction (note: the magnitude is computed with an
+        epsilon of 1e-5). You should give the same argument to
+        average_checkpoints_with_averaged_model() when you use the averaged
+        model.
     """
     weight_cur = params.average_period / params.batch_idx_train
     weight_avg = 1 - weight_cur
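Note (not part of the patch): a minimal sketch of the decomposition this docstring describes, writing a tensor t as scale * direction with scale equal to t's RMS value plus eps = 1e-5:

    import torch

    eps = 1.0e-05
    t = torch.tensor([3.0, 4.0])
    scale = (t ** 2).mean().sqrt() + eps  # RMS magnitude, about 3.5355
    direction = t / scale                 # unit-RMS direction
    log_scale = scale.log()
    # the decomposition reconstructs t (up to floating point):
    assert torch.allclose(log_scale.exp() * direction, t)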
@@ -378,6 +385,7 @@ def update_averaged_model(
         state_dict_2=cur,
         weight_1=weight_avg,
         weight_2=weight_cur,
+        decompose=decompose
     )
 
 
@@ -385,6 +393,7 @@ def average_checkpoints_with_averaged_model(
     filename_start: str,
     filename_end: str,
     device: torch.device = torch.device("cpu"),
+    decompose: bool = False,
 ) -> Dict[str, Tensor]:
     """Average model parameters over the range with given
     start model (excluded) and end model.
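Note (not part of the patch): the hunks below recover this start-excluded average by weighting the two stored running means, with weight_end = batch_idx_train_end / interval and a negative weight_start. Worked arithmetic with made-up numbers:

    samples = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]  # stand-in models, one per batch
    batch_idx_train_start, batch_idx_train_end = 2, 6
    interval = batch_idx_train_end - batch_idx_train_start  # 4
    weight_end = batch_idx_train_end / interval             # 1.5
    weight_start = 1 - weight_end                           # -0.5, i.e. -start/interval
    mean_start = sum(samples[:2]) / 2  # running mean stored at the start checkpoint
    mean_end = sum(samples) / 6        # running mean stored at the end checkpoint
    # 1.5 * 3.5 - 0.5 * 1.5 == 4.5, the mean of samples 3..6 (start excluded)
    assert weight_end * mean_end + weight_start * mean_start == sum(samples[2:]) / 4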
@@ -416,6 +425,12 @@ def average_checkpoints_with_averaged_model(
         is saved by :func:`save_checkpoint`.
       device:
         Move checkpoints to this device before averaging.
+      decompose:
+        If true, do the averaging after decomposing each non-scalar tensor into
+        a log-magnitude and a direction (note: the magnitude is computed with an
+        epsilon of 1e-5). You should give the same argument to
+        average_checkpoints_with_averaged_model() when you use the averaged
+        model.
     """
     state_dict_start = torch.load(filename_start, map_location=device)
     state_dict_end = torch.load(filename_end, map_location=device)
@@ -425,22 +440,52 @@ def average_checkpoints_with_averaged_model(
     interval = batch_idx_train_end - batch_idx_train_start
     assert interval > 0, interval
     weight_end = batch_idx_train_end / interval
+    # note: weight_start will be negative.
     weight_start = 1 - weight_end
 
     model_end = state_dict_end["model_avg"]
     model_start = state_dict_start["model_avg"]
-    avg = model_end
 
     # scale the weight to avoid overflow
     average_state_dict(
-        state_dict_1=avg,
+        state_dict_1=model_end,
         state_dict_2=model_start,
-        weight_1=1.0,
-        weight_2=weight_start / weight_end,
-        scaling_factor=weight_end,
+        weight_1=weight_end,
+        weight_2=weight_start,
+        decompose=decompose
     )
 
-    return avg
+    # model_end contains averaged model
+    return model_end
+
+
+def average_tensor(
+    t1: Tensor,
+    t2: Tensor,
+    weight_1: float,
+    weight_2: float,
+    decompose: bool):
+    """
+    Computes, in-place,
+       t1[:] = weight_1 * t1 + weight_2 * t2
+    If decompose == True and t1 and t2 have numel()>1, does this after
+    decomposing them into a log-magnitude and a direction (note: the magnitude
+    is computed with an epsilon of 1e-5).
+    """
+    if t1.numel() == 1 or not decompose:
+        t1.mul_(weight_1)
+        t1.add_(t2, alpha=weight_2)
+    else:
+        eps = 1.0e-05
+        scale_1 = (t1 ** 2).mean().sqrt() + eps
+        direction_1 = t1 / scale_1
+        scale_2 = (t2 ** 2).mean().sqrt() + eps
+        direction_2 = t2 / scale_2
+        log_scale_1 = scale_1.log()
+        log_scale_2 = scale_2.log()
+        average_tensor(log_scale_1, log_scale_2, weight_1, weight_2, False)
+        average_tensor(direction_1, direction_2, weight_1, weight_2, False)
+        t1.copy_(log_scale_1.exp() * direction_1)
 
 
 def average_state_dict(
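Note (not part of the patch): a numeric check of what decompose changes, assuming the average_tensor() added above is in scope. With equal weights, averaging t and 2 * t gives the arithmetic mean 1.5 * t without decomposition, but with decompose=True the log-magnitudes are averaged, so the magnitude is the geometric mean and the result is sqrt(2) * t:

    import torch

    t = torch.tensor([3.0, 4.0])

    plain = t.clone()
    average_tensor(plain, 2 * t, 0.5, 0.5, decompose=False)
    assert torch.allclose(plain, 1.5 * t)

    decomposed = t.clone()
    average_tensor(decomposed, 2 * t, 0.5, 0.5, decompose=True)
    assert torch.allclose(decomposed, (2.0 ** 0.5) * t, atol=1e-3)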
@@ -448,12 +493,17 @@ def average_state_dict(
     state_dict_2: Dict[str, Tensor],
     weight_1: float,
     weight_2: float,
-    scaling_factor: float = 1.0,
+    decompose: bool = False,
 ) -> Dict[str, Tensor]:
     """Average two state_dict with given weights:
       state_dict_1 = (state_dict_1 * weight_1 + state_dict_2 * weight_2)
-        * scaling_factor
+
+    The weights do not have to be positive.
     It is an in-place operation on state_dict_1 itself.
+
+    If decompose == True, we do this operation after decomposing
+    each non-scalar tensor into a log-magnitude and a direction (note:
+    the magnitude is computed with an epsilon of 1e-5).
     """
     # Identify shared parameters. Two parameters are said to be shared
     # if they have the same data_ptr
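Note (not part of the patch): the data_ptr() check above is how tied parameters are detected, so that each underlying tensor is averaged only once. A small illustration with a hypothetical tied embedding and projection:

    import torch.nn as nn

    emb = nn.Embedding(10, 4)
    proj = nn.Linear(4, 10, bias=False)
    proj.weight = emb.weight  # tie the two parameters: one tensor, two names
    state_dict = {"emb.weight": emb.weight, "proj.weight": proj.weight}
    assert state_dict["emb.weight"].data_ptr() == state_dict["proj.weight"].data_ptr()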
@@ -466,8 +516,8 @@ def average_state_dict(
 
     uniqued_names = list(uniqued.values())
     for k in uniqued_names:
-        state_dict_1[k] *= weight_1
-        state_dict_1[k] += (
-            state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
-        )
-        state_dict_1[k] *= scaling_factor
+        average_tensor(state_dict_1[k],
+                       state_dict_2[k].to(device=state_dict_1[k].device),
+                       weight_1,
+                       weight_2,
+                       decompose=decompose)