mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Add norm+balancer to VggSubsampling
This commit is contained in:
parent
0ee2404ff0
commit
05b5e78d8f
@ -158,6 +158,12 @@ class VggSubsampling(nn.Module):
|
||||
self.out = nn.Linear(
|
||||
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
|
||||
)
|
||||
self.out_norm = BasicNorm(odim, learn_eps=False)
|
||||
# constrain median of output to be close to zero.
|
||||
self.out_balancer = ActivationBalancer(channel_dim=-1,
|
||||
min_positive=0.45,
|
||||
max_positive=0.55)
|
||||
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Subsample x.
|
||||
@ -173,4 +179,6 @@ class VggSubsampling(nn.Module):
|
||||
x = self.layers(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
x = self.out_norm(x)
|
||||
x = self.out_balancer(x)
|
||||
return x
|
||||
|
Loading…
x
Reference in New Issue
Block a user