Remove pos_emb csales

This commit is contained in:
Daniel Povey 2022-11-14 15:32:54 +08:00
parent ba69eb48fe
commit 804917837e

View File

@ -904,7 +904,6 @@ class RelPositionalEncoding(torch.nn.Module):
self.embed_dim = embed_dim
assert embed_dim % 2 == 0
self.dropout = torch.nn.Dropout(dropout_rate)
self.pe_scales = torch.nn.Parameter(torch.ones(embed_dim))
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(max_len))
@ -985,13 +984,6 @@ class RelPositionalEncoding(torch.nn.Module):
+ x.size(0),
:
]
scales = self.pe_scales
if self.training and random.random() < 0.5:
# randomly, half the time, clamp to this range; this will discourage
# the scales going outside of this range while allowing them to
# re-enter (because derivs won't always be zero).
scales = scales.clamp(min=0.25, max=4.0)
pos_emb = pos_emb * scales
batch_size = x.size(1)
pos_emb = pos_emb.unsqueeze(0).expand(batch_size, -1, -1) # now: (batch_size, 2*time-1, embed_dim)
return self.dropout(pos_emb)