Halve the expected value of the aux_grad scale, and implement it more efficiently via a scale on the probability of using it.

Daniel Povey 2022-11-26 14:52:59 +08:00
parent 110c2601ab
commit 8858fb38f1


@@ -342,8 +342,10 @@ def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
                           (12000.0, ratio * x),
                           default=x)

-def _aux_grad_scale() -> ScheduledFloat:
-    return ScheduledFloat((0.0, 0.2), (1000.0, 0.01))
+def _aux_grad_scale() -> float:
+    return 0.1
+
+def _aux_grad_prob() -> ScheduledFloat:
+    return ScheduledFloat((0.0, 0.25), (1000.0, 0.0125))

 class ZipformerEncoderLayer(nn.Module):
     """
@@ -1291,7 +1293,7 @@ class AttentionSqueeze(nn.Module):
         self.in_proj = LinearWithAuxLoss(embed_dim, embed_dim,
                                          bias=False,
-                                         aux_grad_scale=_aux_grad_scale())
+                                         aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob())

         self.to_bottleneck_proj = LinearWithAuxLoss(embed_dim,
                                                     bottleneck_dim)
@@ -1342,7 +1344,7 @@ class AttentionSqueeze(nn.Module):
         self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, embed_dim)

         self.out_proj = LinearWithAuxLoss(embed_dim, embed_dim,
-                                          aux_grad_scale=_aux_grad_scale(),
+                                          aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob(),
                                           bias=False, initial_scale=0.05)

     def forward(self,
@@ -1390,7 +1392,7 @@ class FeedforwardModule(nn.Module):
                  dropout: float):
         super(FeedforwardModule, self).__init__()
         self.in_proj = LinearWithAuxLoss(embed_dim, feedforward_dim,
-                                         aux_grad_scale=_aux_grad_scale())
+                                         aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob())

         self.hidden_balancer = ActivationBalancer(feedforward_dim,
                                                   channel_dim=-1, max_abs=10.0,
@@ -1399,7 +1401,7 @@ class FeedforwardModule(nn.Module):
         self.dropout = nn.Dropout(dropout)
         self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim,
                                           initial_scale=0.01,
-                                          aux_grad_scale=_aux_grad_scale())
+                                          aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob())
         self.out_whiten = Whiten(num_groups=1,
                                  whitening_limit=_whitening_schedule(7.5),
                                  prob=(0.025, 0.25),
@@ -1431,7 +1433,7 @@ class NonlinAttentionModule(nn.Module):
         super().__init__()

         self.in_proj = LinearWithAuxLoss(channels, 2 * channels, bias=True,
-                                         aux_grad_scale=_aux_grad_scale())
+                                         aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob())

         # balancer that goes after the glu mechanism.
         self.balancer = ActivationBalancer(
@@ -1445,7 +1447,7 @@ class NonlinAttentionModule(nn.Module):
         self.activation = Identity() # for diagnostics.
         self.out_proj = LinearWithAuxLoss(channels, channels,
                                           bias=True,
-                                          aux_grad_scale=_aux_grad_scale(),
+                                          aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob(),
                                           initial_scale=0.05)

         self.whiten1 = Whiten(num_groups=1,
@@ -1522,7 +1524,7 @@ class ConvolutionModule(nn.Module):
         self.in_proj = LinearWithAuxLoss(
             channels, 2 * channels,
-            aux_grad_scale=_aux_grad_scale()
+            aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob()
         )
@@ -1569,7 +1571,7 @@ class ConvolutionModule(nn.Module):
         self.out_proj = LinearWithAuxLoss(
             channels, channels,
-            aux_grad_scale=_aux_grad_scale(),
+            aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob(),
             initial_scale=0.05,
         )
@@ -1685,7 +1687,7 @@ class Conv2dSubsampling(nn.Module):
         self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels,
-                                     aux_grad_scale=_aux_grad_scale())
+                                     aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob())
         self.dropout = nn.Dropout(dropout)
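
Note on what the new prob argument buys: the expected contribution of the auxiliary gradient is the product aux_grad_scale * prob, so moving the factor-of-20 decay (from step 0 to step 1000) out of the scale schedule and into the application probability preserves the same expected decay while the auxiliary computation can be skipped entirely on the steps where it is not applied. The sketch below is a minimal illustration of that idea, not the actual LinearWithAuxLoss code; the helper name is hypothetical, and the reading that the old expected value was twice the new one assumes the previous application probability was a constant, which is only implied by the commit message.

import random

def maybe_apply_aux(aux_term: float, scale: float, prob: float) -> float:
    # Hypothetical helper, not from the repo: apply an auxiliary gradient
    # term on a random subset of steps.  In expectation it contributes
    # scale * prob; on the remaining steps the aux computation can be
    # skipped entirely, which is the efficiency gain this commit is after.
    if random.random() < prob:
        return scale * aux_term
    return 0.0

# Illustrative numbers from the diff above:
#   old: scale scheduled 0.2 -> 0.01 over the first 1000 steps (factor 20)
#   new: fixed scale 0.1, prob scheduled 0.25 -> 0.0125 (same factor 20),
# so the expected value 0.1 * prob follows the same decay at half the old
# peak scale, and the aux path runs on only ~25% (later ~1.25%) of steps.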