mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Reduce constraints from deriv-balancer in ConvModule.
This commit is contained in:
parent
788963d40a
commit
8d17a05dd2
@ -861,7 +861,8 @@ class ConvolutionModule(nn.Module):
|
|||||||
# constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
|
# constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
|
||||||
# it will be in a better position to start learning something, i.e. to latch onto
|
# it will be in a better position to start learning something, i.e. to latch onto
|
||||||
# the correct range.
|
# the correct range.
|
||||||
self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0)
|
self.deriv_balancer = DerivBalancer(channel_dim=1, max_abs=10.0,
|
||||||
|
min_positive=0.0, max_positive=1.0)
|
||||||
|
|
||||||
self.depthwise_conv = ScaledConv1d(
|
self.depthwise_conv = ScaledConv1d(
|
||||||
channels,
|
channels,
|
||||||
@ -873,8 +874,6 @@ class ConvolutionModule(nn.Module):
|
|||||||
bias=bias,
|
bias=bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
self.deriv_balancer2 = DerivBalancer(channel_dim=1)
|
|
||||||
# shape: (channels, 1), broadcasts with (batch, channel, time).
|
# shape: (channels, 1), broadcasts with (batch, channel, time).
|
||||||
self.activation = SwishOffset()
|
self.activation = SwishOffset()
|
||||||
|
|
||||||
@ -904,13 +903,12 @@ class ConvolutionModule(nn.Module):
|
|||||||
# GLU mechanism
|
# GLU mechanism
|
||||||
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
|
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
|
||||||
|
|
||||||
x = self.deriv_balancer1(x)
|
x = self.deriv_balancer(x)
|
||||||
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
|
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
|
||||||
|
|
||||||
# 1D Depthwise Conv
|
# 1D Depthwise Conv
|
||||||
x = self.depthwise_conv(x)
|
x = self.depthwise_conv(x)
|
||||||
|
|
||||||
x = self.deriv_balancer2(x)
|
|
||||||
x = self.activation(x)
|
x = self.activation(x)
|
||||||
|
|
||||||
x = self.pointwise_conv2(x) # (batch, channel, time)
|
x = self.pointwise_conv2(x) # (batch, channel, time)
|
||||||
|
@ -110,7 +110,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv",
|
default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2",
|
||||||
help="""The experiment dir.
|
help="""The experiment dir.
|
||||||
It specifies the directory where all training related
|
It specifies the directory where all training related
|
||||||
files, e.g., checkpoints, log, etc, are saved
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
Loading…
x
Reference in New Issue
Block a user