fix for black

Author: yifanyeung
Date:   2024-02-18 13:15:56 +08:00
Parent: 809bdb07f0
Commit: c0a5601c3d

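The change below is a pure re-formatting pass over the MultiheadAttention module: the file is run through the black code formatter, which joins manually wrapped expressions that fit within its line-length limit and keeps subscripts carrying a trailing ("magic") comma expanded across lines. As a minimal sketch of that transformation, the snippet here feeds one expression from the first large hunk through black's Python API; black.format_str and black.Mode are black's actual public API, but the black version and the default line length of 88 are assumptions about this repository's setup rather than facts taken from the commit.

import black

# One of the expressions rewritten by the hunk at "@@ -219,35 +218,57 @@" below,
# reproduced out of context: `torch`, `self`, `start_idx`, and `end_idx` are just
# names that make the snippet parse, not real objects here.
src = (
    "x = torch.sum(\n"
    "    torch.abs(self.k_proj.weight[start_idx:end_idx,])\n"
    ").tolist()\n"
)

# With a recent black release, the magic trailing comma inside the subscript keeps
# the slice expanded over several lines, mirroring the "+" side of that hunk; the
# exact output can differ between black versions and configurations.
print(black.format_str(src, mode=black.Mode(line_length=88)))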

@@ -155,8 +155,7 @@ class MultiheadAttention(nn.Module):
         self.encoder_decoder_attention = encoder_decoder_attention
 
         assert not self.self_attention or self.qkv_same_dim, (
-            "Self-attention requires query, key and "
-            "value to be of the same size"
+            "Self-attention requires query, key and value to be of the same size"
         )
 
         self.k_proj = quant_noise(
@@ -219,35 +218,57 @@ class MultiheadAttention(nn.Module):
             end_idx = (i + 1) * self.head_dim
             k_proj_heads_norm.append(
                 torch.sum(
-                    torch.abs(self.k_proj.weight[start_idx:end_idx,])
+                    torch.abs(
+                        self.k_proj.weight[
+                            start_idx:end_idx,
+                        ]
+                    )
                 ).tolist()
                 + torch.sum(
-                    torch.abs(self.k_proj.bias[start_idx:end_idx])
+                    torch.abs(
+                        self.k_proj.bias[
+                            start_idx:end_idx
+                        ]
+                    )
                 ).tolist()
             )
             q_proj_heads_norm.append(
                 torch.sum(
-                    torch.abs(self.q_proj.weight[start_idx:end_idx,])
+                    torch.abs(
+                        self.q_proj.weight[
+                            start_idx:end_idx,
+                        ]
+                    )
                 ).tolist()
                 + torch.sum(
-                    torch.abs(self.q_proj.bias[start_idx:end_idx])
+                    torch.abs(
+                        self.q_proj.bias[
+                            start_idx:end_idx
+                        ]
+                    )
                 ).tolist()
             )
             v_proj_heads_norm.append(
                 torch.sum(
-                    torch.abs(self.v_proj.weight[start_idx:end_idx,])
+                    torch.abs(
+                        self.v_proj.weight[
+                            start_idx:end_idx,
+                        ]
+                    )
                 ).tolist()
                 + torch.sum(
-                    torch.abs(self.v_proj.bias[start_idx:end_idx])
+                    torch.abs(
+                        self.v_proj.bias[
+                            start_idx:end_idx
+                        ]
+                    )
                 ).tolist()
             )
 
         heads_norm = []
         for i in range(self.num_heads):
             heads_norm.append(
-                k_proj_heads_norm[i]
-                + q_proj_heads_norm[i]
-                + v_proj_heads_norm[i]
+                k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i]
             )
 
         sorted_head_index = sorted(
@@ -271,19 +292,29 @@ class MultiheadAttention(nn.Module):
 
         for ele in reserve_head_index:
             start_idx, end_idx = ele
-            new_q_weight.append(self.q_proj.weight[start_idx:end_idx,])
+            new_q_weight.append(
+                self.q_proj.weight[
+                    start_idx:end_idx,
+                ]
+            )
             new_q_bias.append(self.q_proj.bias[start_idx:end_idx])
 
-            new_k_weight.append(self.k_proj.weight[start_idx:end_idx,])
+            new_k_weight.append(
+                self.k_proj.weight[
+                    start_idx:end_idx,
+                ]
+            )
 
             new_k_bias.append(self.k_proj.bias[start_idx:end_idx])
 
-            new_v_weight.append(self.v_proj.weight[start_idx:end_idx,])
+            new_v_weight.append(
+                self.v_proj.weight[
+                    start_idx:end_idx,
+                ]
+            )
             new_v_bias.append(self.v_proj.bias[start_idx:end_idx])
 
-            new_out_proj_weight.append(
-                self.out_proj.weight[:, start_idx:end_idx]
-            )
+            new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx])
 
         new_q_weight = torch.cat(new_q_weight).detach()
         new_k_weight = torch.cat(new_k_weight).detach()
@@ -330,9 +361,7 @@ class MultiheadAttention(nn.Module):
     ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
         if attn_mask is not None:
             shape = attn_mask.size()[:-1] + torch.Size([1])
-            attn_mask = torch.cat(
-                [attn_mask, attn_mask.new_zeros(shape)], dim=-1
-            )
+            attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1)
         if key_padding_mask is not None:
             shape = key_padding_mask.size()[:-1] + torch.Size([1])
             key_padding_mask = torch.cat(
@@ -388,9 +417,7 @@ class MultiheadAttention(nn.Module):
         key: Optional[Tensor],
         value: Optional[Tensor],
         key_padding_mask: Optional[Tensor] = None,
-        incremental_state: Optional[
-            Dict[str, Dict[str, Optional[Tensor]]]
-        ] = None,
+        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
         need_weights: bool = True,
         static_kv: bool = False,
         attn_mask: Optional[Tensor] = None,
@@ -455,9 +482,7 @@ class MultiheadAttention(nn.Module):
                 self.embed_dim,
                 self.num_heads,
                 torch.empty([0]),
-                torch.cat(
-                    (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)
-                ),
+                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                 self.bias_k,
                 self.bias_v,
                 self.add_zero_attn,
@@ -465,9 +490,7 @@ class MultiheadAttention(nn.Module):
                 self.out_proj.weight,
                 self.out_proj.bias,
                 self.training or self.dropout_module.apply_during_inference,
-                key_padding_mask.bool()
-                if key_padding_mask is not None
-                else None,
+                key_padding_mask.bool() if key_padding_mask is not None else None,
                 need_weights,
                 attn_mask,
                 use_separate_proj_weight=True,
@@ -482,10 +505,7 @@ class MultiheadAttention(nn.Module):
                 # previous time steps are cached - no need to recompute
                 # key and value if they are static
                 if static_kv:
-                    assert (
-                        self.encoder_decoder_attention
-                        and not self.self_attention
-                    )
+                    assert self.encoder_decoder_attention and not self.self_attention
                     key = value = None
         else:
             saved_state = None
@@ -503,9 +523,9 @@ class MultiheadAttention(nn.Module):
             else:
                 if self.beam_size > 1 and bsz == key.size(1):
                     # key is [T, bsz*beam_size, C], reduce to [T, bsz, C]
-                    key = key.view(
-                        key.size(0), -1, self.beam_size, key.size(2)
-                    )[:, :, 0, :]
+                    key = key.view(key.size(0), -1, self.beam_size, key.size(2))[
+                        :, :, 0, :
+                    ]
                     if key_padding_mask is not None:
                         key_padding_mask = key_padding_mask.view(
                             -1, self.beam_size, key_padding_mask.size(1)
@@ -552,9 +572,7 @@ class MultiheadAttention(nn.Module):
                 _prev_key = saved_state["prev_key"]
                 assert _prev_key is not None
                 kv_bsz = _prev_key.size(0)
-                prev_key = _prev_key.view(
-                    kv_bsz * self.num_heads, -1, self.head_dim
-                )
+                prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim)
                 if static_kv:
                     k = prev_key
                 else:
@@ -585,18 +603,14 @@ class MultiheadAttention(nn.Module):
                 static_kv=static_kv,
             )
 
-            saved_state["prev_key"] = k.view(
-                kv_bsz, self.num_heads, -1, self.head_dim
-            )
+            saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim)
             saved_state["prev_value"] = v.view(
                 kv_bsz, self.num_heads, -1, self.head_dim
             )
             saved_state["prev_key_padding_mask"] = key_padding_mask
             # In this branch incremental_state is never None
             assert incremental_state is not None
-            incremental_state = self._set_input_buffer(
-                incremental_state, saved_state
-            )
+            incremental_state = self._set_input_buffer(incremental_state, saved_state)
 
         assert k is not None
         assert k.size(1) == src_len
@@ -622,14 +636,10 @@ class MultiheadAttention(nn.Module):
                 q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]),
                 k.view((kv_bsz, self.num_heads) + k.size()[1:]),
             )
-            attn_weights = attn_weights.reshape(
-                (-1,) + attn_weights.size()[-2:]
-            )
+            attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:])
         else:
             attn_weights = torch.bmm(q, k.transpose(1, 2))
-        attn_weights = self.apply_sparse_mask(
-            attn_weights, tgt_len, src_len, bsz
-        )
+        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
 
         assert list(attn_weights.size()) == [
             bsz * self.num_heads,
@@ -645,9 +655,7 @@ class MultiheadAttention(nn.Module):
 
         if key_padding_mask is not None:
             # don't attend to padding symbols
-            attn_weights = attn_weights.view(
-                bsz, self.num_heads, tgt_len, src_len
-            )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             if not is_tpu:
                 attn_weights = attn_weights.view(
                     kv_bsz, -1, self.num_heads, tgt_len, src_len
@@ -661,13 +669,9 @@ class MultiheadAttention(nn.Module):
                 )
             else:
                 attn_weights = attn_weights.transpose(0, 2)
-                attn_weights = attn_weights.masked_fill(
-                    key_padding_mask, float("-inf")
-                )
+                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                 attn_weights = attn_weights.transpose(0, 2)
-            attn_weights = attn_weights.view(
-                bsz * self.num_heads, tgt_len, src_len
-            )
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
         if before_softmax:
             return attn_weights, v
@@ -712,11 +716,7 @@ class MultiheadAttention(nn.Module):
             # the transpose is a no-op copy before view, thus unnecessary
             attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim)
         else:
-            attn = (
-                attn.transpose(0, 1)
-                .contiguous()
-                .view(tgt_len, bsz, self.embed_dim)
-            )
+            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
         attn = self.out_proj(attn)
         attn_weights: Optional[Tensor] = None
         if need_weights:
@@ -786,9 +786,7 @@ class MultiheadAttention(nn.Module):
                 input_buffer_k = input_buffer[k]
                 if input_buffer_k is not None:
                     if self.encoder_decoder_attention:
-                        if input_buffer_k.size(
-                            0
-                        ) * self.beam_size == new_order.size(0):
+                        if input_buffer_k.size(0) * self.beam_size == new_order.size(0):
                             return incremental_state
                         elif self.beam_size > 1:
                             input_buffer[k] = input_buffer_k.index_select(
@@ -797,16 +795,10 @@ class MultiheadAttention(nn.Module):
                                 // self.beam_size,
                             )
                         else:
-                            input_buffer[k] = input_buffer_k.index_select(
-                                0, new_order
-                            )
+                            input_buffer[k] = input_buffer_k.index_select(0, new_order)
                     else:
-                        input_buffer[k] = input_buffer_k.index_select(
-                            0, new_order
-                        )
-            incremental_state = self._set_input_buffer(
-                incremental_state, input_buffer
-            )
+                        input_buffer[k] = input_buffer_k.index_select(0, new_order)
+            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
         return incremental_state
 
     def set_beam_size(self, beam_size):
@@ -829,13 +821,9 @@ class MultiheadAttention(nn.Module):
         incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
         buffer: Dict[str, Optional[Tensor]],
     ):
-        return self.set_incremental_state(
-            incremental_state, "attn_state", buffer
-        )
+        return self.set_incremental_state(incremental_state, "attn_state", buffer)
 
-    def apply_sparse_mask(
-        self, attn_weights, tgt_len: int, src_len: int, bsz: int
-    ):
+    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
         return attn_weights
 
     def upgrade_state_dict_named(self, state_dict, name):
@@ -847,27 +835,19 @@ class MultiheadAttention(nn.Module):
                 # in_proj_weight used to be q + k + v with same dimensions
                 dim = int(state_dict[k].shape[0] / 3)
                 items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
-                items_to_add[prefix + "k_proj.weight"] = state_dict[k][
-                    dim : 2 * dim
-                ]
-                items_to_add[prefix + "v_proj.weight"] = state_dict[k][
-                    2 * dim :
-                ]
+                items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
+                items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
 
                 keys_to_remove.append(k)
 
                 k_bias = prefix + "in_proj_bias"
                 if k_bias in state_dict.keys():
                     dim = int(state_dict[k].shape[0] / 3)
-                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][
-                        :dim
-                    ]
+                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
                     items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
                         dim : 2 * dim
                     ]
-                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][
-                        2 * dim :
-                    ]
+                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
 
                     keys_to_remove.append(prefix + "in_proj_bias")
 
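The joined lines throughout the diff are the other half of the same formatting rule: once an expression fits within black's default 88-character limit, the redundant wrapping parentheses are removed and the expression goes back onto one line. A small illustrative check using the same assumed API as above (the snippet is constructed here, not copied from the repository):

import black

# The manually wrapped method chain from the hunk at "@@ -712,11 +716,7 @@" above,
# again with `attn`, `tgt_len`, `bsz`, and `self` as placeholder names.
src = (
    "attn = (\n"
    "    attn.transpose(0, 1)\n"
    "    .contiguous()\n"
    "    .view(tgt_len, bsz, self.embed_dim)\n"
    ")\n"
)

# black drops the now-redundant parentheses and joins the chain onto a single
# line, matching the "+" side of that hunk (subject to the black version used).
print(black.format_str(src, mode=black.Mode()))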