mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00
Add norm+balancer to VggSubsampling
This commit is contained in:
parent
0ee2404ff0
commit
05b5e78d8f
@ -158,6 +158,12 @@ class VggSubsampling(nn.Module):
|
|||||||
self.out = nn.Linear(
|
self.out = nn.Linear(
|
||||||
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
|
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
|
||||||
)
|
)
|
||||||
|
self.out_norm = BasicNorm(odim, learn_eps=False)
|
||||||
|
# constrain median of output to be close to zero.
|
||||||
|
self.out_balancer = ActivationBalancer(channel_dim=-1,
|
||||||
|
min_positive=0.45,
|
||||||
|
max_positive=0.55)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""Subsample x.
|
"""Subsample x.
|
||||||
@ -173,4 +179,6 @@ class VggSubsampling(nn.Module):
|
|||||||
x = self.layers(x)
|
x = self.layers(x)
|
||||||
b, c, t, f = x.size()
|
b, c, t, f = x.size()
|
||||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||||
|
x = self.out_norm(x)
|
||||||
|
x = self.out_balancer(x)
|
||||||
return x
|
return x
|
||||||
|
Loading…
x
Reference in New Issue
Block a user