diff --git a/egs/librispeech/ASR/conformer_ctc2/attention.py b/egs/librispeech/ASR/conformer_ctc2/attention.py index 0f4313c17..af481e4fd 100644 --- a/egs/librispeech/ASR/conformer_ctc2/attention.py +++ b/egs/librispeech/ASR/conformer_ctc2/attention.py @@ -182,6 +182,9 @@ class MultiheadAttention(nn.Module): query, key, value = [x.transpose(1, 0) for x in (query, key, value)] if not self._qkv_same_embed_dim: + q_proj_weight = self.q_proj_weight.get_weight() if self.q_proj_weight is not None else None + k_proj_weight = self.k_proj_weight.get_weight() if self.k_proj_weight is not None else None + v_proj_weight = self.v_proj_weight.get_weight() if self.v_proj_weight is not None else None ( attn_output, attn_output_weights, @@ -204,9 +207,9 @@ class MultiheadAttention(nn.Module): need_weights=need_weights, attn_mask=attn_mask, use_separate_proj_weight=True, - q_proj_weight=self.q_proj_weight.get_weight(), - k_proj_weight=self.k_proj_weight.get_weight(), - v_proj_weight=self.v_proj_weight.get_weight(), + q_proj_weight=q_proj_weight, + k_proj_weight=k_proj_weight, + v_proj_weight=v_proj_weight, ) else: (