diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 95e2b5ae8..fb601beb1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1010,7 +1010,7 @@ class RelPositionalEncoding(torch.nn.Module):
 
         # e.g. width_factor == 0.1, if ratio is 1.1. determines steepness of sigmoid.
         width_factor = ratio - 1.0
-        # centers of sigmoids (positive)
+        # centers of tanh functions (positive)
         centers = (1.0 / (ratio - 1.0)) * (ratio ** torch.arange(self.embed_dim // 2, device=x.device))
 
         # adjust centers so the 1st value is 0.5, the 2nd value is about 1.5, and so on.
@@ -1018,10 +1018,14 @@
         centers = centers.reshape(self.embed_dim // 2, 1).expand(self.embed_dim // 2, 2).contiguous()
         centers[:, 1] *= -1.0
         centers = centers.reshape(self.embed_dim)
+
         widths = (centers.abs() * width_factor).clamp(min=1.0)
 
         # shape: (2*T - 1, embed_dim)
-        pe = ((x - centers) / widths).sigmoid()
+        pe = ((x - centers) / widths).tanh()
+        # Let the last dimension of the embedding be a constant 1.0, to provide a bias term
+        # so we don't have to give a bias to the projection.
+        pe[:, -1] = 1.0
 
         self.pe = pe.to(dtype=x.dtype)
 
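
For context, here is a minimal standalone sketch of the tanh-based positional encoding as it stands after this patch. The center-shifting step falls between the two hunks and is not shown above, so its reconstruction here from the `# adjust centers ...` comment is an assumption; the function name `tanh_rel_pos_encoding` and the example values of `T`, `embed_dim`, and `ratio` are likewise hypothetical.

```python
import torch

def tanh_rel_pos_encoding(T: int, embed_dim: int, ratio: float = 1.1) -> torch.Tensor:
    # Relative positions -(T-1) .. (T-1), as a column vector: shape (2*T - 1, 1).
    x = torch.arange(-(T - 1), T, dtype=torch.float32).unsqueeze(1)
    # width_factor determines the steepness of the tanh curves.
    width_factor = ratio - 1.0
    # Geometrically spaced positive centers: (1 / (ratio - 1)) * ratio**i.
    centers = (1.0 / (ratio - 1.0)) * (
        ratio ** torch.arange(embed_dim // 2, dtype=torch.float32)
    )
    # ASSUMPTION: shift so the 1st center is 0.5, the 2nd about 1.5, and so on,
    # matching the "# adjust centers ..." comment; the actual line is not in the diff.
    centers = centers - centers[0] + 0.5
    # Pair each center with a negated copy to cover negative relative offsets.
    centers = centers.reshape(embed_dim // 2, 1).expand(embed_dim // 2, 2).contiguous()
    centers[:, 1] *= -1.0
    centers = centers.reshape(embed_dim)
    # Wider tanh for centers further from zero, but never narrower than 1.
    widths = (centers.abs() * width_factor).clamp(min=1.0)
    pe = ((x - centers) / widths).tanh()  # shape: (2*T - 1, embed_dim)
    # Constant last channel: a built-in bias term for the following projection.
    pe[:, -1] = 1.0
    return pe

pe = tanh_rel_pos_encoding(T=8, embed_dim=16)
print(pe.shape)  # torch.Size([15, 16])
```

Since tanh(z) = 2 * sigmoid(2z) - 1, the switch from sigmoid to tanh keeps the same family of transition curves but makes each feature zero-centered with range (-1, 1); each dimension saturates at a different (geometrically spaced) distance, giving fine resolution for small offsets and coarse resolution for large ones.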