Add LinearWithAuxLoss to nonlin_attention and AttentionSqueeze modules.

Daniel Povey 2022-11-26 14:15:09 +08:00
parent 4058d56c0d
commit 5f80807027
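
The definition of LinearWithAuxLoss itself is not shown in this diff; the hunks below only swap nn.Linear / ScaledLinear projections for it and pass it an aux_grad_scale schedule. As a rough mental model only, and not the actual implementation, such a layer can be sketched as an nn.Linear whose output additionally receives the gradient of a small auxiliary penalty, scaled by a value that decays over training. The class name LinearWithAuxLossSketch, the choice of mean-squared-output penalty, and the autograd mechanics here are illustrative assumptions:

import torch
import torch.nn as nn


class _AddAuxGrad(torch.autograd.Function):
    # Identity in the forward pass.  In the backward pass it adds the gradient
    # of a simple auxiliary penalty, scale * mean(y ** 2), to the incoming
    # gradient, so the penalty influences the weights without changing the
    # activations that the rest of the network sees.
    @staticmethod
    def forward(ctx, y: torch.Tensor, scale: float) -> torch.Tensor:
        ctx.save_for_backward(y)
        ctx.scale = scale
        return y

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (y,) = ctx.saved_tensors
        aux_grad = ctx.scale * 2.0 * y / y.numel()  # d/dy of scale * mean(y**2)
        return grad_output + aux_grad, None


class LinearWithAuxLossSketch(nn.Linear):
    # Hypothetical stand-in for LinearWithAuxLoss (not the real class): an
    # nn.Linear that, while training, injects a scheduled-down auxiliary
    # gradient on its output.  `aux_grad_scale` may be a plain float or a
    # ScheduledFloat-like object convertible with float(); `initial_scale`
    # shrinks the initial parameters, in the spirit of ScaledLinear.
    def __init__(self, in_features, out_features, bias=True,
                 aux_grad_scale=0.01, initial_scale=1.0):
        super().__init__(in_features, out_features, bias=bias)
        self.aux_grad_scale = aux_grad_scale
        with torch.no_grad():
            self.weight.mul_(initial_scale)
            if self.bias is not None:
                self.bias.mul_(initial_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = super().forward(x)
        if self.training and torch.is_grad_enabled():
            y = _AddAuxGrad.apply(y, float(self.aux_grad_scale))
        return y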


@@ -342,6 +342,9 @@ def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
                           (12000.0, ratio * x),
                           default=x)
 
+def _aux_grad_scale() -> ScheduledFloat:
+    return ScheduledFloat((0.0, 0.2), (1000.0, 0.01))
+
 class ZipformerEncoderLayer(nn.Module):
     """
     Args:
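
For reference, the helper added above starts the auxiliary-gradient scale at 0.2 and decays it to 0.01 by batch 1000. Assuming the usual piecewise-linear reading of ScheduledFloat's (batch_count, value) pairs, clamped outside the given range, its values can be reproduced with a small stand-alone function (purely illustrative):

def aux_grad_scale_value(batch_count: float) -> float:
    # Linear interpolation between (0, 0.2) and (1000, 0.01), clamped at the
    # ends -- the assumed semantics of ScheduledFloat((0.0, 0.2), (1000.0, 0.01)).
    (x0, y0), (x1, y1) = (0.0, 0.2), (1000.0, 0.01)
    if batch_count <= x0:
        return y0
    if batch_count >= x1:
        return y1
    t = (batch_count - x0) / (x1 - x0)
    return y0 + t * (y1 - y0)

# aux_grad_scale_value(0)    -> 0.2
# aux_grad_scale_value(500)  -> 0.105
# aux_grad_scale_value(2000) -> 0.01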
@@ -1286,8 +1289,9 @@ class AttentionSqueeze(nn.Module):
         super().__init__()
         self.bottleneck_dim = bottleneck_dim
 
-        self.in_proj = nn.Linear(embed_dim, embed_dim,
-                                 bias=False)
+        self.in_proj = LinearWithAuxLoss(embed_dim, embed_dim,
+                                         bias=False,
+                                         aux_grad_scale=_aux_grad_scale())
 
         self.to_bottleneck_proj = LinearWithAuxLoss(embed_dim,
                                                     bottleneck_dim)
@@ -1337,8 +1341,9 @@ class AttentionSqueeze(nn.Module):
         self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, embed_dim)
 
-        self.out_proj = ScaledLinear(embed_dim, embed_dim,
-                                     bias=False, initial_scale=0.05)
+        self.out_proj = LinearWithAuxLoss(embed_dim, embed_dim,
+                                          aux_grad_scale=_aux_grad_scale(),
+                                          bias=False, initial_scale=0.05)
 
     def forward(self,
                 x: Tensor,
@@ -1385,7 +1390,7 @@ class FeedforwardModule(nn.Module):
                  dropout: float):
         super(FeedforwardModule, self).__init__()
         self.in_proj = LinearWithAuxLoss(embed_dim, feedforward_dim,
-                                         aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01)))
+                                         aux_grad_scale=_aux_grad_scale())
 
         self.hidden_balancer = ActivationBalancer(feedforward_dim,
                                                   channel_dim=-1, max_abs=10.0,
@@ -1394,7 +1399,7 @@ class FeedforwardModule(nn.Module):
         self.dropout = nn.Dropout(dropout)
         self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim,
                                           initial_scale=0.01,
-                                          aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01)))
+                                          aux_grad_scale=_aux_grad_scale())
         self.out_whiten = Whiten(num_groups=1,
                                  whitening_limit=_whitening_schedule(7.5),
                                  prob=(0.025, 0.25),
@@ -1425,7 +1430,8 @@ class NonlinAttentionModule(nn.Module):
     ) -> None:
         super().__init__()
 
-        self.in_proj = nn.Linear(channels, 2 * channels, bias=True)
+        self.in_proj = LinearWithAuxLoss(channels, 2 * channels, bias=True,
+                                         aux_grad_scale=_aux_grad_scale())
 
         # balancer that goes after the glu mechanism.
         self.balancer = ActivationBalancer(
@@ -1437,9 +1443,10 @@ class NonlinAttentionModule(nn.Module):
         self.sigmoid = nn.Sigmoid()
         self.activation = Identity()  # for diagnostics.
 
-        self.out_proj = ScaledLinear(channels, channels,
-                                     bias=True,
-                                     initial_scale=0.05)
+        self.out_proj = LinearWithAuxLoss(channels, channels,
+                                          bias=True,
+                                          aux_grad_scale=_aux_grad_scale(),
+                                          initial_scale=0.05)
 
         self.whiten1 = Whiten(num_groups=1,
                               whitening_limit=_whitening_schedule(5.0),
@@ -1515,7 +1522,7 @@ class ConvolutionModule(nn.Module):
         self.in_proj = LinearWithAuxLoss(
             channels, 2 * channels,
-            aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01))
+            aux_grad_scale=_aux_grad_scale()
         )
@@ -1562,7 +1569,7 @@ class ConvolutionModule(nn.Module):
         self.out_proj = LinearWithAuxLoss(
             channels, channels,
-            aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01)),
+            aux_grad_scale=_aux_grad_scale(),
             initial_scale=0.05,
         )
@@ -1678,7 +1685,7 @@ class Conv2dSubsampling(nn.Module):
         self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels,
-                                     aux_grad_scale=ScheduledFloat((0.0, 0.2), (1000.0, 0.01)))
+                                     aux_grad_scale=_aux_grad_scale())
         self.dropout = nn.Dropout(dropout)
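
These schedules do not advance on their own: the training loop has to push the current batch index into the model so that each ScheduledFloat (and therefore _aux_grad_scale()) knows where it is in training. A minimal sketch of that pattern, under the assumption that ScheduledFloat registers itself as a submodule and reads a writable batch_count attribute:

import torch.nn as nn

def set_batch_count(model: nn.Module, batch_count: float) -> None:
    # Walk the module tree and update every submodule that exposes a
    # `batch_count` attribute, so scheduled values such as the
    # aux_grad_scale decay as training proceeds.  (Assumption: this mirrors
    # how the schedules in this file are driven; it is not taken from the diff.)
    for module in model.modules():
        if hasattr(module, "batch_count"):
            module.batch_count = batch_count

# Sketch of use inside a training loop:
#   for batch_idx, batch in enumerate(train_dl):
#       set_batch_count(model, float(global_batch_idx))
#       ... forward / backward / optimizer step ...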