diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index c83c56a53..d09c72328 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -444,6 +444,11 @@ def average_checkpoints_with_averaged_model( scaling_factor=weight_end, ) + # avoid negative running variance of batchnorm layers + for k in avg.keys(): + if k.split('.')[-1] == 'running_var': + avg[k] = torch.clip(avg[k], min=1e-2) + return avg