diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 951fb863d..f876b15ce 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -843,7 +843,7 @@ class RelPositionMultiheadAttention(nn.Module):
 
         # the initial_scale is supposed to take over the "scaling" factor of
         # head_dim ** -0.5, dividing it between the query and key.
-        self.in_proj = ScaledLinear(embed_dim, 7 * attention_dim // 2, bias=True,
+        self.in_proj = ScaledLinear(embed_dim, 3 * attention_dim, bias=True,
                                     initial_scale=self.head_dim**-0.25)
 
         # self.whiten_values is applied on the values in forward()
@@ -863,14 +863,14 @@ class RelPositionMultiheadAttention(nn.Module):
         self.linear_pos = ScaledLinear(embed_dim, attention_dim // 2,
                                        bias=False, initial_scale=0.05)
 
-        self.in_balancer = ActivationBalancer(7 * attention_dim // 2,
+        self.in_balancer = ActivationBalancer(3 * attention_dim,
                                               channel_dim=-1, max_abs=5.0)
         self.out_proj = ScaledLinear(
-            attention_dim, embed_dim, bias=True, initial_scale=0.05
+            attention_dim // 2, embed_dim, bias=True, initial_scale=0.05
         )
 
-        self.in_proj2 = nn.Linear(embed_dim, attention_dim, bias=False)
-        self.out_proj2 = ScaledLinear(attention_dim, embed_dim, bias=True,
+        self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False)
+        self.out_proj2 = ScaledLinear(attention_dim // 2, embed_dim, bias=True,
                                       initial_scale=0.05)
         # self.whiten_values2 is applied on the values in forward2()
         self.whiten_values2 = Whiten(num_groups=num_heads,
@@ -1005,10 +1005,8 @@ class RelPositionMultiheadAttention(nn.Module):
 
         # self-attention
-        q = x[...,:attention_dim]
-        k = x[...,attention_dim:2*attention_dim]
-        v = x[...,2*attention_dim:3*attention_dim]
-        p = x[...,3*attention_dim:]
+        q, k, pv = x.chunk(3, dim=-1)
+        p, v = pv.chunk(2, dim=-1)
 
         k = self.whiten_keys(k)  # does nothing in the forward pass.
@@ -1066,7 +1064,7 @@ class RelPositionMultiheadAttention(nn.Module):
         q = q.reshape(seq_len, bsz, num_heads, head_dim)
         p = p.reshape(seq_len, bsz, num_heads, head_dim // 2)
         k = k.reshape(seq_len, bsz, num_heads, head_dim)
-        v = v.reshape(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
+        v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1)
 
         if key_padding_mask is not None:
@@ -1140,11 +1138,11 @@ class RelPositionMultiheadAttention(nn.Module):
         )
 
         attn_output = torch.bmm(attn_output_weights, v)
-        assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim]
+        assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2]
         attn_output = (
             attn_output.transpose(0, 1)
             .contiguous()
-            .view(seq_len, bsz, attention_dim)
+            .view(seq_len, bsz, attention_dim // 2)
         )
         attn_output = nn.functional.linear(
             attn_output, out_proj_weight, out_proj_bias
@@ -1173,9 +1171,9 @@ class RelPositionMultiheadAttention(nn.Module):
         # v: (tgt_len, bsz, embed_dim // 2)
         v = self.in_proj2(x)
         v = self.whiten_values2(v)  # does nothing in the forward pass.
-        v = v.reshape(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
+        v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1)
 
-        # now v: (bsz * num_heads, seq_len, head_dim)
+        # now v: (bsz * num_heads, seq_len, head_dim // 2)
         attn_output = torch.bmm(attn_weights, v)
 
         if random.random() < 0.001 or __name__ == "__main__":
@@ -1185,7 +1183,7 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = (
             attn_output.transpose(0, 1)
             .contiguous()
-            .view(seq_len, bsz, self.attention_dim)
+            .view(seq_len, bsz, self.attention_dim // 2)
         )
         # returned value is of shape (seq_len, bsz, embed_dim), like x.
         return self.out_proj2(attn_output)
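Below is a minimal, self-contained sketch (not part of the patch) of the shape bookkeeping this change relies on: in_proj now emits 3 * attention_dim channels, chunked into q, k and pv, with pv further split into p and v of width attention_dim // 2 each, so the attention output and the input to out_proj are attention_dim // 2 wide. Plain nn.Linear stands in for the recipe's ScaledLinear, the attention weights are random placeholders, and the dimensions are illustrative values, not taken from the recipe.

import torch
import torch.nn as nn

seq_len, bsz = 10, 2
embed_dim, attention_dim, num_heads = 512, 512, 8
head_dim = attention_dim // num_heads

# Stand-in for ScaledLinear(embed_dim, 3 * attention_dim, bias=True, ...)
in_proj = nn.Linear(embed_dim, 3 * attention_dim, bias=True)
# Stand-in for ScaledLinear(attention_dim // 2, embed_dim, bias=True, ...)
out_proj = nn.Linear(attention_dim // 2, embed_dim, bias=True)

x = torch.randn(seq_len, bsz, embed_dim)
q, k, pv = in_proj(x).chunk(3, dim=-1)  # each (seq_len, bsz, attention_dim)
p, v = pv.chunk(2, dim=-1)              # each (seq_len, bsz, attention_dim // 2)

q = q.reshape(seq_len, bsz, num_heads, head_dim)
p = p.reshape(seq_len, bsz, num_heads, head_dim // 2)
k = k.reshape(seq_len, bsz, num_heads, head_dim)
v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1)

# Dummy attention weights; the real ones come from q, k and p.
attn_weights = torch.softmax(torch.randn(bsz * num_heads, seq_len, seq_len), dim=-1)
attn_output = torch.bmm(attn_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2]

attn_output = attn_output.transpose(0, 1).contiguous().view(seq_len, bsz, attention_dim // 2)
print(out_proj(attn_output).shape)  # torch.Size([10, 2, 512]), i.e. back to embed_dim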