Remove pos_emb schedule

This commit is contained in:
Daniel Povey 2022-11-14 15:31:56 +08:00
parent 54048009db
commit ba69eb48fe

View File

@ -1014,8 +1014,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
query_head_dim: dimension of the query (and key), per head. e.g. 24.
pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
dropout: dropout probability for attn_output_weights. Default: 0.0.
pos_emb_skip: probability for skipping the pos_emb part of the scores on
any given call to forward(), in training time.
"""
def __init__(
@ -1026,8 +1024,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
query_head_dim: int,
pos_head_dim: int,
dropout: float = 0.0,
pos_emb_skip: FloatLike = ScheduledFloat((0.0, 0.5),
(4000.0, 0.025), default=0.0)
) -> None:
super().__init__()
self.embed_dim = embed_dim
@ -1035,7 +1031,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
self.query_head_dim = query_head_dim
self.pos_head_dim = pos_head_dim
self.dropout = dropout
self.pos_emb_skip = copy.deepcopy(pos_emb_skip)
key_head_dim = query_head_dim
in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
@ -1120,7 +1115,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
attn_scores = torch.matmul(q, k)
if not self.training or random.random() >= float(self.pos_emb_skip):
pos_emb = self.linear_pos(pos_emb)
seq_len2 = 2 * seq_len - 1
pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)