From 70408d22fe8e112f27f6267ec5ceb35dae659027 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Nov 2022 15:44:08 +0800 Subject: [PATCH] Add trainable scales for pos_emb # Conflicts: # egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py --- .../ASR/pruned_transducer_stateless7/zipformer.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index c7dc693ea..6e1f20fc7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -986,6 +986,7 @@ class RelPositionalEncoding(torch.nn.Module): self.embed_dim = embed_dim assert embed_dim % 2 == 0 self.dropout = torch.nn.Dropout(dropout_rate) + self.pe_scales = torch.nn.Parameter(torch.ones(embed_dim)) self.pe = None self.extend_pe(torch.tensor(0.0).expand(max_len)) @@ -1052,6 +1053,13 @@ class RelPositionalEncoding(torch.nn.Module): + x.size(0), : ] + scales = self.pe_scales + if self.training and random.random() < 0.5: + # randomly, half the time, clamp to this range; this will discourage + # the scales going outside of this range while allowing them to + # re-enter (because derivs won't always be zero). + scales = scales.clamp(min=0.25, max=4.0) + pos_emb = pos_emb * scales batch_size = x.size(1) pos_emb = pos_emb.unsqueeze(0).expand(batch_size, -1, -1) # now: (batch_size, 2*time-1, embed_dim) return self.dropout(pos_emb)