diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index ff9468167..5b510a71d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -544,7 +544,8 @@ class ZipformerEncoder(nn.Module): final_layerdrop_rate: float = 0.05, ) -> None: super().__init__() - self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15) + self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15, + length_factor=1.5) self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] @@ -897,11 +898,14 @@ class CompactRelPositionalEncoding(torch.nn.Module): embed_dim: Embedding dimension. dropout_rate: Dropout rate. max_len: Maximum input length: just a heuristic for initialization. + length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives + less weight to small differences of offset near the origin. """ def __init__( self, embed_dim: int, dropout_rate: float, - max_len: int = 1000 + max_len: int = 1000, + length_factor: float = 1.0, ) -> None: """Construct a CompactRelPositionalEncoding object.""" super(CompactRelPositionalEncoding, self).__init__() @@ -909,8 +913,12 @@ class CompactRelPositionalEncoding(torch.nn.Module): assert embed_dim % 2 == 0 self.dropout = torch.nn.Dropout(dropout_rate) self.pe = None + assert length_factor >= 1.0 + self.length_factor = length_factor self.extend_pe(torch.tensor(0.0).expand(max_len)) + + def extend_pe(self, x: Tensor) -> None: """Reset the positional encodings.""" if self.pe is not None: @@ -940,18 +948,16 @@ class CompactRelPositionalEncoding(torch.nn.Module): # is important. x_compressed = compression_length * x.sign() * ((x.abs() + compression_length).log() - math.log(compression_length)) - # length_factor is chosen so that the FFT can exactly separate points - # close to the origin (T == 0). So this part of the formulation is not really - # heuristic. - length_factor = self.embed_dim / (2.0 * math.pi) - # multiplying length_factor by this heuristic constant should reduce the resolution near to the - # origin, i.e. reduce its ability to separate points near zero. - length_factor *= 2.0 + # if self.length_factor == 1.0, then length_scale is chosen so that the + # FFT can exactly separate points close to the origin (T == 0). So this + # part of the formulation is not really heuristic. + # But empirically, for ASR at least, length_factor > 1.0 seems to work better. + length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) # note for machine implementations: if atan is not available, we can use: # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) - x_atan = (x_compressed / length_factor).atan() # results between -pi and pi + x_atan = (x_compressed / length_scale).atan() # results between -pi and pi cosines = (x_atan * freqs).cos() sines = (x_atan * freqs).sin() @@ -961,14 +967,6 @@ class CompactRelPositionalEncoding(torch.nn.Module): pe[:, 1::2] = sines pe[:, -1] = 1.0 # for bias. - # if we have the length_factor correct, the cosines around 0 offset (T in the array) - # should be oscillating in sign like -1, 1, -1; and the sines should all be close to - # zero. - #r = 2 - #print("cosines = ", cosines[T-r:T+r,-5:]) - #print("sines = ", sines[T-r:T+r,-5:]) - - self.pe = pe.to(dtype=x.dtype)