Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-09 14:05:33 +00:00
fix for black

parent c0a5601c3d
commit 911bfacffd
@@ -154,9 +154,9 @@ class MultiheadAttention(nn.Module):
         self.self_attention = self_attention
         self.encoder_decoder_attention = encoder_decoder_attention
 
-        assert not self.self_attention or self.qkv_same_dim, (
-            "Self-attention requires query, key and value to be of the same size"
-        )
+        assert (
+            not self.self_attention or self.qkv_same_dim
+        ), "Self-attention requires query, key and value to be of the same size"
 
         self.k_proj = quant_noise(
             nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
@@ -224,13 +224,7 @@ class MultiheadAttention(nn.Module):
                         ]
                     )
                 ).tolist()
-                + torch.sum(
-                    torch.abs(
-                        self.k_proj.bias[
-                            start_idx:end_idx
-                        ]
-                    )
-                ).tolist()
+                + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist()
             )
             q_proj_heads_norm.append(
                 torch.sum(
@@ -240,13 +234,7 @@ class MultiheadAttention(nn.Module):
                         ]
                     )
                 ).tolist()
-                + torch.sum(
-                    torch.abs(
-                        self.q_proj.bias[
-                            start_idx:end_idx
-                        ]
-                    )
-                ).tolist()
+                + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist()
             )
             v_proj_heads_norm.append(
                 torch.sum(
@@ -256,13 +244,7 @@ class MultiheadAttention(nn.Module):
                         ]
                     )
                 ).tolist()
-                + torch.sum(
-                    torch.abs(
-                        self.v_proj.bias[
-                            start_idx:end_idx
-                        ]
-                    )
-                ).tolist()
+                + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist()
             )
 
             heads_norm = []
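For context, the lines black collapses above each compute a per-head L1 norm: the sum of absolute values over one head's rows of a projection weight matrix plus the matching slice of its bias. Below is a minimal, self-contained sketch of that computation for the key projection only, written in the single-line style the reformatted hunks settle on; embed_dim, num_heads, and the standalone k_proj layer are illustrative stand-ins, not the actual MultiheadAttention attributes or configuration.

import torch
import torch.nn as nn

embed_dim, num_heads = 8, 2
head_dim = embed_dim // num_heads

# Same multi-line assert layout black produces in the hunk at line 154.
assert (
    head_dim * num_heads == embed_dim
), "embed_dim must be divisible by num_heads"

# Stand-in for self.k_proj in the real module.
k_proj = nn.Linear(embed_dim, embed_dim, bias=True)

k_proj_heads_norm = []
for i in range(num_heads):
    start_idx = i * head_dim
    end_idx = (i + 1) * head_dim
    # L1 norm of this head's weight rows plus its bias slice, on one line
    # each, as in the reformatted hunks above.
    k_proj_heads_norm.append(
        torch.sum(torch.abs(k_proj.weight[start_idx:end_idx])).tolist()
        + torch.sum(torch.abs(k_proj.bias[start_idx:end_idx])).tolist()
    )

print(k_proj_heads_norm)  # one scalar per attention head

The diff repeats the same pattern for q_proj and v_proj in the later hunks.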