mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Add balancer to ConvNeXt
This commit is contained in:
parent
11f68afa1f
commit
39e7c613c7
@ -1712,6 +1712,13 @@ class ConvNeXt(nn.Module):
|
|||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
initial_scale=0.01)
|
initial_scale=0.01)
|
||||||
|
|
||||||
|
self.out_balancer = ActivationBalancer(
|
||||||
|
channels, channel_dim=1,
|
||||||
|
min_positive=0.5, max_positive=0.5,
|
||||||
|
min_abs=0.25, max_abs=6.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
"""
|
"""
|
||||||
@ -1732,7 +1739,9 @@ class ConvNeXt(nn.Module):
|
|||||||
mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate
|
mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate
|
||||||
x = x * mask
|
x = x * mask
|
||||||
|
|
||||||
return bypass + x
|
x = bypass + x
|
||||||
|
x = self.out_balancer(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user