From 9e4ce825acf7227f6fa3fef854916dcf9363e0bc Mon Sep 17 00:00:00 2001
From: Guo Liyong
Date: Wed, 21 Dec 2022 15:25:51 +0800
Subject: [PATCH 1/2] simplify interface

---
 .../ASR/pruned_transducer_stateless7/optim.py | 147 ++++++++++++++++--
 .../ASR/pruned_transducer_stateless7/train.py |   7 +-
 2 files changed, 135 insertions(+), 19 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index ff8fbb32c..9588a3ef8 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -18,7 +18,7 @@ import contextlib
 import logging
 import random
 from collections import defaultdict
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 from lhotse.utils import fix_random_seed
@@ -132,6 +132,9 @@ class ScaledAdam(BatchedOptimizer):
 
     Args:
          params:  The parameters or param_groups to optimize (like other Optimizer subclasses)
+                  Unlike common optimizers, which accepts model.parameters() or groups of parameters(),
+                  this optimizer could accept model.named_parameters() or groups of named_parameters().
+                  See comments of function _get_names_of_parameters for its 4 possible cases.
          lr:  The learning rate.  We will typically use a learning rate schedule that starts
               at 0.03 and decreases over time, i.e. much higher than other common
               optimizers.
@@ -178,15 +181,8 @@
         scalar_max=10.0,
         size_update_period=4,
         clipping_update_period=100,
-        parameters_names=None,
-        show_dominant_parameters=True,
     ):
 
-        assert parameters_names is not None, (
-            "Please prepare parameters_names,"
-            "which is a List[List[str]]. Each List[str] is for a group"
-            "and each str is for a parameter"
-        )
         defaults = dict(
             lr=lr,
             clipping_scale=clipping_scale,
@@ -200,10 +196,135 @@
             clipping_update_period=clipping_update_period,
         )
 
+        # If params only contains parameters or groups of parameters,
+        # i.e. when parameter names are not given,
+        # this flag will be set to False in function _get_names_of_parameters.
+        self.show_dominant_parameters = True
+        params, parameters_names = self._get_names_of_parameters(params)
         super(ScaledAdam, self).__init__(params, defaults)
         assert len(self.param_groups) == len(parameters_names)
         self.parameters_names = parameters_names
-        self.show_dominant_parameters = show_dominant_parameters
+
+    def _get_names_of_parameters(
+        self, params_or_named_params
+    ) -> Tuple[List[Dict], List[List[str]]]:
+        """
+        Args:
+          params_or_named_params: according to the way ScaledAdam is initialized in train.py,
+            this argument can be one of the following 4 cases:
+            case 1, a generator of parameters, e.g.:
+              optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0)
+
+            case 2, a list of parameter groups with different configs, e.g.:
+              model_param_groups = [
+                      {'params': model.encoder.parameters(), 'lr': 0.05},
+                      {'params': model.decoder.parameters(), 'lr': 0.01},
+                      {'params': model.joiner.parameters(), 'lr': 0.03},
+              ]
+              optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0)
+
+            case 3, a generator of named_parameters, e.g.:
+              optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0)
+
+            case 4, a list of named_parameter groups with different configs, e.g.:
+              model_named_param_groups = [
+                      {'params': model.encoder.named_parameters(), 'lr': 0.05},
+                      {'params': model.decoder.named_parameters(), 'lr': 0.01},
+                      {'params': model.joiner.named_parameters(), 'lr': 0.03},
+              ]
+              optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0)
+
+        For case 1 and case 2, the input params are used to initialize the underlying torch.Optimizer.
+        For case 3 and case 4, names and params are first extracted from the input named_params;
+        the extracted params are then used to initialize the underlying torch.Optimizer,
+        and the extracted names are mainly used by the function
+        `_show_gradient_dominating_parameter`.
+
+        Returns:
+          Returns a tuple containing 2 elements:
+          - `param_groups` with type List[Dict], each Dict element is a parameter group.
+            An example of `param_groups` could be:
+            [
+              {'params': `one iterable of Parameter`, 'lr': 0.05},
+              {'params': `another iterable of Parameter`, 'lr': 0.08},
+              {'params': `a third iterable of Parameter`, 'lr': 0.1},
+            ]
+          - `param_groups_names` with type List[List[str]],
+            each `List[str]` is for a group['params'] in param_groups,
+            and each `str` is the name of a parameter.
+            A dummy name "foo" is assigned to each parameter
+            if the input params come without names, i.e. case 1 or case 2.
+        """
+        # variable naming convention in this function:
+        # p is short for param.
+        # np is short for named_param.
+        # p_or_np is short for param_or_named_param.
+        # curt is short for current.
+        # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}.
+        # groups is a List[group]
+
+        iterable_or_groups = list(params_or_named_params)
+        if len(iterable_or_groups) == 0:
+            raise ValueError("optimizer got an empty parameter list")
+
+        # The first value of the returned tuple.
+        param_groups = []
+
+        # The second value of the returned tuple,
+        # a List[List[str]], each sub-List is for a group.
+        param_groups_names = []
+
+        if not isinstance(iterable_or_groups[0], dict):
+            # case 1 or case 3,
+            # the input is an iterable of parameters or named parameters.
+            param_iterable_curt_group = []
+            param_names_curt_group = []
+            for p_or_np in iterable_or_groups:
+                if isinstance(p_or_np, tuple):
+                    # case 3
+                    name, param = p_or_np
+                else:
+                    # case 1
+                    assert isinstance(p_or_np, torch.Tensor)
+                    param = p_or_np
+                    # Assign a dummy name as a placeholder
+                    name = "foo"
+                    self.show_dominant_parameters = False
+                param_iterable_curt_group.append(param)
+                param_names_curt_group.append(name)
+            param_groups.append({"params": param_iterable_curt_group})
+            param_groups_names.append(param_names_curt_group)
+        else:
+            # case 2 or case 4
+            # the input is groups of parameters or named parameters.
+            for p_or_np_curt_group in iterable_or_groups:
+                param_iterable_curt_group = []
+                param_names_curt_group = []
+                p_or_np_iterable = p_or_np_curt_group["params"]
+                for p_or_np in p_or_np_iterable:
+                    if isinstance(p_or_np, tuple):
+                        # case 4
+                        name, param = p_or_np
+                    else:
+                        # case 2
+                        assert isinstance(p_or_np, torch.Tensor)
+                        param = p_or_np
+                        # Assign a dummy name as a placeholder
+                        name = "foo"
+                        self.show_dominant_parameters = False
+                    param_iterable_curt_group.append(param)
+                    param_names_curt_group.append(name)
+
+                # The original `params` field contains named_parameters.
+                # After the following assignment,
+                # it will be changed to an iterable of parameters,
+                # and other fields, if present, keep their original values.
+                # So param_groups could be used to initialize
+                # an underlying torch.Optimizer later.
+                p_or_np_curt_group["params"] = param_iterable_curt_group
+                param_groups.append(p_or_np_curt_group)
+                param_groups_names.append(param_names_curt_group)
+        return param_groups, param_groups_names
 
     def __setstate__(self, state):
         super(ScaledAdam, self).__setstate__(state)
@@ -398,7 +519,7 @@ class ScaledAdam(BatchedOptimizer):
         self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
     ):
         """
-        Show information of parameter wihch dominanting tot_sumsq.
+        Show information of parameter which dominates tot_sumsq.
 
         Args:
             tuples: a list of tuples of (param, state, param_names)
@@ -416,7 +537,7 @@ class ScaledAdam(BatchedOptimizer):
             batch_grad = p.grad
             if p.numel() == p.shape[0]:  # a batch of scalars
                 batch_sumsq_orig = batch_grad**2
-                # Dummpy values used by following `zip` statement.
batch_rms_orig = torch.ones(p.shape[0]) else: batch_rms_orig = state["param_rms"] @@ -449,11 +570,11 @@ class ScaledAdam(BatchedOptimizer): dominant_grad, ) = sorted_by_proportion[dominant_param_name] logging.info( - f"Parameter Dominanting tot_sumsq {dominant_param_name}" + f"Parameter dominating tot_sumsq {dominant_param_name}" f" with proportion {dominant_proportion:.2f}," f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" f"={dominant_sumsq:.3e}," - f" grad_sumsq = {(dominant_grad**2).sum():.3e}," + f" grad_sumsq={(dominant_grad**2).sum():.3e}," f" orig_rms_sq={(dominant_rms**2).item():.3e}" ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 31a3a0505..6c8ff1813 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -988,15 +988,10 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) optimizer = ScaledAdam( - model.parameters(), + model.named_parameters(), lr=params.base_lr, clipping_scale=2.0, - parameters_names=parameters_names, ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) From b011448a287784fa14d44d9c6a5aeb8dc9705c27 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Wed, 21 Dec 2022 16:35:38 +0800 Subject: [PATCH 2/2] change current abbreviation --- .../ASR/pruned_transducer_stateless7/optim.py | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 9588a3ef8..8b6f67d3f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -132,7 +132,7 @@ class ScaledAdam(BatchedOptimizer): Args: params: The parameters or param_groups to optimize (like other Optimizer subclasses) - Unlike common optimizers, which accepts model.parameters() or groups of parameters(), + Unlike common optimizers, which accept model.parameters() or groups of parameters(), this optimizer could accept model.named_parameters() or groups of named_parameters(). See comments of function _get_names_of_parameters for its 4 possible cases. lr: The learning rate. We will typically use a learning rate schedule that starts @@ -259,7 +259,7 @@ class ScaledAdam(BatchedOptimizer): # p is short for param. # np is short for named_param. # p_or_np is short for param_or_named_param. - # curt is short for current. + # cur is short for current. # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}. # groups is a List[group] @@ -277,8 +277,8 @@ class ScaledAdam(BatchedOptimizer): if not isinstance(iterable_or_groups[0], dict): # case 1 or case 3, # the input is an iterable of parameter or named parameter. 
-            param_iterable_curt_group = []
-            param_names_curt_group = []
+            param_iterable_cur_group = []
+            param_names_cur_group = []
             for p_or_np in iterable_or_groups:
                 if isinstance(p_or_np, tuple):
                     # case 3
@@ -290,17 +290,17 @@ class ScaledAdam(BatchedOptimizer):
                     # Assign a dummy name as a placeholder
                     name = "foo"
                     self.show_dominant_parameters = False
-                param_iterable_curt_group.append(param)
-                param_names_curt_group.append(name)
-            param_groups.append({"params": param_iterable_curt_group})
-            param_groups_names.append(param_names_curt_group)
+                param_iterable_cur_group.append(param)
+                param_names_cur_group.append(name)
+            param_groups.append({"params": param_iterable_cur_group})
+            param_groups_names.append(param_names_cur_group)
         else:
             # case 2 or case 4
             # the input is groups of parameters or named parameters.
-            for p_or_np_curt_group in iterable_or_groups:
-                param_iterable_curt_group = []
-                param_names_curt_group = []
-                p_or_np_iterable = p_or_np_curt_group["params"]
+            for p_or_np_cur_group in iterable_or_groups:
+                param_iterable_cur_group = []
+                param_names_cur_group = []
+                p_or_np_iterable = p_or_np_cur_group["params"]
                 for p_or_np in p_or_np_iterable:
                     if isinstance(p_or_np, tuple):
                         # case 4
@@ -312,8 +312,8 @@ class ScaledAdam(BatchedOptimizer):
                         # Assign a dummy name as a placeholder
                         name = "foo"
                         self.show_dominant_parameters = False
-                    param_iterable_curt_group.append(param)
-                    param_names_curt_group.append(name)
+                    param_iterable_cur_group.append(param)
+                    param_names_cur_group.append(name)
 
                 # The original `params` field contains named_parameters.
                 # After the following assignment,
                 # it will be changed to an iterable of parameters,
                 # and other fields, if present, keep their original values.
                 # So param_groups could be used to initialize
                 # an underlying torch.Optimizer later.
-                p_or_np_curt_group["params"] = param_iterable_curt_group
-                param_groups.append(p_or_np_curt_group)
-                param_groups_names.append(param_names_curt_group)
+                p_or_np_cur_group["params"] = param_iterable_cur_group
+                param_groups.append(p_or_np_cur_group)
+                param_groups_names.append(param_names_cur_group)
         return param_groups, param_groups_names
 
     def __setstate__(self, state):
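
Usage sketch (illustrative, not part of the patches above): the four ways of constructing ScaledAdam that the new _get_names_of_parameters supports, written against the patched optim.py. The toy model, the import path, and the learning-rate/clipping values below are assumptions made only for this example; they are not taken from the recipe.

    import torch.nn as nn

    # Assumes this script sits next to the patched
    # egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
    from optim import ScaledAdam

    # A toy stand-in for the transducer model built in train.py.
    model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

    # case 1: an iterable of parameters; no names are available, so the
    # optimizer assigns the dummy name "foo" and disables
    # dominant-parameter logging.
    opt1 = ScaledAdam(model.parameters(), lr=0.03, clipping_scale=2.0)

    # case 2: groups of parameters with per-group options.
    opt2 = ScaledAdam(
        [
            {"params": model[0].parameters(), "lr": 0.05},
            {"params": model[1].parameters(), "lr": 0.01},
        ],
        lr=0.03,
        clipping_scale=2.0,
    )

    # case 3: an iterable of (name, parameter) pairs, as train.py now passes.
    opt3 = ScaledAdam(model.named_parameters(), lr=0.03, clipping_scale=2.0)

    # case 4: groups of (name, parameter) pairs with per-group options.
    opt4 = ScaledAdam(
        [
            {"params": model[0].named_parameters(), "lr": 0.05},
            {"params": model[1].named_parameters(), "lr": 0.01},
        ],
        lr=0.03,
        clipping_scale=2.0,
    )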
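
For reference, a rough equivalent (again an illustration, not the code in the patch) of what the name/parameter split amounts to for case 3 on the toy model above: the bare parameters feed the underlying torch Optimizer, while the name lists are kept as parameters_names for _show_gradient_dominating_parameter's logging.

    named = list(model.named_parameters())
    param_groups = [{"params": [p for _, p in named]}]
    param_groups_names = [[name for name, _ in named]]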