diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index e594e4e4d..81f025b79 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -1216,9 +1216,9 @@ class TanSwish(torch.nn.Module):
 
 class SwooshFunction(torch.autograd.Function):
     """
-    swoosh(x) = log(1 + exp(x-4)) - 0.055*x - 0.15
+    swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.15
 
-    derivatives are between -0.055 and 1-0.055.
+    derivatives are between -0.08 and 0.92.
     """
 
     @staticmethod
@@ -1235,15 +1235,15 @@ class SwooshFunction(torch.autograd.Function):
         with torch.enable_grad():
             x = x.detach()
             x.requires_grad = True
-            y = torch.logaddexp(one, x - 4) - 0.055 * x - 0.15
+            y = torch.logaddexp(one, x - 4) - 0.08 * x - 0.15
             if not requires_grad:
                 return y
             y.backward(gradient = torch.ones_like(y))
             grad = x.grad
-            floor = -0.055
-            ceil = 0.946 # real ceil would be 0.0945, give it extra room for roundoff.
+            floor = -0.08
+            ceil = 0.925 # real ceil would be 0.92, give it extra room for roundoff.
             d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad))
             if __name__ == "__main__":
@@ -1261,8 +1261,8 @@ class SwooshFunction(torch.autograd.Function):
     def backward(ctx, y_grad: Tensor) -> Tensor:
         d, = ctx.saved_tensors
         # the same constants as used in forward pass.
-        floor = -0.055
-        ceil = 0.946
+        floor = -0.08
+        ceil = 0.925
         d = (d * ((ceil - floor) / 255.0) + floor)
         return (y_grad * d)
 
@@ -1273,7 +1273,7 @@ class Swoosh(torch.nn.Module):
         """
         if torch.jit.is_scripting():
            one = torch.tensor(1.0, dtype=x.dtype, device=x.device)
-           return torch.logaddexp(one, x - 4) - 0.055 * x - 0.15
+           return torch.logaddexp(one, x - 4) - 0.08 * x - 0.15
        return SwooshFunction.apply(x)
 
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index b97bd51ee..889b9eba0 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -29,6 +29,7 @@ from scaling import (
     BasicNorm,
     MaxEig,
     DoubleSwish,
+    Swoosh,
     TanSwish,
     ScaledConv1d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
@@ -1421,10 +1422,10 @@ class FeedforwardModule(nn.Module):
         self.hidden_balancer = ActivationBalancer(feedforward_dim,
                                                   channel_dim=-1,
                                                   min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
-                                                  min_abs=1.5,
-                                                  max_abs=15.0,
+                                                  min_abs=2.0,
+                                                  max_abs=10.0,
                                                   min_prob=0.25)
-        self.activation = DoubleSwish()
+        self.activation = Swoosh()
         self.dropout = nn.Dropout(dropout)
         self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim,
                                           initial_scale=0.01,
@@ -1599,10 +1600,11 @@ class ConvolutionModule(nn.Module):
             channels, channel_dim=1,
             min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
             max_positive=1.0,
-            max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=10),
+            min_abs=1.0,
+            max_abs=ScheduledFloat((0.0, 10.0), (8000.0, 20.0), default=10),
         )
 
-        self.activation = DoubleSwish()
+        self.activation = Swoosh()
 
         self.whiten = Whiten(num_groups=1,
                              whitening_limit=_whitening_schedule(7.5),
diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index e596c0028..91580f986 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -547,7 +547,7 @@ def attach_diagnostics(
     module.register_forward_hook(forward_hook)
     module.register_backward_hook(backward_hook)
 
-    if type(module).__name__ in ["Sigmoid", "Tanh", "ReLU", "TanSwish", "Swish", "DoubleSwish"]:
+    if type(module).__name__ in ["Sigmoid", "Tanh", "ReLU", "TanSwish", "Swish", "DoubleSwish", "Swoosh"]:
         # For these specific module types, accumulate some additional diagnostics
         # that can help us improve the activation function.  These require a lot of memory,
         # to save the forward activations, so limit this to some select classes.
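
For anyone who wants to sanity-check the new activation outside of icefall, the following is a minimal, self-contained sketch (plain PyTorch, not part of the patch): it implements the docstring's formula log(1 + exp(x-4)) - 0.08*x - 0.15 via logaddexp(0, .), checks the stated derivative bounds, and round-trips the dithered uint8 gradient encoding that SwooshFunction's forward() saves and backward() decodes.

# Sketch only: verifies the docstring's formula and derivative bounds, and
# round-trips the uint8 gradient encoding used by SwooshFunction above.
import torch

def swoosh(x: torch.Tensor) -> torch.Tensor:
    # swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.15;
    # log(1 + exp(t)) is computed stably as logaddexp(0, t).
    return torch.logaddexp(torch.zeros_like(x), x - 4.0) - 0.08 * x - 0.15

x = torch.linspace(-20.0, 20.0, 4001, requires_grad=True)
y = swoosh(x)
y.backward(gradient=torch.ones_like(y))
grad = x.grad

# d/dx swoosh(x) = sigmoid(x-4) - 0.08, so it lies in (-0.08, 0.92).
assert grad.min() >= -0.08 - 1e-6 and grad.max() <= 0.92 + 1e-6

# Dithered 8-bit encode (as in forward()) and decode (as in backward()).
floor, ceil = -0.08, 0.925
d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)
d_int = d_scaled.to(torch.uint8)  # what forward() would save for backward()
d = d_int.to(grad.dtype) * ((ceil - floor) / 255.0) + floor
# The decoded derivative stays within one quantization step of the true one;
# the random dither makes the quantization error zero-mean in expectation.
assert (d - grad).abs().max() <= (ceil - floor) / 255.0 + 1e-6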
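Similarly, a rough standalone sketch of what the diagnostics.py change enables: forward hooks attached to modules whose class name is in the whitelist (now including Swoosh), accumulating per-module activation statistics. The helper names here (make_hook, attach, stats) are illustrative, not icefall's actual API.

# Sketch only: per-module activation statistics via forward hooks, keyed on
# the module's class name as in attach_diagnostics() above.
import torch
import torch.nn as nn

ACTIVATION_CLASSES = ("Sigmoid", "Tanh", "ReLU", "TanSwish", "Swish",
                      "DoubleSwish", "Swoosh")
stats = {}

def make_hook(name):
    def hook(module, inputs, output):
        # Accumulate count, sum and sum-of-squares of the module's input.
        x = inputs[0].detach()
        s = stats.setdefault(name, {"n": 0, "sum": 0.0, "sumsq": 0.0})
        s["n"] += x.numel()
        s["sum"] += x.sum().item()
        s["sumsq"] += (x * x).sum().item()
    return hook

def attach(model: nn.Module) -> None:
    for name, module in model.named_modules():
        if type(module).__name__ in ACTIVATION_CLASSES:
            module.register_forward_hook(make_hook(name))

model = nn.Sequential(nn.Linear(8, 8), nn.Tanh(), nn.Linear(8, 8))
attach(model)
model(torch.randn(4, 8))
for name, s in stats.items():
    mean = s["sum"] / s["n"]
    rms = (s["sumsq"] / s["n"]) ** 0.5
    print(f"{name}: input mean={mean:.3f}, input rms={rms:.3f}")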