mirror of https://github.com/k2-fsa/icefall.git
Remove pos_emb scales
commit 804917837e
parent ba69eb48fe
@@ -904,7 +904,6 @@ class RelPositionalEncoding(torch.nn.Module):
         self.embed_dim = embed_dim
         assert embed_dim % 2 == 0
         self.dropout = torch.nn.Dropout(dropout_rate)
-        self.pe_scales = torch.nn.Parameter(torch.ones(embed_dim))
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(max_len))

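For context, a minimal sketch of the mechanism this hunk removes (not the full icefall class; the class name, constructor signature, and helper method here are illustrative assumptions): pe_scales was one learnable scale per embedding dimension, initialized to 1.0, multiplied elementwise into the relative positional embedding.

import torch

class RelPositionalEncodingWithScales(torch.nn.Module):
    """Hedged sketch of the removed mechanism, not the icefall code:
    a learnable per-dimension scale vector applied to pos_emb."""

    def __init__(self, embed_dim: int, dropout_rate: float = 0.1):
        super().__init__()
        assert embed_dim % 2 == 0
        self.embed_dim = embed_dim
        self.dropout = torch.nn.Dropout(dropout_rate)
        # The parameter deleted by this commit: one scale per dimension,
        # starting at 1.0 (i.e. initially a no-op).
        self.pe_scales = torch.nn.Parameter(torch.ones(embed_dim))

    def apply_scales(self, pos_emb: torch.Tensor) -> torch.Tensor:
        # pos_emb: (2*time-1, embed_dim); broadcasting applies one
        # learnable scale per embedding dimension.
        return pos_emb * self.pe_scales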
@@ -985,13 +984,6 @@ class RelPositionalEncoding(torch.nn.Module):
             + x.size(0),
             :
         ]
-        scales = self.pe_scales
-        if self.training and random.random() < 0.5:
-            # randomly, half the time, clamp to this range; this will discourage
-            # the scales going outside of this range while allowing them to
-            # re-enter (because derivs won't always be zero).
-            scales = scales.clamp(min=0.25, max=4.0)
-        pos_emb = pos_emb * scales
         batch_size = x.size(1)
         pos_emb = pos_emb.unsqueeze(0).expand(batch_size, -1, -1)  # now: (batch_size, 2*time-1, embed_dim)
         return self.dropout(pos_emb)
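The deleted block documents a trick worth noting: during training the scales were clamped to [0.25, 4.0] only half the time, so values outside the range still receive nonzero gradients on unclamped steps and can drift back in. A standalone sketch of that trick (the function name and signature are illustrative, not from icefall):

import random
import torch

def stochastic_clamp(scales: torch.Tensor, training: bool,
                     lo: float = 0.25, hi: float = 4.0) -> torch.Tensor:
    # clamp() has zero gradient for values outside [lo, hi], so clamping
    # on every step would freeze out-of-range scales in place.  Clamping
    # only half the time, as the deleted comment explains, discourages
    # leaving the range while still letting gradients pull stray values
    # back in on the unclamped steps.
    if training and random.random() < 0.5:
        scales = scales.clamp(min=lo, max=hi)
    return scales

In the removed code this was applied to self.pe_scales immediately before the pos_emb = pos_emb * scales multiply.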