From dece8ad204d8c92732457d3fae2de58f4cad88a3 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 9 Oct 2022 21:14:52 +0800
Subject: [PATCH] Various fixes from debugging with nvtx, but removed the NVTX
 annotations.

---
 .../ASR/pruned_transducer_stateless7/optim.py | 46 ++++++++++---------
 1 file changed, 24 insertions(+), 22 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index c1589b907..544324148 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -76,39 +76,37 @@ class BatchedOptimizer(Optimizer):
         batches = defaultdict(list)  # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
         for p in param_group:
-            assert not p.grad.is_sparse and "Sparse gradients not supported."
             key = (str(p.dtype), *p.shape)
             batches[key].append(p)
 
         stacked_params_dict = dict()
 
+        # turn batches into a list, in deterministic order.
+        batches = [ batches[key] for key in sorted(batches.keys()) ]
+        # pairs will contain pairs of (stacked_param, state), one for each batch
+        # in `batches`.
         pairs = []
-        for p in param_group:
-            key = (str(p.dtype), *p.shape)
-            batch = batches[key]
-            if p is batch[0]:
-                # if this is the 1st param in the batch...
-                # we arbitrarily store the state in the
-                # state corresponding to the 1st parameter in the
-                # group. class Optimizer will take care of saving/loading state.
-                state = self.state[p]
-                p_stacked = torch.stack(batch)
-                grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch ])
-                p_stacked.grad = grad
-                stacked_params_dict[key] = p_stacked
-                pairs.append((p_stacked, state))
+        for batch in batches:
+            p = batch[0]
+            # we arbitrarily store the state in the
+            # state corresponding to the 1st parameter in the
+            # group. class Optimizer will take care of saving/loading state.
+            state = self.state[p]
+            p_stacked = torch.stack(batch)
+            grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch ])
+            p_stacked.grad = grad
+            stacked_params_dict[key] = p_stacked
+            pairs.append((p_stacked, state))
 
         yield pairs  # <-- calling code will do the actual optimization here!
 
-        for key, batch in batches.items():
-            stacked_params = stacked_params_dict[key]
+        for ((stacked_params, _state), batch) in zip(pairs, batches):
             for i, p in enumerate(batch):  # batch is list of Parameter
                 p.copy_(stacked_params[i])
 
-
 class ScaledAdam(BatchedOptimizer):
     """
     Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
@@ -230,7 +228,6 @@ class ScaledAdam(BatchedOptimizer):
 
         return loss
 
-
     def _init_state(self,
                     group: dict,
                     p: Tensor,
@@ -326,6 +323,7 @@ class ScaledAdam(BatchedOptimizer):
         first_state["model_norms"][step % clipping_update_period] = tot_norm
 
         if step % clipping_update_period == 0:
+            print(f"step = {step}")
             # Print some stats.
             # We don't reach here if step == 0 because we would have returned
             # above.
@@ -349,9 +347,13 @@ class ScaledAdam(BatchedOptimizer):
         if step < clipping_update_period:
             return 1.0  # We have not yet estimated a norm to clip to.
         else:
-            model_norm_threshold = first_state["model_norm_threshold"]
-            ans = min(1.0,
-                      (model_norm_threshold / (tot_norm + 1.0e-20)).item())
+            try:
+                model_norm_threshold = first_state["model_norm_threshold"]
+            except:
+                logging.info("Warning: model_norm_threshold not in state: possibly "
+                             "you changed config when restarting, adding clipping_scale option?")
+                return 1.0
+            ans = min(1.0,(model_norm_threshold / (tot_norm + 1.0e-20)).item())
             if ans < 1.0:
                 first_state["num_clipped"] += 1
             if ans < 0.1:
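
Note (illustrative, not part of the patch): the first hunk refactors the parameter-batching generator in BatchedOptimizer so that parameters are grouped once by (dtype, shape), each group is stacked into a single tensor for the update, and the results are copied back into the individual parameters after the yield. The sketch below shows that batching-and-copy-back idea in isolation; the plain SGD update and the name batched_sgd_step are assumptions for demonstration only, not code from optim.py.

from collections import defaultdict

import torch


@torch.no_grad()
def batched_sgd_step(params, lr=0.1):
    # Group parameters by (dtype, shape); same-shaped tensors can then be
    # updated together with far fewer kernel launches.
    batches = defaultdict(list)
    for p in params:
        key = (str(p.dtype), *p.shape)
        batches[key].append(p)

    # Deterministic order, mirroring sorted(batches.keys()) in the patch.
    for key in sorted(batches.keys()):
        batch = batches[key]
        stacked = torch.stack(batch)
        grad = torch.stack(
            [torch.zeros_like(p) if p.grad is None else p.grad for p in batch]
        )
        # Toy update (plain SGD): one batched operation instead of
        # len(batch) separate parameter updates.
        stacked -= lr * grad

        # Copy the updated slices back into the original parameters,
        # as the zip(pairs, batches) loop does in the patch.
        for i, p in enumerate(batch):
            p.copy_(stacked[i])


if __name__ == "__main__":
    params = [torch.nn.Parameter(torch.randn(3, 4)) for _ in range(5)]
    for p in params:
        p.grad = torch.ones_like(p)
    batched_sgd_step(params)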
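
A second note on the last hunk (again illustrative, not from optim.py): the quantity it computes is a per-step scale that stays at 1.0 while the current norm is at or below the stored model_norm_threshold and shrinks proportionally otherwise, with the new try/except falling back to 1.0 when the threshold is missing from a restored state. A minimal standalone version of the visible formula, assuming the threshold and current norm are plain floats:

def clipping_scale(model_norm_threshold, tot_norm):
    # Mirrors ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
    # from the patch, but on floats; the 1.0e-20 guards against division by zero.
    return min(1.0, model_norm_threshold / (tot_norm + 1.0e-20))


# Example: threshold 10.0, current norm 25.0 -> updates scaled by 0.4;
#          threshold 10.0, current norm  8.0 -> scale stays at 1.0.
print(clipping_scale(10.0, 25.0), clipping_scale(10.0, 8.0))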