diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index d6c74a360..90819cf38 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1008,14 +1008,28 @@ class RelPositionalEncoding(torch.nn.Module): x = torch.arange(-(T-1), T, device=x.device).to(torch.float32).unsqueeze(1) - - length_factor = self.embed_dim / (2.0 * math.pi) # todo: test this. - freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) - x_atan = (x / length_factor).atan() # results between -pi and pi + # `compression_length` is arbitrary/heuristic; if it is larger we have more resolution + # for small time offsets but less resolution for large time offsets. + compression_length = (self.embed_dim ** 0.5) + # x_compressed, like x, goes from -infinity to infinity as x goes from -infinity to infinity; + # but it does so more slowly than x for large absolute values of x. + # The formula is chosen so that d(x_compressed)/dx is 1 around x == 0, which + # is important. + x_compressed = compression_length * x.sign() * (x.abs() + compression_length).log() + # length_factor is chosen so that the FFT can exactly separate points + # close to the origin (x == 0). So this part of the formulation is not really + # heuristic. + length_factor = self.embed_dim / (2.0 * math.pi) # todo: test this. + + # note for machine implementations: if atan is not available, we can use: + # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) + # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) + x_atan = (x_compressed / length_factor).atan() # results between -pi/2 and pi/2 + cosines = (x_atan * freqs).cos() sines = (x_atan * freqs).sin()