Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 10:02:22 +00:00
Remove learnable offset, use relu instead.

parent 48a764eccf
commit a859dcb205
@@ -440,19 +440,8 @@ class RelPositionMultiheadAttention(nn.Module):
         ), "embed_dim must be divisible by num_heads"

         self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)

-        self.in_proj_floor_scale = 10.0  # so it learns fast enough..
-        with torch.no_grad():
-            in_proj_floor = torch.Tensor(3 * embed_dim)
-            # key and query get a floor value quite close to zero.
-            in_proj_floor[:2*embed_dim] = -0.2 / self.in_proj_floor_scale
-            # value gets very low floor, may be close to having no effectc.
-            in_proj_floor[2*embed_dim:] = -1.5 / self.in_proj_floor_scale
-            self.in_proj_floor = nn.Parameter(in_proj_floor)

         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)


         # linear transformation for positional encoding.
         self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
         # these two learnable bias are used in matrix c and matrix d
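For context, here is a minimal sketch of what changes semantically; the names (x, floor, floor_scale) and the toy dimensions are illustrative only and are not part of the diff. The removed code clamped each channel of the input projection at a small learnable negative floor (stored divided by a scale of 10 so the parameter moves faster, then multiplied back at the call site), while the new code clamps at zero, which is exactly ReLU.

import torch

embed_dim = 4                       # toy size, for illustration only
x = torch.randn(2, 3 * embed_dim)   # stand-in for the fused q/k/v projection output

# Old behaviour (removed): per-channel learnable floor.
floor_scale = 10.0
floor = torch.empty(3 * embed_dim)
floor[:2 * embed_dim] = -0.2 / floor_scale   # query/key floors close to zero
floor[2 * embed_dim:] = -1.5 / floor_scale   # value floor well below zero
old = torch.maximum(x, floor * floor_scale)  # clamp at the (rescaled) floor

# New behaviour: clamp at zero, i.e. a plain ReLU.
new = x.relu()

# With the floor fixed at exactly zero the two operations coincide.
assert torch.equal(torch.maximum(x, torch.zeros_like(x)), new)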
@@ -537,7 +526,6 @@ class RelPositionMultiheadAttention(nn.Module):
             key_padding_mask=key_padding_mask,
             need_weights=need_weights,
             attn_mask=attn_mask,
-            in_proj_floor=self.in_proj_floor*self.in_proj_floor_scale
         )

     def rel_shift(self, x: Tensor) -> Tensor:
@@ -582,7 +570,6 @@ class RelPositionMultiheadAttention(nn.Module):
         key_padding_mask: Optional[Tensor] = None,
         need_weights: bool = True,
         attn_mask: Optional[Tensor] = None,
-        in_proj_floor: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
         Args:
@@ -642,12 +629,7 @@ class RelPositionMultiheadAttention(nn.Module):

         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            _qkv = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            )
-            if in_proj_floor is not None:
-                _qkv = torch.maximum(_qkv, in_proj_floor)
-            q, k, v = _qkv.chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).relu().chunk(3, dim=-1)

         elif torch.equal(key, value):
             # encoder-decoder attention
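A small sketch of the new self-attention branch, under assumed toy shapes (time, batch, embed_dim) that are not taken from the diff: one fused in_proj matmul produces q, k and v together, the ReLU clamps the whole fused output at zero, and chunk then splits it along the feature dimension.

import torch
import torch.nn as nn

embed_dim, T, B = 8, 5, 2                       # toy shapes for illustration
query = torch.randn(T, B, embed_dim)
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)
in_proj_bias = torch.randn(3 * embed_dim)

# New self-attention path: fused projection, ReLU, then split into q, k, v.
q, k, v = (
    nn.functional.linear(query, in_proj_weight, in_proj_bias)
    .relu()
    .chunk(3, dim=-1)
)
assert q.shape == k.shape == v.shape == (T, B, embed_dim)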
@@ -658,10 +640,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            q = nn.functional.linear(query, _w, _b)
-            if in_proj_floor is not None:
-                _f = in_proj_floor[_start:_end]
-                q = torch.maximum(q, _f)
+            q = nn.functional.linear(query, _w, _b).relu()

             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -670,11 +649,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:, :]
             if _b is not None:
                 _b = _b[_start:]
-            _kv = nn.functional.linear(key, _w, _b)
-            if in_proj_floor is not None:
-                _f = in_proj_floor[_start:_end]
-                _kv = torch.maximum(_kv, _f)
-            k, v = _kv.chunk(2, dim=-1)
+            k, v = nn.functional.linear(key, _w, _b).relu().chunk(2, dim=-1)

         else:
             # This is inline in_proj function with in_proj_weight and in_proj_bias
@@ -684,10 +659,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            q = nn.functional.linear(query, _w, _b)
-            if in_proj_floor is not None:
-                _f = in_proj_floor[_start:_end]
-                q = torch.maximum(q, _f)
+            q = nn.functional.linear(query, _w, _b).relu()


             # This is inline in_proj function with in_proj_weight and in_proj_bias
@@ -697,10 +669,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            k = nn.functional.linear(key, _w, _b)
-            if in_proj_floor is not None:
-                _f = in_proj_floor[_start:_end]
-                k = torch.maximum(k, _f)
+            k = nn.functional.linear(key, _w, _b).relu()

             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -709,10 +678,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:, :]
             if _b is not None:
                 _b = _b[_start:]
-            v = nn.functional.linear(value, _w, _b)
-            if in_proj_floor is not None:
-                _f = in_proj_floor[_start:_end]
-                v = torch.maximum(v, _f)
+            v = nn.functional.linear(value, _w, _b).relu()


         if attn_mask is not None:
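One property worth noting, shown as a sketch with made-up shapes rather than code from the repository: ReLU acts elementwise, so applying it to the fused projection before chunking (the self-attention branch) gives the same q, k, v as chunking first and clamping each piece separately (the other branches), keeping all paths consistent.

import torch

fused = torch.randn(5, 2, 3 * 8)   # toy (time, batch, 3 * embed_dim) projection
a = fused.relu().chunk(3, dim=-1)                      # relu then split
b = tuple(t.relu() for t in fused.chunk(3, dim=-1))    # split then relu
assert all(torch.equal(x, y) for x, y in zip(a, b))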