Avoid negative values of running_var when using the averaged model

This commit is contained in:
Tiance Wang 2023-05-09 11:58:05 +08:00 committed by GitHub
parent ebbab37776
commit 77c8e03c6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -444,6 +444,11 @@ def average_checkpoints_with_averaged_model(
scaling_factor=weight_end,
)
# avoid negative running variance of batchnorm layers
for k in avg.keys():
if k.split('.')[-1] == 'running_var':
avg[k] = torch.clip(avg[k], min=1e-2)
return avg