Mirror of https://github.com/k2-fsa/icefall.git
Add activation balancer to stop activations in self_attn from getting too large
parent da2ffd4d27
commit 61619c031e
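In outline, the commit constructs an ActivationBalancer next to the existing in_proj/out_proj pair and routes every query/key/value projection through it, so the projected self-attention activations are discouraged from exceeding about 5.0 in absolute value. Below is a minimal sketch of the resulting wiring, not the actual icefall code: it assumes ActivationBalancer and ScaledLinear can be imported from the recipe's scaling.py, and that the balancer acts as an identity in the forward pass while shaping gradients during training.

    # Sketch only: toy module showing the in_proj -> in_balancer -> chunk -> out_proj wiring.
    # ActivationBalancer and ScaledLinear are assumed to come from the recipe's scaling.py.
    import torch
    import torch.nn as nn

    from scaling import ActivationBalancer, ScaledLinear  # assumed import path

    class BalancedSelfAttnProjections(nn.Module):
        def __init__(self, embed_dim: int) -> None:
            super().__init__()
            self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
            # Discourages |activation| > 5.0 along the last (channel) dimension.
            self.in_balancer = ActivationBalancer(channel_dim=-1, max_abs=5.0)
            self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.5)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            q, k, v = self.in_balancer(self.in_proj(x)).chunk(3, dim=-1)
            # ... scaled dot-product attention over q, k, v would go here ...
            return self.out_proj(v)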
@@ -464,6 +464,7 @@ class RelPositionMultiheadAttention(nn.Module):
         ), "embed_dim must be divisible by num_heads"
 
         self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
+        self.in_balancer = ActivationBalancer(channel_dim=-1, max_abs=5.0)
         self.out_proj = ScaledLinear(
             embed_dim, embed_dim, bias=True, initial_scale=0.5
         )
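The balancer constructed here should not change the forward computation; it only nudges gradients during training so that channel activations stay within max_abs. A quick, hypothetical check of that property (assuming scaling.py from the same recipe is importable):

    # Hypothetical check: the balancer is expected to leave forward values untouched.
    import torch
    from scaling import ActivationBalancer  # assumed import from the recipe's scaling.py

    balancer = ActivationBalancer(channel_dim=-1, max_abs=5.0)
    x = torch.randn(10, 4, 512, requires_grad=True)
    y = balancer(x)
    print(torch.allclose(x, y))  # expected: True; the constraint acts via the backward pass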
@@ -649,9 +650,12 @@ class RelPositionMultiheadAttention(nn.Module):
 
         scaling = float(head_dim) ** -0.5
 
+        def linear(x, w, b):
+            return self.in_balancer(nn.functional.linear(x, w, b))
+
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
+            q, k, v = linear(
                 query, in_proj_weight, in_proj_bias
             ).chunk(3, dim=-1)
 
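The new local helper keeps the single fused projection for q, k and v and only interposes the balancer before the chunk; the remaining hunks just swap it in at each call site. A standalone illustration of the shapes involved, with nn.Identity standing in for the balancer (names and sizes are illustrative):

    import torch
    import torch.nn as nn

    embed_dim, seq_len, batch = 8, 5, 2
    in_proj_weight = torch.randn(3 * embed_dim, embed_dim)
    in_proj_bias = torch.randn(3 * embed_dim)
    query = torch.randn(seq_len, batch, embed_dim)

    in_balancer = nn.Identity()  # stand-in for self.in_balancer

    def linear(x, w, b):
        # Same wrapper as in the diff: project, then pass the result through the balancer.
        return in_balancer(nn.functional.linear(x, w, b))

    q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
    print(q.shape, k.shape, v.shape)  # each: torch.Size([5, 2, 8])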
@@ -664,7 +668,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            q = nn.functional.linear(query, _w, _b)
+            q = linear(query, _w, _b)
 
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -673,7 +677,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:, :]
             if _b is not None:
                 _b = _b[_start:]
-            k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
+            k, v = linear(key, _w, _b).chunk(2, dim=-1)
 
         else:
             # This is inline in_proj function with in_proj_weight and in_proj_bias
@@ -683,7 +687,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            q = nn.functional.linear(query, _w, _b)
+            q = linear(query, _w, _b)
 
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -692,7 +696,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            k = nn.functional.linear(key, _w, _b)
+            k = linear(key, _w, _b)
 
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -701,7 +705,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:, :]
             if _b is not None:
                 _b = _b[_start:]
-            v = nn.functional.linear(value, _w, _b)
+            v = linear(value, _w, _b)
 
         if attn_mask is not None:
             assert (