diff --git a/egs/librispeech/ASR/conformer_ctc2/attention.py b/egs/librispeech/ASR/conformer_ctc2/attention.py
index af481e4fd..d37d80225 100644
--- a/egs/librispeech/ASR/conformer_ctc2/attention.py
+++ b/egs/librispeech/ASR/conformer_ctc2/attention.py
@@ -182,9 +182,21 @@ class MultiheadAttention(nn.Module):
         query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
 
         if not self._qkv_same_embed_dim:
-            q_proj_weight = self.q_proj_weight.get_weight() if self.q_proj_weight is not None else None
-            k_proj_weight = self.k_proj_weight.get_weight() if self.k_proj_weight is not None else None
-            v_proj_weight = self.v_proj_weight.get_weight() if self.v_proj_weight is not None else None
+            q_proj_weight = (
+                self.q_proj_weight.get_weight()
+                if self.q_proj_weight is not None
+                else None
+            )
+            k_proj_weight = (
+                self.k_proj_weight.get_weight()
+                if self.k_proj_weight is not None
+                else None
+            )
+            v_proj_weight = (
+                self.v_proj_weight.get_weight()
+                if self.v_proj_weight is not None
+                else None
+            )
             (
                 attn_output,
                 attn_output_weights,