Add BasicNorm on output of Conv2dSubsampling module

This commit is contained in:
Daniel Povey 2022-12-20 15:00:01 +08:00
parent d2b272ab50
commit f59697555f

View File

@ -1821,6 +1821,7 @@ class Conv2dSubsampling(nn.Module):
self.out = nn.Linear(out_height * layer3_channels, out_channels)
self.out_norm = BasicNorm(out_channels, channel_dim=-1)
self.dropout = Dropout2(dropout)
@ -1856,6 +1857,7 @@ class Conv2dSubsampling(nn.Module):
max=float(self.scale_max))
x = self.out(x)
x = self.out_norm(x)
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
x = self.dropout(x)
return x