Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Remove pos_emb schedule

parent 54048009db
commit ba69eb48fe
@@ -1014,8 +1014,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        query_head_dim: dimension of the query (and key), per head. e.g. 24.
        pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
        dropout: dropout probability for attn_output_weights. Default: 0.0.
-       pos_emb_skip: probability for skipping the pos_emb part of the scores on
-           any given call to forward(), in training time.
     """
 
     def __init__(
@@ -1026,8 +1024,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         query_head_dim: int,
         pos_head_dim: int,
         dropout: float = 0.0,
-        pos_emb_skip: FloatLike = ScheduledFloat((0.0, 0.5),
-                                                 (4000.0, 0.025), default=0.0)
     ) -> None:
         super().__init__()
         self.embed_dim = embed_dim
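A note on the removed default (an editor's sketch, not part of the commit): icefall's ScheduledFloat is, as I understand it, a piecewise-linear value keyed on the training batch count, so ScheduledFloat((0.0, 0.5), (4000.0, 0.025), default=0.0) means a skip probability of 0.5 at the start of training decaying to 0.025 by batch 4000, with default used when no batch count is known. A minimal stand-alone approximation of that schedule and of the training-time skip decision it fed (the function names here are hypothetical, not icefall APIs):

import random

# Hypothetical stand-in for ScheduledFloat((0.0, 0.5), (4000.0, 0.025), default=0.0):
# piecewise-linear interpolation of a value against the training batch index.
def scheduled_float(batch_idx, points=((0.0, 0.5), (4000.0, 0.025)), default=0.0):
    if batch_idx is None:
        return default                      # no batch count known yet
    if batch_idx <= points[0][0]:
        return points[0][1]
    if batch_idx >= points[-1][0]:
        return points[-1][1]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if x0 <= batch_idx <= x1:
            t = (batch_idx - x0) / (x1 - x0)
            return y0 + t * (y1 - y0)

def skip_pos_emb(batch_idx, training):
    # Mirrors the guard removed later in this commit: the positional part of the
    # attention scores was dropped only in training, with the scheduled probability.
    return training and random.random() < scheduled_float(batch_idx)

print(scheduled_float(0), scheduled_float(2000), scheduled_float(4000))  # 0.5 0.2625 0.025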
@@ -1035,7 +1031,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self.query_head_dim = query_head_dim
         self.pos_head_dim = pos_head_dim
         self.dropout = dropout
-        self.pos_emb_skip = copy.deepcopy(pos_emb_skip)
 
         key_head_dim = query_head_dim
         in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
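On the unchanged context line above: in_proj_dim packs the query, key and positional projections for all heads into a single linear layer. A minimal sketch of how the projection is then expected to be sliced apart in forward(); the slicing layout below is my assumption for illustration, not code from this file:

import torch
import torch.nn as nn

embed_dim, num_heads = 16, 2
query_head_dim, pos_head_dim = 4, 2
key_head_dim = query_head_dim
in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads

in_proj = nn.Linear(embed_dim, in_proj_dim)      # stand-in for the module's input projection
x = in_proj(torch.randn(10, 3, embed_dim))       # (seq_len, batch, in_proj_dim)

query_dim = query_head_dim * num_heads
q = x[..., :query_dim]                           # queries for all heads
k = x[..., query_dim:2 * query_dim]              # keys (key_head_dim == query_head_dim)
p = x[..., 2 * query_dim:]                       # per-head positional projection
assert p.shape[-1] == pos_head_dim * num_heads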
@@ -1120,26 +1115,25 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 
         attn_scores = torch.matmul(q, k)
 
-        if not self.training or random.random() >= float(self.pos_emb_skip):
         pos_emb = self.linear_pos(pos_emb)
         seq_len2 = 2 * seq_len - 1
         pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
         # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
 
         # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
         # [where seq_len2 represents relative position.]
         pos_scores = torch.matmul(p, pos_emb)
         # the following .as_strided() expression converts the last axis of pos_scores from relative
         # to absolute position.  I don't know whether I might have got the time-offsets backwards or
         # not, but let this code define which way round it is supposed to be.
         pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
                                            (pos_scores.stride(0),
                                             pos_scores.stride(1),
                                             pos_scores.stride(2)-pos_scores.stride(3),
                                             pos_scores.stride(3)),
                                            storage_offset=pos_scores.stride(3) * (seq_len - 1))
 
         attn_scores = attn_scores + pos_scores
 
         if self.training and random.random() < 0.1:
             # This is a harder way of limiting the attention scores to not be
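To make the relative-to-absolute trick in the hunk above concrete, here is a small self-contained check (an editor's sketch, not part of the commit). For a contiguous pos_scores of shape (num_heads, batch_size, seq_len, 2*seq_len - 1) whose last axis indexes relative position, the strided view picks out[h, b, i, j] == pos_scores[h, b, i, (seq_len - 1) + j - i], which is the re-indexing the code comment describes:

import torch

num_heads, batch_size, seq_len = 2, 3, 5
seq_len2 = 2 * seq_len - 1
pos_scores = torch.randn(num_heads, batch_size, seq_len, seq_len2)

out = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
                            (pos_scores.stride(0),
                             pos_scores.stride(1),
                             pos_scores.stride(2) - pos_scores.stride(3),
                             pos_scores.stride(3)),
                            storage_offset=pos_scores.stride(3) * (seq_len - 1))

# Naive reference: gather the same elements by explicit indexing over (i, j).
ref = torch.empty(num_heads, batch_size, seq_len, seq_len)
for i in range(seq_len):
    for j in range(seq_len):
        ref[:, :, i, j] = pos_scores[:, :, i, (seq_len - 1) + j - i]

assert torch.equal(out, ref)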