check whether q k v weight is None

This commit is contained in:
Quandwang 2022-07-21 20:01:37 +08:00
parent eb86873ee6
commit 1f32d34e9b

View File

@ -182,6 +182,9 @@ class MultiheadAttention(nn.Module):
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
if not self._qkv_same_embed_dim:
q_proj_weight = self.q_proj_weight.get_weight() if self.q_proj_weight is not None else None
k_proj_weight = self.k_proj_weight.get_weight() if self.k_proj_weight is not None else None
v_proj_weight = self.v_proj_weight.get_weight() if self.v_proj_weight is not None else None
(
attn_output,
attn_output_weights,
@ -204,9 +207,9 @@ class MultiheadAttention(nn.Module):
need_weights=need_weights,
attn_mask=attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight.get_weight(),
k_proj_weight=self.k_proj_weight.get_weight(),
v_proj_weight=self.v_proj_weight.get_weight(),
q_proj_weight=q_proj_weight,
k_proj_weight=k_proj_weight,
v_proj_weight=v_proj_weight,
)
else:
(