Mirror of https://github.com/k2-fsa/icefall.git

Commit edb2bd56b2 (parent 736a60fb48): black formatted
@@ -116,7 +116,7 @@ class BatchedOptimizer(Optimizer):
         yield tuples  # <-- calling code will do the actual optimization here!

-        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
+        for (stacked_params, _state, _names), batch in zip(tuples, batches):
             for i, p in enumerate(batch):  # batch is list of Parameter
                 p.copy_(stacked_params[i])

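For context, this hunk is the tail of BatchedOptimizer.batched_params: after the caller has updated the yielded stacked tensors, each slice of the stack is copied back into its original Parameter. A minimal, self-contained sketch of that stack / yield / copy-back pattern (a hypothetical helper, not the icefall implementation, which also groups parameters by shape and carries optimizer state and names):

    from contextlib import contextmanager
    from typing import List

    import torch


    @contextmanager
    def batched_params_sketch(params: List[torch.nn.Parameter]):
        # Stack same-shaped parameters along a new leading dim so a single
        # vectorized update can handle all of them at once.
        stacked = torch.stack([p.detach() for p in params])
        stacked.grad = torch.stack(
            [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
        )
        yield stacked  # <-- calling code does the actual optimization here
        # Copy the (possibly updated) slices back into the real parameters.
        with torch.no_grad():
            for i, p in enumerate(params):
                p.copy_(stacked[i])


    # Usage: one vectorized SGD-style step over eight same-shaped parameters.
    ps = [torch.nn.Parameter(torch.randn(4, 3)) for _ in range(8)]
    for p in ps:
        p.grad = torch.randn_like(p)
    with batched_params_sketch(ps) as stacked:
        stacked -= 0.1 * stacked.grad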
@@ -181,7 +181,6 @@ class ScaledAdam(BatchedOptimizer):
         size_update_period=4,
         clipping_update_period=100,
     ):
-
         defaults = dict(
             lr=lr,
             clipping_scale=clipping_scale,
@@ -327,9 +326,7 @@ class ScaledAdam(BatchedOptimizer):
         batch = True

         for group, group_params_names in zip(self.param_groups, self.parameters_names):
-
             with self.batched_params(group["params"], group_params_names) as batches:
-
                 # batches is list of pairs (stacked_param, state). stacked_param is like
                 # a regular parameter, and will have a .grad, but the 1st dim corresponds to
                 # a stacking dim, it is not a real dim.
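The zip over self.param_groups and self.parameters_names above assumes the caller passed one list of parameter names per parameter group when constructing the optimizer. A hedged usage sketch, with illustrative hyperparameter values, of how ScaledAdam (defined in this same file) is typically constructed in icefall training scripts:

    import torch

    model = torch.nn.Linear(512, 512)

    # One list of names per parameter group; here there is a single group.
    parameters_names = [[name for name, _ in model.named_parameters()]]
    optimizer = ScaledAdam(
        model.parameters(),
        lr=0.045,            # illustrative value
        clipping_scale=2.0,  # enables the gradient-clipping path in the next hunk
        parameters_names=parameters_names,
    )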
@@ -429,7 +426,7 @@ class ScaledAdam(BatchedOptimizer):
         clipping_update_period = group["clipping_update_period"]

         tot_sumsq = torch.tensor(0.0, device=first_p.device)
-        for (p, state, param_names) in tuples:
+        for p, state, param_names in tuples:
             grad = p.grad
             if grad.is_sparse:
                 raise RuntimeError(
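This hunk sits inside the gradient-clipping logic: a scalar tot_sumsq is accumulated over every batched parameter's gradient and later turned into a clipping scale, with statistics refreshed every clipping_update_period steps. A simplified, hypothetical sketch of the basic idea (not ScaledAdam's exact rule, which also tracks a history of norms):

    import torch


    def global_clipping_scale_sketch(params, max_norm: float = 10.0) -> float:
        # Accumulate the squared L2 norm of all gradients into one scalar.
        tot_sumsq = torch.tensor(0.0)
        for p in params:
            if p.grad is not None:
                tot_sumsq = tot_sumsq + (p.grad.detach().float() ** 2).sum().cpu()
        tot_norm = tot_sumsq.sqrt()
        # Scale factor <= 1.0 that would bring the global grad norm down to max_norm.
        return min(1.0, (max_norm / (tot_norm + 1.0e-20)).item())


    # Usage: multiply every gradient by the returned scale before the update, e.g.
    # scale = global_clipping_scale_sketch(model.parameters(), max_norm=10.0)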
@@ -514,7 +511,7 @@ class ScaledAdam(BatchedOptimizer):
         from tuples, we still pass it to save some time.
         """
         all_sumsq_orig = {}
-        for (p, state, batch_param_names) in tuples:
+        for p, state, batch_param_names in tuples:
             # p is a stacked batch parameters.
             batch_grad = p.grad
             if p.numel() == p.shape[0]:  # a batch of scalars
@@ -530,7 +527,6 @@ class ScaledAdam(BatchedOptimizer):
             for name, sumsq_orig, rms, grad in zip(
                 batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
             ):
-
                 proportion_orig = sumsq_orig / tot_sumsq
                 all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)

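The dictionary filled in here maps each parameter name to (proportion of the total gradient sum-of-squares, sumsq, rms, grad); ScaledAdam uses it to report which parameters dominate the gradient norm when clipping is unusually strong. A hypothetical sketch of how such a mapping can be summarized:

    from typing import Dict, Tuple

    import torch


    def print_dominant_parameters_sketch(
        all_sumsq_orig: Dict[str, Tuple], k: int = 5
    ) -> None:
        # Sort names by their share of the total gradient sumsq, largest first.
        ranked = sorted(
            all_sumsq_orig.items(), key=lambda kv: float(kv[1][0]), reverse=True
        )
        for name, (proportion, sumsq, rms, grad) in ranked[:k]:
            print(f"{name}: proportion={float(proportion):.3f} sumsq={float(sumsq):.3e}")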
@@ -1106,7 +1102,7 @@ def _test_scaled_adam(hidden_dim: int):

     # if epoch == 130:
     #     opts = diagnostics.TensorDiagnosticOptions(
-    #        2 ** 22
+    #        512
     #     ) # allow 4 megabytes per sub-module
     #     diagnostic = diagnostics.attach_diagnostics(m, opts)

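The commented-out block in _test_scaled_adam shows how the diagnostics module is meant to be wired into a training loop. A hedged sketch of that usage, assuming the icefall diagnostics API as used elsewhere in the repo (TensorDiagnosticOptions, attach_diagnostics, print_diagnostics) and the single positional budget argument shown in the comment above:

    import torch

    from icefall import diagnostics  # assumed import path, as in icefall train scripts


    def probe_model_sketch(model: torch.nn.Module, batch: torch.Tensor) -> None:
        opts = diagnostics.TensorDiagnosticOptions(512)  # per-sub-module budget (see comment above)
        diagnostic = diagnostics.attach_diagnostics(model, opts)
        model(batch).sum().backward()  # run forward + backward so both hooks fire
        diagnostic.print_diagnostics()  # dump the accumulated per-module statistics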
@@ -244,16 +244,14 @@ class TensorDiagnostic(object):

                 if stats_type == "eigs":
                     try:
-                        if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
+                        if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
                             eigs, _ = torch.linalg.eigh(stats)
                         else:
                             eigs, _ = torch.symeig(stats)
                         stats = eigs.abs().sqrt()
                     except:  # noqa
-                        print(
-                            "Error getting eigenvalues, trying another method."
-                        )
-                        if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eig'):
+                        print("Error getting eigenvalues, trying another method.")
+                        if hasattr(torch, "linalg") and hasattr(torch.linalg, "eig"):
                             eigs, _ = torch.linalg.eig(stats)
                             eigs = eigs.abs()
                         else:
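The try/except above prefers the symmetric eigensolver (torch.linalg.eigh, or torch.symeig on PyTorch older than 1.8) and falls back to the general eigensolver when it fails, since the accumulated matrix can be numerically awkward. A standalone sketch of that compatibility pattern:

    import torch


    def robust_abs_eigs_sketch(stats: torch.Tensor) -> torch.Tensor:
        # Return |eigenvalues| of a nominally symmetric matrix, tolerating old
        # PyTorch versions and inputs where the symmetric solver fails.
        try:
            if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
                eigs = torch.linalg.eigh(stats).eigenvalues
            else:
                eigs, _ = torch.symeig(stats, eigenvectors=False)  # pre-1.8 API
            return eigs.abs()
        except Exception:
            print("Error getting eigenvalues, trying another method.")
            if hasattr(torch, "linalg") and hasattr(torch.linalg, "eig"):
                return torch.linalg.eig(stats).eigenvalues.abs()
            eigs, _ = torch.eig(stats)  # old API: (n, 2) real/imag pairs
            return eigs.norm(dim=1)

In TensorDiagnostic the result is additionally square-rooted, as in the hunk above, since stats there holds a covariance, which puts the printed values on the scale of singular values.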
@@ -579,10 +577,15 @@ def attach_diagnostics(
                     )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
-                        _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
-                            class_name=get_class_name(_module))
+                    if isinstance(o, Tensor) and o.dtype in (
+                        torch.float32,
+                        torch.float16,
+                        torch.float64,
+                    ):
+                        _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
+                            o, class_name=get_class_name(_module)
+                        )

         def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
@@ -596,9 +599,15 @@ def attach_diagnostics(
                     )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
-                        _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
-                            class_name=get_class_name(_module))
+                    if isinstance(o, Tensor) and o.dtype in (
+                        torch.float32,
+                        torch.float16,
+                        torch.float64,
+                    ):
+                        _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
+                            o, class_name=get_class_name(_module)
+                        )

         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)

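Both hunks above follow the standard torch.nn.Module hook pattern: a forward hook observes each module's output and a backward hook observes the gradient with respect to that output, and attach_diagnostics registers one of each per module. A minimal self-contained sketch of the same pattern, using a simple running mean of absolute values instead of icefall's TensorDiagnostic accumulator, and the non-deprecated register_full_backward_hook:

    from collections import defaultdict

    import torch
    from torch import Tensor, nn


    def attach_simple_stats_sketch(model: nn.Module) -> dict:
        stats = defaultdict(lambda: {"count": 0, "abs_mean": 0.0})

        def accumulate(key: str, t: Tensor) -> None:
            # Running mean of |t| over all calls for this key.
            s = stats[key]
            s["count"] += 1
            s["abs_mean"] += (t.detach().abs().mean().item() - s["abs_mean"]) / s["count"]

        for name, module in model.named_modules():

            def forward_hook(_module, _input, _output, _name=name):
                if isinstance(_output, Tensor):
                    accumulate(f"{_name}.output", _output)

            def backward_hook(_module, _grad_input, _grad_output, _name=name):
                for i, g in enumerate(_grad_output):
                    if isinstance(g, Tensor):
                        accumulate(f"{_name}.grad[{i}]", g)

            module.register_forward_hook(forward_hook)
            module.register_full_backward_hook(backward_hook)

        return stats


    # Usage:
    # m = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
    # stats = attach_simple_stats_sketch(m)
    # m(torch.randn(4, 8)).sum().backward()
    # print(dict(stats))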