diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless4/conformer.py
index ca393f2b9..99a6f7ce5 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/conformer.py
@@ -449,16 +449,8 @@ class RelPositionMultiheadAttention(nn.Module):
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
         self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
-        self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
-        self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
         self._reset_parameters()
 
-    def _pos_bias_u(self):
-        return self.pos_bias_u * self.pos_bias_u_scale.exp()
-
-    def _pos_bias_v(self):
-        return self.pos_bias_v * self.pos_bias_v_scale.exp()
-
     def _reset_parameters(self) -> None:
         nn.init.uniform_(self.pos_bias_u, -0.05, 0.05)
         nn.init.uniform_(self.pos_bias_v, -0.05, 0.05)
@@ -756,11 +748,11 @@ class RelPositionMultiheadAttention(nn.Module):
         p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
         p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
 
-        q_with_bias_u = (q + self._pos_bias_u()).transpose(
+        q_with_bias_u = (q + self.pos_bias_u).transpose(
             1, 2
         )  # (batch, head, time1, d_k)
 
-        q_with_bias_v = (q + self._pos_bias_v()).transpose(
+        q_with_bias_v = (q + self.pos_bias_v).transpose(
             1, 2
         )  # (batch, head, time1, d_k)
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py
index 9db522b31..5089c84c9 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/optim.py
@@ -599,7 +599,9 @@ class Cain(Optimizer):
                 scale_exp_avg_sq.mul_(beta2).addcmul_(scale_deriv, scale_deriv,
                                                       value=1 - beta2)
 
-                scale_bias_correction2 = 1 - beta2 ** step
+                # should actually be step + 1, so on 1st minibatch we are not learning
+                # anything here. May fix this at some point.
+                scale_bias_correction2 = 1 - beta2 ** (step + 1)
 
                 scale_denom = (scale_exp_avg_sq.sqrt()).add_(group["eps"])
 
@@ -621,7 +623,6 @@ class Cain(Optimizer):
                                                    device=scale_delta.device,
                                                    dtype=scale_delta.dtype),
                                       scale_delta)
-                exp_avg.add_(p, alpha=scale_delta)
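
Note on the optim.py hunk (a minimal standalone sketch, not part of the patch): it only evaluates the scalar bias-correction term that the change adjusts, assuming step counts from 0 as the added comment implies; how Cain consumes scale_bias_correction2 downstream is not shown here, and the beta2 value is illustrative only.

    # Hypothetical illustration; beta2 = 0.999 is a typical Adam-style default,
    # not necessarily the value Cain uses.
    beta2 = 0.999
    for step in range(3):
        old = 1 - beta2 ** step        # 0.0 when step == 0, so the term vanishes on the 1st minibatch
        new = 1 - beta2 ** (step + 1)  # 1 - beta2 when step == 0
        print(step, old, new)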