From 89c3982a0760f135740556ae67c11d0af434303c Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Sat, 26 Nov 2022 00:50:21 +0800 Subject: [PATCH 1/3] show dominant parameters --- .../ASR/pruned_transducer_stateless7/optim.py | 79 ++++++++++++++++--- .../ASR/pruned_transducer_stateless7/train.py | 13 ++- 2 files changed, 79 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 8b90c9a0d..ab55381d7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -42,7 +42,7 @@ class BatchedOptimizer(Optimizer): super(BatchedOptimizer, self).__init__(params, defaults) @contextlib.contextmanager - def batched_params(self, param_group): + def batched_params(self, param_group, group_params_names=None): """ This function returns (technically, yields) a list of of tuples (p, state), where @@ -75,20 +75,28 @@ class BatchedOptimizer(Optimizer): batches = defaultdict( list ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches_names = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str - for p in param_group: + for p, named_p in zip(param_group, group_params_names): key = (str(p.dtype), *p.shape) batches[key].append(p) + batches_names[key].append(named_p) + + batches_names_keys = list(batches_names.keys()) + sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i]) + batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] + batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] stacked_params_dict = dict() # turn batches into a list, in deterministic order. - batches = [batches[key] for key in sorted(batches.keys())] # pairs will contain pairs of (stacked_param, state), one for each batch # in `batches`. pairs = [] - for batch in batches: + for batch, batch_names in zip(batches, batches_names): p = batch[0] # we arbitrarily store the state in the # state corresponding to the 1st parameter in the @@ -100,11 +108,11 @@ class BatchedOptimizer(Optimizer): ) p_stacked.grad = grad stacked_params_dict[key] = p_stacked - pairs.append((p_stacked, state)) + pairs.append((p_stacked, state, batch_names)) yield pairs # <-- calling code will do the actual optimization here! 
- for ((stacked_params, _state), batch) in zip(pairs, batches): + for ((stacked_params, _state, _names), batch) in zip(pairs, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) @@ -165,6 +173,8 @@ class ScaledAdam(BatchedOptimizer): scalar_max=10.0, size_update_period=4, clipping_update_period=100, + parameters_names=None, + show_dominant_parameters=False, ): defaults = dict( @@ -181,6 +191,8 @@ class ScaledAdam(BatchedOptimizer): ) super(ScaledAdam, self).__init__(params, defaults) + self.parameters_names = parameters_names + self.show_dominant_parameters = show_dominant_parameters def __setstate__(self, state): super(ScaledAdam, self).__setstate__(state) @@ -199,9 +211,11 @@ class ScaledAdam(BatchedOptimizer): loss = closure() batch = True - for group in self.param_groups: + assert len(self.param_groups) == len(self.parameters_names) - with self.batched_params(group["params"]) as batches: + for group, group_params_names in zip(self.param_groups, self.parameters_names): + + with self.batched_params(group["params"], group_params_names) as batches: # batches is list of pairs (stacked_param, state). stacked_param is like # a regular parameter, and will have a .grad, but the 1st dim corresponds to @@ -214,7 +228,7 @@ class ScaledAdam(BatchedOptimizer): else: clipping_scale = self._get_clipping_scale(group, batches) - for p, state in batches: + for p, state, _ in batches: # Perform optimization step. # grad is not going to be None, we handled that when creating the batches. grad = p.grad @@ -276,7 +290,7 @@ class ScaledAdam(BatchedOptimizer): state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) def _get_clipping_scale( - self, group: dict, pairs: List[Tuple[Tensor, dict]] + self, group: dict, pairs: List[Tuple[Tensor, dict, List[str]]] ) -> float: """ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients @@ -289,7 +303,7 @@ class ScaledAdam(BatchedOptimizer): """ assert len(pairs) >= 1 clipping_scale = group["clipping_scale"] - (first_p, first_state) = pairs[0] + (first_p, first_state, _) = pairs[0] step = first_state["step"] if clipping_scale is None or step == 0: # no clipping. return early on step == 0 because the other @@ -298,7 +312,7 @@ class ScaledAdam(BatchedOptimizer): clipping_update_period = group["clipping_update_period"] tot_sumsq = torch.tensor(0.0, device=first_p.device) - for (p, state) in pairs: + for (p, state, param_names) in pairs: grad = p.grad if grad.is_sparse: raise RuntimeError( @@ -361,8 +375,49 @@ class ScaledAdam(BatchedOptimizer): logging.warn( f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" ) + if self.show_dominant_parameters: + assert p.shape[0] == len(param_names) + self._show_gradient_dominating_parameter(pairs, tot_sumsq) return ans + def _show_gradient_dominating_parameter(self, pairs, tot_sumsq): + # ori means calculated with state["param_rms"] + # cur means calculated with "param_rms" of current param. + # bt is short batch + # all_sumsq_ori_rms + all_sumsq_ori = {} + all_sumsq_cur = {} + for (p, state, batch_param_names) in pairs: + # p is a stacked batch parameters. + grad = p.grad + if p.numel() == p.shape[0]: # a batch of scalars + batch_sumsq_ori = grad**2 # sum() to change shape [1] to [] + batch_sumsq_cur = batch_sumsq_ori # sum() to change shape [1] to [] + # Dummpy values used by following `zip` statement. 
+ batch_rms_ori = torch.zeros(p.shape[0]) + batch_rms_cur = batch_rms_ori + else: + batch_rms_ori = state["param_rms"] + batch_sumsq_ori = ((grad * batch_rms_ori) ** 2).sum(dim=list(range(1, grad.ndim))) + + batch_rms_cur = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + batch_sumsq_cur = ((grad * batch_rms_cur) ** 2).sum(dim=list(range(1, grad.ndim))) + + for name, sumsq_ori, sumsq_cur in zip( + batch_param_names, batch_sumsq_ori, batch_sumsq_cur): + + proportion_ori = sumsq_ori / tot_sumsq + proportion_cur = sumsq_cur / tot_sumsq + + all_sumsq_ori[name] = (proportion_ori, sumsq_ori) + all_sumsq_cur[name] = (proportion_cur, sumsq_cur) + + for rms_type, all_sumsq in zip(("ori", "cur"), (all_sumsq_ori, all_sumsq_cur)): + sorted_by_proportion = {k: v for k, v in sorted(all_sumsq.items(), key=lambda item: item[1][0], reverse=True)} + dominant_param_name = next(iter(sorted_by_proportion)) + dominant_proportion, dominant_sumsq = sorted_by_proportion[dominant_param_name] + logging.info(f"Dominant sumsq with {rms_type}_rms: {dominant_param_name} {dominant_proportion} {dominant_sumsq} {tot_sumsq}") + def _step_one_batch( self, group: dict, p: Tensor, state: dict, clipping_scale: float ): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index b27c573ab..8375b1a18 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -368,6 +368,13 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--show-dominant-parameters", + type=str2bool, + default=False, + help="Whether to show dominant parameters.", + ) + add_model_arguments(parser) return parser @@ -988,7 +995,11 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) + parameters_names = [] + parameters_names.append([name_param_pair[0] for name_param_pair in model.named_parameters()]) + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, + clipping_scale=2.0, parameters_names=parameters_names, + show_dominant_parameters=params.show_dominant_parameters) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) From 9cf79cac3f23757e25b499947f045efb0f4d71a6 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Sat, 26 Nov 2022 21:48:17 +0800 Subject: [PATCH 2/3] message formatting --- .../ASR/pruned_transducer_stateless7/optim.py | 76 +++++++++++-------- .../ASR/pruned_transducer_stateless7/train.py | 10 +-- 2 files changed, 45 insertions(+), 41 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index ab55381d7..790752fe1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -42,7 +42,7 @@ class BatchedOptimizer(Optimizer): super(BatchedOptimizer, self).__init__(params, defaults) @contextlib.contextmanager - def batched_params(self, param_group, group_params_names=None): + def batched_params(self, param_group, group_params_names): """ This function returns (technically, yields) a list of of tuples (p, state), where @@ -85,7 +85,9 @@ class BatchedOptimizer(Optimizer): batches_names[key].append(named_p) batches_names_keys = list(batches_names.keys()) - sorted_idx = sorted(range(len(batches_names)), 
key=lambda i: batches_names_keys[i]) + sorted_idx = sorted( + range(len(batches_names)), key=lambda i: batches_names_keys[i] + ) batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] @@ -174,7 +176,7 @@ class ScaledAdam(BatchedOptimizer): size_update_period=4, clipping_update_period=100, parameters_names=None, - show_dominant_parameters=False, + show_dominant_parameters=True, ): defaults = dict( @@ -211,7 +213,7 @@ class ScaledAdam(BatchedOptimizer): loss = closure() batch = True - assert len(self.param_groups) == len(self.parameters_names) + assert len(self.param_groups) == len(self.parameters_names) for group, group_params_names in zip(self.param_groups, self.parameters_names): @@ -381,42 +383,52 @@ class ScaledAdam(BatchedOptimizer): return ans def _show_gradient_dominating_parameter(self, pairs, tot_sumsq): - # ori means calculated with state["param_rms"] - # cur means calculated with "param_rms" of current param. - # bt is short batch - # all_sumsq_ori_rms - all_sumsq_ori = {} - all_sumsq_cur = {} + all_sumsq_orig = {} for (p, state, batch_param_names) in pairs: # p is a stacked batch parameters. - grad = p.grad + batch_grad = p.grad if p.numel() == p.shape[0]: # a batch of scalars - batch_sumsq_ori = grad**2 # sum() to change shape [1] to [] - batch_sumsq_cur = batch_sumsq_ori # sum() to change shape [1] to [] + batch_sumsq_orig = batch_grad**2 # Dummpy values used by following `zip` statement. - batch_rms_ori = torch.zeros(p.shape[0]) - batch_rms_cur = batch_rms_ori + batch_rms_orig = torch.ones(p.shape[0]) else: - batch_rms_ori = state["param_rms"] - batch_sumsq_ori = ((grad * batch_rms_ori) ** 2).sum(dim=list(range(1, grad.ndim))) + batch_rms_orig = state["param_rms"] + batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum( + dim=list(range(1, batch_grad.ndim)) + ) - batch_rms_cur = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - batch_sumsq_cur = ((grad * batch_rms_cur) ** 2).sum(dim=list(range(1, grad.ndim))) + for name, sumsq_orig, rms, grad in zip( + batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad + ): - for name, sumsq_ori, sumsq_cur in zip( - batch_param_names, batch_sumsq_ori, batch_sumsq_cur): + proportion_orig = sumsq_orig / tot_sumsq + all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) - proportion_ori = sumsq_ori / tot_sumsq - proportion_cur = sumsq_cur / tot_sumsq - - all_sumsq_ori[name] = (proportion_ori, sumsq_ori) - all_sumsq_cur[name] = (proportion_cur, sumsq_cur) - - for rms_type, all_sumsq in zip(("ori", "cur"), (all_sumsq_ori, all_sumsq_cur)): - sorted_by_proportion = {k: v for k, v in sorted(all_sumsq.items(), key=lambda item: item[1][0], reverse=True)} - dominant_param_name = next(iter(sorted_by_proportion)) - dominant_proportion, dominant_sumsq = sorted_by_proportion[dominant_param_name] - logging.info(f"Dominant sumsq with {rms_type}_rms: {dominant_param_name} {dominant_proportion} {dominant_sumsq} {tot_sumsq}") + assert torch.isclose( + sum([value[0] for value in all_sumsq_orig.values()]).cpu(), + torch.tensor(1.0), + ) + sorted_by_proportion = { + k: v + for k, v in sorted( + all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True + ) + } + dominant_param_name = next(iter(sorted_by_proportion)) + ( + dominant_proportion, + dominant_sumsq, + dominant_rms, + dominant_grad, + ) = sorted_by_proportion[dominant_param_name] + logging.info( + f"Parameter Dominanting tot_sumsq {dominant_param_name}" + f" with 
proportion {dominant_proportion:.2f},"
+            f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
+            f"={dominant_sumsq:.3e},"
+            f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
+            f" orig_rms_sq={(dominant_rms**2).item():.3e}"
+        )
 
     def _step_one_batch(
         self, group: dict, p: Tensor, state: dict, clipping_scale: float
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 8375b1a18..e5a3e68df 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -368,13 +368,6 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
-    parser.add_argument(
-        "--show-dominant-parameters",
-        type=str2bool,
-        default=False,
-        help="Whether to show dominant parameters.",
-    )
-
     add_model_arguments(parser)
 
     return parser
@@ -998,8 +991,7 @@ def run(rank, world_size, args):
     parameters_names = []
     parameters_names.append([name_param_pair[0] for name_param_pair in model.named_parameters()])
     optimizer = ScaledAdam(model.parameters(), lr=params.base_lr,
-                           clipping_scale=2.0, parameters_names=parameters_names,
-                           show_dominant_parameters=params.show_dominant_parameters)
+                           clipping_scale=2.0, parameters_names=parameters_names)
 
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 

From 4fee3e7f1ea6c2aefe7594e325ede1e530e54d3d Mon Sep 17 00:00:00 2001
From: Guo Liyong
Date: Mon, 28 Nov 2022 16:55:18 +0800
Subject: [PATCH 3/3] improve comments

---
 .../ASR/pruned_transducer_stateless7/optim.py | 63 +++++++++++++------
 .../ASR/pruned_transducer_stateless7/train.py | 12 +++-
 2 files changed, 54 insertions(+), 21 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 790752fe1..ff8fbb32c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -64,13 +64,15 @@ class BatchedOptimizer(Optimizer):
 
         you can do:
           with self.batched_params(group["params"]) as batches:
-             for p, state in batches:
+             for p, state, p_names in batches:
                  ...
 
 
         Args:
           group: a parameter group, which is a list of parameters; should be
-                one of self.groups.
+                one of self.param_groups.
+          group_params_names: name for each parameter in group,
+                which is a List[str].
         """
         batches = defaultdict(
             list
@@ -79,6 +81,7 @@ class BatchedOptimizer(Optimizer):
             list
         )  # `batches` maps from tuple (dtype_as_str,*shape) to list of str
 
+        assert len(param_group) == len(group_params_names)
         for p, named_p in zip(param_group, group_params_names):
             key = (str(p.dtype), *p.shape)
             batches[key].append(p)
@@ -94,9 +97,9 @@ class BatchedOptimizer(Optimizer):
         stacked_params_dict = dict()
 
         # turn batches into a list, in deterministic order.
-        # pairs will contain pairs of (stacked_param, state), one for each batch
-        # in `batches`.
-        pairs = []
+        # tuples will contain tuples of (stacked_param, state, stacked_params_names),
+        # one for each batch in `batches`.
+        tuples = []
 
         for batch, batch_names in zip(batches, batches_names):
             p = batch[0]
@@ -110,11 +113,11 @@ class BatchedOptimizer(Optimizer):
                 )
                 p_stacked.grad = grad
                 stacked_params_dict[key] = p_stacked
-                pairs.append((p_stacked, state, batch_names))
+                tuples.append((p_stacked, state, batch_names))
 
-        yield pairs  # <-- calling code will do the actual optimization here!
+        yield tuples  # <-- calling code will do the actual optimization here!
 
-        for ((stacked_params, _state, _names), batch) in zip(pairs, batches):
+        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
             for i, p in enumerate(batch):  # batch is list of Parameter
                 p.copy_(stacked_params[i])
 
@@ -179,6 +182,11 @@ class ScaledAdam(BatchedOptimizer):
         show_dominant_parameters=True,
     ):
 
+        assert parameters_names is not None, (
+            "Please prepare parameters_names, "
+            "which is a List[List[str]]. Each List[str] is for a group, "
+            "and each str is for a parameter."
+        )
         defaults = dict(
             lr=lr,
             clipping_scale=clipping_scale,
@@ -193,6 +201,7 @@ class ScaledAdam(BatchedOptimizer):
         )
 
         super(ScaledAdam, self).__init__(params, defaults)
+        assert len(self.param_groups) == len(parameters_names)
         self.parameters_names = parameters_names
         self.show_dominant_parameters = show_dominant_parameters
 
@@ -213,7 +222,6 @@ class ScaledAdam(BatchedOptimizer):
             loss = closure()
 
         batch = True
-        assert len(self.param_groups) == len(self.parameters_names)
 
        for group, group_params_names in zip(self.param_groups, self.parameters_names):
 
@@ -292,7 +300,7 @@ class ScaledAdam(BatchedOptimizer):
             state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
 
     def _get_clipping_scale(
-        self, group: dict, pairs: List[Tuple[Tensor, dict, List[str]]]
+        self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
     ) -> float:
         """
         Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
        Args:
           group: the parameter group, an item in self.param_groups
-          pairs: a list of pairs of (param, state) where param is a batched set of parameters, with a .grad
-              (1st dim is batch dim) and state is the state-dict where optimization parameters are kept.
+          tuples: a list of tuples of (param, state, param_names)
+              where param is a batched set of parameters,
+              with a .grad (1st dim is batch dim)
+              and state is the state-dict where optimization parameters are kept.
+              param_names is a List[str], where each str is the name of a parameter
+              in the batched set of parameters "param".
        """
-        assert len(pairs) >= 1
+        assert len(tuples) >= 1
         clipping_scale = group["clipping_scale"]
-        (first_p, first_state, _) = pairs[0]
+        (first_p, first_state, _) = tuples[0]
         step = first_state["step"]
         if clipping_scale is None or step == 0:
             # no clipping.  return early on step == 0 because the other
@@ -314,7 +326,7 @@ class ScaledAdam(BatchedOptimizer):
             clipping_update_period = group["clipping_update_period"]
 
         tot_sumsq = torch.tensor(0.0, device=first_p.device)
-        for (p, state, param_names) in pairs:
+        for (p, state, param_names) in tuples:
             grad = p.grad
             if grad.is_sparse:
                 raise RuntimeError(
@@ -379,12 +391,27 @@ class ScaledAdam(BatchedOptimizer):
                 )
             if self.show_dominant_parameters:
                 assert p.shape[0] == len(param_names)
-                self._show_gradient_dominating_parameter(pairs, tot_sumsq)
+                self._show_gradient_dominating_parameter(tuples, tot_sumsq)
         return ans
 
-    def _show_gradient_dominating_parameter(self, pairs, tot_sumsq):
+    def _show_gradient_dominating_parameter(
+        self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
+    ):
+        """
+        Show information of the parameter which dominates tot_sumsq.
+
+        Args:
+            tuples: a list of tuples of (param, state, param_names)
+                where param is a batched set of parameters,
+                with a .grad (1st dim is batch dim)
+                and state is the state-dict where optimization parameters are kept.
+                param_names is a List[str], where each str is the name of a parameter
+                in the batched set of parameters "param".
+            tot_sumsq: sumsq of all parameters. Though it could be calculated
+                from tuples, we pass it in to save some time.
+        """
         all_sumsq_orig = {}
-        for (p, state, batch_param_names) in pairs:
+        for (p, state, batch_param_names) in tuples:
             # p is a stacked batch parameters.
             batch_grad = p.grad
             if p.numel() == p.shape[0]:  # a batch of scalars
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index e5a3e68df..31a3a0505 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -989,9 +989,15 @@ def run(rank, world_size, args):
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
     parameters_names = []
-    parameters_names.append([name_param_pair[0] for name_param_pair in model.named_parameters()])
-    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr,
-                           clipping_scale=2.0, parameters_names=parameters_names)
+    parameters_names.append(
+        [name_param_pair[0] for name_param_pair in model.named_parameters()]
+    )
+    optimizer = ScaledAdam(
+        model.parameters(),
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=parameters_names,
+    )
 
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 
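
For reference, below is a minimal usage sketch of the new `parameters_names` argument introduced by these patches. It is not part of the patches themselves; the toy model and the learning rate are assumptions made purely for illustration, and the `from optim import ScaledAdam` import assumes the recipe directory is on sys.path, as train.py already does. The key point is that ScaledAdam now expects one List[str] per parameter group, aligned element-wise with that group's parameters; model.parameters() and model.named_parameters() iterate in the same order, so the single group below stays aligned with its name list.

    import torch
    import torch.nn as nn

    from optim import ScaledAdam  # optim.py from this recipe directory

    # A hypothetical toy model standing in for the transducer model.
    model = nn.Sequential(nn.Linear(80, 256), nn.ReLU(), nn.Linear(256, 500))

    # One List[str] per parameter group, aligned with that group's parameters.
    parameters_names = [[name for name, _ in model.named_parameters()]]

    optimizer = ScaledAdam(
        model.parameters(),
        lr=0.05,
        clipping_scale=2.0,
        parameters_names=parameters_names,
    )

    loss = model(torch.randn(4, 80)).sum()
    loss.backward()
    # With show_dominant_parameters=True (the default after PATCH 2/3), the
    # parameter that dominates tot_sumsq is logged whenever clipping actually
    # scales the gradients, i.e. right after the "Scaling gradients" warning.
    optimizer.step()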