Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
fix for black

commit 911bfacffd
parent c0a5601c3d
@@ -154,9 +154,9 @@ class MultiheadAttention(nn.Module):
         self.self_attention = self_attention
         self.encoder_decoder_attention = encoder_decoder_attention
 
-        assert not self.self_attention or self.qkv_same_dim, (
-            "Self-attention requires query, key and value to be of the same size"
-        )
+        assert (
+            not self.self_attention or self.qkv_same_dim
+        ), "Self-attention requires query, key and value to be of the same size"
 
         self.k_proj = quant_noise(
             nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
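The first hunk is a pure formatting change from black: the assertion message moves out of the parentheses and the condition is wrapped instead. A minimal sketch of the two equivalent forms, using placeholder booleans rather than the actual module attributes:

# Before black: bare condition, parenthesized message.
self_attention, qkv_same_dim = True, True
assert not self_attention or qkv_same_dim, (
    "Self-attention requires query, key and value to be of the same size"
)

# After black: parenthesized condition, message after the closing parenthesis.
assert (
    not self_attention or qkv_same_dim
), "Self-attention requires query, key and value to be of the same size"

Both forms raise the same AssertionError with the same message when the condition fails; only the line wrapping differs.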
@@ -224,13 +224,7 @@ class MultiheadAttention(nn.Module):
                         ]
                     )
                 ).tolist()
-                + torch.sum(
-                    torch.abs(
-                        self.k_proj.bias[
-                            start_idx:end_idx
-                        ]
-                    )
-                ).tolist()
+                + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist()
             )
             q_proj_heads_norm.append(
                 torch.sum(
@@ -240,13 +234,7 @@ class MultiheadAttention(nn.Module):
                         ]
                     )
                 ).tolist()
-                + torch.sum(
-                    torch.abs(
-                        self.q_proj.bias[
-                            start_idx:end_idx
-                        ]
-                    )
-                ).tolist()
+                + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist()
             )
             v_proj_heads_norm.append(
                 torch.sum(
@@ -256,13 +244,7 @@ class MultiheadAttention(nn.Module):
                         ]
                     )
                 ).tolist()
-                + torch.sum(
-                    torch.abs(
-                        self.v_proj.bias[
-                            start_idx:end_idx
-                        ]
-                    )
-                ).tolist()
+                + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist()
             )
 
         heads_norm = []
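The remaining three hunks collapse the multi-line torch.sum(torch.abs(...)) expression over each projection bias into a single line, again without changing behavior. A self-contained sketch of the per-head L1-norm pattern these lines compute (a toy nn.Linear stands in for k_proj / q_proj / v_proj; the names here are illustrative, not the icefall ones):

import torch
import torch.nn as nn

embed_dim, num_heads = 8, 2
head_dim = embed_dim // num_heads
proj = nn.Linear(embed_dim, embed_dim)  # stand-in for k_proj / q_proj / v_proj

heads_norm = []
for i in range(num_heads):
    start_idx, end_idx = i * head_dim, (i + 1) * head_dim
    # L1 norm of this head's weight rows plus the L1 norm of its bias slice,
    # matching the reformatted one-liner in the diff.
    heads_norm.append(
        torch.sum(torch.abs(proj.weight[start_idx:end_idx])).tolist()
        + torch.sum(torch.abs(proj.bias[start_idx:end_idx])).tolist()
    )
print(heads_norm)  # one scalar per head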