diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py b/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py index a767761eb..159e363c7 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py @@ -116,7 +116,7 @@ class BatchedOptimizer(Optimizer): yield tuples # <-- calling code will do the actual optimization here! - for ((stacked_params, _state, _names), batch) in zip(tuples, batches): + for (stacked_params, _state, _names), batch in zip(tuples, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) @@ -181,7 +181,6 @@ class ScaledAdam(BatchedOptimizer): size_update_period=4, clipping_update_period=100, ): - defaults = dict( lr=lr, clipping_scale=clipping_scale, @@ -327,9 +326,7 @@ class ScaledAdam(BatchedOptimizer): batch = True for group, group_params_names in zip(self.param_groups, self.parameters_names): - with self.batched_params(group["params"], group_params_names) as batches: - # batches is list of pairs (stacked_param, state). stacked_param is like # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. @@ -429,7 +426,7 @@ class ScaledAdam(BatchedOptimizer): clipping_update_period = group["clipping_update_period"] tot_sumsq = torch.tensor(0.0, device=first_p.device) - for (p, state, param_names) in tuples: + for p, state, param_names in tuples: grad = p.grad if grad.is_sparse: raise RuntimeError( @@ -514,7 +511,7 @@ class ScaledAdam(BatchedOptimizer): from tuples, we still pass it to save some time. """ all_sumsq_orig = {} - for (p, state, batch_param_names) in tuples: + for p, state, batch_param_names in tuples: # p is a stacked batch parameters. batch_grad = p.grad if p.numel() == p.shape[0]: # a batch of scalars @@ -530,7 +527,6 @@ class ScaledAdam(BatchedOptimizer): for name, sumsq_orig, rms, grad in zip( batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad ): - proportion_orig = sumsq_orig / tot_sumsq all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) @@ -1106,7 +1102,7 @@ def _test_scaled_adam(hidden_dim: int): # if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( - # 2 ** 22 + # 512 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index ebf61784e..65b6f67b0 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -244,16 +244,14 @@ class TensorDiagnostic(object): if stats_type == "eigs": try: - if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'): + if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"): eigs, _ = torch.linalg.eigh(stats) else: eigs, _ = torch.symeig(stats) stats = eigs.abs().sqrt() except: # noqa - print( - "Error getting eigenvalues, trying another method." - ) - if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eig'): + print("Error getting eigenvalues, trying another method.") + if hasattr(torch, "linalg") and hasattr(torch.linalg, "eig"): eigs, _ = torch.linalg.eig(stats) eigs = eigs.abs() else: @@ -579,10 +577,15 @@ def attach_diagnostics( ) elif isinstance(_output, tuple): for i, o in enumerate(_output): - if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ): - _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o, - class_name=get_class_name(_module)) - + if isinstance(o, Tensor) and o.dtype in ( + torch.float32, + torch.float16, + torch.float64, + ): + _model_diagnostic[f"{_name}.output[{i}]"].accumulate( + o, class_name=get_class_name(_module) + ) + def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name): if isinstance(_output, tuple) and len(_output) == 1: _output = _output[0] @@ -596,9 +599,15 @@ def attach_diagnostics( ) elif isinstance(_output, tuple): for i, o in enumerate(_output): - if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ): - _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o, - class_name=get_class_name(_module)) + if isinstance(o, Tensor) and o.dtype in ( + torch.float32, + torch.float16, + torch.float64, + ): + _model_diagnostic[f"{_name}.grad[{i}]"].accumulate( + o, class_name=get_class_name(_module) + ) + module.register_forward_hook(forward_hook) module.register_backward_hook(backward_hook)