mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-12 18:44:20 +00:00
avoid negative values of running_var when using average model
This commit is contained in:
parent
ebbab37776
commit
77c8e03c6b
@ -444,6 +444,11 @@ def average_checkpoints_with_averaged_model(
|
|||||||
scaling_factor=weight_end,
|
scaling_factor=weight_end,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# avoid negative running variance of batchnorm layers
|
||||||
|
for k in avg.keys():
|
||||||
|
if k.split('.')[-1] == 'running_var':
|
||||||
|
avg[k] = torch.clip(avg[k], min=1e-2)
|
||||||
|
|
||||||
return avg
|
return avg
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user