Memory efficient backprop for dropout3

Daniel Povey 2023-01-10 17:46:32 +08:00
parent 4033000730
commit 8bbcd81604


@@ -1873,6 +1873,25 @@ class Dropout2(nn.Module):
                          p=float(self.p),
                          training=self.training)
 
+class MulForDropout3(torch.autograd.Function):
+    # Returns (x * y * alpha) where alpha is a float and y doesn't require
+    # grad and is zero-or-one.  Saving only the output `ans` for backward,
+    # and recovering the mask as (ans != 0), avoids keeping the mask or a
+    # full-size intermediate alive; this is the memory saving.
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, x, y, alpha):
+        assert not y.requires_grad
+        ans = x * y * alpha
+        ctx.save_for_backward(ans)
+        ctx.alpha = alpha
+        return ans
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, ans_grad):
+        ans, = ctx.saved_tensors
+        # d(ans)/dx = y * alpha; (ans != 0) reconstructs y, exactly so
+        # wherever x is nonzero under the mask.
+        x_grad = ctx.alpha * ans_grad * (ans != 0)
+        return x_grad, None, None
+
 # Dropout3 is just like normal dropout, except it supports schedules on the dropout rates,
 # and it lets you choose one dimension to share the dropout mask over
 class Dropout3(nn.Module):
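
A quick standalone check of the trick (not part of the commit): the gradient taken through MulForDropout3 should match the plain (x * mask) * scale path, since backward recovers the mask as ans != 0. This sketch assumes the MulForDropout3 class above is in scope, with custom_fwd/custom_bwd imported from torch.cuda.amp as elsewhere in this file.

import torch

x = torch.randn(4, 5, requires_grad=True)
mask = torch.rand(4, 1) > 0.5   # zero-or-one mask, shared over dim 1
scale = 2.0                     # i.e. 1 / (1 - p) for p = 0.5

# Reference path: explicit mask-and-scale, as Dropout3 did before this commit.
ref = (x * mask) * scale
ref.sum().backward()
g_ref = x.grad.clone()
x.grad = None

# Memory-efficient path: only `ans` is saved for backward; the mask is
# reconstructed as (ans != 0).  This matches the true gradient as long as
# x has no exact zeros under the mask (almost surely true for randn).
ans = MulForDropout3.apply(x, mask, scale)
ans.sum().backward()
assert torch.allclose(g_ref, x.grad)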
@@ -1884,11 +1903,12 @@ class Dropout3(nn.Module):
         p = float(self.p)
         if not self.training or p == 0:
             return _no_op(x)
-        scale = 1.0 / (1 - self.p)
+        scale = 1.0 / (1 - p)
         rand_shape = list(x.shape)
         rand_shape[self.shared_dim] = 1
         mask = torch.rand(*rand_shape, device=x.device) > p
-        return (x * mask) * scale
+        ans = MulForDropout3.apply(x, mask, scale)
+        return ans
 
 class SwooshLFunction(torch.autograd.Function):
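
For context, a hypothetical usage sketch of the patched Dropout3. The constructor signature Dropout3(p, shared_dim) is assumed from this file's __init__, which the hunks above do not show.

m = Dropout3(p=0.1, shared_dim=1)   # assumed signature; p may also be a schedule
m.train()
x = torch.randn(8, 100, 256, requires_grad=True)   # (batch, time, channel)
y = m(x)         # one mask of shape (8, 1, 256), shared across the time dim
y.sum().backward()   # backward saves only y itself, not a separate mask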