diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index d0f29b4f4..484eb24ef 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -1235,15 +1235,15 @@ class SwooshFunction(torch.autograd.Function): with torch.enable_grad(): x = x.detach() x.requires_grad = True - y = torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 + y = torch.logaddexp(zero, x - 4.0) - 0.1 * x - 0.035 if not requires_grad: return y y.backward(gradient = torch.ones_like(y)) grad = x.grad - floor = -0.08 - ceil = 0.925 # real ceil would be 0.092, give it extra room for roundoff. + floor = -0.1 + ceil = 0.905 # real ceil would be 0.9, give it extra room for roundoff. d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)) if __name__ == "__main__": @@ -1261,8 +1261,8 @@ class SwooshFunction(torch.autograd.Function): def backward(ctx, y_grad: Tensor) -> Tensor: d, = ctx.saved_tensors # the same constants as used in forward pass. - floor = -0.08 - ceil = 0.925 + floor = -0.1 + ceil = 0.905 d = (d * ((ceil - floor) / 255.0) + floor) return (y_grad * d) @@ -1273,7 +1273,7 @@ class Swoosh(torch.nn.Module): """ if torch.jit.is_scripting(): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return torch.logaddexp(zero, x - 1.) 
- 0.08 * x - 0.313261687 + return torch.logaddexp(zero, x - 4.0) - 0.1 * x - 0.035 return SwooshFunction.apply(x) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 44dd311c0..87543e830 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1421,7 +1421,8 @@ class FeedforwardModule(nn.Module): self.hidden_balancer = ActivationBalancer(feedforward_dim, channel_dim=-1, - min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), + min_positive=0.3, + max_positive=1.0, min_abs=2.0, max_abs=10.0, min_prob=0.25)