Add trainable scales for pos_emb

# Conflicts:
#	egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
Daniel Povey 2022-11-12 15:44:08 +08:00
parent 603be9933b
commit 70408d22fe

@@ -986,6 +986,7 @@ class RelPositionalEncoding(torch.nn.Module):
         self.embed_dim = embed_dim
         assert embed_dim % 2 == 0
         self.dropout = torch.nn.Dropout(dropout_rate)
+        self.pe_scales = torch.nn.Parameter(torch.ones(embed_dim))
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(max_len))
@@ -1052,6 +1053,13 @@ class RelPositionalEncoding(torch.nn.Module):
             + x.size(0),
             :
         ]
+        scales = self.pe_scales
+        if self.training and random.random() < 0.5:
+            # randomly, half the time, clamp to this range; this will discourage
+            # the scales going outside of this range while allowing them to
+            # re-enter (because derivs won't always be zero).
+            scales = scales.clamp(min=0.25, max=4.0)
+        pos_emb = pos_emb * scales
         batch_size = x.size(1)
         pos_emb = pos_emb.unsqueeze(0).expand(batch_size, -1, -1)  # now: (batch_size, 2*time-1, embed_dim)
         return self.dropout(pos_emb)
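
For reference, below is a minimal, self-contained sketch of the stochastic-clamp trick the diff applies to the trainable scales. The module name ScaledVector and the tensor sizes are illustrative assumptions, not part of the commit; only the clamp range [0.25, 4.0] and the 50% probability come from the diff above. Note the added forward-pass code assumes random is already imported at the top of zipformer.py.

import random

import torch


class ScaledVector(torch.nn.Module):
    """Multiplies its input by per-dimension trainable scales that are
    randomly clamped during training (illustrative sketch, not the commit)."""

    def __init__(self, embed_dim: int):
        super().__init__()
        # one trainable scale per embedding dimension, initialized to 1.0
        self.scales = torch.nn.Parameter(torch.ones(embed_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scales = self.scales
        if self.training and random.random() < 0.5:
            # clamp() has zero derivative outside [0.25, 4.0], so on the
            # clamped half of training steps an out-of-range scale receives
            # no gradient and is not pushed further out; on the unclamped
            # steps the derivative is nonzero, letting it drift back in.
            scales = scales.clamp(min=0.25, max=4.0)
        return x * scales


# usage: in eval() mode the raw, unclamped scales are always used,
# matching the self.training check in the diff above
m = ScaledVector(embed_dim=4)
m.train()
y = m(torch.randn(3, 4))

Applying the clamp only on a random half of the steps keeps the constraint soft: a hard clamp on every step would permanently zero the gradient for any scale that left the range, whereas here the scales are merely discouraged from leaving it and can still re-enter.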