mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-12 11:32:19 +00:00
Print indexes of largest grad
This commit is contained in:
parent
32cf5beaa7
commit
06b9138f33
@ -633,18 +633,37 @@ class ScaledAdam(BatchedOptimizer):
|
||||
"""
|
||||
largest_ratio = 0.0
|
||||
largest_name = ""
|
||||
# ratios_names is a list of 3-tuples: (grad_ratio, param_name, tensor)
|
||||
ratios_names = []
|
||||
for (p, state, batch_param_names) in tuples:
|
||||
dims = list(range(1, p.ndim))
|
||||
grad_ratio = (p.grad**2).mean(dim=dims) / state["exp_avg_sq"].mean(
|
||||
dim=dims
|
||||
|
||||
def mean(x):
|
||||
# workaround for bad interface of torch's "mean" for when dims is the empty list.
|
||||
if len(dims) > 0:
|
||||
return x.mean(dim=dims)
|
||||
else:
|
||||
return x
|
||||
|
||||
grad_ratio = (
|
||||
(mean(p.grad**2) / state["exp_avg_sq"].mean(dim=dims))
|
||||
.sqrt()
|
||||
.to("cpu")
|
||||
)
|
||||
max_grad_ratio, max_index = grad_ratio.to("cpu").max(dim=0)
|
||||
if max_grad_ratio.item() > largest_ratio:
|
||||
largest_ratio = max_grad_ratio.item()
|
||||
largest_name = batch_param_names[max_index.item()]
|
||||
|
||||
ratios_names += zip(
|
||||
grad_ratio.tolist(), batch_param_names, p.grad.unbind(dim=0)
|
||||
)
|
||||
|
||||
ratios_names = sorted(ratios_names, reverse=True)
|
||||
ratios_names = ratios_names[:10]
|
||||
ratios_names = [
|
||||
(ratio, name, largest_index(tensor))
|
||||
for (ratio, name, tensor) in ratios_names
|
||||
]
|
||||
|
||||
logging.warning(
|
||||
f"Parameter with most larger-than-usual grad is {largest_name}, with ratio (cur_grad / normal_grad) of "
|
||||
f"{largest_ratio ** 0.5}"
|
||||
f"Parameters with most larger-than-usual grads, with ratios, are: {ratios_names}"
|
||||
)
|
||||
|
||||
def _show_gradient_dominating_parameter(
|
||||
@ -714,6 +733,12 @@ class ScaledAdam(BatchedOptimizer):
|
||||
)
|
||||
|
||||
|
||||
def largest_index(x: Tensor):
|
||||
x = x.contiguous()
|
||||
argmax = x.abs().argmax().item()
|
||||
return [(argmax // x.stride(i)) % x.size(i) for i in range(x.ndim)]
|
||||
|
||||
|
||||
class LRScheduler(object):
|
||||
"""
|
||||
Base-class for learning rate schedulers where the learning-rate depends on both the
|
||||
|
Loading…
x
Reference in New Issue
Block a user