From e996f7f3717b10951b678fc9f836eb2bbb78ccbe Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 18 Jun 2022 15:17:51 +0800
Subject: [PATCH] First version of the speedup code that I am running

---
 .../ASR/pruned_transducer_stateless7/optim.py | 118 +++++++++++++-----
 .../ASR/pruned_transducer_stateless7/train.py |  14 ++-
 2 files changed, 97 insertions(+), 35 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index f50f1e763..e8c9f6c0e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -22,6 +22,7 @@ import random
 from torch import Tensor
 from torch.optim import Optimizer
 from icefall import diagnostics # only for testing code
+import logging
 
 class NeutralGradient(Optimizer):
     """
@@ -90,9 +91,9 @@ class NeutralGradient(Optimizer):
         estimate_period=2000,
         stats_steps=200,
         param_pow=1.0,
-        lr_for_speedup=0.1,
-        speedup_first_step=500, # TEMP. Make this 3500? must be > warmup.
+        lr_for_speedup=0.03,
         speedup_recalibrate_period=100,
+        speedup_first_step=500,
     ):
 
         if not 0.0 <= lr:
@@ -117,6 +118,10 @@ class NeutralGradient(Optimizer):
             raise ValueError("Invalid estimate_period value: {}".format(estimate_period))
         if not 0 < stats_steps < estimate_period:
             raise ValueError("Invalid stats_steps value: {}".format(stats_steps))
+        if not 0 < speedup_first_step:
+            raise ValueError("Invalid speedup_first_step value: {}".format(speedup_first_step))
+        if not 0 < speedup_recalibrate_period:
+            raise ValueError("Invalid speedup_recalibrate_period value: {}".format(speedup_recalibrate_period))
 
 
         defaults = dict(
@@ -132,9 +137,9 @@ class NeutralGradient(Optimizer):
             stats_steps=stats_steps,
             param_pow=param_pow,
             lr_for_speedup=lr_for_speedup,
-            speedup_first_step=speedup_first_step,
             speedup_recalibrate_period=speedup_recalibrate_period,
-            group_step=0,
+            speedup_first_step=speedup_first_step,
+            speedup_step=-speedup_first_step,
         )
 
         super(NeutralGradient, self).__init__(params, defaults)
@@ -171,11 +176,11 @@ class NeutralGradient(Optimizer):
             speedup_first_step = group["speedup_first_step"]
             speedup_recalibrate_period = group["speedup_recalibrate_period"]
 
-            group_step = group["group_step"]
-            group["group_step"] = group_step + 1
+            speedup_step = group["speedup_step"]
+            group["speedup_step"] = speedup_step + 1
 
-            recalibrate = (group_step >= speedup_first_step and
-                           (group_step - speedup_first_step) % speedup_recalibrate_period == 0)
+            recalibrate = (speedup_step >= 0 and
+                           speedup_step % speedup_recalibrate_period == 0)
             if recalibrate:
                 # param_prods will be an array of products, one per parameter tensor, equal to:
                 #   E[grad . step],
@@ -270,31 +275,26 @@ class NeutralGradient(Optimizer):
                 step_period = state["step_period"]
                 n_cached_grads = state["n_cached_grads"]
 
-                if step_period > 1:
+                if step_period > 1 or n_cached_grads > 0:
                     cached_grad = state["cached_grad"]
-                    cached_grad += grad
                     n_cached_grads += 1
                     state["n_cached_grads"] = n_cached_grads
                     # the point of random_step_prob is to inject a little randomness
                     # into the timing of the updates, so they don't stay synchronized.
                    random_step_prob = 0.05 / step_period
-                    if n_cached_grads == step_period or random.random() < random_step_prob:
+                    if n_cached_grads >= step_period or random.random() < random_step_prob:
                         # We will continue to the code below, and do a step
-                        grad = cached_grad
+                        grad += cached_grad
+                        cached_grad.zero_()
+                        state["n_cached_grads"] = 0
                     else:
+                        cached_grad += grad
                         # Don't do a step this time.
                         continue
                 else:
                     # We are doing a step on every iteration.
                     n_cached_grads = 1 # will be used to scale gradients below.
-                # sqrt_n_grads will be used to scale down estimates of grad variances
-                # (in situations where the scale matters), to reflect the sizes of
-                # grads of individual
-                inv_sqrt_n_grads = n_cached_grads ** 0.5
-
-
-
                 delta = state["delta"]
                 step = state["step"]
@@ -410,12 +410,14 @@ class NeutralGradient(Optimizer):
 
                 if random.random() < 0.0001:
                     print(f"Delta rms = {(delta**2).mean().item()}, shape = {delta.shape}")
-                max_abs = delta.abs().max()
-                if max_abs > 1.0:
-                    print("Max_abs_change = ", max_abs)
+                #max_abs = delta.abs().max()
+                #if max_abs > 1.0:
+                #    print("Max_abs_change = ", max_abs)
                 p.add_(delta)
                 if step % 10 == 0:
                     p.clamp_(min=-param_max, max=param_max)
+
+                state["step"] += 1
 
         return loss
 
@@ -551,6 +553,10 @@ class NeutralGradient(Optimizer):
                               param_rel_eps*param_rel_eps * params_sq.mean())
 
+        params_sq_norm = params_sq
+        if param_pow != 1.0:
+            params_sq_partnorm = params_sq
+
         for _ in range(3):
             # Iterate 3 times, this should be enough to converge.
             for dim in range(p.ndim):
@@ -561,8 +567,10 @@ class NeutralGradient(Optimizer):
                 other_dims = [ i for i in range(p.ndim) if i != dim ]
                 # Compute diagonal variance along this dimension
                 # (this is after normalizing previous dimensions' variance)
-                this_var = params_sq.mean(dim=other_dims, keepdim=True)
-                params_sq = params_sq / this_var
+                this_var = params_sq_norm.mean(dim=other_dims, keepdim=True)
+                params_sq_norm = params_sq_norm / this_var
+                if param_pow != 1.0:
+                    params_sq_partnorm = params_sq_partnorm / (this_var ** param_pow)
                 this_var = this_var.reshape(size)
                 this_scale = (this_var ** (param_pow * 0.5)).reshape(size)
 
@@ -576,6 +584,17 @@ class NeutralGradient(Optimizer):
                 # diagonalized space.
                 proj *= this_scale.unsqueeze(-1)
 
+        if param_pow != 1.0:
+            # need to get the overall scale correct, as if we had param_pow == 1.0
+            scale = (params_sq_partnorm.mean() ** 0.5)
+            for dim in range(p.ndim):
+                size = p.shape[dim]
+                if size == 1:
+                    continue
+                else:
+                    state[f"proj_{dim}"] *= scale
+                    break
+
         # Reset the squared-gradient stats because we have changed the basis.
        state["exp_avg_sq"].fill_(0.0)
@@ -789,13 +808,29 @@ class NeutralGradient(Optimizer):
 
         return P
 
+    def reset_speedup(self):
+        """
+        Reset the step_period so that we'll do a step on every iteration, until
+        the next time the speedup is recalibrated.
+        """
+        for group in self.param_groups:
+            group["speedup_step"] = -group["speedup_first_step"]
+            for p in group["params"]:
+                state = self.state[p]
+                state["step_period"] = 1
+
     def _recalibrate_speedup(self, group: dict):
         param_prods = []
         speedup = group["lr_for_speedup"] / group["lr"]
         if speedup > 10.0:
             speedup = 10.0
-        if speedup < 2.0:
-            speedup = 2.0
+        if speedup <= 1.0:
+            for p in group["params"]:
+                state = self.state[p]
+                state["step_period"] = 1
+            return
+
+
         _, beta2 = group["betas"]
 
         for p in group["params"]:
@@ -819,13 +854,16 @@ class NeutralGradient(Optimizer):
                 param_prods.append(prod)
 
         param_prods = torch.stack(param_prods).to('cpu')
+        # TEMP
+        param_prods = param_prods.sqrt()
         avg_update_freq = 1.0 / speedup
         param_freqs = param_prods * (avg_update_freq / param_prods.mean())
 
-        # Don't delay more than 40 steps for any parameter.  Note: with beta=0.1,
-        # this would mean an average delay of about 500 steps before the gradient
-        # makes it to the parameter.
-        min_freq = 1.0 / 40
+        # Don't let any parameter's update frequency fall below 1 / (3*speedup),
+        # i.e. don't delay any parameter's update by more than 3*speedup steps.
+        # Note: with beta=0.1, this would mean an average delay of about
+        # 30*speedup steps before the gradient makes it to the parameter.
+        min_freq = 1.0 / (3 * speedup)
 
         # get the mean after applying limits of [min_freq..1] for the frequencies of
         # update
@@ -836,15 +874,23 @@ class NeutralGradient(Optimizer):
         param_freqs = param_freqs.clamp(min=min_freq, max=1.0)
         param_periods = (1.0 / param_freqs).to(torch.int32) # .to(torch.int32) rounds down
 
+        actual_speedup = 1.0 / (1.0 / param_periods).mean()
+
         param_prods = param_prods.tolist()
         param_freqs = param_freqs.tolist()
         param_periods = param_periods.tolist()
-        print("speedup = ", speedup)
-        for i,p in enumerate(group["params"]):
+        logging.info(f"NeutralGradient._recalibrate_speedup: speedup = {speedup:.2g}, actual_speedup = {actual_speedup:.2g}")
+        print_info = True # random.random() < 0.1
+        i = 0
+        for p in group["params"]:
             if p.grad is None:
                 continue
-            print(f"Shape = {p.shape}, prod = {param_prods[i]}, freq = {param_freqs[i]}, period = {param_periods[i]}.")
+            state = self.state[p]
+            state["step_period"] = param_periods[i]
+            if print_info:
+                logging.info(f"Shape = {p.shape}, prod = {param_prods[i]}, freq = {param_freqs[i]}, period = {param_periods[i]}.")
+            i += 1
 
         # Do nothing more, as yet.
 
@@ -1507,6 +1553,8 @@ def _test_eve_cain():
     avg_loss = 0.0
     for epoch in range(150):
         scheduler.step_epoch()
+        if epoch == 100 and iter in [2,3]:
+            optim.reset_speedup() # check it doesn't crash.
        if epoch == 130:
            opts = diagnostics.TensorDiagnosticOptions(
@@ -1531,7 +1579,8 @@ def _test_eve_cain():
                 #scale1b = '%.2e' % (m[0].bias_scale.exp().item())
                 #scale2 = '%.2e' % (m[2].weight_scale.exp().item())
                 #scale2b = '%.2e' % (m[2].bias_scale.exp().item())
-                print(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
+                lr = scheduler.get_last_lr()[0]
+                print(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
             loss.log().backward()
             optim.step()
             optim.zero_grad()
@@ -1562,6 +1611,7 @@ def _test_eve_cain():
 
 
 if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
     _test_eve_cain()

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index da7cde00c..300f0f0f8 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -762,6 +762,16 @@ def train_one_epoch(
             # in the batch and there is no normalization to it so far.
             scaler.scale(loss).backward()
             scheduler.step_batch(params.batch_idx_train)
+
+            if params.batch_idx_train == params.model_warm_step:
+                # We're about to start using the pruned loss, which brings new
+                # modules into play, so reset the frequencies of update, to
+                # avoid possible instability.
+                try:
+                    optimizer.reset_speedup()
+                    logging.info("Reset speedup on optimizer")
+                except AttributeError: # optimizer may lack reset_speedup()
+                    pass
             scaler.step(optimizer)
             scaler.update()
             optimizer.zero_grad()
@@ -916,7 +926,9 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank])
 
-    optimizer = NeutralGradient(model.parameters(), lr=params.initial_lr)
+    optimizer = NeutralGradient(model.parameters(),
+                                lr=params.initial_lr,
+                                lr_for_speedup=params.initial_lr)
 
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
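
For reference, the gradient-caching scheme that the new code in optim.py's step()
implements can be summarized by the standalone sketch below. The helper name
maybe_accumulate and the plain-dict state are illustrative only: the real
optimizer keeps step_period, n_cached_grads and cached_grad in self.state[p],
accumulates into the gradient in place, and preconditions the returned gradient
before applying it.

    import random
    from typing import Optional

    import torch

    def maybe_accumulate(state: dict, grad: torch.Tensor) -> Optional[torch.Tensor]:
        """Return the gradient to step with now, or None if this parameter
        skips the current iteration and the gradient is cached instead."""
        step_period = state["step_period"]
        if step_period > 1 or state["n_cached_grads"] > 0:
            state["n_cached_grads"] += 1
            # A small chance of stepping early keeps different tensors'
            # update times from staying synchronized with each other.
            random_step_prob = 0.05 / step_period
            if (state["n_cached_grads"] >= step_period
                    or random.random() < random_step_prob):
                total = grad + state["cached_grad"]
                state["cached_grad"].zero_()
                state["n_cached_grads"] = 0
                return total
            state["cached_grad"] += grad
            return None
        return grad  # step_period == 1: step on every iteration

    # Example: with step_period == 4, roughly every 4th call returns the
    # accumulated gradient (sometimes earlier, because of random_step_prob).
    state = {"step_period": 4, "n_cached_grads": 0, "cached_grad": torch.zeros(3)}
    for it in range(8):
        g = maybe_accumulate(state, torch.ones(3))
        if g is not None:
            print(it, g)  # e.g. "3 tensor([4., 4., 4.])"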
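The period computation in _recalibrate_speedup then amounts to roughly the
following. update_periods is a hypothetical standalone version: param_prods
stands in for the E[grad . step] statistics gathered above, and the
renormalization line is an assumption, since the lines between the hunks at
-819 and -836 are not shown in this diff.

    import torch

    def update_periods(param_prods: torch.Tensor, speedup: float):
        """Turn per-tensor importance stats into integer step periods whose
        mean update frequency is roughly 1/speedup."""
        param_prods = param_prods.sqrt()  # the "TEMP" compression in the patch
        avg_update_freq = 1.0 / speedup
        param_freqs = param_prods * (avg_update_freq / param_prods.mean())
        # Cap the longest delay at 3*speedup steps between updates.
        min_freq = 1.0 / (3 * speedup)
        # Renormalize so the mean frequency still matches 1/speedup after
        # clamping (assumed; this part of the function is elided in the diff).
        mean_freq = param_freqs.clamp(min=min_freq, max=1.0).mean()
        param_freqs = (param_freqs * (avg_update_freq / mean_freq)).clamp(min=min_freq, max=1.0)
        param_periods = (1.0 / param_freqs).to(torch.int32)  # truncation keeps period >= 1
        actual_speedup = (1.0 / (1.0 / param_periods).mean()).item()
        return param_periods, actual_speedup

    periods, actual = update_periods(torch.tensor([1e-4, 4e-4, 1e-2, 1.0]), speedup=5.0)
    print(periods, actual)  # low-product tensors get long periods, capped at 3*speedup

Because the clamp to [min_freq, 1] compresses the frequency range, the achieved
actual_speedup is generally smaller than the requested speedup, which is why the
patch logs both values.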