Remove max-positive constraint in deriv-balancing; add second DerivBalancer in conv module.

Daniel Povey 2022-03-15 13:10:35 +08:00
parent a23010fc10
commit 86e5dcba11
3 changed files with 8 additions and 6 deletions

View File

@@ -527,7 +527,7 @@ class DerivBalancer(torch.nn.Module):
     """
     def __init__(self, channel_dim: int,
                  min_positive: float = 0.05,
-                 max_positive: float = 0.95,
+                 max_positive: float = 1.0,
                  max_factor: float = 0.01,
                  min_abs: float = 0.2,
                  max_abs: float = 100.0):
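
The new default of max_positive = 1.0 effectively disables the upper constraint on how often a channel's values may be positive (a proportion can never exceed 1.0), while the min_positive = 0.05 floor is kept. As a rough illustration only (this is not the icefall DerivBalancer; the class name SimpleDerivBalancer and the exact gradient correction are assumptions), the idea is an identity forward pass whose backward pass nudges gradients for channels whose proportion of positive values falls outside [min_positive, max_positive]:

import torch


class SimpleDerivBalancer(torch.autograd.Function):
    # Illustrative stand-in, NOT the icefall implementation: identity in the
    # forward pass; the backward pass adds a small correction that pushes
    # channels whose proportion of positive values lies outside
    # [min_positive, max_positive] back toward that range.

    @staticmethod
    def forward(ctx, x, channel_dim, min_positive, max_positive, max_factor):
        ctx.save_for_backward(x)
        ctx.channel_dim = channel_dim
        ctx.min_positive = min_positive
        ctx.max_positive = max_positive
        ctx.max_factor = max_factor
        return x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        dims = [d for d in range(x.ndim) if d != ctx.channel_dim]
        prop_positive = (x > 0).to(grad_output.dtype).mean(dim=dims, keepdim=True)
        # +max_factor where a channel is positive too rarely (push values up),
        # -max_factor where it is positive too often (push values down).
        # With max_positive == 1.0 the "too often" branch can never trigger.
        factor = (
            (prop_positive < ctx.min_positive).to(grad_output.dtype)
            - (prop_positive > ctx.max_positive).to(grad_output.dtype)
        ) * ctx.max_factor
        # Subtracting factor * |grad| biases gradient descent so that the
        # offending channels drift back inside the allowed range.
        return grad_output - factor * grad_output.abs(), None, None, None, None


x = torch.randn(4, 8, 50, requires_grad=True)         # (batch, channels, time)
y = SimpleDerivBalancer.apply(x, 1, 0.05, 1.0, 0.01)   # max_positive=1.0: no upper constraint
y.sum().backward()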

View File

@@ -862,8 +862,7 @@ class ConvolutionModule(nn.Module):
         # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
         # it will be in a better position to start learning something, i.e. to latch onto
         # the correct range.
-        self.deriv_balancer = DerivBalancer(channel_dim=1, max_abs=10.0,
-                                            min_positive=0.0, max_positive=1.0)
+        self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0)
 
         self.depthwise_conv = ScaledConv1d(
             channels,
@@ -875,7 +874,9 @@ class ConvolutionModule(nn.Module):
             bias=bias,
         )
 
-        # shape: (channels, 1), broadcasts with (batch, channel, time).
+        self.deriv_balancer2 = DerivBalancer(channel_dim=1)
+
+        # Shape: (channels, 1), broadcasts with (batch, channel, time).
         self.activation = SwishOffset()
 
         self.pointwise_conv2 = ScaledConv1d(
@@ -904,12 +905,13 @@ class ConvolutionModule(nn.Module):
         # GLU mechanism
         x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
 
-        x = self.deriv_balancer(x)
+        x = self.deriv_balancer1(x)
         x = nn.functional.glu(x, dim=1)  # (batch, channels, time)
 
         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
 
+        x = self.deriv_balancer2(x)
         x = self.activation(x)
 
         x = self.pointwise_conv2(x)  # (batch, channel, time)
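
For orientation, here is a minimal sketch of the ConvolutionModule forward path after this change, with plain nn.Conv1d, nn.SiLU and nn.Identity standing in for icefall's ScaledConv1d, SwishOffset and DerivBalancer (the class name ConvModuleSketch is invented for illustration). The only point is the ordering: deriv_balancer1 sits between pointwise_conv1 and the GLU, and the newly added deriv_balancer2 sits between the depthwise convolution and the activation.

import torch
import torch.nn as nn


class ConvModuleSketch(nn.Module):
    # Sketch of the post-commit layer ordering; nn.Identity stands in for the
    # two DerivBalancer instances renamed/added in this diff.
    def __init__(self, channels: int, kernel_size: int = 31):
        super().__init__()
        self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1)
        self.deriv_balancer1 = nn.Identity()  # real code: DerivBalancer(channel_dim=1, max_abs=10.0)
        self.depthwise_conv = nn.Conv1d(
            channels, channels, kernel_size,
            padding=(kernel_size - 1) // 2, groups=channels,
        )
        self.deriv_balancer2 = nn.Identity()  # real code: DerivBalancer(channel_dim=1), new in this commit
        self.activation = nn.SiLU()           # stand-in for SwishOffset
        self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, time)
        x = self.pointwise_conv1(x)        # (batch, 2*channels, time)
        x = self.deriv_balancer1(x)        # balance before the GLU gate
        x = nn.functional.glu(x, dim=1)    # (batch, channels, time)
        x = self.depthwise_conv(x)
        x = self.deriv_balancer2(x)        # second balancer, added by this commit
        x = self.activation(x)
        x = self.pointwise_conv2(x)        # (batch, channels, time)
        return x


m = ConvModuleSketch(channels=256)
out = m(torch.randn(2, 256, 100))
assert out.shape == (2, 256, 100)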

View File

@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2warmup",
+        default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved