show dominant parameters

Guo Liyong 2022-11-26 00:50:21 +08:00
parent e5d942696a
commit 89c3982a07
2 changed files with 79 additions and 13 deletions

Changed file 1 of 2:

@@ -42,7 +42,7 @@ class BatchedOptimizer(Optimizer):
         super(BatchedOptimizer, self).__init__(params, defaults)
 
     @contextlib.contextmanager
-    def batched_params(self, param_group):
+    def batched_params(self, param_group, group_params_names=None):
         """
         This function returns (technically, yields) a list of
           of tuples (p, state), where
@@ -75,20 +75,28 @@ class BatchedOptimizer(Optimizer):
         batches = defaultdict(
             list
         )  # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
+        batches_names = defaultdict(
+            list
+        )  # `batches_names` maps from tuple (dtype_as_str,*shape) to list of str
 
-        for p in param_group:
+        for p, named_p in zip(param_group, group_params_names):
             key = (str(p.dtype), *p.shape)
             batches[key].append(p)
+            batches_names[key].append(named_p)
+
+        batches_names_keys = list(batches_names.keys())
+        sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i])
+        batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
+        batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
 
         stacked_params_dict = dict()
         # turn batches into a list, in deterministic order.
-        batches = [batches[key] for key in sorted(batches.keys())]
         # pairs will contain pairs of (stacked_param, state), one for each batch
         # in `batches`.
         pairs = []
-        for batch in batches:
+        for batch, batch_names in zip(batches, batches_names):
             p = batch[0]
             # we arbitrarily store the state in the
             # state corresponding to the 1st parameter in the
@@ -100,11 +108,11 @@ class BatchedOptimizer(Optimizer):
             )
             p_stacked.grad = grad
             stacked_params_dict[key] = p_stacked
-            pairs.append((p_stacked, state))
+            pairs.append((p_stacked, state, batch_names))
 
         yield pairs  # <-- calling code will do the actual optimization here!
 
-        for ((stacked_params, _state), batch) in zip(pairs, batches):
+        for ((stacked_params, _state, _names), batch) in zip(pairs, batches):
             for i, p in enumerate(batch):  # batch is list of Parameter
                 p.copy_(stacked_params[i])
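
For context on the change above: batched_params buckets parameters by (dtype, shape) so that each bucket can be stacked into a single tensor, and with this commit the matching parameter names are bucketed alongside so every stacked batch can be reported by name. A minimal standalone sketch of that grouping, using a toy module purely for illustration (the real caller passes the names from model.named_parameters()):

    from collections import defaultdict

    import torch
    import torch.nn as nn

    # Toy module standing in for the real model (illustrative only).
    toy_module = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
    names = [name for name, _ in toy_module.named_parameters()]
    params = list(toy_module.parameters())

    batches = defaultdict(list)        # (dtype_as_str, *shape) -> list of Parameter
    batches_names = defaultdict(list)  # (dtype_as_str, *shape) -> list of str

    for p, name in zip(params, names):
        key = (str(p.dtype), *p.shape)
        batches[key].append(p)
        batches_names[key].append(name)

    # Deterministic order, keeping each batch of parameters aligned with its names.
    for key in sorted(batches.keys()):
        stacked = torch.stack(batches[key])  # one stacked tensor per (dtype, shape) bucket
        print(key, batches_names[key], tuple(stacked.shape))
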
@@ -165,6 +173,8 @@ class ScaledAdam(BatchedOptimizer):
         scalar_max=10.0,
         size_update_period=4,
         clipping_update_period=100,
+        parameters_names=None,
+        show_dominant_parameters=False,
     ):
 
         defaults = dict(
@@ -181,6 +191,8 @@ class ScaledAdam(BatchedOptimizer):
         )
 
         super(ScaledAdam, self).__init__(params, defaults)
+        self.parameters_names = parameters_names
+        self.show_dominant_parameters = show_dominant_parameters
 
     def __setstate__(self, state):
         super(ScaledAdam, self).__setstate__(state)
@@ -199,9 +211,11 @@ class ScaledAdam(BatchedOptimizer):
                 loss = closure()
 
         batch = True
-        for group in self.param_groups:
-            with self.batched_params(group["params"]) as batches:
+        assert len(self.param_groups) == len(self.parameters_names)
+        for group, group_params_names in zip(self.param_groups, self.parameters_names):
+            with self.batched_params(group["params"], group_params_names) as batches:
 
                 # batches is list of pairs (stacked_param, state). stacked_param is like
                 # a regular parameter, and will have a .grad, but the 1st dim corresponds to
@@ -214,7 +228,7 @@ class ScaledAdam(BatchedOptimizer):
                 else:
                     clipping_scale = self._get_clipping_scale(group, batches)
 
-                for p, state in batches:
+                for p, state, _ in batches:
                     # Perform optimization step.
                     # grad is not going to be None, we handled that when creating the batches.
                     grad = p.grad
@@ -276,7 +290,7 @@ class ScaledAdam(BatchedOptimizer):
             state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
 
     def _get_clipping_scale(
-        self, group: dict, pairs: List[Tuple[Tensor, dict]]
+        self, group: dict, pairs: List[Tuple[Tensor, dict, List[str]]]
     ) -> float:
         """
         Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
@@ -289,7 +303,7 @@ class ScaledAdam(BatchedOptimizer):
         """
         assert len(pairs) >= 1
         clipping_scale = group["clipping_scale"]
-        (first_p, first_state) = pairs[0]
+        (first_p, first_state, _) = pairs[0]
         step = first_state["step"]
         if clipping_scale is None or step == 0:
             # no clipping.  return early on step == 0 because the other
@@ -298,7 +312,7 @@ class ScaledAdam(BatchedOptimizer):
         clipping_update_period = group["clipping_update_period"]
 
         tot_sumsq = torch.tensor(0.0, device=first_p.device)
-        for (p, state) in pairs:
+        for (p, state, param_names) in pairs:
             grad = p.grad
             if grad.is_sparse:
                 raise RuntimeError(
@@ -361,8 +375,49 @@ class ScaledAdam(BatchedOptimizer):
                 logging.warn(
                     f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
                 )
+                if self.show_dominant_parameters:
+                    assert p.shape[0] == len(param_names)
+                    self._show_gradient_dominating_parameter(pairs, tot_sumsq)
             return ans
 
+    def _show_gradient_dominating_parameter(self, pairs, tot_sumsq):
+        # "ori" means calculated with state["param_rms"];
+        # "cur" means calculated with the rms of the current parameter value.
+        all_sumsq_ori = {}
+        all_sumsq_cur = {}
+        for (p, state, batch_param_names) in pairs:
+            # p is a batch of stacked parameters.
+            grad = p.grad
+            if p.numel() == p.shape[0]:  # a batch of scalars
+                batch_sumsq_ori = grad**2
+                batch_sumsq_cur = batch_sumsq_ori
+                # Placeholder rms values, so both branches define the same names.
+                batch_rms_ori = torch.zeros(p.shape[0])
+                batch_rms_cur = batch_rms_ori
+            else:
+                batch_rms_ori = state["param_rms"]
+                batch_sumsq_ori = ((grad * batch_rms_ori) ** 2).sum(dim=list(range(1, grad.ndim)))
+                batch_rms_cur = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
+                batch_sumsq_cur = ((grad * batch_rms_cur) ** 2).sum(dim=list(range(1, grad.ndim)))
+
+            for name, sumsq_ori, sumsq_cur in zip(
+                batch_param_names, batch_sumsq_ori, batch_sumsq_cur
+            ):
+                proportion_ori = sumsq_ori / tot_sumsq
+                proportion_cur = sumsq_cur / tot_sumsq
+                all_sumsq_ori[name] = (proportion_ori, sumsq_ori)
+                all_sumsq_cur[name] = (proportion_cur, sumsq_cur)
+
+        for rms_type, all_sumsq in zip(("ori", "cur"), (all_sumsq_ori, all_sumsq_cur)):
+            sorted_by_proportion = {
+                k: v
+                for k, v in sorted(all_sumsq.items(), key=lambda item: item[1][0], reverse=True)
+            }
+            dominant_param_name = next(iter(sorted_by_proportion))
+            dominant_proportion, dominant_sumsq = sorted_by_proportion[dominant_param_name]
+            logging.info(
+                f"Dominant sumsq with {rms_type}_rms: {dominant_param_name} "
+                f"{dominant_proportion} {dominant_sumsq} {tot_sumsq}"
+            )
+
     def _step_one_batch(
         self, group: dict, p: Tensor, state: dict, clipping_scale: float
     ):
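
As a rough illustration of what the new _show_gradient_dominating_parameter reports: each parameter's contribution to the clipping statistic is (grad * param_rms) squared and summed over all dims except the batch dim, and the parameter holding the largest share of tot_sumsq is logged as dominant. A minimal sketch with made-up gradients, rms values and parameter names (none of them come from the model):

    import torch

    # Hypothetical per-parameter gradients and rms values, for illustration only.
    grads = {
        "encoder.linear.weight": torch.randn(4, 4) * 2.0,
        "decoder.linear.weight": torch.randn(4, 4) * 0.1,
    }
    param_rms = {name: torch.tensor(0.05) for name in grads}

    # Contribution of each parameter to the total sum-of-squares used for clipping.
    sumsq = {name: ((g * param_rms[name]) ** 2).sum() for name, g in grads.items()}
    tot_sumsq = sum(sumsq.values())

    # The dominant parameter is the one holding the largest proportion of tot_sumsq.
    proportions = {name: (s / tot_sumsq).item() for name, s in sumsq.items()}
    dominant = max(proportions, key=proportions.get)
    print(f"Dominant sumsq: {dominant} {proportions[dominant]:.3f} {tot_sumsq.item():.3e}")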

Changed file 2 of 2:

@@ -368,6 +368,13 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--show-dominant-parameters",
+        type=str2bool,
+        default=False,
+        help="Whether to show dominant parameters.",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -988,7 +995,11 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
-    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0)
+    parameters_names = []
+    parameters_names.append([name_param_pair[0] for name_param_pair in model.named_parameters()])
+    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr,
+                           clipping_scale=2.0, parameters_names=parameters_names,
+                           show_dominant_parameters=params.show_dominant_parameters)
 
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
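
Because step() now asserts len(self.param_groups) == len(self.parameters_names), parameters_names must contain one list of names per parameter group, with names in the same order as that group's parameters; the single append above covers the single-group case used here. A hedged sketch of the expected structure for a hypothetical two-group setup (the model, the split and the commented-out constructor call are illustrative only):

    import torch.nn as nn

    # Hypothetical model and grouping, only to show the expected shape of
    # `parameters_names`: one list of names per parameter group.
    model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
    named = list(model.named_parameters())

    group0 = [(n, p) for n, p in named if p.ndim >= 2]  # e.g. weight matrices
    group1 = [(n, p) for n, p in named if p.ndim < 2]   # e.g. biases / scales

    param_groups = [
        {"params": [p for _, p in group0]},
        {"params": [p for _, p in group1]},
    ]
    parameters_names = [
        [n for n, _ in group0],
        [n for n, _ in group1],
    ]
    assert len(param_groups) == len(parameters_names)

    # optimizer = ScaledAdam(param_groups, lr=params.base_lr, clipping_scale=2.0,
    #                        parameters_names=parameters_names,
    #                        show_dominant_parameters=True)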