Remove ReLU in attention

Daniel Povey 2022-02-14 19:39:19 +08:00
parent d187ad8b73
commit 2af1b3af98
2 changed files with 7 additions and 7 deletions


@@ -629,7 +629,7 @@ class RelPositionMultiheadAttention(nn.Module):
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).relu().chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -640,7 +640,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            q = nn.functional.linear(query, _w, _b).relu()
+            q = nn.functional.linear(query, _w, _b)
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -649,7 +649,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:, :]
             if _b is not None:
                 _b = _b[_start:]
-            k, v = nn.functional.linear(key, _w, _b).relu().chunk(2, dim=-1)
+            k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
         else:
             # This is inline in_proj function with in_proj_weight and in_proj_bias
@@ -659,7 +659,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            q = nn.functional.linear(query, _w, _b).relu()
+            q = nn.functional.linear(query, _w, _b)
             # This is inline in_proj function with in_proj_weight and in_proj_bias
@@ -669,7 +669,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            k = nn.functional.linear(key, _w, _b).relu()
+            k = nn.functional.linear(key, _w, _b)
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -678,7 +678,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:, :]
             if _b is not None:
                 _b = _b[_start:]
-            v = nn.functional.linear(value, _w, _b).relu()
+            v = nn.functional.linear(value, _w, _b)
         if attn_mask is not None:
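
In effect, every in-projection in this attention module is now a plain affine map; no nonlinearity sits between the input and the query, key and value tensors. Below is a minimal, self-contained sketch of the self-attention case after this change; embed_dim and the tensor shapes are illustrative assumptions, not values taken from the recipe.

import torch
import torch.nn as nn

# Illustrative sizes only; the actual model dimensions live in the recipe.
embed_dim = 512
x = torch.randn(100, 8, embed_dim)  # (time, batch, embed_dim)

# A single packed projection produces q, k and v in one matmul.
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)
in_proj_bias = torch.zeros(3 * embed_dim)

# Before this commit, a ReLU was applied to the projected activations:
#   q, k, v = nn.functional.linear(x, in_proj_weight, in_proj_bias).relu().chunk(3, dim=-1)
# After this commit, the projection is purely affine:
q, k, v = nn.functional.linear(x, in_proj_weight, in_proj_bias).chunk(3, dim=-1)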


@@ -109,7 +109,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/exp-100h-relu-specaugmod_p0.9_0.15_fix",
+        default="transducer_stateless/exp-100h-specaugmod_p0.9_0.15_fix",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved