mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
check whether q k v weight is None
This commit is contained in:
parent
eb86873ee6
commit
1f32d34e9b
@ -182,6 +182,9 @@ class MultiheadAttention(nn.Module):
|
|||||||
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
|
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
|
||||||
|
|
||||||
if not self._qkv_same_embed_dim:
|
if not self._qkv_same_embed_dim:
|
||||||
|
q_proj_weight = self.q_proj_weight.get_weight() if self.q_proj_weight is not None else None
|
||||||
|
k_proj_weight = self.k_proj_weight.get_weight() if self.k_proj_weight is not None else None
|
||||||
|
v_proj_weight = self.v_proj_weight.get_weight() if self.v_proj_weight is not None else None
|
||||||
(
|
(
|
||||||
attn_output,
|
attn_output,
|
||||||
attn_output_weights,
|
attn_output_weights,
|
||||||
@ -204,9 +207,9 @@ class MultiheadAttention(nn.Module):
|
|||||||
need_weights=need_weights,
|
need_weights=need_weights,
|
||||||
attn_mask=attn_mask,
|
attn_mask=attn_mask,
|
||||||
use_separate_proj_weight=True,
|
use_separate_proj_weight=True,
|
||||||
q_proj_weight=self.q_proj_weight.get_weight(),
|
q_proj_weight=q_proj_weight,
|
||||||
k_proj_weight=self.k_proj_weight.get_weight(),
|
k_proj_weight=k_proj_weight,
|
||||||
v_proj_weight=self.v_proj_weight.get_weight(),
|
v_proj_weight=v_proj_weight,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
(
|
(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user