Memory efficient backprop for dropout3

parent 4033000730
commit 8bbcd81604
@@ -1873,6 +1873,25 @@ class Dropout2(nn.Module):
                                            p=float(self.p),
                                            training=self.training)
 
+class MulForDropout3(torch.autograd.Function):
+    # returns (x * y * alpha) where alpha is a float and y doesn't require
+    # grad and is zero-or-one.
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, x, y, alpha):
+        assert not y.requires_grad
+        ans = x * y * alpha
+        ctx.save_for_backward(ans)
+        ctx.alpha = alpha
+        return ans
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, ans_grad):
+        ans, = ctx.saved_tensors
+        x_grad = ctx.alpha * ans_grad * (ans != 0)
+        return x_grad, None, None
+
 # Dropout3 is just like normal dropout, except it supports schedules on the dropout rates,
 # and it lets you choose one dimension to share the dropout mask over
 class Dropout3(nn.Module):
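For reference, a minimal self-contained sketch (not part of the commit) that checks MulForDropout3 against the plain (x * y * alpha) expression and its gradient. It assumes custom_fwd and custom_bwd come from torch.cuda.amp, as elsewhere in scaling.py; the class body mirrors the diff above, and the test tensors are purely illustrative.

# Standalone sketch; assumes custom_fwd/custom_bwd are torch.cuda.amp's
# (they are no-ops outside autocast, so this runs on CPU too).
import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class MulForDropout3(torch.autograd.Function):
    # Saves only `ans` for backward; the zero-or-one mask y is recovered
    # in backward as (ans != 0), so neither x nor y needs to be stored.
    @staticmethod
    @custom_fwd
    def forward(ctx, x, y, alpha):
        assert not y.requires_grad
        ans = x * y * alpha
        ctx.save_for_backward(ans)
        ctx.alpha = alpha
        return ans

    @staticmethod
    @custom_bwd
    def backward(ctx, ans_grad):
        ans, = ctx.saved_tensors
        x_grad = ctx.alpha * ans_grad * (ans != 0)
        return x_grad, None, None

x = torch.randn(4, 8, requires_grad=True)
mask = torch.rand(4, 8) > 0.5        # zero-or-one, does not require grad
alpha = 2.0

ref = x * mask * alpha               # reference path: keeps x and mask alive
ref.sum().backward()
ref_grad, x.grad = x.grad, None

out = MulForDropout3.apply(x, mask, alpha)
out.sum().backward()                 # mask recovered from (out != 0)

assert torch.allclose(out, ref)
assert torch.allclose(x.grad, ref_grad)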
@@ -1884,11 +1903,12 @@ class Dropout3(nn.Module):
         p = float(self.p)
         if not self.training or p == 0:
             return _no_op(x)
-        scale = 1.0 / (1 - self.p)
+        scale = 1.0 / (1 - p)
         rand_shape = list(x.shape)
         rand_shape[self.shared_dim] = 1
         mask = torch.rand(*rand_shape, device=x.device) > p
-        return (x * mask) * scale
+        ans = MulForDropout3.apply(x, mask, scale)
+        return ans
 
 
 class SwooshLFunction(torch.autograd.Function):
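And a hedged sketch (not from the commit) of what the patched forward computes: it emulates the hunk above with plain tensor ops to show how one dropout mask is shared across shared_dim. Dropout3 itself, _no_op, and the schedulable dropout rate live elsewhere in scaling.py; the shapes and dropout rate here are illustrative.

import torch

x = torch.randn(2, 5, 3)
p, shared_dim = 0.5, 1               # drop rate; dim to share the mask over
scale = 1.0 / (1 - p)

rand_shape = list(x.shape)
rand_shape[shared_dim] = 1           # mask has size 1 on the shared dim
mask = torch.rand(*rand_shape) > p   # broadcasts over shared_dim on multiply

ans = x * mask * scale               # what MulForDropout3.apply(x, mask, scale) computes

# Every index along shared_dim sees the same keep/drop pattern:
zeros = (ans == 0)
assert (zeros == zeros.narrow(shared_dim, 0, 1)).all()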