Various fixes from debugging with NVTX; the NVTX annotations themselves have been removed.

Daniel Povey 2022-10-09 21:14:52 +08:00
parent bd7dce460b
commit dece8ad204


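For context on the NVTX debugging mentioned in the message: NVTX ranges mark named spans that show up in Nsight Systems timelines. Below is a minimal sketch of the kind of annotation presumably added for profiling and then stripped out before this commit; it assumes a CUDA build of PyTorch, and the helper name is hypothetical (torch.cuda.nvtx is PyTorch's standard NVTX binding):

import torch

def profiled_step(optimizer):
    # Hypothetical helper: wrap the optimizer step in a named NVTX range
    # so it appears as a labelled span in an Nsight Systems capture.
    torch.cuda.nvtx.range_push("optimizer_step")
    try:
        optimizer.step()
    finally:
        torch.cuda.nvtx.range_pop()
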
@@ -76,39 +76,37 @@ class BatchedOptimizer(Optimizer):
         batches = defaultdict(list)  # `batches` maps from tuple (dtype_as_str, *shape) to list of nn.Parameter
         for p in param_group:
             assert not p.grad.is_sparse and "Sparse gradients not supported."
             key = (str(p.dtype), *p.shape)
             batches[key].append(p)

         stacked_params_dict = dict()

+        # turn batches into a list, in deterministic order.
+        batches = [batches[key] for key in sorted(batches.keys())]
+        # pairs will contain pairs of (stacked_param, state), one for each batch
+        # in `batches`.
         pairs = []
-        for p in param_group:
-            key = (str(p.dtype), *p.shape)
-            batch = batches[key]
-            if p is batch[0]:
-                # if this is the 1st param in the batch...
-                # we arbitrarily store the state in the
-                # state corresponding to the 1st parameter in the
-                # group.  class Optimizer will take care of saving/loading state.
-                state = self.state[p]
-                p_stacked = torch.stack(batch)
-                grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch])
-                p_stacked.grad = grad
-                stacked_params_dict[key] = p_stacked
-                pairs.append((p_stacked, state))
+        for batch in batches:
+            p = batch[0]
+            # we arbitrarily store the state in the
+            # state corresponding to the 1st parameter in the
+            # group.  class Optimizer will take care of saving/loading state.
+            state = self.state[p]
+            p_stacked = torch.stack(batch)
+            grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch])
+            p_stacked.grad = grad
+            stacked_params_dict[key] = p_stacked
+            pairs.append((p_stacked, state))

         yield pairs  # <-- calling code will do the actual optimization here!

-        for key, batch in batches.items():
-            stacked_params = stacked_params_dict[key]
+        for ((stacked_params, _state), batch) in zip(pairs, batches):
             for i, p in enumerate(batch):  # batch is list of Parameter
                 p.copy_(stacked_params[i])


 class ScaledAdam(BatchedOptimizer):
     """
     Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
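For context, the batching in the hunk above groups parameters by (dtype, shape), stacks each group into one (N, *shape) tensor so the update runs once per group rather than once per parameter, and finally copies the slices back. A minimal standalone sketch of that round trip follows; the 0.1 * grad update is a placeholder for illustration, not ScaledAdam's actual rule:

import torch
from collections import defaultdict

params = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(3)]
for p in params:
    p.grad = torch.randn_like(p)

batches = defaultdict(list)
for p in params:
    batches[(str(p.dtype), *p.shape)].append(p)  # group by dtype and shape

with torch.no_grad():
    for key in sorted(batches.keys()):            # deterministic order
        batch = batches[key]
        stacked = torch.stack(batch)              # one (N, *shape) tensor
        grad = torch.stack([p.grad for p in batch])
        stacked -= 0.1 * grad                     # placeholder update rule
        for i, p in enumerate(batch):             # scatter results back
            p.copy_(stacked[i])
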
@@ -230,7 +228,6 @@ class ScaledAdam(BatchedOptimizer):

         return loss

-
     def _init_state(self,
                     group: dict,
                     p: Tensor,
@@ -326,6 +323,7 @@ class ScaledAdam(BatchedOptimizer):
         first_state["model_norms"][step % clipping_update_period] = tot_norm

         if step % clipping_update_period == 0:
+            print(f"step = {step}")
             # Print some stats.
             # We don't reach here if step == 0 because we would have returned
             # above.
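For context, first_state["model_norms"] is used as a ring buffer holding the last clipping_update_period gradient norms, from which a clipping threshold is periodically re-estimated. A minimal sketch of that pattern; the 0.9 quantile is an illustrative choice, not necessarily the statistic this code uses:

import torch

clipping_update_period = 4
model_norms = torch.zeros(clipping_update_period)
for step, tot_norm in enumerate([1.0, 2.0, 50.0, 3.0, 2.5]):
    # overwrite the oldest slot with the current gradient norm
    model_norms[step % clipping_update_period] = tot_norm
    if step % clipping_update_period == 0 and step > 0:
        # re-estimate the threshold from the recent history of norms
        threshold = torch.quantile(model_norms, 0.9)
        print(f"step = {step}, new threshold = {threshold.item():.2f}")
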
@@ -349,9 +347,13 @@ class ScaledAdam(BatchedOptimizer):
         if step < clipping_update_period:
             return 1.0  # We have not yet estimated a norm to clip to.
         else:
-            model_norm_threshold = first_state["model_norm_threshold"]
-            ans = min(1.0,
-                      (model_norm_threshold / (tot_norm + 1.0e-20)).item())
+            try:
+                model_norm_threshold = first_state["model_norm_threshold"]
+            except:
+                logging.info("Warning: model_norm_threshold not in state: possibly "
+                             "you changed config when restarting, adding clipping_scale option?")
+                return 1.0
+            ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
             if ans < 1.0:
                 first_state["num_clipped"] += 1
             if ans < 0.1:
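For concreteness, the clipping factor computed here is min(1.0, model_norm_threshold / tot_norm): it stays at 1.0 (no clipping) while the gradient norm is under the threshold, and scales the gradient down proportionally beyond it. A small sketch with illustrative values:

import torch

model_norm_threshold = torch.tensor(5.0)
for tot_norm in [torch.tensor(2.0), torch.tensor(20.0)]:
    # same formula as above; the epsilon guards against a zero norm
    ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
    print(f"tot_norm = {tot_norm.item():.1f} -> clipping scale {ans:.3f}")
    # prints 1.000 for norm 2.0 (below threshold), 0.250 for norm 20.0
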