From 71b3756ada62fb53673b3ad2feb9e8d4e0609213 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 24 Sep 2022 15:40:44 +0800
Subject: [PATCH] Use half the dim per head, in self_attn layers.

---
 .../pruned_transducer_stateless7/conformer.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index cf3129df2..8f8bebf4f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -469,27 +469,27 @@ class RelPositionMultiheadAttention(nn.Module):
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.dropout = dropout
-        self.head_dim = embed_dim // num_heads
+        self.head_dim = embed_dim // (num_heads * 2)
         assert (
-            self.head_dim * num_heads == self.embed_dim
+            self.head_dim * num_heads == self.embed_dim // 2
         ), "embed_dim must be divisible by num_heads"
 
-        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
-        self.in_balancer = ActivationBalancer(3 * embed_dim,
+        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
+        self.in_balancer = ActivationBalancer(3 * embed_dim // 2,
                                               channel_dim=-1, max_abs=5.0,
                                               max_var_per_eig=0.2)
-        self.proj_balancer = ActivationBalancer(embed_dim,
+        self.proj_balancer = ActivationBalancer(embed_dim // 2,
                                                 channel_dim=-1, max_abs=10.0,
                                                 min_positive=0.0, max_positive=1.0)
         self.out_proj = ScaledLinear(
-            embed_dim, embed_dim, bias=True, initial_scale=0.5
+            embed_dim // 2, embed_dim, bias=True, initial_scale=0.5
         )
 
         self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
         self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))
 
         # linear transformation for positional encoding.
-        self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.linear_pos = nn.Linear(embed_dim, embed_dim // 2, bias=False)
         # these two learnable bias are used in matrix c and matrix d
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
@@ -663,9 +663,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         tgt_len, bsz, _ = x.size()
 
-        head_dim = embed_dim // num_heads
+        head_dim = embed_dim // (num_heads * 2)
         assert (
-            head_dim * num_heads == embed_dim
+            head_dim * num_heads == embed_dim // 2
         ), "embed_dim must be divisible by num_heads"
 
         scaling = float(head_dim) ** -0.5
@@ -815,7 +815,7 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = (
             attn_output.transpose(0, 1)
             .contiguous()
-            .view(tgt_len, bsz, embed_dim)
+            .view(tgt_len, bsz, embed_dim // 2)
         )
         attn_output = nn.functional.linear(
             attn_output, out_proj_weight, out_proj_bias
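
Note (not part of the patch): a minimal sketch of the dimension arithmetic the patch changes, using illustrative values embed_dim=512 and num_heads=8 (these values are assumptions for the example, not taken from the patch). It shows how the per-head dimension, the in_proj output width, and the out_proj input width are halved, while out_proj still maps back to embed_dim.

# Illustration only: effect of halving the dim per head in the self-attention
# projections, assuming example values embed_dim=512, num_heads=8.

embed_dim, num_heads = 512, 8

# Before the patch: each head uses embed_dim // num_heads channels.
old_head_dim = embed_dim // num_heads        # 64
old_qkv_dim = 3 * embed_dim                  # 1536, in_proj output width

# After the patch: half the dim per head, so q/k/v together span embed_dim // 2.
new_head_dim = embed_dim // (num_heads * 2)  # 32
new_qkv_dim = 3 * embed_dim // 2             # 768, in_proj output width

# Matches the patched assertion: heads now cover only half of embed_dim.
assert new_head_dim * num_heads == embed_dim // 2

# out_proj correspondingly maps embed_dim // 2 back up to embed_dim.
print(old_head_dim, new_head_dim, old_qkv_dim, new_qkv_dim)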