mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Make alpha for LinearWithAuxLossFunction be in log space; simplify/rework NonlinAttentionModule, setup more like ConvModule now.
This commit is contained in:
parent
e19118a966
commit
a96b92fb54
@ -421,7 +421,7 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
|
||||
# recompute y as we need the gradient; this is easier to implement than
|
||||
# saving y in the context.
|
||||
y = torch.matmul(x, weight.t())
|
||||
z = alpha * torch.matmul(y, weight)
|
||||
z = alpha.exp() * torch.matmul(y, weight)
|
||||
diff = x - z
|
||||
dims_to_mean = tuple(range(x.ndim-1))
|
||||
mean = diff.mean(dim=dims_to_mean)
|
||||
@ -452,7 +452,7 @@ class LinearWithAuxLoss(nn.Module):
|
||||
Suppose the input is x, and this layer computes:
|
||||
y = M x
|
||||
(the bias is applied separately), then we define:
|
||||
z = alpha * M^T y
|
||||
z = exp(alpha) * M^T y
|
||||
where alpha is learnable; and the auxiliary loss will be:
|
||||
aux_loss = normalize_mean(z - x)^2.
|
||||
(normalize_mean refers to subtracting the average value per channel,
|
||||
@ -485,7 +485,7 @@ class LinearWithAuxLoss(nn.Module):
|
||||
0.01 * initial_scale)
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
self.alpha = nn.Parameter(torch.tensor(1.0))
|
||||
self.alpha = nn.Parameter(torch.tensor(0.0))
|
||||
|
||||
|
||||
def forward(self,
|
||||
|
||||
@ -343,7 +343,7 @@ def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
|
||||
default=x)
|
||||
|
||||
def _aux_grad_scale() -> float:
|
||||
return 0.1
|
||||
return 0.2
|
||||
def _aux_grad_prob() -> ScheduledFloat:
|
||||
return ScheduledFloat((0.0, 0.25), (1000.0, 0.0125))
|
||||
|
||||
@ -1432,33 +1432,27 @@ class NonlinAttentionModule(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.in_proj = LinearWithAuxLoss(channels, 2 * channels, bias=True,
|
||||
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob())
|
||||
self.in_proj = nn.Linear(channels, 2 * channels, bias=True)
|
||||
|
||||
# balancer that goes after the glu mechanism.
|
||||
# balancer that goes before the sigmoid.
|
||||
self.balancer = ActivationBalancer(
|
||||
channels, channel_dim=-1,
|
||||
min_positive=0.2, max_positive=0.8,
|
||||
min_abs=0.2, max_abs=10.0,
|
||||
min_prob=0.05,
|
||||
min_positive=0.05, max_positive=1.0,
|
||||
min_abs=0.2, max_abs=ScheduledFloat((0.0, 2.0),
|
||||
(4000.0, 10.0),
|
||||
default=1.0),
|
||||
)
|
||||
self.pre_sigmoid = Identity() # for diagnostics.
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
self.activation = Identity() # for diagnostics.
|
||||
self.out_proj = LinearWithAuxLoss(channels, channels,
|
||||
bias=True,
|
||||
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob(),
|
||||
initial_scale=0.05)
|
||||
self.out_proj = ScaledLinear(channels, channels,
|
||||
bias=True,
|
||||
initial_scale=0.05)
|
||||
|
||||
self.whiten1 = Whiten(num_groups=1,
|
||||
whitening_limit=_whitening_schedule(5.0),
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.01)
|
||||
self.whiten2 = Whiten(num_groups=1,
|
||||
whitening_limit=_whitening_schedule(7.5),
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.01)
|
||||
self.whiten = Whiten(num_groups=1,
|
||||
whitening_limit=_whitening_schedule(7.5),
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.01)
|
||||
|
||||
|
||||
|
||||
@ -1475,19 +1469,12 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
|
||||
"""
|
||||
x = self.in_proj(x)
|
||||
|
||||
v, s = x.chunk(2, dim=-1)
|
||||
x, s = x.chunk(2, dim=-1)
|
||||
|
||||
if self.training and random.random() < 0.02:
|
||||
# prevent the inputs to the sigmoid from getting very large (this is
|
||||
# hopefully quite a rare phenomenon, so we are giving this path a
|
||||
# very small probability to save time).
|
||||
s = penalize_abs_values_gt(s, limit=20.0, penalty=1.0e-04)
|
||||
s = self.balancer(s)
|
||||
s = self.sigmoid(s)
|
||||
|
||||
v = self.whiten1(v)
|
||||
# GLU mechanism
|
||||
s = self.pre_sigmoid(s)
|
||||
x = self.sigmoid(s) * v
|
||||
x = self.balancer(x)
|
||||
x = x * s
|
||||
|
||||
(seq_len, batch_size, embed_dim) = x.shape
|
||||
num_heads = attn_weights.shape[0]
|
||||
@ -1500,7 +1487,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
|
||||
x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
|
||||
|
||||
x = self.activation(x) # diagnostics only, it's the identity.
|
||||
x = self.whiten2(x)
|
||||
x = self.whiten(x)
|
||||
x = self.out_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user