Make alpha for LinearWithAuxLossFunction be in log space; simplify/rework NonlinAttentionModule, setup more like ConvModule now.

Author: Daniel Povey
Date:   2022-11-26 19:38:29 +08:00
Parent: e19118a966
Commit: a96b92fb54

2 changed files with 22 additions and 35 deletions


@@ -421,7 +421,7 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
         # recompute y as we need the gradient; this is easier to implement than
         # saving y in the context.
         y = torch.matmul(x, weight.t())
-        z = alpha * torch.matmul(y, weight)
+        z = alpha.exp() * torch.matmul(y, weight)
         diff = x - z
         dims_to_mean = tuple(range(x.ndim-1))
         mean = diff.mean(dim=dims_to_mean)
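
The comment above describes a recompute-in-backward pattern: y is cheap to recompute from the saved x and weight, so it is not stored in the autograd context. Below is a minimal, hypothetical sketch of that pattern; RecomputeAuxFunction is an illustrative name, not the real LinearWithAuxLossFunction, which additionally scales the auxiliary gradient by aux_grad_scale and only applies it with some probability.

import torch

class RecomputeAuxFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, alpha):
        ctx.save_for_backward(x, weight, alpha)
        return torch.matmul(x, weight.t())

    @staticmethod
    def backward(ctx, y_grad):
        x, weight, alpha = ctx.saved_tensors
        with torch.enable_grad():
            x_ = x.detach().requires_grad_(True)
            w_ = weight.detach().requires_grad_(True)
            a_ = alpha.detach().requires_grad_(True)
            # recompute y here instead of saving it in ctx
            y = torch.matmul(x_, w_.t())
            # alpha is kept in log space, hence .exp()
            z = a_.exp() * torch.matmul(y, w_)
            diff = x_ - z
            diff = diff - diff.mean(dim=tuple(range(x.ndim - 1)))
            aux_loss = (diff ** 2).mean()
            x_aux, w_aux, a_aux = torch.autograd.grad(aux_loss, (x_, w_, a_))
        # gradients of the main output y = x W^T, plus the auxiliary-loss gradients
        x_grad = torch.matmul(y_grad, weight) + x_aux
        w_grad = torch.matmul(y_grad.reshape(-1, y_grad.shape[-1]).t(),
                              x.reshape(-1, x.shape[-1])) + w_aux
        return x_grad, w_grad, a_aux

x = torch.randn(10, 4, requires_grad=True)
w = torch.randn(8, 4, requires_grad=True)
a = torch.zeros((), requires_grad=True)
RecomputeAuxFunction.apply(x, w, a).sum().backward()
print(a.grad)   # alpha's gradient comes only from the auxiliary loss
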
@@ -452,7 +452,7 @@ class LinearWithAuxLoss(nn.Module):
     Suppose the input is x, and this layer computes:
           y = M x
     (the bias is applied separately), then we define:
-          z = alpha * M^T y
+          z = exp(alpha) * M^T y
     where alpha is learnable; and the auxiliary loss will be:
          aux_loss = normalize_mean(z - x)^2.
     (normalize_mean refers to subtracting the average value per channel,
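
As a quick check on this reparameterization (and on why the initial value of alpha changes from 1.0 to 0.0 in the hunk below), here is a self-contained sketch of the auxiliary loss exactly as the docstring defines it; the helper names and tensor shapes are illustrative only.

import torch

def normalize_mean_sq(diff):
    # subtract the per-channel mean, then average the squared values
    diff = diff - diff.mean(dim=tuple(range(diff.ndim - 1)))
    return (diff ** 2).mean()

def aux_loss_old(x, M, alpha):   # alpha used directly as a scale
    y = torch.matmul(x, M.t())
    return normalize_mean_sq(alpha * torch.matmul(y, M) - x)

def aux_loss_new(x, M, alpha):   # alpha stored in log space
    y = torch.matmul(x, M.t())
    return normalize_mean_sq(alpha.exp() * torch.matmul(y, M) - x)

x = torch.randn(100, 8, 256)
M = torch.randn(512, 256) * 0.06
print(torch.allclose(aux_loss_old(x, M, torch.tensor(1.0)),
                     aux_loss_new(x, M, torch.tensor(0.0))))   # True: exp(0) == 1

Presumably the point of the log-space form is that exp(alpha) stays positive and a fixed step in alpha changes the scale multiplicatively rather than additively.
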
@@ -485,7 +485,7 @@ class LinearWithAuxLoss(nn.Module):
                                       0.01 * initial_scale)
         else:
             self.register_parameter('bias', None)
-        self.alpha = nn.Parameter(torch.tensor(1.0))
+        self.alpha = nn.Parameter(torch.tensor(0.0))
     def forward(self,


@@ -343,7 +343,7 @@ def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
                            default=x)
 def _aux_grad_scale() -> float:
-    return 0.1
+    return 0.2
 def _aux_grad_prob() -> ScheduledFloat:
     return ScheduledFloat((0.0, 0.25), (1000.0, 0.0125))
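
_aux_grad_prob() returns a ScheduledFloat, an icefall scheduling helper. The sketch below shows how a schedule like ((0.0, 0.25), (1000.0, 0.0125)) is presumably interpreted: piecewise-linear in the batch count, clamped to the end values outside the given breakpoints. schedule_value is a hypothetical stand-in, not ScheduledFloat's real API.

def schedule_value(points, batch_count):
    # points: list of (batch_count, value) breakpoints
    points = sorted(points)
    if batch_count <= points[0][0]:
        return points[0][1]
    if batch_count >= points[-1][0]:
        return points[-1][1]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if x0 <= batch_count <= x1:
            t = (batch_count - x0) / (x1 - x0)
            return y0 + t * (y1 - y0)

# the aux-loss probability starts at 0.25 and decays to 0.0125 by batch 1000:
print(schedule_value([(0.0, 0.25), (1000.0, 0.0125)], 500.0))   # 0.13125
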
@@ -1432,33 +1432,27 @@ class NonlinAttentionModule(nn.Module):
     ) -> None:
         super().__init__()
-        self.in_proj = LinearWithAuxLoss(channels, 2 * channels, bias=True,
-                                         aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob())
+        self.in_proj = nn.Linear(channels, 2 * channels, bias=True)
-        # balancer that goes after the glu mechanism.
+        # balancer that goes before the sigmoid.
         self.balancer = ActivationBalancer(
             channels, channel_dim=-1,
-            min_positive=0.2, max_positive=0.8,
-            min_abs=0.2, max_abs=10.0,
-            min_prob=0.05,
+            min_positive=0.05, max_positive=1.0,
+            min_abs=0.2, max_abs=ScheduledFloat((0.0, 2.0),
+                                                (4000.0, 10.0),
+                                                default=1.0),
         )
-        self.pre_sigmoid = Identity() # for diagnostics.
         self.sigmoid = nn.Sigmoid()
         self.activation = Identity() # for diagnostics.
-        self.out_proj = LinearWithAuxLoss(channels, channels,
-                                          bias=True,
-                                          aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob(),
-                                          initial_scale=0.05)
+        self.out_proj = ScaledLinear(channels, channels,
+                                     bias=True,
+                                     initial_scale=0.05)
-        self.whiten1 = Whiten(num_groups=1,
-                              whitening_limit=_whitening_schedule(5.0),
-                              prob=(0.025, 0.25),
-                              grad_scale=0.01)
-        self.whiten2 = Whiten(num_groups=1,
-                              whitening_limit=_whitening_schedule(7.5),
-                              prob=(0.025, 0.25),
-                              grad_scale=0.01)
+        self.whiten = Whiten(num_groups=1,
+                             whitening_limit=_whitening_schedule(7.5),
+                             prob=(0.025, 0.25),
+                             grad_scale=0.01)
@@ -1475,19 +1469,12 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         """
         x = self.in_proj(x)
-        v, s = x.chunk(2, dim=-1)
+        x, s = x.chunk(2, dim=-1)
-        if self.training and random.random() < 0.02:
-            # prevent the inputs to the sigmoid from getting very large (this is
-            # hopefully quite a rare phenomenon, so we are giving this path a
-            # very small probability to save time).
-            s = penalize_abs_values_gt(s, limit=20.0, penalty=1.0e-04)
+        s = self.balancer(s)
+        s = self.sigmoid(s)
-        v = self.whiten1(v)
-        # GLU mechanism
-        s = self.pre_sigmoid(s)
-        x = self.sigmoid(s) * v
-        x = self.balancer(x)
+        x = x * s
         (seq_len, batch_size, embed_dim) = x.shape
         num_heads = attn_weights.shape[0]
@@ -1500,7 +1487,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
         x = self.activation(x) # diagnostics only, it's the identity.
-        x = self.whiten2(x)
+        x = self.whiten(x)
         x = self.out_proj(x)
         return x
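
Putting the pieces of this commit together, here is a self-contained sketch of the reworked NonlinAttentionModule. ActivationBalancer, Whiten and ScaledLinear are icefall-specific modules and are replaced by identity / plain-Linear stand-ins; the head-splitting reshape and the matmul with attn_weights in the middle of forward() are unchanged by this commit and therefore not shown in the diff, so those two lines are an assumption inferred from the documented attn_weights shape and the permute that follows them.

import torch
from torch import nn

class NonlinAttentionSketch(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.in_proj = nn.Linear(channels, 2 * channels, bias=True)
        self.balancer = nn.Identity()    # stand-in for ActivationBalancer on s
        self.sigmoid = nn.Sigmoid()
        self.activation = nn.Identity()  # diagnostics only
        self.whiten = nn.Identity()      # stand-in for Whiten
        self.out_proj = nn.Linear(channels, channels, bias=True)  # ScaledLinear(initial_scale=0.05) in the real code

    def forward(self, x: torch.Tensor, attn_weights: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, batch_size, channels)
        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
        x = self.in_proj(x)
        x, s = x.chunk(2, dim=-1)
        s = self.balancer(s)
        s = self.sigmoid(s)
        x = x * s
        (seq_len, batch_size, embed_dim) = x.shape
        num_heads = attn_weights.shape[0]
        # assumed unchanged middle: split channels over heads, mix across time
        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
        x = torch.matmul(attn_weights, x)   # (num_heads, batch_size, seq_len, head_dim)
        x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
        x = self.activation(x)
        x = self.whiten(x)
        x = self.out_proj(x)
        return x

# quick shape check:
m = NonlinAttentionSketch(256)
x = torch.randn(100, 2, 256)
w = torch.softmax(torch.randn(4, 2, 100, 100), dim=-1)
print(m(x, w).shape)   # torch.Size([100, 2, 256])
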