Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-07 08:04:18 +00:00)
Some small fixes, to bias_correction2 formula and remove bias-u,v-scale
parent: b916789ca3
commit: 9ef11e64ba
@@ -449,16 +449,8 @@ class RelPositionMultiheadAttention(nn.Module):
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
         self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
-        self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
-        self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
         self._reset_parameters()
 
-    def _pos_bias_u(self):
-        return self.pos_bias_u * self.pos_bias_u_scale.exp()
-
-    def _pos_bias_v(self):
-        return self.pos_bias_v * self.pos_bias_v_scale.exp()
-
     def _reset_parameters(self) -> None:
         nn.init.uniform_(self.pos_bias_u, -0.05, 0.05)
         nn.init.uniform_(self.pos_bias_v, -0.05, 0.05)
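The removed `_pos_bias_u`/`_pos_bias_v` helpers multiplied each positional bias by the exponential of a learned scalar log-scale, initialized to zero so the effective scale starts at exp(0) = 1. A minimal standalone sketch (shapes made up, not the icefall module) of how the two parameterizations relate at initialization:

    import torch
    import torch.nn as nn

    num_heads, head_dim = 4, 8

    # Old parameterization: bias times exp(learned log-scale).  The log-scale
    # starts at zero, so exp(0) = 1 and both schemes coincide at init.
    pos_bias_u = nn.Parameter(torch.empty(num_heads, head_dim).uniform_(-0.05, 0.05))
    pos_bias_u_scale = nn.Parameter(torch.zeros(()))
    effective_old = pos_bias_u * pos_bias_u_scale.exp()

    # New parameterization after this commit: the bias is used directly.
    effective_new = pos_bias_u

    assert torch.allclose(effective_old, effective_new)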
@@ -756,11 +748,11 @@ class RelPositionMultiheadAttention(nn.Module):
         p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
         p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
 
-        q_with_bias_u = (q + self._pos_bias_u()).transpose(
+        q_with_bias_u = (q + self.pos_bias_u).transpose(
             1, 2
         )  # (batch, head, time1, d_k)
 
-        q_with_bias_v = (q + self._pos_bias_v()).transpose(
+        q_with_bias_v = (q + self.pos_bias_v).transpose(
             1, 2
         )  # (batch, head, time1, d_k)
 
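For context, these two tensors feed the Transformer-XL score terms: `q_with_bias_u` matches against content keys (terms (a)+(c) in the paper), while `q_with_bias_v` matches against positional embeddings (terms (b)+(d)). A rough sketch with made-up shapes, not icefall's actual forward pass (which also applies a rel-shift to `matrix_bd`):

    import torch

    batch, heads, time1, d_k = 2, 4, 10, 8
    q = torch.randn(batch, time1, heads, d_k)
    k = torch.randn(batch, heads, time1, d_k)
    p = torch.randn(batch, heads, 2 * time1 - 1, d_k)
    pos_bias_u = torch.randn(heads, d_k)
    pos_bias_v = torch.randn(heads, d_k)

    # Terms (a)+(c): content score, with the global content bias u.
    q_with_bias_u = (q + pos_bias_u).transpose(1, 2)  # (batch, head, time1, d_k)
    matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

    # Terms (b)+(d): position score, with the global position bias v.
    q_with_bias_v = (q + pos_bias_v).transpose(1, 2)  # (batch, head, time1, d_k)
    matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))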
@@ -599,7 +599,9 @@ class Cain(Optimizer):
                 scale_exp_avg_sq.mul_(beta2).addcmul_(scale_deriv, scale_deriv,
                                                       value=1 - beta2)
 
-                scale_bias_correction2 = 1 - beta2 ** step
+                # should actually be step + 1, so on 1st minibatch we are not learning
+                # anything here.  May fix this at some point.
+                scale_bias_correction2 = 1 - beta2 ** (step + 1)
 
                 scale_denom = (scale_exp_avg_sq.sqrt()).add_(group["eps"])
 
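A standalone arithmetic check of the formula change, assuming `step` is 0-indexed as the in-code comment suggests: with the old exponent, `beta2 ** step == 1` on the first minibatch, so the correction factor is exactly zero there.

    # Not icefall code; just the two formulas side by side.
    beta2 = 0.98

    for step in range(3):
        old = 1 - beta2 ** step         # 0.0,  0.02,   0.0396
        new = 1 - beta2 ** (step + 1)   # 0.02, 0.0396, 0.0588
        print(step, old, new)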
@@ -621,7 +623,6 @@ class Cain(Optimizer):
                                          device=scale_delta.device,
                                          dtype=scale_delta.dtype),
                             scale_delta)
 
                 exp_avg.add_(p, alpha=scale_delta)
-
 
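For reference, `Tensor.add_(other, alpha=...)` is an in-place fused multiply-add, so the context line above folds the scale change into `exp_avg` as `exp_avg += scale_delta * p`. A toy illustration with a made-up scalar standing in for Cain's scale update:

    import torch

    p = torch.randn(3)
    exp_avg = torch.zeros(3)
    scale_delta = 0.1  # hypothetical value, for illustration only

    exp_avg.add_(p, alpha=scale_delta)  # exp_avg += scale_delta * p, in place
    assert torch.allclose(exp_avg, 0.1 * p)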