Remove ReLU in attention

Daniel Povey 2022-02-14 19:39:19 +08:00
parent d187ad8b73
commit 2af1b3af98
2 changed files with 7 additions and 7 deletions

@@ -629,7 +629,7 @@ class RelPositionMultiheadAttention(nn.Module):
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).relu().chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -640,7 +640,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            q = nn.functional.linear(query, _w, _b).relu()
+            q = nn.functional.linear(query, _w, _b)
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -649,7 +649,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:, :]
             if _b is not None:
                 _b = _b[_start:]
-            k, v = nn.functional.linear(key, _w, _b).relu().chunk(2, dim=-1)
+            k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
         else:
             # This is inline in_proj function with in_proj_weight and in_proj_bias
@@ -659,7 +659,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            q = nn.functional.linear(query, _w, _b).relu()
+            q = nn.functional.linear(query, _w, _b)
             # This is inline in_proj function with in_proj_weight and in_proj_bias
@@ -669,7 +669,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            k = nn.functional.linear(key, _w, _b).relu()
+            k = nn.functional.linear(key, _w, _b)
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -678,7 +678,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:, :]
             if _b is not None:
                 _b = _b[_start:]
-            v = nn.functional.linear(value, _w, _b).relu()
+            v = nn.functional.linear(value, _w, _b)
         if attn_mask is not None:
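For context, the change in this file only drops the ReLU that had been applied to the q/k/v in-projections, restoring the standard purely linear projection. A minimal, self-contained sketch (not code from the repository; tensor shapes and random weights are made up for illustration):

import torch
import torch.nn as nn

embed_dim, seq_len, batch = 4, 5, 2
query = torch.randn(seq_len, batch, embed_dim)
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)  # packed q/k/v weights
in_proj_bias = torch.randn(3 * embed_dim)

# Before this commit: a ReLU followed the packed projection, so the
# projected q, k, v could never contain negative values.
q_old, k_old, v_old = (
    nn.functional.linear(query, in_proj_weight, in_proj_bias)
    .relu()
    .chunk(3, dim=-1)
)

# After this commit: the in-projection is purely linear, as in standard
# multi-head attention.
q_new, k_new, v_new = nn.functional.linear(
    query, in_proj_weight, in_proj_bias
).chunk(3, dim=-1)

assert (q_old >= 0).all() and (k_old >= 0).all() and (v_old >= 0).all()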

@@ -109,7 +109,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/exp-100h-relu-specaugmod_p0.9_0.15_fix",
+        default="transducer_stateless/exp-100h-specaugmod_p0.9_0.15_fix",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
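The second file only changes the default value of --exp-dir, dropping the "relu" tag from the experiment directory name to match the code change. A minimal argparse sketch of the argument after this commit (an assumption-based mirror of the definition above, not an import of the actual training script; the shortened help text is illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--exp-dir",
    type=str,
    default="transducer_stateless/exp-100h-specaugmod_p0.9_0.15_fix",
    help="The experiment dir where checkpoints, logs, etc. are saved.",
)

args = parser.parse_args([])  # no --exp-dir on the command line, so the new default applies
assert args.exp_dir == "transducer_stateless/exp-100h-specaugmod_p0.9_0.15_fix"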