check whether q k v weight is None

This commit is contained in:
Quandwang 2022-07-21 20:07:08 +08:00
parent 1f32d34e9b
commit a1a84ed148

View File

@ -182,9 +182,15 @@ class MultiheadAttention(nn.Module):
query, key, value = [x.transpose(1, 0) for x in (query, key, value)] query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
if not self._qkv_same_embed_dim: if not self._qkv_same_embed_dim:
q_proj_weight = self.q_proj_weight.get_weight() if self.q_proj_weight is not None else None q_proj_weight = (
k_proj_weight = self.k_proj_weight.get_weight() if self.k_proj_weight is not None else None self.q_proj_weight.get_weight() if self.q_proj_weight else None
v_proj_weight = self.v_proj_weight.get_weight() if self.v_proj_weight is not None else None )
k_proj_weight = (
self.k_proj_weight.get_weight() if self.k_proj_weight else None
)
v_proj_weight = (
self.v_proj_weight.get_weight() if self.v_proj_weight else None
)
( (
attn_output, attn_output,
attn_output_weights, attn_output_weights,