diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 385bfda36..a6ecef1e7 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -563,9 +563,14 @@ class RelPositionMultiheadAttention(nn.Module):
             need_weights=need_weights,
             attn_mask=attn_mask,
         )
-        attn_scores_out = torch.matmul(scores, self.attn_scores_proj_out)
         if attn_scores_in is not None:
+            attn_scores_out = torch.matmul(scores, self.attn_scores_proj_out)
             attn_scores_out = attn_scores_out + attn_scores_in
+        else:
+            # Here, add self.attn_scores_proj_in in order to make sure it has
+            # a grad.
+            attn_scores_out = torch.matmul(scores, self.attn_scores_proj_out +
+                                           self.attn_scores_proj_in)
         return x, weights, attn_scores_out

     def rel_shift(self, x: Tensor) -> Tensor:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 085e8df65..6294592a8 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -127,7 +127,7 @@ class ScaledAdam(Optimizer):

                 # Perform optimization step
                 grad = p.grad
-                if grad.is_sparse:
+                if grad is not None and grad.is_sparse:
                     raise RuntimeError(
                         "ScaledAdam optimizer does not support sparse gradients"
                     )
@@ -138,6 +138,8 @@ class ScaledAdam(Optimizer):

                 if i == 0:
                     clipping_scale = self._get_clipping_scale(group, p, state)

+                if grad is None:
+                    continue
                 self._step_one_batch(group, p, state, clipping_scale)

@@ -211,6 +213,8 @@ class ScaledAdam(Optimizer):
         for p in group["params"]:
             state = self.state[p]
             grad = p.grad
+            if grad is None:
+                continue
             if grad.is_sparse:
                 raise RuntimeError(
                     "ScaledAdam optimizer does not support sparse gradients"
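
Note on the conformer.py hunk: when attn_scores_in is None, self.attn_scores_proj_in would otherwise take no part in the forward pass, so it would receive no gradient (and DDP would flag it as an unused parameter). Adding it into the matmul keeps it in the autograd graph. Below is a minimal sketch of this pattern; ToyAttention, proj_in, and proj_out are hypothetical stand-ins, not the real RelPositionMultiheadAttention weights.

    import torch
    import torch.nn as nn

    class ToyAttention(nn.Module):
        def __init__(self, num_heads: int):
            super().__init__()
            self.proj_out = nn.Parameter(torch.randn(num_heads, num_heads))
            self.proj_in = nn.Parameter(torch.randn(num_heads, num_heads))

        def forward(self, scores, scores_in=None):
            if scores_in is not None:
                out = torch.matmul(scores, self.proj_out)
                return out + scores_in
            else:
                # Folding proj_in into the matmul keeps it in the autograd
                # graph, so it gets a grad even on the branch that does not
                # otherwise use it.
                return torch.matmul(scores, self.proj_out + self.proj_in)

    toy = ToyAttention(num_heads=4)
    toy(torch.randn(2, 4)).sum().backward()
    assert toy.proj_in.grad is not None  # proj_in received a gradient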
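Note on the optim.py hunks: PyTorch leaves p.grad as None for any parameter that took no part in the backward pass, so an optimizer's step() must guard against None before touching grad.is_sparse or updating the parameter, which is exactly what the added checks do. A minimal sketch of the same guard in a bare-bones hypothetical optimizer (SketchSGD is illustrative, not ScaledAdam):

    import torch
    from torch.optim import Optimizer

    class SketchSGD(Optimizer):
        def __init__(self, params, lr=0.1):
            super().__init__(params, dict(lr=lr))

        @torch.no_grad()
        def step(self, closure=None):
            for group in self.param_groups:
                for p in group["params"]:
                    grad = p.grad
                    if grad is None:
                        # No backward pass touched this parameter (e.g. an
                        # unused branch); skip it rather than crash on
                        # grad.is_sparse below.
                        continue
                    if grad.is_sparse:
                        raise RuntimeError("sparse gradients not supported")
                    p.add_(grad, alpha=-group["lr"])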