Share dropout masks across time in ff modules

Daniel Povey 2023-01-10 16:49:36 +08:00
parent 3110ed045a
commit 4033000730
2 changed files with 21 additions and 1 deletion

@@ -1873,6 +1873,24 @@ class Dropout2(nn.Module):
            p=float(self.p),
            training=self.training)

# Dropout3 is just like normal dropout, except it supports schedules on the dropout
# rate, and it lets you choose one dimension to share the dropout mask over.
class Dropout3(nn.Module):
    def __init__(self, p: FloatLike, shared_dim: int):
        super().__init__()
        self.p = p
        self.shared_dim = shared_dim

    def forward(self, x: Tensor) -> Tensor:
        p = float(self.p)  # evaluate the (possibly scheduled) dropout rate
        if not self.training or p == 0:
            return _no_op(x)
        scale = 1.0 / (1 - p)
        # Give the random mask size 1 on the shared dimension, so broadcasting
        # reuses the same mask everywhere along that dimension.
        rand_shape = list(x.shape)
        rand_shape[self.shared_dim] = 1
        mask = torch.rand(*rand_shape, device=x.device) > p
        return (x * mask) * scale

class SwooshLFunction(torch.autograd.Function):
    """
    swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.035

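For illustration, a minimal self-contained sketch of the mask-sharing logic in Dropout3.forward above (toy shapes, re-implementing the forward logic inline rather than importing the class):

import torch

torch.manual_seed(0)
p = 0.5
x = torch.ones(4, 2, 6)              # e.g. (seq_len, batch, channels)

rand_shape = list(x.shape)
rand_shape[0] = 1                    # shared_dim = 0
mask = torch.rand(*rand_shape) > p   # mask has shape (1, 2, 6)
y = (x * mask) * (1.0 / (1 - p))

# The same entries are zeroed at every position along dim 0:
assert torch.equal(y[0], y[1]) and torch.equal(y[0], y[-1])
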
@@ -31,6 +31,7 @@ from scaling import (
    ConvNorm1d,
    ConvNorm2d,
    Dropout2,
    Dropout3,
    MaxEig,
    DoubleSwish,
    SwooshL,
@@ -1544,7 +1545,8 @@ class FeedforwardModule(nn.Module):
                                        min_abs=0.75,
                                        max_abs=5.0)
        self.activation = SwooshL()
        self.dropout = Dropout2(dropout)
        # shared_dim=0 means we share the dropout mask along the time axis
        self.dropout = Dropout3(dropout, shared_dim=0)
        self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
                                     initial_scale=0.1)
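
A usage sketch of the new feedforward dropout, assuming scaling.py (with the Dropout3 added above) is importable, and assuming the module's input layout is (seq_len, batch, channels), which is what the "time axis" comment implies for shared_dim=0:

import torch
from scaling import Dropout3

dropout = Dropout3(p=0.1, shared_dim=0)
dropout.train()                      # the mask is only applied in training mode

x = torch.ones(100, 8, 256)          # (seq_len, batch, channels)
y = dropout(x)

# One mask of shape (1, 8, 256) is broadcast over the 100 frames, so within a
# given sequence the same channels are dropped at every time step.
zeroed = (y == 0)
assert (zeroed == zeroed[0:1]).all()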