icefall (https://github.com/k2-fsa/icefall.git)
Commit 48a764eccf (parent 8f8ec223a7): Add min in q,k,v of attention
@@ -440,8 +440,19 @@ class RelPositionMultiheadAttention(nn.Module):
         ), "embed_dim must be divisible by num_heads"
 
         self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
+
+        self.in_proj_floor_scale = 10.0  # so it learns fast enough.
+        with torch.no_grad():
+            in_proj_floor = torch.Tensor(3 * embed_dim)
+            # key and query get a floor value quite close to zero.
+            in_proj_floor[:2*embed_dim] = -0.2 / self.in_proj_floor_scale
+            # value gets a very low floor, may be close to having no effect.
+            in_proj_floor[2*embed_dim:] = -1.5 / self.in_proj_floor_scale
+        self.in_proj_floor = nn.Parameter(in_proj_floor)
+
+
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
 
         # linear transformation for positional encoding.
         self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
         # these two learnable bias are used in matrix c and matrix d
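For orientation, here is a standalone sketch (my own, not part of the diff; embed_dim=8 is an arbitrary example value) of what this initialisation produces. The stored parameter holds the floor divided by in_proj_floor_scale, and the forward pass multiplies it back, so the query/key channels start with an effective floor of about -0.2 and the value channels about -1.5.

import torch
import torch.nn as nn

embed_dim = 8                        # arbitrary example value
floor_scale = 10.0
with torch.no_grad():
    floor = torch.empty(3 * embed_dim)           # channel layout: [q | k | v]
    floor[:2 * embed_dim] = -0.2 / floor_scale   # query and key channels
    floor[2 * embed_dim:] = -1.5 / floor_scale   # value channels
floor = nn.Parameter(floor)

effective = floor * floor_scale      # what the forward pass actually clamps with
print(effective[0].item(), effective[-1].item())  # approx -0.2 and -1.5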
@@ -526,6 +537,7 @@ class RelPositionMultiheadAttention(nn.Module):
             key_padding_mask=key_padding_mask,
             need_weights=need_weights,
             attn_mask=attn_mask,
+            in_proj_floor=self.in_proj_floor*self.in_proj_floor_scale
         )
 
     def rel_shift(self, x: Tensor) -> Tensor:
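A rough illustration (my own sketch, not from the commit) of why the floor is stored pre-divided by in_proj_floor_scale and only multiplied back at this call site: the scale factor ends up multiplying the gradient that reaches the stored parameter, which appears to be what the "so it learns fast enough" comment is after.

import torch

scale = 10.0
p = torch.tensor(-0.02, requires_grad=True)  # stored value, i.e. floor / scale

x = torch.tensor([-1.0, 0.5])
y = torch.maximum(x, p * scale)              # effective floor is p * scale = -0.2
y.sum().backward()

# The floor is active only for the first element, and there dy/dp = scale,
# so the stored parameter sees a 10x larger gradient than an unscaled floor would.
print(p.grad)  # tensor(10.)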
@@ -570,6 +582,7 @@ class RelPositionMultiheadAttention(nn.Module):
         key_padding_mask: Optional[Tensor] = None,
         need_weights: bool = True,
         attn_mask: Optional[Tensor] = None,
+        in_proj_floor: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
         Args:
@@ -629,9 +642,12 @@ class RelPositionMultiheadAttention(nn.Module):
 
         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
+            _qkv = nn.functional.linear(
                 query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            )
+            if in_proj_floor is not None:
+                _qkv = torch.maximum(_qkv, in_proj_floor)
+            q, k, v = _qkv.chunk(3, dim=-1)
 
         elif torch.equal(key, value):
             # encoder-decoder attention
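A small shape-level sketch (hypothetical sizes, not from the repo) of what this branch now does: the fused in-projection output has 3*embed_dim channels, the (3*embed_dim,)-shaped floor broadcasts over the time and batch dimensions, and the split into q, k, v happens only after the clamp.

import torch
import torch.nn as nn

T, N, embed_dim = 4, 2, 8                       # hypothetical sizes
in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
in_proj_floor = torch.full((3 * embed_dim,), -0.2)

x = torch.randn(T, N, embed_dim)
_qkv = nn.functional.linear(x, in_proj.weight, in_proj.bias)  # (T, N, 3*embed_dim)
_qkv = torch.maximum(_qkv, in_proj_floor)       # floor broadcasts over (T, N)
q, k, v = _qkv.chunk(3, dim=-1)                 # each (T, N, embed_dim)
assert (_qkv >= in_proj_floor).all()            # nothing falls below the floor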
@@ -643,6 +659,10 @@ class RelPositionMultiheadAttention(nn.Module):
             if _b is not None:
                 _b = _b[_start:_end]
             q = nn.functional.linear(query, _w, _b)
+            if in_proj_floor is not None:
+                _f = in_proj_floor[_start:_end]
+                q = torch.maximum(q, _f)
+
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
             _start = embed_dim
@@ -650,7 +670,11 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:, :]
             if _b is not None:
                 _b = _b[_start:]
-            k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
+            _kv = nn.functional.linear(key, _w, _b)
+            if in_proj_floor is not None:
+                _f = in_proj_floor[_start:_end]
+                _kv = torch.maximum(_kv, _f)
+            k, v = _kv.chunk(2, dim=-1)
 
         else:
             # This is inline in_proj function with in_proj_weight and in_proj_bias
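For the slice used here (my own illustration, hypothetical embed_dim): with _start = embed_dim and _end = None, in_proj_floor[_start:_end] picks the last 2*embed_dim entries, i.e. the key floor followed by the value floor, which matches the channel layout of the fused k/v projection that is chunked afterwards.

import torch

embed_dim = 4                                    # hypothetical size
in_proj_floor = torch.cat([
    torch.full((embed_dim,), -0.2),              # query channels
    torch.full((embed_dim,), -0.2),              # key channels
    torch.full((embed_dim,), -1.5),              # value channels
])
_start, _end = embed_dim, None
_f = in_proj_floor[_start:_end]                  # key floor then value floor
print(_f.shape)                                  # torch.Size([8])
print(_f[:embed_dim].unique(), _f[embed_dim:].unique())  # -0.2 and -1.5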
@@ -661,6 +685,10 @@ class RelPositionMultiheadAttention(nn.Module):
             if _b is not None:
                 _b = _b[_start:_end]
             q = nn.functional.linear(query, _w, _b)
+            if in_proj_floor is not None:
+                _f = in_proj_floor[_start:_end]
+                q = torch.maximum(q, _f)
+
 
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -670,6 +698,9 @@ class RelPositionMultiheadAttention(nn.Module):
             if _b is not None:
                 _b = _b[_start:_end]
             k = nn.functional.linear(key, _w, _b)
+            if in_proj_floor is not None:
+                _f = in_proj_floor[_start:_end]
+                k = torch.maximum(k, _f)
 
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
@@ -679,6 +710,10 @@ class RelPositionMultiheadAttention(nn.Module):
             if _b is not None:
                 _b = _b[_start:]
             v = nn.functional.linear(value, _w, _b)
+            if in_proj_floor is not None:
+                _f = in_proj_floor[_start:_end]
+                v = torch.maximum(v, _f)
+
 
         if attn_mask is not None:
             assert (
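A back-of-the-envelope check (my own, assuming the projection outputs are roughly standard-normal, which the real model need not satisfy) of the comment that the value floor "may be close to having no effect": a floor of -0.2 would clip a sizeable fraction of such outputs, while -1.5 clips only the tail.

import torch

normal = torch.distributions.Normal(0.0, 1.0)
print(normal.cdf(torch.tensor(-0.2)).item())  # ~0.42: fraction clipped by the q/k floor
print(normal.cdf(torch.tensor(-1.5)).item())  # ~0.07: fraction clipped by the v floor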
@@ -918,3 +953,13 @@ class Swish(torch.nn.Module):
 
 def identity(x):
     return x
+
+
+if __name__ == '__main__':
+    feature_dim = 50
+    c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
+    batch_size = 5
+    seq_len = 20
+    # Just make sure the forward pass runs.
+    f = c(torch.randn(batch_size, seq_len, feature_dim),
+          torch.full((batch_size,), seq_len, dtype=torch.int64))
@@ -82,6 +82,7 @@ class Decoder(nn.Module):
         Returns:
           Return a tensor of shape (N, U, embedding_dim).
         """
+        y = y.to(torch.int64)
         embedding_out = self.embedding(y)
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
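A minimal sketch (my own example, not from the repo) of why this cast helps: nn.Embedding lookups expect integer index tensors, typically int64, so converting y up front guards against callers that pass e.g. int32 token ids.

import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=10, embedding_dim=4)
y = torch.tensor([[1, 2, 3]], dtype=torch.int32)   # e.g. ids produced as int32 elsewhere
out = embedding(y.to(torch.int64))                 # safe regardless of the incoming dtype
print(out.shape)  # torch.Size([1, 3, 4])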