diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 817fffd0d..d2ad6a6e5 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -781,7 +781,6 @@ def train_one_epoch(
                 params=params,
                 model_cur=model,
                 model_avg=model_avg,
-                decompose=True
             )
 
         if (
@@ -905,7 +904,7 @@ def run(rank, world_size, args):
     model_avg: Optional[nn.Module] = None
     if rank == 0:
         # model_avg is only used with rank 0
-        model_avg = copy.deepcopy(model)
+        model_avg = copy.deepcopy(model).to(torch.float64)
 
     assert params.start_epoch > 0, params.start_epoch
     checkpoints = load_checkpoint_if_available(
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index 97e5bba65..4e02dd382 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -86,7 +86,7 @@ def save_checkpoint(
     }
 
     if model_avg is not None:
-        checkpoint["model_avg"] = model_avg.state_dict()
+        checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()
 
     if params:
         for k, v in params.items():
@@ -351,7 +351,6 @@ def update_averaged_model(
     params: Dict[str, Tensor],
     model_cur: Union[nn.Module, DDP],
     model_avg: nn.Module,
-    decompose: bool = False
 ) -> None:
     """Update the averaged model:
       model_avg = model_cur * (average_period / batch_idx_train)
@@ -364,12 +363,6 @@
         The current model.
       model_avg:
         The averaged model to be updated.
-      decompose:
-        If true, do the averaging after decomposing each non-scalar tensor into
-        a log-magnitude and a direction (note: the magnitude is computed with an
-        epsilon of 1e-5). You should give the same argument to
-        average_checkpoints_with_averaged_model() when you use the averaged
-        model.
     """
     weight_cur = params.average_period / params.batch_idx_train
     weight_avg = 1 - weight_cur
@@ -385,7 +378,6 @@
         state_dict_2=cur,
         weight_1=weight_avg,
         weight_2=weight_cur,
-        decompose=decompose
     )
 
 
@@ -393,7 +385,6 @@ def average_checkpoints_with_averaged_model(
     filename_start: str,
     filename_end: str,
     device: torch.device = torch.device("cpu"),
-    decompose: bool = False,
 ) -> Dict[str, Tensor]:
     """Average model parameters over the range with given
     start model (excluded) and end model.
@@ -425,12 +416,6 @@
         is saved by :func:`save_checkpoint`.
       device:
         Move checkpoints to this device before averaging.
-      decompose:
-        If true, do the averaging after decomposing each non-scalar tensor into
-        a log-magnitude and a direction (note: the magnitude is computed with an
-        epsilon of 1e-5). You should give the same argument to
-        average_checkpoints_with_averaged_model() when you use the averaged
-        model.
     """
     state_dict_start = torch.load(filename_start, map_location=device)
     state_dict_end = torch.load(filename_end, map_location=device)
@@ -440,56 +425,22 @@
     interval = batch_idx_train_end - batch_idx_train_start
     assert interval > 0, interval
     weight_end = batch_idx_train_end / interval
-    # note: weight_start will be negative.
     weight_start = 1 - weight_end
 
     model_end = state_dict_end["model_avg"]
     model_start = state_dict_start["model_avg"]
+    avg = model_end
 
+    # scale the weight to avoid overflow
     average_state_dict(
-        state_dict_1=model_end,
+        state_dict_1=avg,
         state_dict_2=model_start,
-        weight_1=weight_end,
-        weight_2=weight_start,
-        decompose=decompose
+        weight_1=1.0,
+        weight_2=weight_start / weight_end,
+        scaling_factor=weight_end,
     )
-    # model_end contains averaged model
-    return model_end
-
-
-def average_tensor(
-    t1: Tensor,
-    t2: Tensor,
-    weight_1: float,
-    weight_2: float,
-    decompose: bool):
-    """
-    Computes, in-place,
-      t1[:] = weight_1 * t1 + weight_2 * t2
-    If decompose == True and t1 and t2 have numel()>1, does this after
-    decomposing them into a log-magnitude and a direction (note: the magnitude
-    is computed with an epsilon of 1e-5).
-    """
-    if t1.numel() == 1 or not decompose:
-        t1.mul_(weight_1)
-        t1.add_(t2, alpha=weight_2)
-    else:
-        # do this in double precision to reduce roundoff error.
-        output = t1
-        t1 = t1.to(torch.float64)
-        t2 = t2.to(torch.float64)
-        eps = 1.0e-05
-        scale_1 = (t1 ** 2).mean().sqrt() + eps
-        direction_1 = t1 / scale_1
-        scale_2 = (t2 ** 2).mean().sqrt() + eps
-        direction_2 = t2 / scale_2
-        log_scale_1 = scale_1.log()
-        log_scale_2 = scale_2.log()
-        average_tensor(log_scale_1, log_scale_2, weight_1, weight_2, False)
-        average_tensor(direction_1, direction_2, weight_1, weight_2, False)
-        output.copy_((log_scale_1.exp() * direction_1).to(dtype=t1.dtype))
+    return avg
 
 
 def average_state_dict(
@@ -497,17 +448,12 @@ def average_state_dict(
     state_dict_2: Dict[str, Tensor],
     weight_1: float,
     weight_2: float,
-    decompose: bool = False,
+    scaling_factor: float = 1.0,
 ) -> Dict[str, Tensor]:
     """Average two state_dict with given weights:
       state_dict_1 = (state_dict_1 * weight_1 + state_dict_2 * weight_2)
-
-    The weights do not have to be positive.
+        * scaling_factor
     It is an in-place operation on state_dict_1 itself.
-
-    If decompose == True, we do this operation after decomposing
-    each non-scalar tensor into a log-magnitude and a direction (note:
-    the magnitude is computed with an epsilon of 1e-5).
     """
     # Identify shared parameters. Two parameters are said to be shared
     # if they have the same data_ptr
@@ -520,8 +466,8 @@
     uniqued_names = list(uniqued.values())
 
     for k in uniqued_names:
-        average_tensor(state_dict_1[k],
-                       state_dict_2[k].to(device=state_dict_1[k].device),
-                       weight_1,
-                       weight_2,
-                       decompose=decompose)
+        state_dict_1[k] *= weight_1
+        state_dict_1[k] += (
+            state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
+        )
+        state_dict_1[k] *= scaling_factor
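
A note on the reweighting in average_checkpoints_with_averaged_model (the algebra the patch relies on, not part of the patch itself): model_avg is a running average over all training batches, so the average over only the checkpoints in (start, end] is weight_end * model_end + weight_start * model_start, where weight_end = batch_idx_train_end / interval exceeds 1 and weight_start = 1 - weight_end is therefore negative. The new code computes the identical quantity as (1.0 * model_end + (weight_start / weight_end) * model_start) * weight_end, so every per-tensor weight passed to average_state_dict stays close to 1 and the single large factor weight_end is applied once, via scaling_factor. A minimal sketch with hypothetical batch counts, checking that the two formulations agree:

import torch

# Hypothetical batch counts at the two checkpoints (illustrative values only).
batch_idx_train_start = 9000
batch_idx_train_end = 12000

interval = batch_idx_train_end - batch_idx_train_start
weight_end = batch_idx_train_end / interval  # 4.0; grows as the interval shrinks
weight_start = 1 - weight_end                # -3.0; negative by construction

# Stand-ins for one tensor of model_avg at the start and end checkpoints.
m_start = torch.randn(4, 4, dtype=torch.float64)
m_end = torch.randn(4, 4, dtype=torch.float64)

# Direct weighted combination (the formulation the patch replaces).
direct = weight_end * m_end + weight_start * m_start

# Rescaled formulation: per-tensor weights stay O(1), and the large
# factor weight_end is applied once, as scaling_factor.
rescaled = (1.0 * m_end + (weight_start / weight_end) * m_start) * weight_end

assert torch.allclose(direct, rescaled)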
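Separately, the patch keeps model_avg in torch.float64 during training and casts it back to torch.float32 only inside save_checkpoint. A rough illustration of why (again a sketch, not code from the patch): repeatedly applying the in-place update performed by update_averaged_model accumulates rounding error in float32 that double-precision accumulation largely avoids.

import torch

torch.manual_seed(0)
target = torch.randn(1000)

# If every model_cur equals `target`, the running average should stay
# exactly `target`; any deviation is accumulated rounding error.
avg32 = target.clone()
avg64 = target.clone().to(torch.float64)
for batch_idx_train in range(1, 10_001):
    weight_cur = 1.0 / batch_idx_train  # i.e. average_period = 1, for simplicity
    weight_avg = 1 - weight_cur
    avg32.mul_(weight_avg).add_(target, alpha=weight_cur)
    avg64.mul_(weight_avg).add_(target.to(torch.float64), alpha=weight_cur)

print((avg32 - target).abs().max())                    # visible float32 drift
print((avg64 - target.to(torch.float64)).abs().max())  # orders of magnitude smaller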