mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
check whether q k v weight is None
This commit is contained in:
parent
1f32d34e9b
commit
a1a84ed148
@ -182,9 +182,15 @@ class MultiheadAttention(nn.Module):
|
||||
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
|
||||
|
||||
if not self._qkv_same_embed_dim:
|
||||
q_proj_weight = self.q_proj_weight.get_weight() if self.q_proj_weight is not None else None
|
||||
k_proj_weight = self.k_proj_weight.get_weight() if self.k_proj_weight is not None else None
|
||||
v_proj_weight = self.v_proj_weight.get_weight() if self.v_proj_weight is not None else None
|
||||
q_proj_weight = (
|
||||
self.q_proj_weight.get_weight() if self.q_proj_weight else None
|
||||
)
|
||||
k_proj_weight = (
|
||||
self.k_proj_weight.get_weight() if self.k_proj_weight else None
|
||||
)
|
||||
v_proj_weight = (
|
||||
self.v_proj_weight.get_weight() if self.v_proj_weight else None
|
||||
)
|
||||
(
|
||||
attn_output,
|
||||
attn_output_weights,
|
||||
|
Loading…
x
Reference in New Issue
Block a user