Share dropout masks across time in ff modules

commit 4033000730
parent 3110ed045a
@@ -1873,6 +1873,24 @@ class Dropout2(nn.Module):
                                            p=float(self.p),
                                            training=self.training)
 
+# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates,
+# and it lets you choose one dimension to share the dropout mask over
+class Dropout3(nn.Module):
+    def __init__(self, p: FloatLike, shared_dim: int):
+        super().__init__()
+        self.p = p
+        self.shared_dim = shared_dim
+
+    def forward(self, x: Tensor) -> Tensor:
+        p = float(self.p)
+        if not self.training or p == 0:
+            return _no_op(x)
+        scale = 1.0 / (1 - p)
+        rand_shape = list(x.shape)
+        rand_shape[self.shared_dim] = 1
+        mask = torch.rand(*rand_shape, device=x.device) > p
+        return (x * mask) * scale
+
+
 class SwooshLFunction(torch.autograd.Function):
     """
     swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.035
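The broadcasting trick the new class relies on is easy to check in isolation. The sketch below is illustrative only (the helper name shared_dropout is not part of icefall): drawing the random mask with size 1 on the shared dimension and letting broadcasting reuse it means every index along that dimension keeps or drops the same elements.

import torch

def shared_dropout(x: torch.Tensor, p: float, shared_dim: int) -> torch.Tensor:
    # Minimal standalone version of the masking logic above.
    if p == 0.0:
        return x
    rand_shape = list(x.shape)
    rand_shape[shared_dim] = 1            # one mask entry, broadcast over shared_dim
    mask = torch.rand(*rand_shape, device=x.device) > p
    return (x * mask) * (1.0 / (1 - p))   # rescale kept activations as usual

x = torch.ones(4, 2, 3)                   # (time, batch, channels)
y = shared_dropout(x, p=0.5, shared_dim=0)
# All time steps share one mask, so every frame has the same entries zeroed:
assert (y == y[0:1]).all()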
@@ -31,6 +31,7 @@ from scaling import (
     ConvNorm1d,
     ConvNorm2d,
     Dropout2,
+    Dropout3,
     MaxEig,
     DoubleSwish,
     SwooshL,
@@ -1544,7 +1545,8 @@ class FeedforwardModule(nn.Module):
                                min_abs=0.75,
                                max_abs=5.0)
         self.activation = SwooshL()
-        self.dropout = Dropout2(dropout)
+        # shared_dim=0 means we share the dropout mask along the time axis
+        self.dropout = Dropout3(dropout, shared_dim=0)
         self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
                                      initial_scale=0.1)
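For context on how this is meant to be used: the feed-forward input here is laid out as (seq_len, batch_size, embed_dim), so shared_dim=0 ties the mask across the frames of an utterance (dim 0 is time, as the new comment notes) while still varying it per batch element and channel. A minimal usage sketch, assuming the scaling.py from this commit is importable:

import torch
from scaling import Dropout3

dropout = Dropout3(p=0.5, shared_dim=0)   # share the mask along the time axis
dropout.train()                           # masking only happens in training mode

x = torch.ones(10, 2, 8)                  # (time, batch, channels)
y = dropout(x)

# With an all-ones input, identical masks per time step make every frame equal:
assert (y == y[0:1]).all()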