mirror of https://github.com/k2-fsa/icefall.git
Memory efficient backprop for dropout3
This commit is contained in:
parent 4033000730
commit 8bbcd81604
@@ -1873,6 +1873,25 @@ class Dropout2(nn.Module):
                                            p=float(self.p),
                                            training=self.training)
 
+class MulForDropout3(torch.autograd.Function):
+    # returns (x * y * alpha) where alpha is a float and y doesn't require
+    # grad and is zero-or-one.
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, x, y, alpha):
+        assert not y.requires_grad
+        ans = x * y * alpha
+        ctx.save_for_backward(ans)
+        ctx.alpha = alpha
+        return ans
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, ans_grad):
+        ans, = ctx.saved_tensors
+        x_grad = ctx.alpha * ans_grad * (ans != 0)
+        return x_grad, None, None
+
 # Dropout3 is just like normal dropout, except it supports schedules on the dropout rates,
 # and it lets you choose one dimension to share the dropout mask over
 class Dropout3(nn.Module):
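
The backward pass here relies only on the function's own output: because y is zero-or-one, the mask can be recovered as (ans != 0), so save_for_backward keeps just ans rather than a separate mask or the input x. Below is a minimal sketch, not part of the commit, that checks MulForDropout3 against the plain (x * mask) * scale expression it replaces; it assumes the MulForDropout3 class from the diff above is in scope (e.g. defined in the same file).

    import torch

    # Assumes MulForDropout3 (from the diff above) is defined/importable here.
    x1 = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
    x2 = x1.detach().clone().requires_grad_(True)
    mask = torch.rand(4, 1, dtype=torch.double) > 0.5   # zero-or-one, broadcast over dim 1
    scale = 2.0

    y_naive = (x1 * mask) * scale                     # autograd keeps the mask for backward
    y_custom = MulForDropout3.apply(x2, mask, scale)  # saves only the output tensor ans

    y_naive.sum().backward()
    y_custom.sum().backward()
    # Values match; gradients match wherever x is nonzero (ans != 0 treats exact
    # zeros of x as dropped, but randn produces no exact zeros in practice).
    print(torch.allclose(y_naive, y_custom), torch.allclose(x1.grad, x2.grad))
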
@@ -1884,11 +1903,12 @@ class Dropout3(nn.Module):
         p = float(self.p)
         if not self.training or p == 0:
             return _no_op(x)
-        scale = 1.0 / (1 - self.p)
+        scale = 1.0 / (1 - p)
         rand_shape = list(x.shape)
         rand_shape[self.shared_dim] = 1
         mask = torch.rand(*rand_shape, device=x.device) > p
-        return (x * mask) * scale
+        ans = MulForDropout3.apply(x, mask, scale)
+        return ans
 
 
 class SwooshLFunction(torch.autograd.Function):
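
In the updated forward, the random mask has size 1 along self.shared_dim and is broadcast by the multiply, so every position along that dimension is dropped or kept together; per the comment above, the rate p can also follow a schedule. A small standalone sketch of just the mask-sharing behaviour, using a hypothetical helper rather than the Dropout3 module itself:

    import torch

    def dropout3_like(x, p, shared_dim):
        # Hypothetical standalone version of the masking in Dropout3.forward
        # (no schedule on p, no MulForDropout3).
        scale = 1.0 / (1 - p)
        rand_shape = list(x.shape)
        rand_shape[shared_dim] = 1                    # one mask entry along shared_dim
        mask = torch.rand(*rand_shape, device=x.device) > p
        return (x * mask) * scale                     # mask broadcasts along shared_dim

    x = torch.randn(2, 5, 4)                          # (batch, time, channel)
    y = dropout3_like(x, p=0.5, shared_dim=1)
    # Each (batch, channel) pair is either entirely zero or entirely nonzero
    # across the shared time dimension.
    print((y != 0).all(dim=1) | (y == 0).all(dim=1))  # all True
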