Remove max-positive constraint in deriv-balancing; add second DerivBalancer in conv module.

This commit is contained in:
Daniel Povey 2022-03-15 13:10:35 +08:00
parent a23010fc10
commit 86e5dcba11
3 changed files with 8 additions and 6 deletions

View File

@@ -527,7 +527,7 @@ class DerivBalancer(torch.nn.Module):
"""
def __init__(self, channel_dim: int,
min_positive: float = 0.05,
max_positive: float = 0.95,
max_positive: float = 1.0,
max_factor: float = 0.01,
min_abs: float = 0.2,
max_abs: float = 100.0):

View File

@@ -862,8 +862,7 @@ class ConvolutionModule(nn.Module):
# constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
# it will be in a better position to start learning something, i.e. to latch onto
# the correct range.
self.deriv_balancer = DerivBalancer(channel_dim=1, max_abs=10.0,
min_positive=0.0, max_positive=1.0)
self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0)
self.depthwise_conv = ScaledConv1d(
channels,
@@ -875,7 +874,9 @@ class ConvolutionModule(nn.Module):
bias=bias,
)
# shape: (channels, 1), broadcasts with (batch, channel, time).
self.deriv_balancer2 = DerivBalancer(channel_dim=1)
# Shape: (channels, 1), broadcasts with (batch, channel, time).
self.activation = SwishOffset()
self.pointwise_conv2 = ScaledConv1d(
@@ -904,12 +905,13 @@ class ConvolutionModule(nn.Module):
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
x = self.deriv_balancer(x)
x = self.deriv_balancer1(x)
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
x = self.deriv_balancer2(x)
x = self.activation(x)
x = self.pointwise_conv2(x) # (batch, channel, time)

View File

@ -110,7 +110,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2warmup",
default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved