Various fixes from debugging with nvtx, but removed the NVTX annotations.
commit dece8ad204 (parent bd7dce460b)
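For context on the commit message: NVTX ranges for profiling with Nsight Systems are typically pushed and popped around the regions of interest via torch.cuda.nvtx. The sketch below is a hypothetical example of that kind of annotation, not the code that was actually removed in this commit:

import torch

# Hypothetical example of the kind of NVTX annotation used while profiling
# with Nsight Systems (the commit message says such annotations were removed
# again before committing).  Requires a CUDA build of PyTorch.
def backward_and_step(optimizer, loss):
    torch.cuda.nvtx.range_push("backward")
    loss.backward()
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_push("optimizer_step")
    optimizer.step()
    torch.cuda.nvtx.range_pop()

    optimizer.zero_grad()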
@@ -76,39 +76,37 @@ class BatchedOptimizer(Optimizer):
         batches = defaultdict(list)  # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter

         for p in param_group:
             assert not p.grad.is_sparse and "Sparse gradients not supported."
             key = (str(p.dtype), *p.shape)
             batches[key].append(p)

         stacked_params_dict = dict()

         # turn batches into a list, in deterministic order.
         batches = [ batches[key] for key in sorted(batches.keys()) ]
         # pairs will contain pairs of (stacked_param, state), one for each batch
         # in `batches`.
         pairs = []

-        for p in param_group:
-            key = (str(p.dtype), *p.shape)
-            batch = batches[key]
-            if p is batch[0]:
-                # if this is the 1st param in the batch...
-                # we arbitrarily store the state in the
-                # state corresponding to the 1st parameter in the
-                # group. class Optimizer will take care of saving/loading state.
-                state = self.state[p]
-                p_stacked = torch.stack(batch)
-                grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch ])
-                p_stacked.grad = grad
-
-                stacked_params_dict[key] = p_stacked
-                pairs.append((p_stacked, state))
+        for batch in batches:
+            p = batch[0]
+            # we arbitrarily store the state in the
+            # state corresponding to the 1st parameter in the
+            # group. class Optimizer will take care of saving/loading state.
+            state = self.state[p]
+            p_stacked = torch.stack(batch)
+            grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch ])
+            p_stacked.grad = grad
+            stacked_params_dict[key] = p_stacked
+            pairs.append((p_stacked, state))

         yield pairs  # <-- calling code will do the actual optimization here!

-        for key, batch in batches.items():
-            stacked_params = stacked_params_dict[key]
+        for ((stacked_params, _state), batch) in zip(pairs, batches):
             for i, p in enumerate(batch):  # batch is list of Parameter
                 p.copy_(stacked_params[i])


 class ScaledAdam(BatchedOptimizer):
     """
     Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
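In the hunk above, `batched_params` groups parameters by (dtype, shape), stacks each group into a single tensor whose .grad holds the stacked gradients, yields the (stacked_param, state) pairs for the actual optimizer update, and afterwards copies each slice back into the original parameters. A minimal standalone sketch of that stack / update / copy-back round trip, with made-up parameter shapes and a dummy update standing in for the real optimizer step (this is an illustration, not the icefall code itself):

import torch
from collections import defaultdict

# Group parameters by (dtype, shape), stack each group into one tensor,
# "optimize" the stacked tensor, then copy the slices back.  One kernel per
# group instead of one per parameter is the point of the batching.
params = [torch.nn.Parameter(torch.randn(4, 3)) for _ in range(5)] + [
    torch.nn.Parameter(torch.randn(2))
]

batches = defaultdict(list)
for p in params:
    batches[(str(p.dtype), *p.shape)].append(p)

with torch.no_grad():
    for key in sorted(batches.keys()):
        batch = batches[key]
        stacked = torch.stack(batch)                # one tensor per group
        stacked -= 0.1 * torch.ones_like(stacked)   # stand-in for an optimizer update
        for i, p in enumerate(batch):
            p.copy_(stacked[i])                     # write the updated slice back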
@@ -230,7 +228,6 @@ class ScaledAdam(BatchedOptimizer):

         return loss

-
     def _init_state(self,
                     group: dict,
                     p: Tensor,
@@ -326,6 +323,7 @@ class ScaledAdam(BatchedOptimizer):
         first_state["model_norms"][step % clipping_update_period] = tot_norm

         if step % clipping_update_period == 0:
+            print(f"step = {step}")
             # Print some stats.
             # We don't reach here if step == 0 because we would have returned
             # above.
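Here `first_state["model_norms"]` works as a circular buffer indexed by `step % clipping_update_period`, so it always holds `tot_norm` for the most recent `clipping_update_period` steps, and a clipping threshold is re-estimated from it once per period. A hedged sketch of that bookkeeping follows; the threshold rule shown (a multiple of the median recent norm) is an illustrative assumption, not necessarily the exact statistic the real code computes:

import torch

# Rolling-norm bookkeeping: the buffer is indexed with
# step % clipping_update_period, so it holds the norms of the most recent
# `clipping_update_period` steps.  The threshold rule below is an assumption
# made for illustration.
clipping_update_period = 100
clipping_scale = 2.0
model_norms = torch.zeros(clipping_update_period)
model_norm_threshold = None

def record_norm(step: int, tot_norm: float):
    global model_norm_threshold
    model_norms[step % clipping_update_period] = tot_norm
    if step % clipping_update_period == 0 and step > 0:
        # Re-estimate the threshold once per period from the recorded norms.
        model_norm_threshold = clipping_scale * model_norms.median()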
@@ -349,9 +347,13 @@ class ScaledAdam(BatchedOptimizer):
         if step < clipping_update_period:
             return 1.0  # We have not yet estimated a norm to clip to.
         else:
-            model_norm_threshold = first_state["model_norm_threshold"]
-            ans = min(1.0,
-                      (model_norm_threshold / (tot_norm + 1.0e-20)).item())
+            try:
+                model_norm_threshold = first_state["model_norm_threshold"]
+            except:
+                logging.info("Warning: model_norm_threshold not in state: possibly "
+                             "you changed config when restarting, adding clipping_scale option?")
+                return 1.0
+            ans = min(1.0,(model_norm_threshold / (tot_norm + 1.0e-20)).item())
             if ans < 1.0:
                 first_state["num_clipped"] += 1
             if ans < 0.1:
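The try/except added above lets training resume from a checkpoint whose state predates the clipping_scale option: if "model_norm_threshold" is missing, clipping is simply skipped (scale 1.0) instead of raising a KeyError. The scale itself is min(1, model_norm_threshold / (tot_norm + eps)); a toy illustration with made-up numbers:

# Toy illustration of how the clipping scale behaves: gradients whose total
# norm stays under the threshold are untouched (scale 1.0); larger ones are
# scaled down proportionally.  Numbers are made up.
model_norm_threshold = 5.0

def clipping_scale_for(tot_norm: float) -> float:
    return min(1.0, model_norm_threshold / (tot_norm + 1.0e-20))

for tot_norm in [2.0, 5.0, 8.0, 50.0]:
    print(tot_norm, clipping_scale_for(tot_norm))
# -> 1.0 for 2.0 and 5.0; 0.625 for 8.0; 0.1 for 50.0
#    (the last two, being below 1.0, would also bump num_clipped)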