Add min in q,k,v of attention

This commit is contained in:
Daniel Povey 2022-02-06 21:19:37 +08:00
parent 8f8ec223a7
commit 48a764eccf
2 changed files with 49 additions and 3 deletions

View File

@ -440,8 +440,19 @@ class RelPositionMultiheadAttention(nn.Module):
), "embed_dim must be divisible by num_heads"
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
self.in_proj_floor_scale = 10.0 # so it learns fast enough..
with torch.no_grad():
in_proj_floor = torch.Tensor(3 * embed_dim)
# key and query get a floor value quite close to zero.
in_proj_floor[:2*embed_dim] = -0.2 / self.in_proj_floor_scale
# value gets very low floor, may be close to having no effectc.
in_proj_floor[2*embed_dim:] = -1.5 / self.in_proj_floor_scale
self.in_proj_floor = nn.Parameter(in_proj_floor)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
# linear transformation for positional encoding.
self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
# these two learnable bias are used in matrix c and matrix d
@ -526,6 +537,7 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
in_proj_floor=self.in_proj_floor*self.in_proj_floor_scale
)
def rel_shift(self, x: Tensor) -> Tensor:
@ -570,6 +582,7 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
in_proj_floor: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
@ -629,9 +642,12 @@ class RelPositionMultiheadAttention(nn.Module):
if torch.equal(query, key) and torch.equal(key, value):
# self-attention
q, k, v = nn.functional.linear(
_qkv = nn.functional.linear(
query, in_proj_weight, in_proj_bias
).chunk(3, dim=-1)
)
if in_proj_floor is not None:
_qkv = torch.maximum(_qkv, in_proj_floor)
q, k, v = _qkv.chunk(3, dim=-1)
elif torch.equal(key, value):
# encoder-decoder attention
@ -643,6 +659,10 @@ class RelPositionMultiheadAttention(nn.Module):
if _b is not None:
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
if in_proj_floor is not None:
_f = in_proj_floor[_start:_end]
q = torch.maximum(q, _f)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
@ -650,7 +670,11 @@ class RelPositionMultiheadAttention(nn.Module):
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
_kv = nn.functional.linear(key, _w, _b)
if in_proj_floor is not None:
_f = in_proj_floor[_start:_end]
_kv = torch.maximum(_kv, _f)
k, v = _kv.chunk(2, dim=-1)
else:
# This is inline in_proj function with in_proj_weight and in_proj_bias
@ -661,6 +685,10 @@ class RelPositionMultiheadAttention(nn.Module):
if _b is not None:
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
if in_proj_floor is not None:
_f = in_proj_floor[_start:_end]
q = torch.maximum(q, _f)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
@ -670,6 +698,9 @@ class RelPositionMultiheadAttention(nn.Module):
if _b is not None:
_b = _b[_start:_end]
k = nn.functional.linear(key, _w, _b)
if in_proj_floor is not None:
_f = in_proj_floor[_start:_end]
k = torch.maximum(k, _f)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
@ -679,6 +710,10 @@ class RelPositionMultiheadAttention(nn.Module):
if _b is not None:
_b = _b[_start:]
v = nn.functional.linear(value, _w, _b)
if in_proj_floor is not None:
_f = in_proj_floor[_start:_end]
v = torch.maximum(v, _f)
if attn_mask is not None:
assert (
@ -918,3 +953,13 @@ class Swish(torch.nn.Module):
def identity(x):
return x
if __name__ == '__main__':
feature_dim = 50
c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
batch_size = 5
seq_len = 20
# Just make sure the forward pass runs.
f = c(torch.randn(batch_size, seq_len, feature_dim),
torch.full((batch_size,), seq_len, dtype=torch.int64))

View File

@ -82,6 +82,7 @@ class Decoder(nn.Module):
Returns:
Return a tensor of shape (N, U, embedding_dim).
"""
y = y.to(torch.int64)
embedding_out = self.embedding(y)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)