Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Make alpha for LinearWithAuxLossFunction be in log space; simplify/rework NonlinAttentionModule, setting it up more like ConvModule now.
commit a96b92fb54 (parent e19118a966)
@@ -421,7 +421,7 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
         # recompute y as we need the gradient; this is easier to implement than
         # saving y in the context.
         y = torch.matmul(x, weight.t())
-        z = alpha * torch.matmul(y, weight)
+        z = alpha.exp() * torch.matmul(y, weight)
         diff = x - z
         dims_to_mean = tuple(range(x.ndim-1))
         mean = diff.mean(dim=dims_to_mean)
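For readers skimming the diff, here is a minimal, hedged sketch of what the recomputed forward pass above now does, assuming x has shape (..., in_dim), weight has shape (out_dim, in_dim), and alpha is a scalar parameter stored in log space (the function and variable names are illustrative, not the icefall API):

import torch

def recompute_z(x: torch.Tensor, weight: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    # y = x W^T, i.e. the layer's output recomputed for the backward pass.
    y = torch.matmul(x, weight.t())
    # exp(alpha) keeps the learned scale strictly positive; that is the effect
    # of storing alpha in log space instead of as a raw multiplier.
    return alpha.exp() * torch.matmul(y, weight)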
@@ -452,7 +452,7 @@ class LinearWithAuxLoss(nn.Module):
    Suppose the input is x, and this layer computes:
          y = M x
    (the bias is applied separately), then we define:
-         z = alpha * M^T y
+         z = exp(alpha) * M^T y
    where alpha is learnable; and the auxiliary loss will be:
          aux_loss = normalize_mean(z - x)^2.
    (normalize_mean refers to subtracting the average value per channel,
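A hedged worked example of the auxiliary loss the docstring describes, using the updated exp(alpha) scaling; the helper below is illustrative only, and mirrors the normalize_mean step (subtracting the per-channel mean) from the function in the first hunk:

import torch

def aux_loss_sketch(x: torch.Tensor, M: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    y = torch.matmul(x, M.t())                   # y = M x
    z = alpha.exp() * torch.matmul(y, M)         # z = exp(alpha) * M^T y
    diff = z - x                                 # sign is irrelevant once squared
    dims_to_mean = tuple(range(x.ndim - 1))      # every dim except the channel dim
    diff = diff - diff.mean(dim=dims_to_mean)    # "normalize_mean"
    return (diff ** 2).mean()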
@@ -485,7 +485,7 @@ class LinearWithAuxLoss(nn.Module):
                                        0.01 * initial_scale)
         else:
             self.register_parameter('bias', None)
-        self.alpha = nn.Parameter(torch.tensor(1.0))
+        self.alpha = nn.Parameter(torch.tensor(0.0))


     def forward(self,
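The new initial value follows directly from the log-space change: exp(0) == 1, so the effective initial scale is unchanged from the old alpha = 1.0. A one-line check, purely illustrative:

import torch

alpha = torch.nn.Parameter(torch.tensor(0.0))          # new log-space initialization
assert torch.isclose(alpha.exp(), torch.tensor(1.0))   # same effective scale as the old 1.0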
@@ -343,7 +343,7 @@ def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
                            default=x)

 def _aux_grad_scale() -> float:
-    return 0.1
+    return 0.2
 def _aux_grad_prob() -> ScheduledFloat:
     return ScheduledFloat((0.0, 0.25), (1000.0, 0.0125))

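A minimal sketch of how the ScheduledFloat returned by _aux_grad_prob is assumed to behave (piecewise-linear interpolation over batch count between the given (batch_count, value) points); this is an illustration, not the icefall implementation:

def scheduled_value(batch_count: float,
                    points=((0.0, 0.25), (1000.0, 0.0125))) -> float:
    # Clamp outside the endpoints, interpolate linearly in between, so the
    # probability of applying the auxiliary-loss gradient starts at 0.25 and
    # decays to 0.0125 by batch 1000.
    (x0, y0), (x1, y1) = points
    if batch_count <= x0:
        return y0
    if batch_count >= x1:
        return y1
    t = (batch_count - x0) / (x1 - x0)
    return y0 + t * (y1 - y0)

# e.g. scheduled_value(500.0) == 0.13125, halfway between 0.25 and 0.0125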
@@ -1432,33 +1432,27 @@ class NonlinAttentionModule(nn.Module):
     ) -> None:
         super().__init__()

-        self.in_proj = LinearWithAuxLoss(channels, 2 * channels, bias=True,
-                                         aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob())
+        self.in_proj = nn.Linear(channels, 2 * channels, bias=True)

-        # balancer that goes after the glu mechanism.
+        # balancer that goes before the sigmoid.
         self.balancer = ActivationBalancer(
             channels, channel_dim=-1,
-            min_positive=0.2, max_positive=0.8,
-            min_abs=0.2, max_abs=10.0,
-            min_prob=0.05,
+            min_positive=0.05, max_positive=1.0,
+            min_abs=0.2, max_abs=ScheduledFloat((0.0, 2.0),
+                                                (4000.0, 10.0),
+                                                default=1.0),
         )
-        self.pre_sigmoid = Identity()  # for diagnostics.
         self.sigmoid = nn.Sigmoid()

         self.activation = Identity()  # for diagnostics.
-        self.out_proj = LinearWithAuxLoss(channels, channels,
+        self.out_proj = ScaledLinear(channels, channels,
                                      bias=True,
-                                     aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob(),
-                                     initial_scale=0.05)
+                                     initial_scale=0.05)

-        self.whiten1 = Whiten(num_groups=1,
-                              whitening_limit=_whitening_schedule(5.0),
+        self.whiten = Whiten(num_groups=1,
+                             whitening_limit=_whitening_schedule(7.5),
                              prob=(0.025, 0.25),
                              grad_scale=0.01)
-        self.whiten2 = Whiten(num_groups=1,
-                              whitening_limit=_whitening_schedule(7.5),
-                              prob=(0.025, 0.25),
-                              grad_scale=0.01)


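One way to read the balancer change above (an interpretation, not stated in the diff): the gate s is now constrained before the sigmoid, with max_abs ramping from 2.0 to 10.0 over the first 4000 batches, which keeps the sigmoid gate away from hard saturation early in training. A rough numeric illustration:

import torch

for limit in (2.0, 10.0):
    s = torch.tensor(limit)
    gate = torch.sigmoid(s)
    grad_factor = gate * (1.0 - gate)   # d(sigmoid)/ds at the limit
    print(f"max_abs={limit}: gate={gate.item():.5f}, gradient factor={grad_factor.item():.2e}")
# max_abs=2.0:  gate=0.88080, gradient factor=1.05e-01
# max_abs=10.0: gate=0.99995, gradient factor=4.54e-05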
@@ -1475,19 +1469,12 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         """
         x = self.in_proj(x)

-        v, s = x.chunk(2, dim=-1)
+        x, s = x.chunk(2, dim=-1)

-        if self.training and random.random() < 0.02:
-            # prevent the inputs to the sigmoid from getting very large (this is
-            # hopefully quite a rare phenomenon, so we are giving this path a
-            # very small probability to save time).
-            s = penalize_abs_values_gt(s, limit=20.0, penalty=1.0e-04)
+        s = self.balancer(s)
+        s = self.sigmoid(s)

-        v = self.whiten1(v)
-        # GLU mechanism
-        s = self.pre_sigmoid(s)
-        x = self.sigmoid(s) * v
-        x = self.balancer(x)
+        x = x * s

         (seq_len, batch_size, embed_dim) = x.shape
         num_heads = attn_weights.shape[0]
@@ -1500,7 +1487,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)

         x = self.activation(x)  # diagnostics only, it's the identity.
-        x = self.whiten2(x)
+        x = self.whiten(x)
         x = self.out_proj(x)
         return x

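Putting the forward-pass hunks together, here is a hedged end-to-end sketch of the simplified path: gate with a balanced sigmoid, mix frames with the attention weights, then whiten and project. The reshape/matmul in the middle is not shown in this diff and is assumed unchanged; the callables passed in are illustrative stand-ins for the module's own submodules, not the icefall API.

import torch

def nonlin_attention_forward(x, attn_weights, in_proj, balancer, whiten, out_proj):
    # x: (seq_len, batch_size, channels); attn_weights: (num_heads, batch, seq_len, seq_len)
    x = in_proj(x)                    # (seq_len, batch, 2 * channels)
    x, s = x.chunk(2, dim=-1)         # value-like part and gate part
    s = torch.sigmoid(balancer(s))    # balancer constrains s before the sigmoid
    x = x * s                         # elementwise gating, ConvModule-style

    (seq_len, batch_size, embed_dim) = x.shape
    num_heads = attn_weights.shape[0]
    # Split channels into heads and mix time frames with the attention weights
    # (assumed unchanged between the two hunks shown above).
    x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
    x = torch.matmul(attn_weights, x)                            # (num_heads, batch, seq_len, head_dim)
    x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)   # back to (seq_len, batch, channels)

    x = whiten(x)                     # single Whiten module now (was whiten2 here)
    return out_proj(x)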