From 8bbcd816041e0ce478e3a56eaac732095a89ec53 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 10 Jan 2023 17:46:32 +0800
Subject: [PATCH] Memory efficient backprop for dropout3

---
 .../pruned_transducer_stateless7/scaling.py | 24 ++++++++++++++++++++++--
 1 file changed, 22 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index c925fc32f..143690c3b 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -1873,6 +1873,25 @@ class Dropout2(nn.Module):
             p=float(self.p), training=self.training)
 
 
+class MulForDropout3(torch.autograd.Function):
+    # returns (x * y * alpha) where alpha is a float and y doesn't require
+    # grad and is zero-or-one.
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, x, y, alpha):
+        assert not y.requires_grad
+        ans = x * y * alpha
+        ctx.save_for_backward(ans)
+        ctx.alpha = alpha
+        return ans
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, ans_grad):
+        ans, = ctx.saved_tensors
+        x_grad = ctx.alpha * ans_grad * (ans != 0)
+        return x_grad, None, None
+
 # Dropout3 is just like normal dropout, except it supports schedules on the dropout rates,
 # and it lets you choose one dimension to share the dropout mask over
 class Dropout3(nn.Module):
@@ -1884,11 +1903,12 @@ class Dropout3(nn.Module):
         p = float(self.p)
         if not self.training or p == 0:
             return _no_op(x)
-        scale = 1.0 / (1 - self.p)
+        scale = 1.0 / (1 - p)
         rand_shape = list(x.shape)
         rand_shape[self.shared_dim] = 1
         mask = torch.rand(*rand_shape, device=x.device) > p
-        return (x * mask) * scale
+        ans = MulForDropout3.apply(x, mask, scale)
+        return ans
 
 
 class SwooshLFunction(torch.autograd.Function):
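
Note (not part of the patch): below is a minimal standalone sketch of the idea, checking that MulForDropout3 produces the same output and input gradient as the plain (x * mask) * scale computation while saving only the output tensor for backward. The custom_fwd/custom_bwd AMP decorators and the surrounding Dropout3 module from scaling.py are omitted; the example shapes and dropout rate are illustrative only.

# Standalone verification sketch (not part of the patch); assumes plain PyTorch.
import torch


class MulForDropout3(torch.autograd.Function):
    # Returns (x * y * alpha) where alpha is a float and y is a zero-or-one
    # mask that doesn't require grad.  Only the output is saved for backward.
    @staticmethod
    def forward(ctx, x, y, alpha):
        assert not y.requires_grad
        ans = x * y * alpha
        ctx.save_for_backward(ans)  # saves one tensor instead of x and y
        ctx.alpha = alpha
        return ans

    @staticmethod
    def backward(ctx, ans_grad):
        (ans,) = ctx.saved_tensors
        # (ans != 0) recovers the dropout mask from the saved output, so the
        # grad w.r.t. x is alpha * mask * ans_grad without storing the mask.
        x_grad = ctx.alpha * ans_grad * (ans != 0)
        return x_grad, None, None


if __name__ == "__main__":
    torch.manual_seed(0)
    p = 0.25
    scale = 1.0 / (1 - p)
    x = torch.randn(4, 6, requires_grad=True)
    mask = torch.rand(4, 1) > p  # zero-or-one mask shared over the last dim

    # Memory-efficient path.
    y1 = MulForDropout3.apply(x, mask, scale)
    (g1,) = torch.autograd.grad(y1.sum(), x)

    # Reference path: plain elementwise computation, as in the old code.
    x2 = x.detach().clone().requires_grad_(True)
    y2 = (x2 * mask) * scale
    (g2,) = torch.autograd.grad(y2.sum(), x2)

    assert torch.allclose(y1, y2)
    assert torch.allclose(g1, g2)
    print("MulForDropout3 matches the reference output and gradient")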