Add balancer to ConvNeXt

This commit is contained in:
Daniel Povey 2022-12-21 18:41:05 +08:00
parent 11f68afa1f
commit 39e7c613c7

View File

@@ -1712,6 +1712,13 @@ class ConvNeXt(nn.Module):
  kernel_size=1,
  initial_scale=0.01)
+ self.out_balancer = ActivationBalancer(
+ channels, channel_dim=1,
+ min_positive=0.5, max_positive=0.5,
+ min_abs=0.25, max_abs=6.0,
+ )
  def forward(self, x: Tensor) -> Tensor:
  """
@@ -1732,7 +1739,9 @@ class ConvNeXt(nn.Module):
  mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate
  x = x * mask
- return bypass + x
+ x = bypass + x
+ x = self.out_balancer(x)
+ return x