Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Use half the dim per head, in self_attn layers.

commit 71b3756ada · parent ce3f59d9c7
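In plain terms, each attention head now works with embed_dim // (num_heads * 2) dimensions instead of embed_dim // num_heads, so the query/key/value projections together span 3 * embed_dim // 2 channels rather than 3 * embed_dim. A minimal sketch of the arithmetic; embed_dim=512 and num_heads=8 are chosen purely for illustration and are not taken from the commit:

embed_dim, num_heads = 512, 8                    # illustrative values only

old_head_dim = embed_dim // num_heads            # 64 dims per head before this commit
new_head_dim = embed_dim // (num_heads * 2)      # 32 dims per head after this commit

assert new_head_dim * num_heads == embed_dim // 2
print(old_head_dim, new_head_dim, 3 * embed_dim // 2)   # 64 32 768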
@@ -469,27 +469,27 @@ class RelPositionMultiheadAttention(nn.Module):
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.dropout = dropout
-        self.head_dim = embed_dim // num_heads
+        self.head_dim = embed_dim // (num_heads * 2)
         assert (
-            self.head_dim * num_heads == self.embed_dim
+            self.head_dim * num_heads == self.embed_dim // 2
         ), "embed_dim must be divisible by num_heads"

-        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
-        self.in_balancer = ActivationBalancer(3 * embed_dim,
+        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
+        self.in_balancer = ActivationBalancer(3 * embed_dim // 2,
                                               channel_dim=-1, max_abs=5.0,
                                               max_var_per_eig=0.2)
-        self.proj_balancer = ActivationBalancer(embed_dim,
+        self.proj_balancer = ActivationBalancer(embed_dim // 2,
                                                 channel_dim=-1, max_abs=10.0,
                                                 min_positive=0.0, max_positive=1.0)
         self.out_proj = ScaledLinear(
-            embed_dim, embed_dim, bias=True, initial_scale=0.5
+            embed_dim // 2, embed_dim, bias=True, initial_scale=0.5
         )

         self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
         self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))

         # linear transformation for positional encoding.
-        self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.linear_pos = nn.Linear(embed_dim, embed_dim // 2, bias=False)
         # these two learnable bias are used in matrix c and matrix d
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
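A minimal shape check of the rewired projections in __init__, using plain nn.Linear stand-ins for icefall's ActivationBalancer and ScaledLinear; the stand-ins and the example sizes are assumptions for illustration, not the repository's code:

import torch
import torch.nn as nn

embed_dim, num_heads = 512, 8                      # illustrative sizes
head_dim = embed_dim // (num_heads * 2)            # 32: half the usual dim per head

in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)   # q, k, v: embed_dim // 2 each
linear_pos = nn.Linear(embed_dim, embed_dim // 2, bias=False)   # positional projection, halved too
out_proj = nn.Linear(embed_dim // 2, embed_dim, bias=True)      # stand-in for ScaledLinear

x = torch.randn(10, 2, embed_dim)                  # (seq_len, batch, embed_dim)
q, k, v = in_proj(x).chunk(3, dim=-1)
assert q.shape[-1] == k.shape[-1] == v.shape[-1] == embed_dim // 2
assert q.shape[-1] == num_heads * head_dim         # heads still tile the halved dimension
assert out_proj(v).shape[-1] == embed_dim          # output is projected back to embed_dim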
@@ -663,9 +663,9 @@ class RelPositionMultiheadAttention(nn.Module):

         tgt_len, bsz, _ = x.size()

-        head_dim = embed_dim // num_heads
+        head_dim = embed_dim // (num_heads * 2)
         assert (
-            head_dim * num_heads == embed_dim
+            head_dim * num_heads == embed_dim // 2
         ), "embed_dim must be divisible by num_heads"

         scaling = float(head_dim) ** -0.5
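One side effect worth noting: the attention-score scaling is now derived from the halved head_dim, so it grows by a factor of sqrt(2). A small sketch under the same illustrative sizes (assumed, not from the commit):

embed_dim, num_heads = 512, 8                                 # illustrative sizes, as above

old_scaling = float(embed_dim // num_heads) ** -0.5           # 64 ** -0.5 = 0.125
new_scaling = float(embed_dim // (num_heads * 2)) ** -0.5     # 32 ** -0.5 ≈ 0.1768

assert abs(new_scaling / old_scaling - 2 ** 0.5) < 1e-9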
@@ -815,7 +815,7 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = (
             attn_output.transpose(0, 1)
             .contiguous()
-            .view(tgt_len, bsz, embed_dim)
+            .view(tgt_len, bsz, embed_dim // 2)
         )
         attn_output = nn.functional.linear(
             attn_output, out_proj_weight, out_proj_bias
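A minimal sketch of the output path after this hunk: the per-head outputs concatenate to embed_dim // 2 channels, and the output projection (whose weight has shape (embed_dim, embed_dim // 2)) restores the full embedding dimension. The sizes and the randomly initialized weights are illustrative assumptions:

import torch
import torch.nn as nn

tgt_len, bsz, embed_dim, num_heads = 10, 2, 512, 8   # illustrative sizes
head_dim = embed_dim // (num_heads * 2)

# per-head attention output, shaped as in torch-style MHA internals
attn_output = torch.randn(bsz * num_heads, tgt_len, head_dim)

out_proj_weight = torch.randn(embed_dim, embed_dim // 2)   # maps embed_dim // 2 -> embed_dim
out_proj_bias = torch.zeros(embed_dim)

attn_output = (
    attn_output.transpose(0, 1)
    .contiguous()
    .view(tgt_len, bsz, embed_dim // 2)
)
attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
assert attn_output.shape == (tgt_len, bsz, embed_dim)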