From 6f5c4688efc58d6d92200daff77e48002e1010fe Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 30 Apr 2023 15:19:34 +0800
Subject: [PATCH] Add (back) straight_through_rate, with rate 0.025; try to
 handle memory allocation failures in backprop better.

---
 .../pruned_transducer_stateless7/scaling.py   | 104 ++++++++++--------
 .../pruned_transducer_stateless7/zipformer.py |  14 ++-
 2 files changed, 68 insertions(+), 50 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 9b6c8880b..603110d95 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -683,40 +683,44 @@ class BalancerFunction(torch.autograd.Function):
         x, = ctx.saved_tensors
         (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config
 
-        with torch.enable_grad():
-            with torch.cuda.amp.autocast(enabled=False):
-                x = x.to(torch.float32)
-                x = x.detach()
-                x.requires_grad = True
-                mean_dims = [ i for i in range(x.ndim) if i != channel_dim ]
-                uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True)
-                mean = x.mean(dim=mean_dims, keepdim=True)
-                stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
-                rms = uncentered_var.clamp(min=1.0e-20).sqrt()
-                m = mean / stddev
-                # part of loss that relates to mean / stddev
-                m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
+        try:
+            with torch.enable_grad():
+                with torch.cuda.amp.autocast(enabled=False):
+                    x = x.to(torch.float32)
+                    x = x.detach()
+                    x.requires_grad = True
+                    mean_dims = [ i for i in range(x.ndim) if i != channel_dim ]
+                    uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True)
+                    mean = x.mean(dim=mean_dims, keepdim=True)
+                    stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
+                    rms = uncentered_var.clamp(min=1.0e-20).sqrt()
 
-                # put a much larger scale on the RMS-max-limit loss, so that if both it and the
-                # m_loss are violated we fix the RMS loss first.
-                rms_clamped = rms.clamp(min=min_rms, max=max_rms)
-                r_loss = (rms_clamped / rms).log().abs()
+                    m = mean / stddev
+                    # part of loss that relates to mean / stddev
+                    m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
 
-                loss = (m_loss + r_loss)
+                    # put a much larger scale on the RMS-max-limit loss, so that if both it and the
+                    # m_loss are violated we fix the RMS loss first.
+                    rms_clamped = rms.clamp(min=min_rms, max=max_rms)
+                    r_loss = (rms_clamped / rms).log().abs()
 
-                loss.backward(gradient=torch.ones_like(loss))
-                loss_grad = x.grad
-                loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)
+                    loss = (m_loss + r_loss)
 
-                loss_grad = loss_grad * (grad_scale / loss_grad_rms)
+                    loss.backward(gradient=torch.ones_like(loss))
+                    loss_grad = x.grad
+                    loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)
 
-                x_grad_float = x_grad.to(torch.float32)
-                # scale each element of loss_grad by the absolute value of the corresponding
-                # element of x_grad, which we view as a noisy estimate of its magnitude for that
-                # (frame and dimension).  later we can consider factored versions.
-                x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
-                x_grad = x_grad_mod.to(x_grad.dtype)
+                    loss_grad = loss_grad * (grad_scale / loss_grad_rms)
+
+                    x_grad_float = x_grad.to(torch.float32)
+                    # scale each element of loss_grad by the absolute value of the corresponding
+                    # element of x_grad, which we view as a noisy estimate of its magnitude for that
+                    # (frame and dimension).  later we can consider factored versions.
+                    x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
+                    x_grad = x_grad_mod.to(x_grad.dtype)
+        except Exception as e:
+            logging.info(f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue.")
 
         return x_grad, None, None, None, None, None, None
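The hunk above wraps the entire gradient-modification computation in try/except, so that a failure inside backward (typically a CUDA out-of-memory error while allocating the temporaries for the balancer loss) falls back to returning the incoming gradient unchanged rather than aborting training. A minimal, self-contained sketch of this guard pattern follows; the class and variable names here are illustrative, not part of the patch:

```python
import logging

import torch
from torch import Tensor


class PassthroughOnFailure(torch.autograd.Function):
    """Identity in forward; in backward, tries to fold in an auxiliary
    gradient term and falls back to the unmodified gradient if the
    extra computation fails (e.g. CUDA OOM)."""

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        ctx.save_for_backward(x)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tensor:
        x, = ctx.saved_tensors
        try:
            with torch.enable_grad():
                x_detached = x.to(torch.float32).detach()
                x_detached.requires_grad = True
                # Stand-in for the balancer loss; the temporaries created
                # here are what can fail to allocate on a tight GPU.
                aux_loss = (x_detached ** 2).mean()
                aux_loss.backward()
                return x_grad + x_detached.grad.to(x_grad.dtype)
        except Exception as e:
            # Log and continue training with the unmodified gradient.
            logging.info(f"Caught exception in backward: {e}, will continue.")
        return x_grad
```

The same shape appears twice in the patch, in BalancerFunction.backward here and in WhiteningPenaltyFunction.backward below.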
@@ -924,28 +928,34 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                  x_grad: Tensor):
         x_orig, = ctx.saved_tensors
         w = ctx.module
-        with torch.enable_grad():
-            with torch.cuda.amp.autocast(enabled=False):
-                x_detached = x_orig.to(torch.float32).detach()
-                x_detached.requires_grad = True
-
-                metric = _whitening_metric(x_detached, w.num_groups)
+        try:
+            with torch.enable_grad():
+                with torch.cuda.amp.autocast(enabled=False):
+                    x_detached = x_orig.to(torch.float32).detach()
+                    x_detached.requires_grad = True
 
-                if random.random() < 0.005 or __name__ == "__main__":
-                    logging.info(f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, "
-                                 f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}")
+                    metric = _whitening_metric(x_detached, w.num_groups)
+
+                    if random.random() < 0.005 or __name__ == "__main__":
+                        logging.info(f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, "
+                                     f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}")
+
+                    if metric < float(w.whitening_limit):
+                        w.prob = w.min_prob
+                        return x_grad, None
+                    else:
+                        w.prob = w.max_prob
+                        metric.backward()
+                        penalty_grad = x_detached.grad
+                        scale = w.grad_scale * (x_grad.to(torch.float32).norm() /
+                                                (penalty_grad.norm() + 1.0e-20))
+                        penalty_grad = penalty_grad * scale
+                        return x_grad + penalty_grad.to(x_grad.dtype), None
+        except Exception as e:
+            logging.info(f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue.")
+        return x_grad, None
 
-                if metric < float(w.whitening_limit):
-                    w.prob = w.min_prob
-                    return x_grad, None
-                else:
-                    w.prob = w.max_prob
-                    metric.backward()
-                    penalty_grad = x_detached.grad
-                    scale = w.grad_scale * (x_grad.to(torch.float32).norm() /
-                                            (penalty_grad.norm() + 1.0e-20))
-                    penalty_grad = penalty_grad * scale
-                    return x_grad + penalty_grad.to(x_grad.dtype), None
 
 class Whiten(nn.Module):

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index c138da451..4ead015a4 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -427,9 +427,10 @@ class Zipformer2EncoderLayer(nn.Module):
         self.embed_dim = embed_dim
 
         # self.bypass implements layer skipping as well as bypass; see its default values.
-        self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate)
+        self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate,
+                                   straight_through_rate=0.025)
         # bypass_mid is bypass used in the middle of the layer.
-        self.bypass_mid = BypassModule(embed_dim)
+        self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0.025)
 
         # skip probability for dynamic modules (meaning: anything but feedforward).
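One detail of the WhiteningPenaltyFunction hunk worth spelling out: before the penalty gradient is added to the incoming gradient, it is rescaled so that its norm is a fixed fraction (w.grad_scale) of the incoming gradient's norm, which keeps the penalty's strength proportional to the main gradient regardless of the raw magnitude of the whitening metric's gradient. A small numerical illustration of that rescaling, with arbitrary shapes and values:

```python
import torch

torch.manual_seed(0)
grad_scale = 0.02                           # plays the role of w.grad_scale
x_grad = torch.randn(10, 384)               # incoming gradient from downstream
penalty_grad = torch.randn(10, 384) * 50.0  # raw whitening-penalty gradient

# Match the penalty's norm to grad_scale times the main gradient's norm,
# exactly as in the hunk above (1.0e-20 guards against division by zero).
scale = grad_scale * (x_grad.norm() / (penalty_grad.norm() + 1.0e-20))
penalty_grad = penalty_grad * scale

print(f"{(penalty_grad.norm() / x_grad.norm()).item():.4f}")  # prints 0.0200
```

Note also that the early returns now sit inside the try block, so the trailing `return x_grad, None` after the except clause is reached only on the exception path.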
@@ -768,11 +769,13 @@ class BypassModule(nn.Module):
             self,
             embed_dim: int,
             skip_rate: FloatLike = 0.0,
+            straight_through_rate: FloatLike = 0.0,
             scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
             scale_max: FloatLike = 1.0):
         super().__init__()
         self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
         self.skip_rate = copy.deepcopy(skip_rate)
+        self.straight_through_rate = copy.deepcopy(straight_through_rate)
         self.scale_min = copy.deepcopy(scale_min)
         self.scale_max = copy.deepcopy(scale_max)
 
@@ -794,6 +797,11 @@ class BypassModule(nn.Module):
             ans = ans * mask
             # now ans is of shape (batch_size, num_channels), and is zero for sequences
             # on which we have randomly chosen to do layer-skipping.
+        straight_through_rate = float(self.straight_through_rate)
+        if straight_through_rate != 0.0:
+            mask = torch.rand((batch_size, 1), device=ans.device) < straight_through_rate
+            ans = torch.maximum(ans, mask.to(ans.dtype))
+
         return ans
 
     def forward(self,
@@ -826,7 +834,7 @@ class DownsampledZipformer2Encoder(nn.Module):
                                          downsample, dropout)
         self.encoder = encoder
         self.upsample = SimpleUpsample(dim, downsample)
-        self.out_combiner = BypassModule(dim)
+        self.out_combiner = BypassModule(dim, straight_through_rate=0.025)
 
     def forward(self,
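The straight-through mechanism added here works on the bypass scale that _get_bypass_scale computes: after the learned per-channel bypass_scale has been clamped to [scale_min, scale_max] (and possibly zeroed for layer-skipping), each sequence independently has its scale forced to 1.0 with probability straight_through_rate via torch.maximum. Assuming the usual BypassModule combination of input and output by this scale, a scale of 1.0 means that sequence uses the sub-module's output unmixed, i.e. "straight through". A minimal sketch of just the mask step, with arbitrary batch size and channel count:

```python
import torch

torch.manual_seed(0)
batch_size, num_channels = 8, 4
straight_through_rate = 0.025

# Stand-in for the clamped (and possibly skip-masked) bypass scale.
ans = torch.full((batch_size, num_channels), 0.5)

# Per-sequence Bernoulli(straight_through_rate) mask, broadcast over channels:
# where it fires, torch.maximum lifts every channel's scale to 1.0.
mask = torch.rand((batch_size, 1), device=ans.device) < straight_through_rate
ans = torch.maximum(ans, mask.to(ans.dtype))

# Rows of all-ones correspond to sequences that pass the module output
# straight through on this step; the rest keep their learned mix.
print(ans)
```

With the rate of 0.025 used at all three construction sites, roughly one sequence in forty takes the straight-through path on any given training step.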