Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Share dropout masks across time in ff modules
parent 3110ed045a
commit 4033000730
@@ -1873,6 +1873,24 @@ class Dropout2(nn.Module):
                                          p=float(self.p),
                                          training=self.training)
 
+# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates,
+# and it lets you choose one dimension to share the dropout mask over
+class Dropout3(nn.Module):
+    def __init__(self, p: FloatLike, shared_dim: int):
+        super().__init__()
+        self.p = p
+        self.shared_dim = shared_dim
+
+    def forward(self, x: Tensor) -> Tensor:
+        p = float(self.p)
+        if not self.training or p == 0:
+            return _no_op(x)
+        scale = 1.0 / (1 - p)
+        rand_shape = list(x.shape)
+        rand_shape[self.shared_dim] = 1
+        mask = torch.rand(*rand_shape, device=x.device) > p
+        return (x * mask) * scale
+
 class SwooshLFunction(torch.autograd.Function):
     """
     swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.035
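Below is a minimal sketch (not part of the commit) of what Dropout3 does with shared_dim=0: a single Bernoulli mask is drawn per (batch, channel) position and broadcast over the time axis, so the same channels are dropped at every frame. The (time, batch, channel) layout and the concrete sizes are illustrative assumptions.

# Standalone illustration of mask sharing along dim 0; plain PyTorch, no icefall imports.
import torch

x = torch.randn(10, 2, 4)            # (time, batch, channel), sizes chosen arbitrarily
p = 0.1
rand_shape = list(x.shape)
rand_shape[0] = 1                     # shared_dim=0: one mask entry per (batch, channel)
mask = torch.rand(*rand_shape) > p    # shape (1, 2, 4), broadcast over the time axis
y = (x * mask) / (1 - p)              # rescale kept activations, as in inverted dropout
# every (batch, channel) position is either zeroed at all 10 frames or kept at all 10 frames
assert ((y == 0).all(dim=0) | (y != 0).all(dim=0)).all()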
@@ -31,6 +31,7 @@ from scaling import (
     ConvNorm1d,
     ConvNorm2d,
     Dropout2,
+    Dropout3,
     MaxEig,
     DoubleSwish,
     SwooshL,
@@ -1544,7 +1545,8 @@ class FeedforwardModule(nn.Module):
                                            min_abs=0.75,
                                            max_abs=5.0)
         self.activation = SwooshL()
-        self.dropout = Dropout2(dropout)
+        # shared_dim=0 means we share the dropout mask along the time axis
+        self.dropout = Dropout3(dropout, shared_dim=0)
         self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
                                      initial_scale=0.1)
 
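As a usage sketch (not part of the commit), the feedforward module's dropout can now reuse one mask across all frames of an utterance. The import path and tensor sizes below are assumptions; the (seq_len, batch, channels) layout is implied by the "time axis" comment in the hunk above.

# Assumes the zipformer recipe directory is on the path, as in the import hunk above.
from scaling import Dropout3
import torch

ff_dropout = Dropout3(0.1, shared_dim=0)    # p may also be a schedule (FloatLike)
ff_dropout.train()                          # Dropout3 is a no-op in eval mode or when p == 0
x = torch.randn(100, 8, 1536)               # (seq_len, batch, feedforward_dim), illustrative sizes
y = ff_dropout(x)                           # one mask per (batch, channel), reused at all 100 frames
assert y.shape == x.shape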