Add activation balancer to stop activations in self_attn from getting too large

Author: Daniel Povey
Date:   2022-06-01 00:40:45 +08:00
parent  da2ffd4d27
commit  61619c031e

@@ -464,6 +464,7 @@ class RelPositionMultiheadAttention(nn.Module):
         ), "embed_dim must be divisible by num_heads"

         self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
+        self.in_balancer = ActivationBalancer(channel_dim=-1, max_abs=5.0)
         self.out_proj = ScaledLinear(
             embed_dim, embed_dim, bias=True, initial_scale=0.5
         )
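Note: ActivationBalancer here is icefall's utility from its scaling module; as I understand it, it behaves as an identity in the forward pass and only nudges gradients in the backward pass so that activations along channel_dim are discouraged from exceeding max_abs (5.0 here). The following is a minimal stand-in sketch of just that max_abs behaviour, not the real implementation: the names SimpleActivationBound / SimpleBalancer and the scale knob are invented for illustration, and the real class applies further constraints (e.g. on the proportion of positive values per channel).

import torch
from torch import nn


class SimpleActivationBound(torch.autograd.Function):
    """Identity in the forward pass; in the backward pass, adds a small
    gradient term that pushes activations whose magnitude exceeds
    max_abs back toward zero."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, max_abs: float, scale: float) -> torch.Tensor:
        ctx.save_for_backward(x)
        ctx.max_abs = max_abs
        ctx.scale = scale
        return x

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (x,) = ctx.saved_tensors
        # A gradient component with the same sign as x shrinks |x| under
        # gradient descent; apply it only where |x| exceeds max_abs.
        too_large = (x.abs() > ctx.max_abs).to(grad_output.dtype)
        extra = ctx.scale * too_large * x.sign() * grad_output.abs()
        return grad_output + extra, None, None


class SimpleBalancer(nn.Module):
    """Illustrative stand-in for ActivationBalancer (max_abs part only)."""

    def __init__(self, max_abs: float = 5.0, scale: float = 0.01):
        super().__init__()
        self.max_abs = max_abs
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training:
            return x  # no effect at inference time
        return SimpleActivationBound.apply(x, self.max_abs, self.scale)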
@@ -649,9 +650,12 @@ class RelPositionMultiheadAttention(nn.Module):
         scaling = float(head_dim) ** -0.5

+        def linear(x, w, b):
+            return self.in_balancer(nn.functional.linear(x, w, b))
+
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
+            q, k, v = linear(
                 query, in_proj_weight, in_proj_bias
             ).chunk(3, dim=-1)
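The local linear() helper routes every in-projection output through the balancer, so q, k and v are all constrained no matter which branch below computes them, while forward-pass values and shapes are unchanged. A small self-contained check of that pattern, using the SimpleBalancer sketch above in place of the real ActivationBalancer (dimensions are arbitrary):

import torch
import torch.nn.functional as F

balancer = SimpleBalancer(max_abs=5.0)  # stand-in defined in the sketch above

embed_dim = 16
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)
in_proj_bias = torch.zeros(3 * embed_dim)
x = torch.randn(10, 2, embed_dim)  # (seq_len, batch, embed_dim)

def linear(x, w, b):
    # Same output as F.linear(x, w, b); the balancer only alters gradients.
    return balancer(F.linear(x, w, b))

q, k, v = linear(x, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
assert q.shape == k.shape == v.shape == (10, 2, embed_dim)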
@@ -664,7 +668,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            q = nn.functional.linear(query, _w, _b)
+            q = linear(query, _w, _b)

             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -673,7 +677,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:, :]
             if _b is not None:
                 _b = _b[_start:]
-            k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
+            k, v = linear(key, _w, _b).chunk(2, dim=-1)

         else:
             # This is inline in_proj function with in_proj_weight and in_proj_bias
@@ -683,7 +687,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            q = nn.functional.linear(query, _w, _b)
+            q = linear(query, _w, _b)

             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -692,7 +696,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            k = nn.functional.linear(key, _w, _b)
+            k = linear(key, _w, _b)

             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -701,7 +705,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:, :]
             if _b is not None:
                 _b = _b[_start:]
-            v = nn.functional.linear(value, _w, _b)
+            v = linear(value, _w, _b)

         if attn_mask is not None:
             assert (
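For completeness, a tiny demonstration of what the sketched balancer does to gradients (again using the illustrative SimpleBalancer above, whose correction is much cruder than the real ActivationBalancer's): the forward output equals the input, while elements whose magnitude exceeds max_abs receive a small extra gradient term pointing back toward zero.

import torch

balancer = SimpleBalancer(max_abs=5.0, scale=0.01)
balancer.train()

x = torch.tensor([0.5, 6.0, -7.0], requires_grad=True)
y = balancer(x)
print(torch.equal(y, x))  # True: identity in the forward pass
y.sum().backward()
print(x.grad)             # tensor([1.0000, 1.0100, 0.9900])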