diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 161a5aa4a..0dfc1e5d8 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -1003,38 +1003,38 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx,
                 x: Tensor,
-                num_groups: int,
-                whitening_limit: float,
-                grad_scale: float,
-                name: Optional[str]) -> Tensor:
+                module: nn.Module) -> Tensor:
         ctx.save_for_backward(x)
-        ctx.num_groups = num_groups
-        ctx.whitening_limit = whitening_limit
-        ctx.grad_scale = grad_scale
-        ctx.name = name
+        ctx.module = module
         return x
 
     @staticmethod
     def backward(ctx,
                  x_grad: Tensor):
         x_orig, = ctx.saved_tensors
+        w = ctx.module
         with torch.enable_grad():
             with torch.cuda.amp.autocast(enabled=False):
                 x_detached = x_orig.to(torch.float32).detach()
                 x_detached.requires_grad = True
 
-                metric = _whitening_metric(x_detached, ctx.num_groups)
+                metric = _whitening_metric(x_detached, w.num_groups)
 
                 if random.random() < 0.005 or __name__ == "__main__":
-                    logging.info(f"Whitening: name={ctx.name}, num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
-                                 f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}")
+                    logging.info(f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, "
+                                 f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}")
 
-                (metric - ctx.whitening_limit).relu().backward()
-                penalty_grad = x_detached.grad
-                scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() /
-                                          (penalty_grad.norm() + 1.0e-20))
-                penalty_grad = penalty_grad * scale
-        return x_grad + penalty_grad.to(x_grad.dtype), None, None, None, None
+                if metric < float(w.whitening_limit):
+                    w.prob = w.min_prob
+                    return x_grad, None
+                else:
+                    w.prob = w.max_prob
+                    metric.backward()
+                    penalty_grad = x_detached.grad
+                    scale = w.grad_scale * (x_grad.to(torch.float32).norm() /
+                                            (penalty_grad.norm() + 1.0e-20))
+                    penalty_grad = penalty_grad * scale
+                    return x_grad + penalty_grad.to(x_grad.dtype), None
 
 
 class Whiten(nn.Module):
@@ -1101,21 +1101,7 @@ class Whiten(nn.Module):
         if not x.requires_grad or random.random() > self.prob or grad_scale == 0:
             return _no_op(x)
         else:
-            whitening_limit = float(self.whitening_limit)
-            if hasattr(self, 'min_prob') and random.random() < 0.25:
-                # occasionally switch between min_prob and max_prob, based on whether
-                # we are above or below the threshold.
-                if _whitening_metric(x.to(torch.float32), self.num_groups) > whitening_limit:
-                    # there would be a change to the grad.
-                    self.prob = self.max_prob
-                else:
-                    self.prob = self.min_prob
-
-            return WhiteningPenaltyFunction.apply(x,
-                                                  self.num_groups,
-                                                  whitening_limit,
-                                                  grad_scale,
-                                                  self.name)
+            return WhiteningPenaltyFunction.apply(x, self)
 
 
 class WithLoss(torch.autograd.Function):
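
In short, the diff replaces the per-argument ctx bookkeeping with a single reference to the owning module: backward() now reads num_groups / whitening_limit / grad_scale / name off the Whiten instance, flips its prob between min_prob and max_prob based on the measured metric, and only applies the penalty gradient when the metric exceeds the limit. The snippet below is a minimal, self-contained sketch of that pattern for reference only; the names PenaltyDemo / DemoModule and the (x ** 2).mean() stand-in for _whitening_metric are illustrative and not part of scaling.py:

import torch
from torch import Tensor, nn


class PenaltyDemo(torch.autograd.Function):
    # Mirrors the pattern in the diff: forward() is a no-op on the value and
    # stashes the owning module on ctx; backward() reads hyperparameters from
    # the module and may add a scaled penalty gradient.
    @staticmethod
    def forward(ctx, x: Tensor, module: nn.Module) -> Tensor:
        ctx.save_for_backward(x)
        ctx.module = module  # non-tensor state may be stored directly on ctx
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        (x_orig,) = ctx.saved_tensors
        m = ctx.module
        with torch.enable_grad():
            x_detached = x_orig.to(torch.float32).detach()
            x_detached.requires_grad = True
            # Stand-in for _whitening_metric(x_detached, m.num_groups).
            metric = (x_detached ** 2).mean()
            if metric < m.limit:
                m.prob = m.min_prob  # backward() may mutate the module's state
                return x_grad, None
            m.prob = m.max_prob
            metric.backward()
            penalty_grad = x_detached.grad
            scale = m.grad_scale * (
                x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20)
            )
            return x_grad + (penalty_grad * scale).to(x_grad.dtype), None


class DemoModule(nn.Module):
    # Hypothetical stand-in for Whiten: it only carries the attributes the
    # autograd Function reads or writes.
    def __init__(self):
        super().__init__()
        self.limit = 0.5
        self.grad_scale = 0.01
        self.min_prob, self.max_prob = 0.025, 0.25
        self.prob = self.max_prob

    def forward(self, x: Tensor) -> Tensor:
        return PenaltyDemo.apply(x, self)


if __name__ == "__main__":
    m = DemoModule()
    x = torch.randn(8, 16, requires_grad=True)
    m(x).sum().backward()
    # m.prob is now min_prob or max_prob depending on the metric of x.
    print(m.prob, x.grad.norm())

Passing the module itself also explains why the new backward() returns only two values: the autograd Function now has exactly two inputs (the tensor and the module), and the module gets None as its "gradient".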