Remove learnable offset, use relu instead.

See https://github.com/k2-fsa/icefall/pull/199
This commit is contained in:
Fangjun Kuang 2022-02-07 19:02:47 +08:00
parent b3ea50126a
commit f2a45eb38d

View File

@ -629,9 +629,11 @@ class RelPositionMultiheadAttention(nn.Module):
if torch.equal(query, key) and torch.equal(key, value):
# self-attention
q, k, v = nn.functional.linear(
query, in_proj_weight, in_proj_bias
).chunk(3, dim=-1)
q, k, v = (
nn.functional.linear(query, in_proj_weight, in_proj_bias)
.relu()
.chunk(3, dim=-1)
)
elif torch.equal(key, value):
# encoder-decoder attention
@ -642,7 +644,7 @@ class RelPositionMultiheadAttention(nn.Module):
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
q = nn.functional.linear(query, _w, _b).relu()
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
@ -650,7 +652,7 @@ class RelPositionMultiheadAttention(nn.Module):
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
k, v = nn.functional.linear(key, _w, _b).relu().chunk(2, dim=-1)
else:
# This is inline in_proj function with in_proj_weight and in_proj_bias
@ -660,7 +662,7 @@ class RelPositionMultiheadAttention(nn.Module):
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
q = nn.functional.linear(query, _w, _b).relu()
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
@ -669,7 +671,7 @@ class RelPositionMultiheadAttention(nn.Module):
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
k = nn.functional.linear(key, _w, _b)
k = nn.functional.linear(key, _w, _b).relu()
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
@ -678,7 +680,7 @@ class RelPositionMultiheadAttention(nn.Module):
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
v = nn.functional.linear(value, _w, _b)
v = nn.functional.linear(value, _w, _b).relu()
if attn_mask is not None:
assert (