Add norm+balancer to VggSubsampling

This commit is contained in:
Daniel Povey 2022-03-21 15:55:11 +08:00
parent 0ee2404ff0
commit 05b5e78d8f

View File

@ -158,6 +158,12 @@ class VggSubsampling(nn.Module):
self.out = nn.Linear(
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
)
self.out_norm = BasicNorm(odim, learn_eps=False)
# constrain median of output to be close to zero.
self.out_balancer = ActivationBalancer(channel_dim=-1,
min_positive=0.45,
max_positive=0.55)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
@ -173,4 +179,6 @@ class VggSubsampling(nn.Module):
x = self.layers(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
x = self.out_norm(x)
x = self.out_balancer(x)
return x