Reduce debug frequencies

This commit is contained in:
Daniel Povey 2022-06-20 13:48:42 +08:00
parent c10a9889fa
commit 4124cd7241

View File

@ -395,7 +395,7 @@ class NeutralGradient(Optimizer):
cur_grad = self._change_coordinates(cur_grad, state, forward=False)
if random.random() < 0.0002:
if random.random() < 0.00005:
# This is only for debug. The logic below would not be valid for n_cache_grads>0,
# anyway we will delete this code at some point.
# in principle, the cur_grad is supposed to have the same rms as params, on average.
@ -407,7 +407,7 @@ class NeutralGradient(Optimizer):
param_rms = (p**2).mean().sqrt()
print(f"cur_grad_rms={cur_grad_rms.item():.3e}, corrected_grad_rms={cur_grad_rms_corrected.item():.3e}, param_rms={param_rms.item():.3e}")
if random.random() < 0.001:
if random.random() < 0.0005:
# check the cosine angle between cur_grad and grad, to see how different this update
# is from gradient descent.
prod = (grad*cur_grad).mean()
@ -735,7 +735,7 @@ class NeutralGradient(Optimizer):
R_scale = R.diag().mean() + 1.0e-20
cov_scale = cov.diag().mean() + 1.0e-20
if random.random() < 0.1:
if random.random() < 0.02:
print(f"required_rank={required_rank}, rank={rank}, size={size}, R_scale={R_scale}, rand_scale={rand_scale}, shape={shape}")
cov.add_(R, alpha=rand_scale * cov_scale / R_scale)
@ -816,7 +816,7 @@ class NeutralGradient(Optimizer):
P = torch.matmul(Y, Y.t())
if random.random() < 0.1:
if random.random() < 0.025:
# TEMP:
_,s,_ = P.svd()
@ -909,7 +909,7 @@ class NeutralGradient(Optimizer):
param_periods = param_periods.tolist()
logging.info(f"NeutralGradient._recalibrate_speedup: speedup = {speedup:.2g}, actual_speedup = {actual_speedup:.2g}")
print_info = random.random() < 0.05
print_info = random.random() < 0.01
i = 0
for p in group["params"]:
if p.grad is None:
@ -1150,7 +1150,7 @@ class Cain(Optimizer):
this_delta = grad / denom
alpha = -lr*(1-beta1)*(bias_correction2 ** 0.5)
delta.add_(this_delta, alpha=alpha)
if random.random() < 0.0005:
if random.random() < 0.0001:
print(f"Delta rms = {(delta**2).mean().item()}, shape = {delta.shape}")
p.add_(delta)