From 809bdb07f0b945363a34f4d0a670bf24a8905133 Mon Sep 17 00:00:00 2001
From: yifanyeung
Date: Sun, 18 Feb 2024 12:44:44 +0800
Subject: [PATCH] fix for black

---
 .../SSL/hubert/attention_module.py            | 161 +++++++++++++-----
 egs/librispeech/SSL/hubert/finetune.py        |   4 +-
 egs/librispeech/SSL/hubert/finetune_ce.py     |   4 +-
 egs/librispeech/SSL/hubert/hubert.py          |  24 ++-
 egs/librispeech/SSL/hubert/hubert_ce.py       |  24 ++-
 egs/librispeech/SSL/hubert/pretrain.py        |   4 +-
 egs/librispeech/SSL/hubert/pretrain_ce.py     |   4 +-
 egs/librispeech/SSL/hubert/utils.py           |   3 +-
 egs/librispeech/SSL/hubert/wav2vec2_module.py |   7 +-
 9 files changed, 179 insertions(+), 56 deletions(-)

diff --git a/egs/librispeech/SSL/hubert/attention_module.py b/egs/librispeech/SSL/hubert/attention_module.py
index 54aaf6075..8e47ed7ab 100644
--- a/egs/librispeech/SSL/hubert/attention_module.py
+++ b/egs/librispeech/SSL/hubert/attention_module.py
@@ -155,7 +155,8 @@ class MultiheadAttention(nn.Module):
         self.encoder_decoder_attention = encoder_decoder_attention
 
         assert not self.self_attention or self.qkv_same_dim, (
-            "Self-attention requires query, key and " "value to be of the same size"
+            "Self-attention requires query, key and "
+            "value to be of the same size"
         )
 
         self.k_proj = quant_noise(
@@ -217,22 +218,36 @@ class MultiheadAttention(nn.Module):
             start_idx = i * self.head_dim
             end_idx = (i + 1) * self.head_dim
             k_proj_heads_norm.append(
-                torch.sum(torch.abs(self.k_proj.weight[start_idx:end_idx,])).tolist()
-                + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist()
+                torch.sum(
+                    torch.abs(self.k_proj.weight[start_idx:end_idx,])
+                ).tolist()
+                + torch.sum(
+                    torch.abs(self.k_proj.bias[start_idx:end_idx])
+                ).tolist()
             )
             q_proj_heads_norm.append(
-                torch.sum(torch.abs(self.q_proj.weight[start_idx:end_idx,])).tolist()
-                + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist()
+                torch.sum(
+                    torch.abs(self.q_proj.weight[start_idx:end_idx,])
+                ).tolist()
+                + torch.sum(
+                    torch.abs(self.q_proj.bias[start_idx:end_idx])
+                ).tolist()
             )
             v_proj_heads_norm.append(
-                torch.sum(torch.abs(self.v_proj.weight[start_idx:end_idx,])).tolist()
-                + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist()
+                torch.sum(
+                    torch.abs(self.v_proj.weight[start_idx:end_idx,])
+                ).tolist()
+                + torch.sum(
+                    torch.abs(self.v_proj.bias[start_idx:end_idx])
+                ).tolist()
             )
 
         heads_norm = []
         for i in range(self.num_heads):
             heads_norm.append(
-                k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i]
+                k_proj_heads_norm[i]
+                + q_proj_heads_norm[i]
+                + v_proj_heads_norm[i]
             )
 
         sorted_head_index = sorted(
@@ -266,7 +281,9 @@ class MultiheadAttention(nn.Module):
             new_v_weight.append(self.v_proj.weight[start_idx:end_idx,])
             new_v_bias.append(self.v_proj.bias[start_idx:end_idx])
 
-            new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx])
+            new_out_proj_weight.append(
+                self.out_proj.weight[:, start_idx:end_idx]
+            )
 
         new_q_weight = torch.cat(new_q_weight).detach()
         new_k_weight = torch.cat(new_k_weight).detach()
@@ -313,7 +330,9 @@ class MultiheadAttention(nn.Module):
     ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
         if attn_mask is not None:
             shape = attn_mask.size()[:-1] + torch.Size([1])
-            attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1)
+            attn_mask = torch.cat(
+                [attn_mask, attn_mask.new_zeros(shape)], dim=-1
+            )
         if key_padding_mask is not None:
             shape = key_padding_mask.size()[:-1] + torch.Size([1])
             key_padding_mask = torch.cat(
@@ -351,10 +370,12 @@ class MultiheadAttention(nn.Module):
     ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
         zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:]
         k = torch.cat(
-            [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=-2
+            [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)],
+            dim=-2,
         )
         v = torch.cat(
-            [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=-2
+            [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)],
+            dim=-2,
         )
         key_padding_mask, attn_mask = self._pad_masks(
             key_padding_mask=key_padding_mask, attn_mask=attn_mask
@@ -367,7 +388,9 @@ class MultiheadAttention(nn.Module):
         key: Optional[Tensor],
         value: Optional[Tensor],
         key_padding_mask: Optional[Tensor] = None,
-        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+        incremental_state: Optional[
+            Dict[str, Dict[str, Optional[Tensor]]]
+        ] = None,
         need_weights: bool = True,
         static_kv: bool = False,
         attn_mask: Optional[Tensor] = None,
@@ -432,7 +455,9 @@ class MultiheadAttention(nn.Module):
                 self.embed_dim,
                 self.num_heads,
                 torch.empty([0]),
-                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
+                torch.cat(
+                    (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)
+                ),
                 self.bias_k,
                 self.bias_v,
                 self.add_zero_attn,
@@ -440,7 +465,9 @@ class MultiheadAttention(nn.Module):
                 self.out_proj.weight,
                 self.out_proj.bias,
                 self.training or self.dropout_module.apply_during_inference,
-                key_padding_mask.bool() if key_padding_mask is not None else None,
+                key_padding_mask.bool()
+                if key_padding_mask is not None
+                else None,
                 need_weights,
                 attn_mask,
                 use_separate_proj_weight=True,
@@ -455,7 +482,10 @@ class MultiheadAttention(nn.Module):
                 # previous time steps are cached - no need to recompute
                 # key and value if they are static
                 if static_kv:
-                    assert self.encoder_decoder_attention and not self.self_attention
+                    assert (
+                        self.encoder_decoder_attention
+                        and not self.self_attention
+                    )
                     key = value = None
         else:
             saved_state = None
@@ -473,9 +503,9 @@ class MultiheadAttention(nn.Module):
             else:
                 if self.beam_size > 1 and bsz == key.size(1):
                     # key is [T, bsz*beam_size, C], reduce to [T, bsz, C]
-                    key = key.view(key.size(0), -1, self.beam_size, key.size(2))[
-                        :, :, 0, :
-                    ]
+                    key = key.view(
+                        key.size(0), -1, self.beam_size, key.size(2)
+                    )[:, :, 0, :]
                     if key_padding_mask is not None:
                         key_padding_mask = key_padding_mask.view(
                             -1, self.beam_size, key_padding_mask.size(1)
@@ -522,7 +552,9 @@ class MultiheadAttention(nn.Module):
                 _prev_key = saved_state["prev_key"]
                 assert _prev_key is not None
                 kv_bsz = _prev_key.size(0)
-                prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim)
+                prev_key = _prev_key.view(
+                    kv_bsz * self.num_heads, -1, self.head_dim
+                )
                 if static_kv:
                     k = prev_key
                 else:
@@ -553,14 +585,18 @@ class MultiheadAttention(nn.Module):
                 static_kv=static_kv,
             )
 
-            saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim)
+            saved_state["prev_key"] = k.view(
+                kv_bsz, self.num_heads, -1, self.head_dim
+            )
             saved_state["prev_value"] = v.view(
                 kv_bsz, self.num_heads, -1, self.head_dim
             )
             saved_state["prev_key_padding_mask"] = key_padding_mask
             # In this branch incremental_state is never None
             assert incremental_state is not None
-            incremental_state = self._set_input_buffer(incremental_state, saved_state)
+            incremental_state = self._set_input_buffer(
+                incremental_state, saved_state
+            )
 
         assert k is not None
         assert k.size(1) == src_len
@@ -586,12 +622,20 @@ class MultiheadAttention(nn.Module):
                 q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]),
                 k.view((kv_bsz, self.num_heads) + k.size()[1:]),
             )
-            attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:])
+            attn_weights = attn_weights.reshape(
+                (-1,) + attn_weights.size()[-2:]
+            )
         else:
             attn_weights = torch.bmm(q, k.transpose(1, 2))
-        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+        attn_weights = self.apply_sparse_mask(
+            attn_weights, tgt_len, src_len, bsz
+        )
 
-        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+        assert list(attn_weights.size()) == [
+            bsz * self.num_heads,
+            tgt_len,
+            src_len,
+        ]
 
         if attn_mask is not None:
             attn_mask = attn_mask.unsqueeze(0)
@@ -601,7 +645,9 @@ class MultiheadAttention(nn.Module):
 
         if key_padding_mask is not None:
             # don't attend to padding symbols
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(
+                bsz, self.num_heads, tgt_len, src_len
+            )
             if not is_tpu:
                 attn_weights = attn_weights.view(
                     kv_bsz, -1, self.num_heads, tgt_len, src_len
@@ -615,9 +661,13 @@ class MultiheadAttention(nn.Module):
                 )
             else:
                 attn_weights = attn_weights.transpose(0, 2)
-                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
+                attn_weights = attn_weights.masked_fill(
+                    key_padding_mask, float("-inf")
+                )
                 attn_weights = attn_weights.transpose(0, 2)
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(
+                bsz * self.num_heads, tgt_len, src_len
+            )
 
         if before_softmax:
             return attn_weights, v
@@ -652,13 +702,21 @@ class MultiheadAttention(nn.Module):
             attn = attn.reshape((-1,) + attn.size()[-2:])
         else:
             attn = torch.bmm(attn_probs, v)
-        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+        assert list(attn.size()) == [
+            bsz * self.num_heads,
+            tgt_len,
+            self.head_dim,
+        ]
         if self.onnx_trace and attn.size(1) == 1:
             # when ONNX tracing a single decoder step (sequence length == 1)
             # the transpose is a no-op copy before view, thus unnecessary
             attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim)
         else:
-            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
+            attn = (
+                attn.transpose(0, 1)
+                .contiguous()
+                .view(tgt_len, bsz, self.embed_dim)
+            )
         attn = self.out_proj(attn)
         attn_weights: Optional[Tensor] = None
         if need_weights:
@@ -728,7 +786,9 @@ class MultiheadAttention(nn.Module):
                 input_buffer_k = input_buffer[k]
                 if input_buffer_k is not None:
                     if self.encoder_decoder_attention:
-                        if input_buffer_k.size(0) * self.beam_size == new_order.size(0):
+                        if input_buffer_k.size(
+                            0
+                        ) * self.beam_size == new_order.size(0):
                             return incremental_state
                         elif self.beam_size > 1:
                             input_buffer[k] = input_buffer_k.index_select(
@@ -737,10 +797,16 @@ class MultiheadAttention(nn.Module):
                                 // self.beam_size,
                             )
                         else:
-                            input_buffer[k] = input_buffer_k.index_select(0, new_order)
+                            input_buffer[k] = input_buffer_k.index_select(
+                                0, new_order
+                            )
                     else:
-                        input_buffer[k] = input_buffer_k.index_select(0, new_order)
-            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
+                        input_buffer[k] = input_buffer_k.index_select(
+                            0, new_order
+                        )
+            incremental_state = self._set_input_buffer(
+                incremental_state, input_buffer
+            )
         return incremental_state
 
     def set_beam_size(self, beam_size):
@@ -748,7 +814,8 @@ class MultiheadAttention(nn.Module):
         self.beam_size = beam_size
 
     def _get_input_buffer(
-        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
+        self,
+        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
     ) -> Dict[str, Optional[Tensor]]:
         result = self.get_incremental_state(incremental_state, "attn_state")
         if result is not None:
@@ -762,9 +829,13 @@ class MultiheadAttention(nn.Module):
         incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
         buffer: Dict[str, Optional[Tensor]],
     ):
-        return self.set_incremental_state(incremental_state, "attn_state", buffer)
+        return self.set_incremental_state(
+            incremental_state, "attn_state", buffer
+        )
 
-    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
+    def apply_sparse_mask(
+        self, attn_weights, tgt_len: int, src_len: int, bsz: int
+    ):
         return attn_weights
 
     def upgrade_state_dict_named(self, state_dict, name):
@@ -776,19 +847,27 @@ class MultiheadAttention(nn.Module):
                 # in_proj_weight used to be q + k + v with same dimensions
                 dim = int(state_dict[k].shape[0] / 3)
                 items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
-                items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
-                items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
+                items_to_add[prefix + "k_proj.weight"] = state_dict[k][
+                    dim : 2 * dim
+                ]
+                items_to_add[prefix + "v_proj.weight"] = state_dict[k][
+                    2 * dim :
+                ]
 
                 keys_to_remove.append(k)
 
                 k_bias = prefix + "in_proj_bias"
                 if k_bias in state_dict.keys():
                     dim = int(state_dict[k].shape[0] / 3)
-                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
+                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][
+                        :dim
+                    ]
                     items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
                         dim : 2 * dim
                     ]
-                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
+                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][
+                        2 * dim :
+                    ]
 
                     keys_to_remove.append(prefix + "in_proj_bias")
 
diff --git a/egs/librispeech/SSL/hubert/finetune.py b/egs/librispeech/SSL/hubert/finetune.py
index c7ce5995f..cc6b25590 100644
--- a/egs/librispeech/SSL/hubert/finetune.py
+++ b/egs/librispeech/SSL/hubert/finetune.py
@@ -925,7 +925,9 @@ def train_one_epoch(
                 tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 if params.use_fp16:
                     tb_writer.add_scalar(
-                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
+                        "train/grad_scale",
+                        cur_grad_scale,
+                        params.batch_idx_train,
                     )
 
         if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
diff --git a/egs/librispeech/SSL/hubert/finetune_ce.py b/egs/librispeech/SSL/hubert/finetune_ce.py
index 0e6221778..d56c36665 100644
--- a/egs/librispeech/SSL/hubert/finetune_ce.py
+++ b/egs/librispeech/SSL/hubert/finetune_ce.py
@@ -925,7 +925,9 @@ def train_one_epoch(
                 tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 if params.use_fp16:
                     tb_writer.add_scalar(
-                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
+                        "train/grad_scale",
+                        cur_grad_scale,
+                        params.batch_idx_train,
                     )
 
         if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
diff --git a/egs/librispeech/SSL/hubert/hubert.py b/egs/librispeech/SSL/hubert/hubert.py
index 29018ba23..f800044f4 100644
--- a/egs/librispeech/SSL/hubert/hubert.py
+++ b/egs/librispeech/SSL/hubert/hubert.py
@@ -251,7 +251,10 @@ def add_hubert_arguments(parser: argparse.ArgumentParser):
     )
 
     parser.add_argument(
-        "--encoder-embed-dim", type=int, default=768, help="encoder embedding dimension"
+        "--encoder-embed-dim",
+        type=int,
+        default=768,
+        help="encoder embedding dimension",
     )
 
     parser.add_argument(
@@ -271,7 +274,14 @@ def add_hubert_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--activation-fn",
         type=str,
-        choices=["relu", "gelu", "gelu_fast", "gelu_accurate", "tanh", "linear"],
+        choices=[
+            "relu",
+            "gelu",
+            "gelu_fast",
+            "gelu_accurate",
+            "tanh",
+            "linear",
+        ],
         default="gelu",
         help="activation function to use",
     )
@@ -356,11 +366,17 @@ def add_hubert_arguments(parser: argparse.ArgumentParser):
     )
 
     parser.add_argument(
-        "--conv-bias", type=bool, default=False, help="include bias in conv encoder"
+        "--conv-bias",
+        type=bool,
+        default=False,
+        help="include bias in conv encoder",
     )
 
     parser.add_argument(
-        "--logit-temp", type=float, default=0.1, help="temperature to divide logits by"
+        "--logit-temp",
+        type=float,
+        default=0.1,
+        help="temperature to divide logits by",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/SSL/hubert/hubert_ce.py b/egs/librispeech/SSL/hubert/hubert_ce.py
index c2c50f8c9..ccdd63efd 100644
--- a/egs/librispeech/SSL/hubert/hubert_ce.py
+++ b/egs/librispeech/SSL/hubert/hubert_ce.py
@@ -251,7 +251,10 @@ def add_hubert_arguments(parser: argparse.ArgumentParser):
     )
 
     parser.add_argument(
-        "--encoder-embed-dim", type=int, default=768, help="encoder embedding dimension"
+        "--encoder-embed-dim",
+        type=int,
+        default=768,
+        help="encoder embedding dimension",
     )
 
     parser.add_argument(
@@ -271,7 +274,14 @@ def add_hubert_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--activation-fn",
         type=str,
-        choices=["relu", "gelu", "gelu_fast", "gelu_accurate", "tanh", "linear"],
+        choices=[
+            "relu",
+            "gelu",
+            "gelu_fast",
+            "gelu_accurate",
+            "tanh",
+            "linear",
+        ],
         default="gelu",
         help="activation function to use",
     )
@@ -356,11 +366,17 @@ def add_hubert_arguments(parser: argparse.ArgumentParser):
     )
 
     parser.add_argument(
-        "--conv-bias", type=bool, default=False, help="include bias in conv encoder"
+        "--conv-bias",
+        type=bool,
+        default=False,
+        help="include bias in conv encoder",
     )
 
     parser.add_argument(
-        "--logit-temp", type=float, default=0.1, help="temperature to divide logits by"
+        "--logit-temp",
+        type=float,
+        default=0.1,
+        help="temperature to divide logits by",
     )
 
     parser.add_argument(
diff --git a/egs/librispeech/SSL/hubert/pretrain.py b/egs/librispeech/SSL/hubert/pretrain.py
index 23279e6e6..89bc53338 100644
--- a/egs/librispeech/SSL/hubert/pretrain.py
+++ b/egs/librispeech/SSL/hubert/pretrain.py
@@ -744,7 +744,9 @@ def train_one_epoch(
                 tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 if params.use_fp16:
                     tb_writer.add_scalar(
-                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
+                        "train/grad_scale",
+                        cur_grad_scale,
+                        params.batch_idx_train,
                     )
 
         if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
diff --git a/egs/librispeech/SSL/hubert/pretrain_ce.py b/egs/librispeech/SSL/hubert/pretrain_ce.py
index 5b83e9d00..755be202e 100644
--- a/egs/librispeech/SSL/hubert/pretrain_ce.py
+++ b/egs/librispeech/SSL/hubert/pretrain_ce.py
@@ -744,7 +744,9 @@ def train_one_epoch(
                 tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 if params.use_fp16:
                     tb_writer.add_scalar(
-                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
+                        "train/grad_scale",
+                        cur_grad_scale,
+                        params.batch_idx_train,
                     )
 
         if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
diff --git a/egs/librispeech/SSL/hubert/utils.py b/egs/librispeech/SSL/hubert/utils.py
index 748d3c96e..de980ba62 100644
--- a/egs/librispeech/SSL/hubert/utils.py
+++ b/egs/librispeech/SSL/hubert/utils.py
@@ -254,7 +254,8 @@ def quant_noise(module, p, block_size):
 
                 # split weight matrix into blocks and randomly drop selected blocks
                 mask = torch.zeros(
-                    in_features // block_size * out_features, device=weight.device
+                    in_features // block_size * out_features,
+                    device=weight.device,
                 )
                 mask.bernoulli_(p)
                 mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
diff --git a/egs/librispeech/SSL/hubert/wav2vec2_module.py b/egs/librispeech/SSL/hubert/wav2vec2_module.py
index e9bc9e3b5..4c2e1ce98 100644
--- a/egs/librispeech/SSL/hubert/wav2vec2_module.py
+++ b/egs/librispeech/SSL/hubert/wav2vec2_module.py
@@ -309,11 +309,14 @@ class TransformerEncoder(nn.Module):
                 # layer_check = layer.unwrapped_module
                 if (corpus_key is None) or (
                     not isinstance(
-                        layer_check, (TransformerSentenceEncoderWithAdapterLayer,)
+                        layer_check,
+                        (TransformerSentenceEncoderWithAdapterLayer,),
                     )
                 ):
                     x, (z, lr) = layer(
-                        x, self_attn_padding_mask=padding_mask, need_weights=False
+                        x,
+                        self_attn_padding_mask=padding_mask,
+                        need_weights=False,
                     )
                 else:
                     x, (z, lr) = layer(