diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 5af2402c0..c925fc32f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -1873,6 +1873,24 @@ class Dropout2(nn.Module):
                                            p=float(self.p),
                                            training=self.training)
 
+# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates,
+# and it lets you choose one dimension to share the dropout mask over
+class Dropout3(nn.Module):
+    def __init__(self, p: FloatLike, shared_dim: int):
+        super().__init__()
+        self.p = p
+        self.shared_dim = shared_dim
+    def forward(self, x: Tensor) -> Tensor:
+        p = float(self.p)
+        if not self.training or p == 0:
+            return _no_op(x)
+        scale = 1.0 / (1 - p)
+        rand_shape = list(x.shape)
+        rand_shape[self.shared_dim] = 1
+        mask = torch.rand(*rand_shape, device=x.device) > p
+        return (x * mask) * scale
+
+
 class SwooshLFunction(torch.autograd.Function):
     """
       swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.035
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 2cc2b987e..9b0100c14 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -31,6 +31,7 @@ from scaling import (
     ConvNorm1d,
     ConvNorm2d,
     Dropout2,
+    Dropout3,
     MaxEig,
     DoubleSwish,
     SwooshL,
@@ -1544,7 +1545,8 @@ class FeedforwardModule(nn.Module):
                                            min_abs=0.75,
                                            max_abs=5.0)
         self.activation = SwooshL()
-        self.dropout = Dropout2(dropout)
+        # shared_dim=0 means we share the dropout mask along the time axis
+        self.dropout = Dropout3(dropout, shared_dim=0)
         self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
                                      initial_scale=0.1)
 
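
To sanity-check the mask sharing outside the repo, here is a minimal standalone sketch of the same idea. It inlines the behavior rather than importing scaling.py: the FloatLike/_no_op machinery is replaced with a plain float and an identity return, and the class name and test values are illustrative, not part of the patch.

import torch
from torch import Tensor, nn


class Dropout3Sketch(nn.Module):
    """Standalone sketch of Dropout3: the Bernoulli mask is drawn with
    size 1 on `shared_dim`, then broadcast across that dimension."""

    def __init__(self, p: float, shared_dim: int):
        super().__init__()
        self.p = p  # in scaling.py this is a FloatLike, so it may be a schedule
        self.shared_dim = shared_dim

    def forward(self, x: Tensor) -> Tensor:
        p = float(self.p)
        if not self.training or p == 0:
            return x  # scaling.py returns _no_op(x) here
        scale = 1.0 / (1 - p)
        rand_shape = list(x.shape)
        rand_shape[self.shared_dim] = 1  # mask has size 1 on the shared dim
        mask = torch.rand(*rand_shape, device=x.device) > p
        return (x * mask) * scale


torch.manual_seed(0)
m = Dropout3Sketch(p=0.1, shared_dim=0)
m.train()
x = torch.ones(50, 2, 8)  # (seq_len, batch, channel), as in FeedforwardModule
y = m(x)
# With shared_dim=0 the mask broadcasts over time, so every frame
# has the identical dropout pattern ...
assert (y == y[0]).all()
# ... and surviving values are scaled by 1/(1-p) to preserve the expectation.
print("unique output values:", y.unique())

With shared_dim=0 and a (seq_len, batch, channel) input, as in FeedforwardModule, each (batch, channel) position is kept or dropped consistently across the whole sequence, rather than independently per frame as in Dropout2.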