mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Replace norm in ConvolutionModule with a scaling factor.
This commit is contained in:
parent
87b843f023
commit
425e274c82
@ -857,7 +857,8 @@ class ConvolutionModule(nn.Module):
|
|||||||
bias=bias,
|
bias=bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.norm = nn.LayerNorm(channels)
|
self.scale = ExpScale(1, speed=10.0, initial_scale=1.0)
|
||||||
|
|
||||||
# shape: (channels, 1), broadcasts with (batch, channel, time).
|
# shape: (channels, 1), broadcasts with (batch, channel, time).
|
||||||
self.activation = SwishOffset()
|
self.activation = SwishOffset()
|
||||||
|
|
||||||
@ -891,7 +892,7 @@ class ConvolutionModule(nn.Module):
|
|||||||
x = self.depthwise_conv(x)
|
x = self.depthwise_conv(x)
|
||||||
# x is (batch, channels, time)
|
# x is (batch, channels, time)
|
||||||
x = x.permute(0, 2, 1)
|
x = x.permute(0, 2, 1)
|
||||||
x = self.norm(x)
|
x = self.scale(x)
|
||||||
x = x.permute(0, 2, 1)
|
x = x.permute(0, 2, 1)
|
||||||
|
|
||||||
x = self.activation(x)
|
x = self.activation(x)
|
||||||
|
@ -110,7 +110,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm",
|
default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2",
|
||||||
help="""The experiment dir.
|
help="""The experiment dir.
|
||||||
It specifies the directory where all training related
|
It specifies the directory where all training related
|
||||||
files, e.g., checkpoints, log, etc, are saved
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
Loading…
x
Reference in New Issue
Block a user