Some small fixes: correct the bias_correction2 formula and remove the pos_bias_u/v scale parameters

Daniel Povey 2022-05-22 16:28:33 +08:00
parent b916789ca3
commit 9ef11e64ba
2 changed files with 5 additions and 12 deletions


@@ -449,16 +449,8 @@ class RelPositionMultiheadAttention(nn.Module):
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
         self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
-        self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
-        self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
         self._reset_parameters()
-
-    def _pos_bias_u(self):
-        return self.pos_bias_u * self.pos_bias_u_scale.exp()
-
-    def _pos_bias_v(self):
-        return self.pos_bias_v * self.pos_bias_v_scale.exp()

     def _reset_parameters(self) -> None:
         nn.init.uniform_(self.pos_bias_u, -0.05, 0.05)
         nn.init.uniform_(self.pos_bias_v, -0.05, 0.05)
@@ -756,11 +748,11 @@ class RelPositionMultiheadAttention(nn.Module):
         p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
         p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)

-        q_with_bias_u = (q + self._pos_bias_u()).transpose(
+        q_with_bias_u = (q + self.pos_bias_u).transpose(
             1, 2
         )  # (batch, head, time1, d_k)

-        q_with_bias_v = (q + self._pos_bias_v()).transpose(
+        q_with_bias_v = (q + self.pos_bias_v).transpose(
             1, 2
         )  # (batch, head, time1, d_k)
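
For reference, pos_bias_u and pos_bias_v are the learnable biases from Transformer-XL Section 3.3: one is added to the query for the content (key) term and one for the relative-position term, and after this commit they are used directly rather than being multiplied by exp() of the removed scale parameters. Below is a minimal standalone sketch of how the two terms are formed; the sizes, the dummy tensors, and the simplified handling of the relative-position axis are illustrative assumptions, not the icefall implementation.

    import torch

    # Illustrative sizes (assumptions, not taken from the diff).
    B, H, T, D = 2, 4, 5, 16             # batch, heads, time, head_dim

    q = torch.randn(B, T, H, D)          # queries, (batch, time1, head, d_k)
    k = torch.randn(B, H, T, D)          # keys,    (batch, head, time2, d_k)
    p = torch.randn(B, H, 2 * T - 1, D)  # projected relative position encodings

    # The parameters the commit now uses directly (no learned exp() scales).
    pos_bias_u = torch.randn(H, D)
    pos_bias_v = torch.randn(H, D)

    # Content term, Transformer-XL matrices (a)+(c): bias u is added to the query.
    q_with_bias_u = (q + pos_bias_u).transpose(1, 2)              # (batch, head, time1, d_k)
    matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))  # (batch, head, time1, time2)

    # Position term, matrices (b)+(d): bias v is added to the query.
    q_with_bias_v = (q + pos_bias_v).transpose(1, 2)              # (batch, head, time1, d_k)
    matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))  # (batch, head, time1, 2*time1-1)

    # The real code applies a relative shift to matrix_bd before summing;
    # truncating here just keeps the sketch runnable.
    attn_score = (matrix_ac + matrix_bd[..., :T]) / D ** 0.5      # (batch, head, time1, time2)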


@@ -599,7 +599,9 @@ class Cain(Optimizer):
                     scale_exp_avg_sq.mul_(beta2).addcmul_(scale_deriv, scale_deriv,
                                                           value=1 - beta2)
-                    scale_bias_correction2 = 1 - beta2 ** step
+                    # should actually be step + 1, so on 1st minibatch we are not learning
+                    # anything here.  May fix this at some point.
+                    scale_bias_correction2 = 1 - beta2 ** (step + 1)

                     scale_denom = (scale_exp_avg_sq.sqrt()).add_(group["eps"])
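
Why the exponent matters: if step counts minibatches from 0, the old expression 1 - beta2 ** step is exactly 0 on the first minibatch, so anything scaled by it (or by its square root) is zero there, which matches the "not learning anything on the 1st minibatch" behaviour the added comment mentions; 1 - beta2 ** (step + 1) is the standard Adam-style correction for a 0-indexed counter. The small check below sketches plain Adam-style second-moment debiasing under that 0-indexed assumption; it is not the Cain optimizer itself, and g is a made-up gradient.

    import torch

    beta2 = 0.98
    exp_avg_sq = torch.zeros(())  # running average of the squared gradient

    for step in range(3):         # step assumed to be 0-indexed
        g = torch.tensor(0.5)     # pretend gradient for this minibatch
        exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)

        old_correction = 1 - beta2 ** step        # 0.0 when step == 0
        new_correction = 1 - beta2 ** (step + 1)  # 1 - beta2 when step == 0

        # With the new correction the debiased estimate equals g**2 exactly on the
        # first minibatch; with the old one, anything scaled by sqrt(old_correction)
        # (e.g. an effective step size) collapses to zero there.
        debiased = exp_avg_sq / new_correction
        print(step, float(old_correction), float(new_correction), float(debiased))
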
@@ -621,7 +623,6 @@ class Cain(Optimizer):
                                                            device=scale_delta.device,
                                                            dtype=scale_delta.dtype),
                                               scale_delta)
-                    exp_avg.add_(p, alpha=scale_delta)