Mirror of https://github.com/k2-fsa/icefall.git

Commit 809bdb07f0 ("fix for black")
Parent: b070d04ae8
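The hunks below are purely mechanical rewrites produced by the black formatter: long statements are re-wrapped and trailing commas added, with no change in behaviour. As a small, self-contained sketch of the pattern (the dummy tensors here are illustrative only; the wrapped call itself is taken from the hunk at old line 313 below):

import torch

# Dummy inputs just to make the snippet runnable; shapes are arbitrary.
attn_mask = torch.zeros(2, 3)
shape = attn_mask.size()[:-1] + torch.Size([1])

# Before this commit the call sat on a single line:
# attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1)

# After running black the call is wrapped; the result is identical.
attn_mask = torch.cat(
    [attn_mask, attn_mask.new_zeros(shape)], dim=-1
)
print(attn_mask.shape)  # torch.Size([2, 4])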
@@ -155,7 +155,8 @@ class MultiheadAttention(nn.Module):
         self.encoder_decoder_attention = encoder_decoder_attention

         assert not self.self_attention or self.qkv_same_dim, (
-            "Self-attention requires query, key and " "value to be of the same size"
+            "Self-attention requires query, key and "
+            "value to be of the same size"
         )

         self.k_proj = quant_noise(
@@ -217,22 +218,36 @@ class MultiheadAttention(nn.Module):
             start_idx = i * self.head_dim
             end_idx = (i + 1) * self.head_dim
             k_proj_heads_norm.append(
-                torch.sum(torch.abs(self.k_proj.weight[start_idx:end_idx,])).tolist()
-                + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist()
+                torch.sum(
+                    torch.abs(self.k_proj.weight[start_idx:end_idx,])
+                ).tolist()
+                + torch.sum(
+                    torch.abs(self.k_proj.bias[start_idx:end_idx])
+                ).tolist()
             )
             q_proj_heads_norm.append(
-                torch.sum(torch.abs(self.q_proj.weight[start_idx:end_idx,])).tolist()
-                + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist()
+                torch.sum(
+                    torch.abs(self.q_proj.weight[start_idx:end_idx,])
+                ).tolist()
+                + torch.sum(
+                    torch.abs(self.q_proj.bias[start_idx:end_idx])
+                ).tolist()
             )
             v_proj_heads_norm.append(
-                torch.sum(torch.abs(self.v_proj.weight[start_idx:end_idx,])).tolist()
-                + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist()
+                torch.sum(
+                    torch.abs(self.v_proj.weight[start_idx:end_idx,])
+                ).tolist()
+                + torch.sum(
+                    torch.abs(self.v_proj.bias[start_idx:end_idx])
+                ).tolist()
             )

         heads_norm = []
         for i in range(self.num_heads):
             heads_norm.append(
-                k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i]
+                k_proj_heads_norm[i]
+                + q_proj_heads_norm[i]
+                + v_proj_heads_norm[i]
             )

         sorted_head_index = sorted(
@@ -266,7 +281,9 @@ class MultiheadAttention(nn.Module):
             new_v_weight.append(self.v_proj.weight[start_idx:end_idx,])
             new_v_bias.append(self.v_proj.bias[start_idx:end_idx])

-            new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx])
+            new_out_proj_weight.append(
+                self.out_proj.weight[:, start_idx:end_idx]
+            )

         new_q_weight = torch.cat(new_q_weight).detach()
         new_k_weight = torch.cat(new_k_weight).detach()
@@ -313,7 +330,9 @@ class MultiheadAttention(nn.Module):
     ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
         if attn_mask is not None:
             shape = attn_mask.size()[:-1] + torch.Size([1])
-            attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1)
+            attn_mask = torch.cat(
+                [attn_mask, attn_mask.new_zeros(shape)], dim=-1
+            )
         if key_padding_mask is not None:
             shape = key_padding_mask.size()[:-1] + torch.Size([1])
             key_padding_mask = torch.cat(
@@ -351,10 +370,12 @@ class MultiheadAttention(nn.Module):
     ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
         zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:]
         k = torch.cat(
-            [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=-2
+            [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)],
+            dim=-2,
         )
         v = torch.cat(
-            [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=-2
+            [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)],
+            dim=-2,
         )
         key_padding_mask, attn_mask = self._pad_masks(
             key_padding_mask=key_padding_mask, attn_mask=attn_mask
@@ -367,7 +388,9 @@ class MultiheadAttention(nn.Module):
         key: Optional[Tensor],
         value: Optional[Tensor],
         key_padding_mask: Optional[Tensor] = None,
-        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+        incremental_state: Optional[
+            Dict[str, Dict[str, Optional[Tensor]]]
+        ] = None,
         need_weights: bool = True,
         static_kv: bool = False,
         attn_mask: Optional[Tensor] = None,
@@ -432,7 +455,9 @@ class MultiheadAttention(nn.Module):
                 self.embed_dim,
                 self.num_heads,
                 torch.empty([0]),
-                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
+                torch.cat(
+                    (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)
+                ),
                 self.bias_k,
                 self.bias_v,
                 self.add_zero_attn,
@@ -440,7 +465,9 @@ class MultiheadAttention(nn.Module):
                 self.out_proj.weight,
                 self.out_proj.bias,
                 self.training or self.dropout_module.apply_during_inference,
-                key_padding_mask.bool() if key_padding_mask is not None else None,
+                key_padding_mask.bool()
+                if key_padding_mask is not None
+                else None,
                 need_weights,
                 attn_mask,
                 use_separate_proj_weight=True,
@@ -455,7 +482,10 @@ class MultiheadAttention(nn.Module):
                 # previous time steps are cached - no need to recompute
                 # key and value if they are static
                 if static_kv:
-                    assert self.encoder_decoder_attention and not self.self_attention
+                    assert (
+                        self.encoder_decoder_attention
+                        and not self.self_attention
+                    )
                     key = value = None
         else:
             saved_state = None
@@ -473,9 +503,9 @@ class MultiheadAttention(nn.Module):
             else:
                 if self.beam_size > 1 and bsz == key.size(1):
                     # key is [T, bsz*beam_size, C], reduce to [T, bsz, C]
-                    key = key.view(key.size(0), -1, self.beam_size, key.size(2))[
-                        :, :, 0, :
-                    ]
+                    key = key.view(
+                        key.size(0), -1, self.beam_size, key.size(2)
+                    )[:, :, 0, :]
                     if key_padding_mask is not None:
                         key_padding_mask = key_padding_mask.view(
                             -1, self.beam_size, key_padding_mask.size(1)
@@ -522,7 +552,9 @@ class MultiheadAttention(nn.Module):
                 _prev_key = saved_state["prev_key"]
                 assert _prev_key is not None
                 kv_bsz = _prev_key.size(0)
-                prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim)
+                prev_key = _prev_key.view(
+                    kv_bsz * self.num_heads, -1, self.head_dim
+                )
                 if static_kv:
                     k = prev_key
                 else:
@@ -553,14 +585,18 @@ class MultiheadAttention(nn.Module):
                 static_kv=static_kv,
             )

-            saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim)
+            saved_state["prev_key"] = k.view(
+                kv_bsz, self.num_heads, -1, self.head_dim
+            )
             saved_state["prev_value"] = v.view(
                 kv_bsz, self.num_heads, -1, self.head_dim
             )
             saved_state["prev_key_padding_mask"] = key_padding_mask
             # In this branch incremental_state is never None
             assert incremental_state is not None
-            incremental_state = self._set_input_buffer(incremental_state, saved_state)
+            incremental_state = self._set_input_buffer(
+                incremental_state, saved_state
+            )
         assert k is not None
         assert k.size(1) == src_len

@@ -586,12 +622,20 @@ class MultiheadAttention(nn.Module):
                 q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]),
                 k.view((kv_bsz, self.num_heads) + k.size()[1:]),
             )
-            attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:])
+            attn_weights = attn_weights.reshape(
+                (-1,) + attn_weights.size()[-2:]
+            )
         else:
             attn_weights = torch.bmm(q, k.transpose(1, 2))
-        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+        attn_weights = self.apply_sparse_mask(
+            attn_weights, tgt_len, src_len, bsz
+        )

-        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+        assert list(attn_weights.size()) == [
+            bsz * self.num_heads,
+            tgt_len,
+            src_len,
+        ]

         if attn_mask is not None:
             attn_mask = attn_mask.unsqueeze(0)
@@ -601,7 +645,9 @@ class MultiheadAttention(nn.Module):

         if key_padding_mask is not None:
             # don't attend to padding symbols
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(
+                bsz, self.num_heads, tgt_len, src_len
+            )
             if not is_tpu:
                 attn_weights = attn_weights.view(
                     kv_bsz, -1, self.num_heads, tgt_len, src_len
@@ -615,9 +661,13 @@ class MultiheadAttention(nn.Module):
                 )
             else:
                 attn_weights = attn_weights.transpose(0, 2)
-                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
+                attn_weights = attn_weights.masked_fill(
+                    key_padding_mask, float("-inf")
+                )
                 attn_weights = attn_weights.transpose(0, 2)
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(
+                bsz * self.num_heads, tgt_len, src_len
+            )

         if before_softmax:
             return attn_weights, v
@@ -652,13 +702,21 @@ class MultiheadAttention(nn.Module):
             attn = attn.reshape((-1,) + attn.size()[-2:])
         else:
             attn = torch.bmm(attn_probs, v)
-        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+        assert list(attn.size()) == [
+            bsz * self.num_heads,
+            tgt_len,
+            self.head_dim,
+        ]
         if self.onnx_trace and attn.size(1) == 1:
             # when ONNX tracing a single decoder step (sequence length == 1)
             # the transpose is a no-op copy before view, thus unnecessary
             attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim)
         else:
-            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
+            attn = (
+                attn.transpose(0, 1)
+                .contiguous()
+                .view(tgt_len, bsz, self.embed_dim)
+            )
         attn = self.out_proj(attn)
         attn_weights: Optional[Tensor] = None
         if need_weights:
@@ -728,7 +786,9 @@ class MultiheadAttention(nn.Module):
                 input_buffer_k = input_buffer[k]
                 if input_buffer_k is not None:
                     if self.encoder_decoder_attention:
-                        if input_buffer_k.size(0) * self.beam_size == new_order.size(0):
+                        if input_buffer_k.size(
+                            0
+                        ) * self.beam_size == new_order.size(0):
                             return incremental_state
                         elif self.beam_size > 1:
                             input_buffer[k] = input_buffer_k.index_select(
@@ -737,10 +797,16 @@ class MultiheadAttention(nn.Module):
                                 // self.beam_size,
                             )
                         else:
-                            input_buffer[k] = input_buffer_k.index_select(0, new_order)
+                            input_buffer[k] = input_buffer_k.index_select(
+                                0, new_order
+                            )
                     else:
-                        input_buffer[k] = input_buffer_k.index_select(0, new_order)
-            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
+                        input_buffer[k] = input_buffer_k.index_select(
+                            0, new_order
+                        )
+            incremental_state = self._set_input_buffer(
+                incremental_state, input_buffer
+            )
         return incremental_state

     def set_beam_size(self, beam_size):
@@ -748,7 +814,8 @@ class MultiheadAttention(nn.Module):
         self.beam_size = beam_size

     def _get_input_buffer(
-        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
+        self,
+        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
     ) -> Dict[str, Optional[Tensor]]:
         result = self.get_incremental_state(incremental_state, "attn_state")
         if result is not None:
@@ -762,9 +829,13 @@ class MultiheadAttention(nn.Module):
         incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
         buffer: Dict[str, Optional[Tensor]],
     ):
-        return self.set_incremental_state(incremental_state, "attn_state", buffer)
+        return self.set_incremental_state(
+            incremental_state, "attn_state", buffer
+        )

-    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
+    def apply_sparse_mask(
+        self, attn_weights, tgt_len: int, src_len: int, bsz: int
+    ):
         return attn_weights

     def upgrade_state_dict_named(self, state_dict, name):
@@ -776,19 +847,27 @@ class MultiheadAttention(nn.Module):
                 # in_proj_weight used to be q + k + v with same dimensions
                 dim = int(state_dict[k].shape[0] / 3)
                 items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
-                items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
-                items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
+                items_to_add[prefix + "k_proj.weight"] = state_dict[k][
+                    dim : 2 * dim
+                ]
+                items_to_add[prefix + "v_proj.weight"] = state_dict[k][
+                    2 * dim :
+                ]

                 keys_to_remove.append(k)

                 k_bias = prefix + "in_proj_bias"
                 if k_bias in state_dict.keys():
                     dim = int(state_dict[k].shape[0] / 3)
-                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
+                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][
+                        :dim
+                    ]
                     items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
                         dim : 2 * dim
                     ]
-                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
+                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][
+                        2 * dim :
+                    ]

                     keys_to_remove.append(prefix + "in_proj_bias")

@@ -925,7 +925,9 @@ def train_one_epoch(
                 tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 if params.use_fp16:
                     tb_writer.add_scalar(
-                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
+                        "train/grad_scale",
+                        cur_grad_scale,
+                        params.batch_idx_train,
                     )

         if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:

@@ -925,7 +925,9 @@ def train_one_epoch(
                 tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 if params.use_fp16:
                     tb_writer.add_scalar(
-                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
+                        "train/grad_scale",
+                        cur_grad_scale,
+                        params.batch_idx_train,
                     )

         if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:

@@ -251,7 +251,10 @@ def add_hubert_arguments(parser: argparse.ArgumentParser):
     )

     parser.add_argument(
-        "--encoder-embed-dim", type=int, default=768, help="encoder embedding dimension"
+        "--encoder-embed-dim",
+        type=int,
+        default=768,
+        help="encoder embedding dimension",
     )

     parser.add_argument(
@@ -271,7 +274,14 @@ def add_hubert_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--activation-fn",
         type=str,
-        choices=["relu", "gelu", "gelu_fast", "gelu_accurate", "tanh", "linear"],
+        choices=[
+            "relu",
+            "gelu",
+            "gelu_fast",
+            "gelu_accurate",
+            "tanh",
+            "linear",
+        ],
         default="gelu",
         help="activation function to use",
     )
@@ -356,11 +366,17 @@ def add_hubert_arguments(parser: argparse.ArgumentParser):
     )

     parser.add_argument(
-        "--conv-bias", type=bool, default=False, help="include bias in conv encoder"
+        "--conv-bias",
+        type=bool,
+        default=False,
+        help="include bias in conv encoder",
     )

     parser.add_argument(
-        "--logit-temp", type=float, default=0.1, help="temperature to divide logits by"
+        "--logit-temp",
+        type=float,
+        default=0.1,
+        help="temperature to divide logits by",
     )

     parser.add_argument(

@@ -251,7 +251,10 @@ def add_hubert_arguments(parser: argparse.ArgumentParser):
     )

     parser.add_argument(
-        "--encoder-embed-dim", type=int, default=768, help="encoder embedding dimension"
+        "--encoder-embed-dim",
+        type=int,
+        default=768,
+        help="encoder embedding dimension",
     )

     parser.add_argument(
@@ -271,7 +274,14 @@ def add_hubert_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--activation-fn",
         type=str,
-        choices=["relu", "gelu", "gelu_fast", "gelu_accurate", "tanh", "linear"],
+        choices=[
+            "relu",
+            "gelu",
+            "gelu_fast",
+            "gelu_accurate",
+            "tanh",
+            "linear",
+        ],
         default="gelu",
         help="activation function to use",
     )
@@ -356,11 +366,17 @@ def add_hubert_arguments(parser: argparse.ArgumentParser):
     )

     parser.add_argument(
-        "--conv-bias", type=bool, default=False, help="include bias in conv encoder"
+        "--conv-bias",
+        type=bool,
+        default=False,
+        help="include bias in conv encoder",
     )

     parser.add_argument(
-        "--logit-temp", type=float, default=0.1, help="temperature to divide logits by"
+        "--logit-temp",
+        type=float,
+        default=0.1,
+        help="temperature to divide logits by",
     )

     parser.add_argument(

@@ -744,7 +744,9 @@ def train_one_epoch(
                 tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 if params.use_fp16:
                     tb_writer.add_scalar(
-                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
+                        "train/grad_scale",
+                        cur_grad_scale,
+                        params.batch_idx_train,
                     )

         if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:

@@ -744,7 +744,9 @@ def train_one_epoch(
                 tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 if params.use_fp16:
                     tb_writer.add_scalar(
-                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
+                        "train/grad_scale",
+                        cur_grad_scale,
+                        params.batch_idx_train,
                     )

         if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:

@@ -254,7 +254,8 @@ def quant_noise(module, p, block_size):

                 # split weight matrix into blocks and randomly drop selected blocks
                 mask = torch.zeros(
-                    in_features // block_size * out_features, device=weight.device
+                    in_features // block_size * out_features,
+                    device=weight.device,
                 )
                 mask.bernoulli_(p)
                 mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)

@@ -309,11 +309,14 @@ class TransformerEncoder(nn.Module):
                 # layer_check = layer.unwrapped_module
                 if (corpus_key is None) or (
                     not isinstance(
-                        layer_check, (TransformerSentenceEncoderWithAdapterLayer,)
+                        layer_check,
+                        (TransformerSentenceEncoderWithAdapterLayer,),
                     )
                 ):
                     x, (z, lr) = layer(
-                        x, self_attn_padding_mask=padding_mask, need_weights=False
+                        x,
+                        self_attn_padding_mask=padding_mask,
+                        need_weights=False,
                     )
                 else:
                     x, (z, lr) = layer(
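To catch this kind of formatting drift before committing, black can be run in check mode from the repository root. A minimal sketch, assuming black is installed; the black version and file list used by icefall's style checks are not part of this commit and are assumptions here:

import subprocess

# "--check" only reports files black would reformat; "--diff" prints the rewrites.
# The target path "." is an assumption; narrow it to the files you touched if preferred.
result = subprocess.run(
    ["black", "--check", "--diff", "."],
    capture_output=True,
    text=True,
)
print(result.stdout)
print("clean" if result.returncode == 0 else "would be reformatted (or black reported an error)")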