4x scale on max-eig constraint

This commit is contained in:
Daniel Povey 2022-09-20 14:20:13 +08:00
parent 3d72a65de8
commit db1f4ccdd1

View File

@ -397,14 +397,15 @@ class ActivationBalancer(torch.nn.Module):
return x return x
if self.max_var_per_eig > 0: if self.max_var_per_eig > 0:
max_eig_prob = 0.25
with torch.cuda.amp.autocast(enabled=False): with torch.cuda.amp.autocast(enabled=False):
x, new_direction = MaxEigLimiterFunction.apply( x, new_direction = MaxEigLimiterFunction.apply(
x, self.max_eig_direction, x, self.max_eig_direction,
self.channel_dim, self.channel_dim,
0.25, # prob max_eig_prob,
True, # subtract_mean True, # subtract_mean
self.max_var_per_eig, self.max_var_per_eig,
self.max_factor, self.max_factor / max_eig_prob,
) )
self.max_eig_direction[:] = new_direction.detach() self.max_eig_direction[:] = new_direction.detach()