Fix regarding reverse_cutoff formula

This commit is contained in:
Daniel Povey 2022-06-25 18:24:05 +08:00
parent 8a0277d493
commit 0aa5a334d6

View File

@@ -681,7 +681,7 @@ class NeutralGradient(Optimizer):
if size == 1: if size == 1:
continue continue
param_diag_var = param_diag_vars[dim] param_diag_var = param_diag_vars[dim]
num_samples = (p.numel() // size) * 4 > size num_samples = p.numel() // size
# don't apply this reverse_cutoff thing in situations where we can't get a reasonable estimate # don't apply this reverse_cutoff thing in situations where we can't get a reasonable estimate
# of param_cov even with stats accumulation, due to the shape of the tensor. # of param_cov even with stats accumulation, due to the shape of the tensor.
reverse_cutoff = (param_reverse_cutoff if num_samples > size // 4 else 1.0e+10) reverse_cutoff = (param_reverse_cutoff if num_samples > size // 4 else 1.0e+10)
@@ -920,8 +920,6 @@ class NeutralGradient(Optimizer):
# S_sqrt is S.sqrt() in the limit where param_pow == 1.0, # S_sqrt is S.sqrt() in the limit where param_pow == 1.0,
# param_rel_eps=0, param_rel_max=inf # param_rel_eps=0, param_rel_max=inf
# don't apply this reverse_cutoff thing in situations where we can't get a reasonable estimate
# of param_cov even with stats accumulation, due to the shape of the tensor.
S_smoothed = self._smooth_param_diag_var(S, param_pow, S_smoothed = self._smooth_param_diag_var(S, param_pow,
param_rel_eps, param_rel_max, param_rel_eps, param_rel_max,
param_reverse_cutoff) param_reverse_cutoff)