From 9e30f2bf1270bf1d686dbf4543e69cd3e9d80856 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 13 Oct 2022 12:05:45 +0800
Subject: [PATCH 1/6] Make the ActivationBalancer regress to the data mean,
 not zero, when enforcing abs constraint.

---
 .../pruned_transducer_stateless7/scaling.py   | 43 ++++++++-----------
 1 file changed, 19 insertions(+), 24 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index e74acb7fe..f80e42edb 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -34,6 +34,7 @@ class ActivationBalancerFunction(torch.autograd.Function):
     def forward(
         ctx,
         x: Tensor,
+        mean: Tensor,
         sign_factor: Tensor,
         scale_factor: Tensor,
         channel_dim: int,
@@ -41,8 +42,13 @@ class ActivationBalancerFunction(torch.autograd.Function):
         if channel_dim < 0:
             channel_dim += x.ndim
         ctx.channel_dim = channel_dim
-        xgt0 = (x > 0)
-        ctx.save_for_backward(xgt0, sign_factor, scale_factor)
+        for _ in range(ctx.channel_dim, x.ndim - 1):
+            mean = mean.unsqueeze(-1)
+            sign_factor = sign_factor.unsqueeze(-1)
+            scale_factor = scale_factor.unsqueeze(-1)
+
+        xgtmean = (x > mean)
+        ctx.save_for_backward(xgtmean, sign_factor, scale_factor)
         return x
@@ -50,14 +56,11 @@ class ActivationBalancerFunction(torch.autograd.Function):
     def backward(
         ctx, x_grad: Tensor
     ) -> Tuple[Tensor, None, None, None]:
-        xgt0, sign_factor, scale_factor = ctx.saved_tensors
-        for _ in range(ctx.channel_dim, x_grad.ndim - 1):
-            sign_factor = sign_factor.unsqueeze(-1)
-            scale_factor = scale_factor.unsqueeze(-1)
+        xgtmean, sign_factor, scale_factor = ctx.saved_tensors
 
-        factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+        factor = sign_factor + scale_factor * (xgtmean.to(x_grad.dtype) - 0.5)
         neg_delta_grad = x_grad.abs() * factor
-        return x_grad - neg_delta_grad, None, None, None,
+        return x_grad - neg_delta_grad, None, None, None, None,
@@ -275,6 +278,9 @@ class ActivationBalancer(torch.nn.Module):
         # count measures how many times the forward() function has been called.
         self.count = 0
+        # the mean of the data per channel
+        self.register_buffer('mean', torch.zeros(num_channels))
+
         # the mean of the absolute value of the data per channel
         self.register_buffer('abs_mean', torch.zeros(num_channels))
@@ -307,7 +313,7 @@ class ActivationBalancer(torch.nn.Module):
             sign_factor = factors[0]
             scale_factor = factors[1]
             return ActivationBalancerFunction.apply(
-                x, sign_factor, scale_factor, self.channel_dim,
+                x, self.mean, sign_factor, scale_factor, self.channel_dim,
             )
         else:
             return x
@@ -322,6 +328,7 @@ class ActivationBalancer(torch.nn.Module):
         with torch.no_grad():
             sum_dims = [d for d in range(x.ndim) if d != self.channel_dim]
+            x_mean = torch.mean(x, dim=sum_dims).to(torch.float32)
             x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
             # the random.random() thing is to split the difference if x is zero,
             # between treating it positive or negative
@@ -333,9 +340,11 @@ class ActivationBalancer(torch.nn.Module):
                 mask = (y - y != 0)
                 y.masked_fill_(mask, 0.0)
 
+            filter_inf_nan(x_mean)
             filter_inf_nan(x_abs_mean)
 
             beta = self.beta if count > 0 else 0.0
+            self.mean.mul_(beta).add_(x_mean, alpha=(1-beta))
             self.abs_mean.mul_(beta).add_(x_abs_mean, alpha=(1-beta))
             self.proportion_positive.mul_(beta).add_(proportion_positive, alpha=(1-beta))
@@ -363,25 +372,11 @@ class ActivationBalancer(torch.nn.Module):
         # the factor of 2.0 below is just to cancel out a factor of 0.5 that gets introduced when, in
         # the backprop, we do (xgt0.to(dtype) - 0.5).
-        #
-        # scale_factor_scale, on the other hand, is a heuristically chosen value between 0 and 1,
-        # that we use to make the gradient changes from the 'scale' constraints (min_abs/max_abs)
-        # less strong than those from the sign constraints.
-        #
-        # This is to get rid of a pathology that can happen if, for instance, a
-        # channel is always positive but is too small (max_positive and min_abs constraints both
-        # violated). If scale_factor_scale were equal to 1.0, then the gradient changes from the
-        # min_positive constraint (trying to make the activation more negative) and from the
-        # min_abs constraint (trying to make the activation more positive) would exactly cancel.
-        # Instead we make the min_positive constraint stronger, so it first makes the value
-        # sometimes negative, and only when that is satisfied, can deal with the absolute-value
-        # constraint.
-        scale_factor_scale = 0.5
         below_threshold = (self.abs_mean < self.min_abs)
         above_threshold = (self.abs_mean > self.max_abs)
         scale_factor[:] = ((below_threshold.to(torch.float32) - above_threshold.to(torch.float32))
-                           * (max_factor * (2.0 * scale_factor_scale)))
+                           * (max_factor * 2.0))
 
 
 class MaxEig(torch.nn.Module):
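With this change, the rule applied in ActivationBalancerFunction.backward() can be summarized outside the autograd machinery roughly as in the sketch below. The function name and the 2-D shapes are illustrative only, not part of the patch; in the recipe the per-channel factors are broadcast over whatever trailing dimensions follow channel_dim.

    import torch

    def balancer_grad_sketch(
        x: torch.Tensor,            # activations, shape (batch, num_channels)
        x_grad: torch.Tensor,       # incoming gradient, same shape as x
        mean: torch.Tensor,         # running per-channel mean, shape (num_channels,)
        sign_factor: torch.Tensor,  # per-channel strength of the sign constraint
        scale_factor: torch.Tensor, # per-channel strength of the abs-value constraint
    ) -> torch.Tensor:
        # After this patch the abs-value constraint pushes activations toward or
        # away from their running mean rather than toward or away from zero.
        xgtmean = (x > mean).to(x_grad.dtype)
        factor = sign_factor + scale_factor * (xgtmean - 0.5)
        # The adjustment is proportional to |x_grad|, so it biases the gradient
        # rather than replacing it.
        neg_delta_grad = x_grad.abs() * factor
        return x_grad - neg_delta_grad

Because the comparison is now against the stored mean, a channel whose values are all large and positive is no longer pushed toward zero by the abs constraint, only toward its own typical operating point.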
From 9270e32a5187be49bfc1ab392cc411144cfe1413 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 13 Oct 2022 13:34:35 +0800
Subject: [PATCH 2/6] Remove unused config value

---
 egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index f80e42edb..7034987d9 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -258,7 +258,6 @@ class ActivationBalancer(torch.nn.Module):
         max_factor: float = 0.01,
         min_abs: float = 0.2,
         max_abs: float = 100.0,
-        max_var_per_eig: float = 0.0,
         beta: float = 0.75,
         prob: float = 0.25,
         stats_period: int = 10,

From b09a1b2ae6afccf5694f089ac7a6aaf7404ce615 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 13 Oct 2022 13:40:43 +0800
Subject: [PATCH 3/6] Fix bug when channel_dim < 0

---
 egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 7034987d9..2f8a88681 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -325,7 +325,10 @@ class ActivationBalancer(torch.nn.Module):
         channel.
         """
         with torch.no_grad():
-            sum_dims = [d for d in range(x.ndim) if d != self.channel_dim]
+            channel_dim = self.channel_dim
+            if channel_dim < 0:
+                channel_dim += x.ndim
+            sum_dims = [d for d in range(x.ndim) if d != channel_dim]
             x_mean = torch.mean(x, dim=sum_dims).to(torch.float32)
             x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
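Taken together, the three scaling.py patches make the statistics update behave roughly like the standalone sketch below. The function name and argument list are illustrative, not the module's actual interface, and the real code also tracks proportion_positive and filters inf/nan values before updating the buffers.

    import torch

    def update_channel_stats(
        x: torch.Tensor,         # activations, any shape
        mean: torch.Tensor,      # running per-channel mean buffer
        abs_mean: torch.Tensor,  # running per-channel mean of |x|
        channel_dim: int,        # may be negative, e.g. -1
        beta: float = 0.75,      # EMA constant, the ActivationBalancer default
    ) -> None:
        # Patch 3's fix: normalize a negative channel_dim before comparing it
        # against the (non-negative) dimension indices; otherwise the channel
        # dimension is never excluded from the reduction.
        if channel_dim < 0:
            channel_dim += x.ndim
        sum_dims = [d for d in range(x.ndim) if d != channel_dim]

        x_mean = torch.mean(x, dim=sum_dims).to(torch.float32)
        x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)

        # Exponential moving average, as in the buffer updates added by patch 1.
        mean.mul_(beta).add_(x_mean, alpha=(1 - beta))
        abs_mean.mul_(beta).add_(x_abs_mean, alpha=(1 - beta))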
From 2a50def7c67736f7c366cce76e18c0c5ce70501f Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 13 Oct 2022 15:07:53 +0800
Subject: [PATCH 4/6] Simplify how the positional-embedding scores work in
 attention (thanks to Zengwei for this concept)

---
 .../pruned_transducer_stateless7/conformer.py | 145 +++++++-----------
 1 file changed, 56 insertions(+), 89 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 625651d3c..c00f04e31 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -839,9 +839,6 @@ class RelPositionMultiheadAttention(nn.Module):
                                         channel_dim=-1, max_abs=5.0)
         self.in_max_eig = MaxEig(3 * embed_dim // 2, channel_dim=-1)
-        self.proj_balancer = ActivationBalancer(embed_dim // 2,
-                                                channel_dim=-1, max_abs=10.0,
-                                                min_positive=0.0, max_positive=1.0)
         self.out_proj = ScaledLinear(
             embed_dim // 2, embed_dim, bias=True, initial_scale=0.05
         )
@@ -850,18 +847,9 @@ class RelPositionMultiheadAttention(nn.Module):
         self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
                                       initial_scale=0.05)
-
-        # linear transformation for positional encoding.
-        self.linear_pos = nn.Linear(embed_dim, embed_dim // 2, bias=False)
-        # these two learnable bias are used in matrix c and matrix d
-        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
-        self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
-        self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
-        self._reset_parameters()
-
-    def _reset_parameters(self) -> None:
-        nn.init.uniform_(self.pos_bias_u, -0.05, 0.05)
-        nn.init.uniform_(self.pos_bias_v, -0.05, 0.05)
+        # linear transformation for positional encoding (projects to a scalar per head,
+        # which will be added to the score).
+        self.linear_pos = ScaledLinear(embed_dim, num_heads, initial_scale=0.05)
 
     def forward(
         self,
@@ -909,7 +897,7 @@ class RelPositionMultiheadAttention(nn.Module):
         """
         x, weights = self.multi_head_attention_forward(
            self.in_max_eig(self.in_balancer(self.in_proj(x))),
-            pos_emb,
+            self.linear_pos(pos_emb),
            self.embed_dim,
            self.num_heads,
            self.in_proj.weight,
@@ -923,35 +911,44 @@ class RelPositionMultiheadAttention(nn.Module):
         )
         return x, weights
 
-    def rel_shift(self, x: Tensor) -> Tensor:
-        """Compute relative positional encoding.
+    def rel_shift(self, pos_bias: Tensor) -> Tensor:
+        """Convert relative positional bias from linear to matrix format.
         Args:
-            x: Input tensor (batch, head, time1, 2*time1-1).
-                time1 means the length of query vector.
+            pos_bias: Input tensor (1, 2*T-1, num_heads), where T is the number of frames.
 
         Returns:
-            Tensor: tensor of shape (batch, head, time1, time2)
-            (note: time2 has the same value as time1, but it is for
-            the key, while time1 is for the query).
+            Tensor of shape (1, num_heads, time1, time2)
+            (note: time2 has the same value as time1, but it is for
+            the key, while time1 is for the query).
         """
-        (batch_size, num_heads, time1, n) = x.shape
-        assert n == 2 * time1 - 1
+        (batch_size, n, num_heads) = pos_bias.shape
+        assert batch_size == 1
+        T = (n + 1) // 2
+        assert n == 2 * T - 1
+        # The leading T dimension behaves like a batch dimension.
+        # It is only needed because PyTorch does not currently support
+        # negative strides.
+        pos_bias = pos_bias.expand(T, n, num_heads).contiguous()
+
         # Note: TorchScript requires explicit arg for stride()
-        batch_stride = x.stride(0)
-        head_stride = x.stride(1)
-        time1_stride = x.stride(2)
-        n_stride = x.stride(3)
-        return x.as_strided(
-            (batch_size, num_heads, time1, time1),
-            (batch_stride, head_stride, time1_stride - n_stride, n_stride),
-            storage_offset=n_stride * (time1 - 1),
+        batch_stride = pos_bias.stride(0)
+        time_stride = pos_bias.stride(1)
+        head_stride = pos_bias.stride(2)
+
+        # We could have left the batch dim as 1, and used '-time_stride' below
+        # where we use 'batch_stride - time_stride', but PyTorch does not support negative
+        # strides.
+        return pos_bias.as_strided(
+            (1, num_heads, T, T),
+            (0, head_stride, batch_stride - time_stride, time_stride),
+            storage_offset=time_stride * (T - 1),
         )
 
     def multi_head_attention_forward(
         self,
         x: Tensor,
-        pos_emb: Tensor,
+        pos: Tensor,
         embed_dim: int,
         num_heads: int,
         in_proj_weight: Tensor,
@@ -965,8 +962,8 @@ class RelPositionMultiheadAttention(nn.Module):
     ) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
         Args:
-            query, key, value: map a query and a set of key-value pairs to an output.
-            pos_emb: Positional embedding tensor
+            x_proj: the projected input, to be split into query, key, value.
+            pos: head-specific biases arising from the positional embeddings.
             embed_dim: total dimension of the model.
             num_heads: parallel attention heads.
             in_proj_weight, in_proj_bias: input projection weight and bias.
@@ -981,14 +978,10 @@ class RelPositionMultiheadAttention(nn.Module):
 
         Shape:
            Inputs:
-            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
-              the embedding dimension.
-            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
-              the embedding dimension.
-            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
-              the embedding dimension.
-            - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
-              length, N is the batch size, E is the embedding dimension.
+            - x: :math:`(L, N, 3 * E//2)` where L is the target sequence length, N is the batch size, E is
+              the embedding dimension. Will be split into (query, key, value).
+            - pos: :math:`(N, 2*L-1, H)` or :math:`(1, 2*L-1, H)` where L is the sequence
+              length, N is the batch size, and H is the number of heads.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence
              length. If a ByteTensor is provided, the non-zero positions will be ignored while the zero
              positions will be unchanged. If a BoolTensor is provided, the positions with the
@@ -1008,7 +1001,7 @@ class RelPositionMultiheadAttention(nn.Module):
            H is the num-heads, S is the sequence length.
         """
 
-        tgt_len, bsz, _ = x.size()
+        seq_len, bsz, _ = x.size()
         head_dim = embed_dim // (num_heads * 2)
         assert (
@@ -1040,15 +1033,15 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
-                if list(attn_mask.size()) != [1, tgt_len, tgt_len]:
+                if list(attn_mask.size()) != [1, seq_len, seq_len]:
                     raise RuntimeError(
                         "The size of the 2D attn_mask is not correct."
                     )
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
-                    tgt_len,
-                    tgt_len,
+                    seq_len,
+                    seq_len,
                 ]:
                     raise RuntimeError(
                         "The size of the 3D attn_mask is not correct."
@@ -1071,63 +1064,37 @@ class RelPositionMultiheadAttention(nn.Module):
                 )
             key_padding_mask = key_padding_mask.to(torch.bool)
 
-        q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim)
+        q = (q * scaling).contiguous().view(seq_len, bsz, num_heads, head_dim)
         k = k.contiguous().view(-1, bsz, num_heads, head_dim)
         v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
 
-        src_len = k.size(0)
-
         if key_padding_mask is not None:
             assert key_padding_mask.size(0) == bsz, "{} == {}".format(
                 key_padding_mask.size(0), bsz
             )
-            assert key_padding_mask.size(1) == src_len, "{} == {}".format(
-                key_padding_mask.size(1), src_len
+            assert key_padding_mask.size(1) == seq_len, "{} == {}".format(
+                key_padding_mask.size(1), seq_len
             )
 
-        q = q.transpose(0, 1)  # (batch, time1, head, d_k)
-
-        pos_emb_bsz = pos_emb.size(0)
-        assert pos_emb_bsz in (1, bsz)  # actually it is 1
-        p = self.proj_balancer(self.linear_pos(pos_emb)).view(pos_emb_bsz, -1, num_heads, head_dim)
-        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
-
-        q_with_bias_u = (q + self.pos_bias_u).transpose(
-            1, 2
-        )  # (batch, head, time1, d_k)
-
-        q_with_bias_v = (q + self.pos_bias_v).transpose(
-            1, 2
-        )  # (batch, head, time1, d_k)
-
+        q = q.permute(1, 2, 0, 3)  # (batch, head, time1, head_dim)
         # compute attention score
-        # first compute matrix a and matrix c
-        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
-
-        # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p.transpose(-2, -1)
-        )  # (batch, head, time1, 2*time1-1)
-        matrix_bd = self.rel_shift(matrix_bd)
-
-        attn_output_weights = (
-            matrix_ac + matrix_bd
-        )  # (batch, head, time1, time2)
 
+        # pos_bias: (batch, head, time1, time2)
+        pos_bias = self.rel_shift(pos)
+        attn_output_weights = torch.matmul(q, k) + pos_bias
+        # attn_output_weights: (batch, head, time1, time2)
         attn_output_weights = attn_output_weights.view(
-            bsz * num_heads, tgt_len, -1
+            bsz * num_heads, seq_len, seq_len
         )
 
         assert list(attn_output_weights.size()) == [
             bsz * num_heads,
-            tgt_len,
-            src_len,
+            seq_len,
+            seq_len,
         ]
 
         if attn_mask is not None:
@@ -1138,14 +1105,14 @@ class RelPositionMultiheadAttention(nn.Module):
         if key_padding_mask is not None:
             attn_output_weights = attn_output_weights.view(
-                bsz, num_heads, tgt_len, src_len
+                bsz, num_heads, seq_len, seq_len
             )
             attn_output_weights = attn_output_weights.masked_fill(
                 key_padding_mask.unsqueeze(1).unsqueeze(2),
                 float("-inf"),
             )
             attn_output_weights = attn_output_weights.view(
-                bsz * num_heads, tgt_len, src_len
+                bsz * num_heads, seq_len, seq_len
             )
 
         attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
@@ -1154,11 +1121,11 @@ class RelPositionMultiheadAttention(nn.Module):
         )
 
         attn_output = torch.bmm(attn_output_weights, v)
-        assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
+        assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim]
         attn_output = (
             attn_output.transpose(0, 1)
             .contiguous()
-            .view(tgt_len, bsz, embed_dim // 2)
+            .view(seq_len, bsz, embed_dim // 2)
         )
         attn_output = nn.functional.linear(
             attn_output, out_proj_weight, out_proj_bias
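The net effect of this patch is that the Transformer-XL style pos_bias_u/pos_bias_v terms disappear: linear_pos now maps the positional embedding to one scalar per head for each relative offset, and rel_shift only has to rearrange a (1, 2*T-1, num_heads) tensor into a (1, num_heads, T, T) matrix of per-head biases that is added to the query-key scores. The stride trick in rel_shift can be hard to read, so the loop below is a slow but equivalent reference for what it computes; it is illustrative only and not part of the patch.

    import torch

    def rel_shift_reference(pos_bias: torch.Tensor) -> torch.Tensor:
        # pos_bias: (1, 2*T-1, num_heads). Entry [0, (T-1) + (t2 - t1), h] is the
        # bias that head h adds to the score between query frame t1 and key frame t2.
        _, n, num_heads = pos_bias.shape
        T = (n + 1) // 2
        assert n == 2 * T - 1
        out = torch.empty(1, num_heads, T, T,
                          dtype=pos_bias.dtype, device=pos_bias.device)
        for t1 in range(T):      # query position
            for t2 in range(T):  # key position
                out[0, :, t1, t2] = pos_bias[0, (T - 1) + (t2 - t1), :]
        return out

The (1, num_heads, T, T) result broadcasts over the batch dimension when it is added to torch.matmul(q, k), which is how attn_output_weights is formed in the hunk above; as_strided produces the same tensor without copying.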
From 7d8e460a53fb118a10b487a877bc25d5bbdf146a Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 13 Oct 2022 15:09:50 +0800
Subject: [PATCH 5/6] Revert dropout on attention scores to 0.0.

---
 egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index c00f04e31..177aa3c3b 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -263,7 +263,7 @@ class ConformerEncoderLayer(nn.Module):
         self.d_model = d_model
 
         self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=dropout,
+            d_model, nhead, dropout=0.0,
         )
 
         self.feed_forward1 = FeedforwardModule(d_model,

From ae6478c6873b618f0179f9ea90ca65608481f259 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 13 Oct 2022 19:41:28 +0800
Subject: [PATCH 6/6] This should just be a cosmetic change, regularizing how
 we get the warmup times from the layers.

---
 .../ASR/pruned_transducer_stateless7/conformer.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 177aa3c3b..10527a7a5 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -436,16 +436,14 @@ class ConformerEncoder(nn.Module):
     def get_layers_to_drop(self, rnd_seed: int, warmup_count: float):
         num_layers = len(self.layers)
-        warmup_begin = self.warmup_begin
-        warmup_end = self.warmup_end
 
         def get_layerdrop_prob(layer: int) -> float:
-            layer_warmup_delta = (warmup_end - warmup_begin) / num_layers
-            layer_warmup_begin = warmup_begin + layer * layer_warmup_delta
+            layer_warmup_begin = self.layers[layer].warmup_begin
+            layer_warmup_end = self.layers[layer].warmup_end
+
             initial_layerdrop_prob = 0.5
             final_layerdrop_prob = 0.05
 
-            layer_warmup_end = layer_warmup_begin + layer_warmup_delta
             if warmup_count < layer_warmup_begin:
                 return initial_layerdrop_prob
             elif warmup_count > layer_warmup_end:
@@ -483,7 +481,7 @@ class ConformerEncoder(nn.Module):
             if len(ans) == num_to_drop:
                 break
         if shared_rng.random() < 0.005 or __name__ == "__main__":
-            logging.info(f"warmup_begin={warmup_begin:.1f}, warmup_end={warmup_end:.1f}, warmup_count={warmup_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
+            logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, warmup_count={warmup_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
         return ans
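The last patch changes only where the warmup boundaries come from: each ConformerEncoderLayer now carries its own warmup_begin and warmup_end, instead of the encoder deriving per-layer boundaries from a shared range. For a single layer the resulting layerdrop schedule looks roughly like the sketch below; the endpoint branches are taken from the visible diff, while the ramp between them is assumed linear here for illustration, since that part of get_layerdrop_prob() is unchanged and not shown in the patch.

    def layerdrop_prob_sketch(
        warmup_count: float,
        layer_warmup_begin: float,
        layer_warmup_end: float,
        initial_layerdrop_prob: float = 0.5,
        final_layerdrop_prob: float = 0.05,
    ) -> float:
        # Endpoint behaviour matches the branches visible in the diff.
        if warmup_count < layer_warmup_begin:
            return initial_layerdrop_prob
        elif warmup_count > layer_warmup_end:
            return final_layerdrop_prob
        # Assumed interpolation between the two endpoints during the
        # layer's own warmup window.
        t = (warmup_count - layer_warmup_begin) / (layer_warmup_end - layer_warmup_begin)
        return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob)

Reading the boundaries off the layers themselves keeps the drop schedule consistent with however those per-layer warmup times are assigned, which is why the commit message calls the change cosmetic.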