From 4f18d52c8ceea7432c9c6dcfe9a72c9e2387aba5 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Thu, 5 May 2022 21:16:32 +0800 Subject: [PATCH] add docs of the scaling in function average_checkpoints_with_averaged_model --- icefall/checkpoint.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 77c47fc94..5b562ccc8 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -390,9 +390,20 @@ def average_checkpoints_with_averaged_model( start model (excluded) and end model. Let start = batch_idx_train of model-start; - end = batch_idx_train of model-end. + end = batch_idx_train of model-end; + interval = end - start. Then the average model over range from start (excluded) to end is - avg = (model_end * end - model_start * start) / (start - end). + (1) avg = (model_end * end - model_start * start) / interval. + It can be written as + (2) avg = model_end * weight_end + model_start * weight_start, + where weight_end = end / interval, + weight_start = -start / interval = 1 - weight_end. + The terms `weight_end` and `weight_start` would be large + if the model has been trained for lots of batches, which would cause + overflow when multiplying the model parameters. + To avoid this, we rewrite (2) as: + (3) avg = (model_end + model_start * (weight_start / weight_end)) + * weight_end The model index could be epoch number or checkpoint number. 
@@ -412,6 +423,7 @@ def average_checkpoints_with_averaged_model( batch_idx_train_start = state_dict_start["batch_idx_train"] batch_idx_train_end = state_dict_end["batch_idx_train"] interval = batch_idx_train_end - batch_idx_train_start + assert interval > 0, interval weight_end = batch_idx_train_end / interval weight_start = 1 - weight_end @@ -439,7 +451,7 @@ def average_state_dict( scaling_factor: float = 1.0, ) -> Dict[str, Tensor]: """Average two state_dict with given weights: - state_dict_1 = (state_dict_1 * weight_1 + state_dict_2 + weight_2) + state_dict_1 = (state_dict_1 * weight_1 + state_dict_2 * weight_2) * scaling_factor It is an in-place operation on state_dict_1 itself. """