From 39e7c613c75f66f3edcc06c74417f05f3c9cb374 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 21 Dec 2022 18:41:05 +0800 Subject: [PATCH] Add balancer to ConvNeXt --- .../ASR/pruned_transducer_stateless7/zipformer.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 105a63943..b3029b1cb 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1712,6 +1712,13 @@ class ConvNeXt(nn.Module): kernel_size=1, initial_scale=0.01) + self.out_balancer = ActivationBalancer( + channels, channel_dim=1, + min_positive=0.5, max_positive=0.5, + min_abs=0.25, max_abs=6.0, + ) + + def forward(self, x: Tensor) -> Tensor: """ @@ -1732,7 +1739,9 @@ class ConvNeXt(nn.Module): mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate x = x * mask - return bypass + x + x = bypass + x + x = self.out_balancer(x) + return x