mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 14:44:18 +00:00
add docs of the scaling in function average_checkpoints_with_averaged_model
This commit is contained in:
parent
5c07402af8
commit
4f18d52c8c
@ -390,9 +390,20 @@ def average_checkpoints_with_averaged_model(
|
||||
start model (excluded) and end model.
|
||||
|
||||
Let start = batch_idx_train of model-start;
|
||||
end = batch_idx_train of model-end.
|
||||
end = batch_idx_train of model-end;
|
||||
interval = end - start.
|
||||
Then the average model over the range from start (excluded) to end is
|
||||
avg = (model_end * end - model_start * start) / (start - end).
|
||||
(1) avg = (model_end * end - model_start * start) / interval.
|
||||
It can be written as
|
||||
(2) avg = model_end * weight_end + model_start * weight_start,
|
||||
where weight_end = end / interval,
|
||||
weight_start = -start / interval = 1 - weight_end.
|
||||
The terms `weight_end` and `weight_start` would be large
|
||||
if the model has been trained for lots of batches, which would cause
|
||||
overflow when multiplying the model parameters.
|
||||
To avoid this, we rewrite (2) as:
|
||||
(3) avg = (model_end + model_start * (weight_start / weight_end))
|
||||
* weight_end
|
||||
|
||||
The model index could be epoch number or checkpoint number.
|
||||
|
||||
@ -412,6 +423,7 @@ def average_checkpoints_with_averaged_model(
|
||||
batch_idx_train_start = state_dict_start["batch_idx_train"]
|
||||
batch_idx_train_end = state_dict_end["batch_idx_train"]
|
||||
interval = batch_idx_train_end - batch_idx_train_start
|
||||
assert interval > 0, interval
|
||||
weight_end = batch_idx_train_end / interval
|
||||
weight_start = 1 - weight_end
|
||||
|
||||
@ -439,7 +451,7 @@ def average_state_dict(
|
||||
scaling_factor: float = 1.0,
|
||||
) -> Dict[str, Tensor]:
|
||||
"""Average two state_dict with given weights:
|
||||
state_dict_1 = (state_dict_1 * weight_1 + state_dict_2 + weight_2)
|
||||
state_dict_1 = (state_dict_1 * weight_1 + state_dict_2 * weight_2)
|
||||
* scaling_factor
|
||||
It is an in-place operation on state_dict_1 itself.
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user