mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
Replace BatchNorm in the Conformer model with LayerNorm.
This commit is contained in:
parent
4635af633a
commit
66cc9b4592
@ -866,7 +866,7 @@ class ConvolutionModule(nn.Module):
|
|||||||
groups=channels,
|
groups=channels,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
)
|
)
|
||||||
self.norm = nn.BatchNorm1d(channels)
|
self.norm = nn.LayerNorm(channels)
|
||||||
self.pointwise_conv2 = nn.Conv1d(
|
self.pointwise_conv2 = nn.Conv1d(
|
||||||
channels,
|
channels,
|
||||||
channels,
|
channels,
|
||||||
@ -896,7 +896,12 @@ class ConvolutionModule(nn.Module):
|
|||||||
|
|
||||||
# 1D Depthwise Conv
|
# 1D Depthwise Conv
|
||||||
x = self.depthwise_conv(x)
|
x = self.depthwise_conv(x)
|
||||||
x = self.activation(self.norm(x))
|
# x is (batch, channels, time)
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
x = self.norm(x)
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
|
||||||
|
x = self.activation(x)
|
||||||
|
|
||||||
x = self.pointwise_conv2(x) # (batch, channel, time)
|
x = self.pointwise_conv2(x) # (batch, channel, time)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user