mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-07 08:04:18 +00:00
refactor the checkpoint.py
parent a0592e0d0f
commit 8bf2fef1e0
@@ -127,6 +127,7 @@ def load_checkpoint(
     checkpoint.pop("model")
 
     if model_avg is not None and "model_avg" in checkpoint:
+        logging.info("Loading averaged model")
         model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)
         checkpoint.pop("model_avg")
 
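A minimal usage sketch of the changed code path; the keyword names below are assumed from the lines shown above, and the checkpoint path and model are hypothetical stand-ins, since the full load_checkpoint signature is not part of this diff. When the checkpoint contains a "model_avg" entry and a model_avg module is passed, its weights are restored (now with the added log message) and the entry is popped:

import copy

import torch.nn as nn

from icefall.checkpoint import load_checkpoint  # the module being refactored

model = nn.Linear(4, 2)            # stand-in model, for illustration only
model_avg = copy.deepcopy(model)   # will receive the averaged weights

checkpoint = load_checkpoint(
    "exp/epoch-10.pt",             # hypothetical checkpoint path
    model=model,
    model_avg=model_avg,
)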
@@ -350,7 +351,9 @@ def update_averaged_model(
     model_cur: Union[nn.Module, DDP],
     model_avg: nn.Module,
 ) -> None:
-    """Update the averaged model,
+    """Update the averaged model:
+      model_avg = model_cur * (average_period / batch_idx_train)
+        + model_avg * ((batch_idx_train - average_period) / batch_idx_train)
 
     Args:
       params:
@@ -358,7 +361,7 @@ def update_averaged_model(
       model_cur:
         The current model.
       model_avg:
-        The stored model averaged from start of training to update.
+        The averaged model to be updated.
     """
     weight_cur = params.average_period / params.batch_idx_train
     weight_avg = 1 - weight_cur
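A quick numeric sketch (illustrative values, not from the commit) of the two weights computed above, assuming model_avg is updated once every average_period batches:

average_period = 100    # hypothetical params.average_period
batch_idx_train = 400   # hypothetical params.batch_idx_train

weight_cur = average_period / batch_idx_train   # 0.25
weight_avg = 1 - weight_cur                     # 0.75

# model_avg <- model_cur * 0.25 + model_avg * 0.75:
# the newest snapshot accounts for the last 100 of the 400 batches seen,
# and the previous running average keeps the remaining 300.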
@@ -369,17 +372,12 @@ def update_averaged_model(
     cur = model_cur.state_dict()
     avg = model_avg.state_dict()
 
-    uniqued: Dict[int, str] = dict()
-    for k, v in avg.items():
-        v_data_ptr = v.data_ptr()
-        if v_data_ptr in uniqued:
-            continue
-        uniqued[v_data_ptr] = k
-
-    uniqued_names = list(uniqued.values())
-    for k in uniqued_names:
-        avg[k] *= weight_avg
-        avg[k] += cur[k] * weight_cur
+    average_state_dict(
+        state_dict_1=avg,
+        state_dict_2=cur,
+        weight_1=weight_avg,
+        weight_2=weight_cur,
+    )
 
 
 def average_checkpoints_with_averaged_model(
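A minimal sketch with toy tensors (not from the commit) of the in-place update that the new average_state_dict call performs on avg, equivalent to the removed loop above:

import torch

avg = {"w": torch.tensor([1.0, 2.0])}   # running average (toy values)
cur = {"w": torch.tensor([3.0, 4.0])}   # current model (toy values)
weight_cur = 0.25
weight_avg = 1 - weight_cur

for k in avg:                    # applied once per unique tensor
    avg[k] *= weight_avg
    avg[k] += cur[k] * weight_cur

print(avg["w"])                  # tensor([1.5000, 2.5000])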
@@ -388,12 +386,12 @@ def average_checkpoints_with_averaged_model(
     device: torch.device = torch.device("cpu"),
 ) -> Dict[str, Tensor]:
     """Average model parameters over the range with given
-    start model(excluded) and end model.
+    start model (excluded) and end model.
 
-    Let start = batch_idx_train of model-start,
-    end = batch_idx_train of model-end,
-    Then the average model over epoch [start+1, start+2, ..., end] is
-    avg = (model_end * end - model_start * start) / (start - end)
+    Let start = batch_idx_train of model-start;
+    end = batch_idx_train of model-end.
+    Then the average model over range from start (excluded) to end is
+    avg = (model_end * end - model_start * start) / (start - end).
 
     The model index could be epoch number or checkpoint number.
 
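A numeric sketch (illustrative values, not from the commit) of the weights computed in the next hunk; they correspond to dividing the docstring's numerator by interval = end - start:

batch_idx_train_start = 1000   # "start"
batch_idx_train_end = 3000     # "end"
interval = batch_idx_train_end - batch_idx_train_start   # 2000

weight_end = batch_idx_train_end / interval   # 1.5
weight_start = 1 - weight_end                 # -0.5, i.e. -start / interval

# avg = model_end * 1.5 + model_start * (-0.5)
#     = (model_end * 3000 - model_start * 1000) / 2000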
@@ -413,17 +411,41 @@ def average_checkpoints_with_averaged_model(
     batch_idx_train_start = state_dict_start["batch_idx_train"]
     batch_idx_train_end = state_dict_end["batch_idx_train"]
     interval = batch_idx_train_end - batch_idx_train_start
-    weight_start = -batch_idx_train_start / interval
     weight_end = batch_idx_train_end / interval
+    weight_start = 1 - weight_end
 
     model_end = state_dict_end["model_avg"]
     model_start = state_dict_start["model_avg"]
     avg = model_end
 
+    # scale the weight to avoid overflow
+    average_state_dict(
+        state_dict_1=avg,
+        state_dict_2=model_start,
+        weight_1=1.0,
+        weight_2=weight_start / weight_end,
+        scaling_factor=weight_end,
+    )
+
+    return avg
+
+
+def average_state_dict(
+    state_dict_1: Dict[str, Tensor],
+    state_dict_2: Dict[str, Tensor],
+    weight_1: float,
+    weight_2: float,
+    scaling_factor: float = 1.0,
+) -> Dict[str, Tensor]:
+    """Average two state_dict with given weights:
+      state_dict_1 = (state_dict_1 * weight_1 + state_dict_2 + weight_2)
+        * scaling_factor
+    It is an in-place operation on state_dict_1 itself.
+    """
     # Identify shared parameters. Two parameters are said to be shared
     # if they have the same data_ptr
     uniqued: Dict[int, str] = dict()
-    for k, v in avg.items():
+    for k, v in state_dict_1.items():
         v_data_ptr = v.data_ptr()
         if v_data_ptr in uniqued:
             continue
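A small scalar check (stand-in values, not from the commit) that the argument choice above, weight_1=1.0 and weight_2=weight_start / weight_end together with scaling_factor=weight_end, reproduces avg * weight_end + model_start * weight_start while keeping the in-place multipliers close to 1, which is what the "scale the weight to avoid overflow" comment refers to:

weight_end = 1.5
weight_start = 1 - weight_end            # -0.5
avg, model_start = 2.0, 4.0              # scalar stand-ins for one parameter

direct = avg * weight_end + model_start * weight_start
scaled = (avg * 1.0 + model_start * (weight_start / weight_end)) * weight_end
assert abs(direct - scaled) < 1e-9       # both equal 1.0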
@@ -431,7 +453,6 @@ def average_checkpoints_with_averaged_model(
 
     uniqued_names = list(uniqued.values())
     for k in uniqued_names:
-        avg[k] *= weight_end
-        avg[k] += model_start[k] * weight_start
-
-    return avg
+        state_dict_1[k] *= weight_1
+        state_dict_1[k] += state_dict_2[k] * weight_2
+        state_dict_1[k] *= scaling_factor
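A short sketch (not part of the commit) of why the loop runs only over uniqued_names: tied parameters share storage, so applying the weights once per name would scale the same tensor twice; data_ptr() identifies the shared storage:

import torch

w = torch.ones(2)
state_dict_1 = {"encoder.weight": w, "decoder.weight": w}   # tied parameters

uniqued = {}
for k, v in state_dict_1.items():
    uniqued.setdefault(v.data_ptr(), k)   # keep one name per storage

print(list(uniqued.values()))   # only one of the two tied names is kept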
|
Loading…
x
Reference in New Issue
Block a user