From 2c6f5e82b28e43da361aecbf04f79651551dc39a Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 10 Nov 2022 22:32:38 +0800
Subject: [PATCH] Use atan not tanh

---
 .../ASR/pruned_transducer_stateless7/zipformer.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index bd506310f..739b19a17 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1002,22 +1002,24 @@ class RelPositionalEncoding(torch.nn.Module):
                              device=x.device).to(torch.float32).unsqueeze(1)
 
-        length_factor = self.embed_dim / (2.0*math.pi) # todo: test this.
+        length_factor = self.embed_dim / (2.0 * math.pi) # todo: test this.
         freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)
-        x_tanh = (x / length_factor).tanh() # results between -pi and pi
+        x_atan = (x / length_factor).atan() # results between -pi and pi
-        cosines = (x_tanh * freqs).cos()
-        sines = (x_tanh * freqs).sin()
+
+        cosines = (x_atan * freqs).cos()
+        sines = (x_atan * freqs).sin()
 
         pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
         pe[:, 0::2] = cosines
         pe[:, 1::2] = sines
         pe[:, -1] = 1.0 # for bias.
 
-        #print("cosines = ", cosines[T-r:T+r,:5])
-        #print("sines = ", sines[T-r:T+r,:5])
+        #r = 2
+        #print("cosines = ", cosines[T-r:T+r,-5:])
+        #print("sines = ", sines[T-r:T+r,-5:])
 
         self.pe = pe.to(dtype=x.dtype)
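
A minimal standalone sketch of what the patched code computes, for trying the change outside the class: `embed_dim` and the 1-D relative-position tensor mirror RelPositionalEncoding in zipformer.py, while the helper name demo_rel_pos_enc is hypothetical. Note that atan maps the scaled positions into (-pi/2, pi/2) and saturates far more slowly than tanh (pi/2 - atan(x) decays like 1/x rather than exponentially), so distant relative positions remain distinguishable; that is presumably the motivation for the swap.

    import math
    import torch

    def demo_rel_pos_enc(x: torch.Tensor, embed_dim: int) -> torch.Tensor:
        """Hypothetical sketch of the atan-based relative positional encoding."""
        # x: 1-D tensor of relative positions, e.g. torch.arange(-(T - 1), T)
        x = x.to(torch.float32).unsqueeze(1)
        length_factor = embed_dim / (2.0 * math.pi)
        freqs = 1 + torch.arange(embed_dim // 2, device=x.device)
        # atan compresses the scaled positions into (-pi/2, pi/2); multiplying
        # by the per-channel frequencies gives interleaved cos/sin features.
        x_atan = (x / length_factor).atan()
        pe = torch.zeros(x.shape[0], embed_dim, device=x.device)
        pe[:, 0::2] = (x_atan * freqs).cos()
        pe[:, 1::2] = (x_atan * freqs).sin()
        pe[:, -1] = 1.0  # constant last channel, acts as a bias
        return pe

For example, demo_rel_pos_enc(torch.arange(-9, 10), embed_dim=64) returns a (19, 64) tensor of encodings for relative offsets -9..9.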