diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index fc30e4f53..bd506310f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -967,20 +967,16 @@ class RelPositionalEncoding(torch.nn.Module):
 
     Args:
         embed_dim: Embedding dimension.
-        max_offset: determines the largest offset that can be coded distinctly from smaller
-            offsets.  If you change this, the embedding will change.
         dropout_rate: Dropout rate.
         max_len: Maximum input length: just a heuristic for initialization.
     """
 
     def __init__(
         self, embed_dim: int, dropout_rate: float,
-        max_offset: float = 1000.0,
-        max_len: int = 10
+        max_len: int = 1000
     ) -> None:
         """Construct a PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
-        self.max_offset = max_offset
         self.embed_dim = embed_dim
         assert embed_dim % 2 == 0
         self.dropout = torch.nn.Dropout(dropout_rate)
@@ -1005,27 +1001,24 @@ class RelPositionalEncoding(torch.nn.Module):
 
         x = torch.arange(-(T-1), T, device=x.device).to(torch.float32).unsqueeze(1)
 
-        # e.g. ratio might be 1.1
-        ratio = self.max_offset ** (2.0 / self.embed_dim)
-        # e.g. width_factor == 0.1, if ratio is 1.1.  determines steepness of sigmoid.
-        width_factor = ratio - 1.0
-        # centers of tanh functions (positive)
-        centers = (1.0 / (ratio - 1.0)) * (ratio ** torch.arange(self.embed_dim // 2,
-                                                                 device=x.device))
-        # adjust centers so the 1st value is 0.5, the 2nd value is about 1.5, and so on.
-        centers = centers + (0.5 - centers[0])
-        centers = centers.reshape(self.embed_dim // 2, 1).expand(self.embed_dim // 2, 2).contiguous()
-        centers[:, 1] *= -1.0
-        centers = centers.reshape(self.embed_dim)
+        length_factor = self.embed_dim / (2.0*math.pi)  # todo: test this.
 
-        widths = (centers.abs() * width_factor) + 1.0
+        freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)
+
+        x_tanh = (x / length_factor).tanh()  # results between -pi and pi
+
+        cosines = (x_tanh * freqs).cos()
+        sines = (x_tanh * freqs).sin()
+
+        pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
+        pe[:, 0::2] = cosines
+        pe[:, 1::2] = sines
+        pe[:, -1] = 1.0  # for bias.
+
+        #print("cosines = ", cosines[T-r:T+r,:5])
+        #print("sines = ", sines[T-r:T+r,:5])
 
-        # shape: (2*T - 1, embed_dim)
-        pe = ((x - centers) / widths).tanh()
-        # Let the last dimension of the embedding be a constant 1.0, to provide a bias term
-        # so we don't have to give a bias to the projection.
-        pe[:, -1] = 1.0
         self.pe = pe.to(dtype=x.dtype)
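
For context, the second hunk replaces the tanh-bump encoding (per-dimension centers and widths) with a tanh-warped sinusoidal encoding of the relative offset. The following is a minimal standalone sketch, not part of the patch, that mirrors the new computation in extend_pe() so the shapes and value ranges can be checked in isolation; the function name make_rel_pos_emb and the example sizes are illustrative assumptions only.

    # Illustrative sketch only -- mirrors the new computation added above.
    import math
    import torch

    def make_rel_pos_emb(T: int, embed_dim: int) -> torch.Tensor:
        assert embed_dim % 2 == 0
        # Relative offsets -(T-1) .. (T-1), shape (2*T - 1, 1).
        x = torch.arange(-(T - 1), T).to(torch.float32).unsqueeze(1)

        length_factor = embed_dim / (2.0 * math.pi)
        freqs = 1 + torch.arange(embed_dim // 2)

        # tanh squashes arbitrarily large offsets into (-1, 1), so distant
        # offsets saturate while nearby offsets stay roughly linear.
        x_tanh = (x / length_factor).tanh()

        cosines = (x_tanh * freqs).cos()   # (2*T - 1, embed_dim // 2)
        sines = (x_tanh * freqs).sin()

        pe = torch.zeros(x.shape[0], embed_dim)
        pe[:, 0::2] = cosines
        pe[:, 1::2] = sines
        pe[:, -1] = 1.0  # constant last dim acts as a bias for the projection
        return pe

    emb = make_rel_pos_emb(T=50, embed_dim=8)
    print(emb.shape)  # torch.Size([99, 8])

As in the patch, the final pe[:, -1] = 1.0 overwrites the highest-frequency sine column, leaving one dimension as a constant bias term.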