https://github.com/k2-fsa/icefall.git
Remove ReLU in attention
commit 2af1b3af98
parent d187ad8b73
@@ -629,7 +629,7 @@ class RelPositionMultiheadAttention(nn.Module):

         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).relu().chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

         elif torch.equal(key, value):
             # encoder-decoder attention
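For readers skimming the diff, here is a minimal, self-contained sketch of what the self-attention branch computes after this change: one combined linear projection of the input, split into q, k and v, with no ReLU in between. This is not code from the repository; tensor names and sizes are illustrative.

import torch
import torch.nn as nn

# Illustrative shapes; the real module reads these from its parameters.
seq_len, batch, embed_dim = 10, 2, 16
x = torch.randn(seq_len, batch, embed_dim)

# Query, key and value projections stacked along the output dimension.
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)
in_proj_bias = torch.randn(3 * embed_dim)

# After this commit the projection is purely linear; previously a .relu()
# was applied before the chunk(3, dim=-1) split.
q, k, v = nn.functional.linear(x, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
print(q.shape, k.shape, v.shape)  # each torch.Size([10, 2, 16])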
@@ -640,7 +640,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            q = nn.functional.linear(query, _w, _b).relu()
+            q = nn.functional.linear(query, _w, _b)

             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -649,7 +649,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:, :]
             if _b is not None:
                 _b = _b[_start:]
-            k, v = nn.functional.linear(key, _w, _b).relu().chunk(2, dim=-1)
+            k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)

         else:
             # This is inline in_proj function with in_proj_weight and in_proj_bias
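A similar hedged sketch for the encoder-decoder branch above, assuming illustrative shapes: the query uses the first embed_dim rows of in_proj_weight, while key and value share the remaining rows and are split with chunk(2, dim=-1). After this commit both projections are plain linear maps.

import torch
import torch.nn as nn

# Illustrative shapes only; value is assumed identical to key, as in the
# torch.equal(key, value) branch above.
embed_dim = 16
query = torch.randn(10, 2, embed_dim)  # (tgt_len, batch, embed_dim)
key = torch.randn(12, 2, embed_dim)    # (src_len, batch, embed_dim)

in_proj_weight = torch.randn(3 * embed_dim, embed_dim)
in_proj_bias = torch.randn(3 * embed_dim)

# Query uses the first embed_dim rows; key/value share the remaining rows.
q = nn.functional.linear(
    query, in_proj_weight[:embed_dim, :], in_proj_bias[:embed_dim]
)
k, v = nn.functional.linear(
    key, in_proj_weight[embed_dim:, :], in_proj_bias[embed_dim:]
).chunk(2, dim=-1)
print(q.shape, k.shape, v.shape)  # (10, 2, 16) (12, 2, 16) (12, 2, 16)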
@@ -659,7 +659,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            q = nn.functional.linear(query, _w, _b).relu()
+            q = nn.functional.linear(query, _w, _b)


             # This is inline in_proj function with in_proj_weight and in_proj_bias
@@ -669,7 +669,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            k = nn.functional.linear(key, _w, _b).relu()
+            k = nn.functional.linear(key, _w, _b)

             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -678,7 +678,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:, :]
             if _b is not None:
                 _b = _b[_start:]
-            v = nn.functional.linear(value, _w, _b).relu()
+            v = nn.functional.linear(value, _w, _b)


         if attn_mask is not None:
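And a sketch of the final branch, where query, key and value are distinct tensors and each takes its own embed_dim-row slice of in_proj_weight. Again the names and shapes below are illustrative, not taken from the repository.

import torch
import torch.nn as nn

# Illustrative shapes only.
embed_dim = 16
query = torch.randn(10, 2, embed_dim)
key = torch.randn(12, 2, embed_dim)
value = torch.randn(12, 2, embed_dim)

in_proj_weight = torch.randn(3 * embed_dim, embed_dim)
in_proj_bias = torch.randn(3 * embed_dim)

# The slice boundaries mirror the _start/_end bookkeeping in the diff.
q = nn.functional.linear(
    query, in_proj_weight[:embed_dim], in_proj_bias[:embed_dim]
)
k = nn.functional.linear(
    key,
    in_proj_weight[embed_dim : 2 * embed_dim],
    in_proj_bias[embed_dim : 2 * embed_dim],
)
v = nn.functional.linear(
    value, in_proj_weight[2 * embed_dim :], in_proj_bias[2 * embed_dim :]
)
# Before this commit each of these projections ended in .relu(), clamping
# negative values to zero; now they are plain affine maps.
print(q.shape, k.shape, v.shape)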
@@ -109,7 +109,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/exp-100h-relu-specaugmod_p0.9_0.15_fix",
+        default="transducer_stateless/exp-100h-specaugmod_p0.9_0.15_fix",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
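For completeness, a minimal argparse sketch showing only the option changed above, whose default directory name simply drops the now-stale "relu" tag. The real get_parser() defines many more arguments; this trimmed parser is hypothetical.

import argparse

# Hypothetical, trimmed-down parser with just the --exp-dir option.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--exp-dir",
    type=str,
    default="transducer_stateless/exp-100h-specaugmod_p0.9_0.15_fix",
    help="Directory where checkpoints, logs, etc. are saved.",
)
args = parser.parse_args([])
print(args.exp_dir)  # transducer_stateless/exp-100h-specaugmod_p0.9_0.15_fix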