diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index c7b09c8ac..3ad346a1c 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -127,6 +127,7 @@ def load_checkpoint( checkpoint.pop("model") if model_avg is not None and "model_avg" in checkpoint: + logging.info("Loading averaged model") model_avg.load_state_dict(checkpoint["model_avg"], strict=strict) checkpoint.pop("model_avg") @@ -350,7 +351,9 @@ def update_averaged_model( model_cur: Union[nn.Module, DDP], model_avg: nn.Module, ) -> None: - """Update the averaged model, + """Update the averaged model: + model_avg = model_cur * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train) Args: params: @@ -358,7 +361,7 @@ def update_averaged_model( model_cur: The current model. model_avg: - The stored model averaged from start of training to update. + The averaged model to be updated. """ weight_cur = params.average_period / params.batch_idx_train weight_avg = 1 - weight_cur @@ -369,17 +372,12 @@ def update_averaged_model( cur = model_cur.state_dict() avg = model_avg.state_dict() - uniqued: Dict[int, str] = dict() - for k, v in avg.items(): - v_data_ptr = v.data_ptr() - if v_data_ptr in uniqued: - continue - uniqued[v_data_ptr] = k - - uniqued_names = list(uniqued.values()) - for k in uniqued_names: - avg[k] *= weight_avg - avg[k] += cur[k] * weight_cur + average_state_dict( + state_dict_1=avg, + state_dict_2=cur, + weight_1=weight_avg, + weight_2=weight_cur, + ) def average_checkpoints_with_averaged_model( @@ -388,12 +386,12 @@ def average_checkpoints_with_averaged_model( device: torch.device = torch.device("cpu"), ) -> Dict[str, Tensor]: """Average model parameters over the range with given - start model(excluded) and end model. + start model (excluded) and end model. - Let start = batch_idx_train of model-start, - end = batch_idx_train of model-end, - Then the average model over epoch [start+1, start+2, ..., end] is - avg = (model_end * end - model_start * start) / (start - end) + Let start = batch_idx_train of model-start; + end = batch_idx_train of model-end. + Then the average model over range from start (excluded) to end is + avg = (model_end * end - model_start * start) / (start - end). The model index could be epoch number or checkpoint number. @@ -413,17 +411,41 @@ def average_checkpoints_with_averaged_model( batch_idx_train_start = state_dict_start["batch_idx_train"] batch_idx_train_end = state_dict_end["batch_idx_train"] interval = batch_idx_train_end - batch_idx_train_start - weight_start = -batch_idx_train_start / interval weight_end = batch_idx_train_end / interval + weight_start = 1 - weight_end model_end = state_dict_end["model_avg"] model_start = state_dict_start["model_avg"] avg = model_end + # scale the weight to avoid overflow + average_state_dict( + state_dict_1=avg, + state_dict_2=model_start, + weight_1=1.0, + weight_2=weight_start / weight_end, + scaling_factor=weight_end, + ) + + return avg + + +def average_state_dict( + state_dict_1: Dict[str, Tensor], + state_dict_2: Dict[str, Tensor], + weight_1: float, + weight_2: float, + scaling_factor: float = 1.0, +) -> Dict[str, Tensor]: + """Average two state_dict with given weights: + state_dict_1 = (state_dict_1 * weight_1 + state_dict_2 + weight_2) + * scaling_factor + It is an in-place operation on state_dict_1 itself. + """ # Identify shared parameters. Two parameters are said to be shared # if they have the same data_ptr uniqued: Dict[int, str] = dict() - for k, v in avg.items(): + for k, v in state_dict_1.items(): v_data_ptr = v.data_ptr() if v_data_ptr in uniqued: continue @@ -431,7 +453,6 @@ def average_checkpoints_with_averaged_model( uniqued_names = list(uniqued.values()) for k in uniqued_names: - avg[k] *= weight_end - avg[k] += model_start[k] * weight_start - - return avg + state_dict_1[k] *= weight_1 + state_dict_1[k] += state_dict_2[k] * weight_2 + state_dict_1[k] *= scaling_factor