diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index cf2f7f3aa..174ccf39e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -728,7 +728,7 @@ class LimitParamValue(torch.autograd.Function):
         x, = ctx.saved_tensors
         # where x < ctx.min, ensure all grads are negative (this will tend to make
         # x more positive).
-        x_grad *= torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0)
+        x_grad = x_grad * torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0)
         # where x > ctx.max, ensure all grads are positive (this will tend to make
         # x more negative).
         x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 6317c587e..eeb186a6f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -553,7 +553,7 @@ class ZipformerEncoder(nn.Module):
         final_layerdrop_prob: float = 0.05,
     ) -> None:
         super().__init__()
-        self.encoder_pos = RelPositionalEncoding(pos_dim, dropout_rate=0.15)
+        self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15)

         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
@@ -884,8 +884,22 @@ class SimpleCombiner(torch.nn.Module):



-class RelPositionalEncoding(torch.nn.Module):
-    """Relative positional encoding module.
+class CompactRelPositionalEncoding(torch.nn.Module):
+    """
+    Relative positional encoding module.  This version is "compact", meaning it is able to encode
+    the important information about the relative position in a relatively small number of dimensions.
+    The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001)
+    make very little difference to the embedding.  Such differences were potentially important
+    when encoding absolute position, but not when encoding relative position, because there
+    is now no need to compare two large offsets with each other.
+
+    Our embedding works by projecting the interval [-infinity,infinity] onto a finite interval
+    using the atan() function, before taking the Fourier transform of that fixed interval.  The
+    atan() function on its own would compress the "long tails" too much,
+    making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic
+    function to compress large offsets to a smaller range before applying atan().
+    Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long
+    as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim)


     Args:
@@ -898,8 +912,8 @@ class RelPositionalEncoding(torch.nn.Module):
         dropout_rate: float,
         max_len: int = 1000
     ) -> None:
-        """Construct a PositionalEncoding object."""
-        super(RelPositionalEncoding, self).__init__()
+        """Construct a CompactRelPositionalEncoding object."""
+        super(CompactRelPositionalEncoding, self).__init__()
         self.embed_dim = embed_dim
         assert embed_dim % 2 == 0
         self.dropout = torch.nn.Dropout(dropout_rate)
@@ -1199,7 +1213,7 @@ class SelfAttention(nn.Module):
     Args:
           embed_dim: the input and output embedding dimension
           num_heads: the number of attention heads
-          value_dim: the value dimension per head
+          value_head_dim: the value dimension per head
     """
    def __init__(
        self,
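Note on the scaling.py hunk: the first multiplication of x_grad is now out-of-place instead of in-place. A likely motivation is that the gradient tensor handed to backward() may be shared with other parts of the autograd graph, so mutating it in place risks corrupting gradients seen elsewhere; once x_grad has been rebound to a fresh local tensor, the second, in-place multiply is harmless. Below is a minimal sketch of that pattern using a hypothetical stand-in class (the name LimitValueSketch and the [min_val, max_val] arguments are illustrative, not the full LimitParamValue implementation):

    import torch

    class LimitValueSketch(torch.autograd.Function):
        # Illustrative stand-in, not the icefall LimitParamValue class.
        @staticmethod
        def forward(ctx, x: torch.Tensor, min_val: float, max_val: float) -> torch.Tensor:
            ctx.save_for_backward(x)
            ctx.min_val = min_val
            ctx.max_val = max_val
            return x

        @staticmethod
        def backward(ctx, x_grad: torch.Tensor):
            (x,) = ctx.saved_tensors
            # Out-of-place: do not mutate the incoming gradient tensor, which
            # autograd owns and may still be using elsewhere.
            x_grad = x_grad * torch.where(
                torch.logical_and(x_grad > 0, x < ctx.min_val), -1.0, 1.0
            )
            # x_grad is now a fresh local tensor, so an in-place update is safe.
            x_grad *= torch.where(
                torch.logical_and(x_grad < 0, x > ctx.max_val), -1.0, 1.0
            )
            return x_grad, None, None

    x = torch.randn(4, requires_grad=True)
    LimitValueSketch.apply(x, -1.0, 1.0).sum().backward()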
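The new docstring describes the encoding in three steps: compress large relative offsets with a logarithmic function, map the compressed value onto a finite interval with atan(), then expand that bounded coordinate with sinusoidal (Fourier-style) features. The following is a rough sketch of that idea only; it is not the CompactRelPositionalEncoding code in zipformer.py, and the compression constant and feature layout are assumptions made for illustration:

    import math
    import torch

    def compact_rel_pos_sketch(offsets: torch.Tensor, embed_dim: int) -> torch.Tensor:
        # Illustrative sketch; constants and layout differ from zipformer.py.
        assert embed_dim % 2 == 0
        # Step 1: logarithmic compression, so offsets like 1000 and 1001 become
        # nearly indistinguishable while offsets near 0 stay well separated.
        # The scale below is an assumption, loosely matching the docstring's
        # "abs(offset) <= about sqrt(embedding_dim)".
        compression = math.sqrt(embed_dim)
        x = offsets.sign() * torch.log1p(offsets.abs() / compression)
        # Step 2: atan() maps the compressed line onto the finite interval (-pi/2, pi/2).
        x = torch.atan(x)
        # Step 3: Fourier-style expansion of the bounded coordinate.
        freqs = torch.arange(1, embed_dim // 2 + 1, dtype=x.dtype)
        angles = x.unsqueeze(-1) * freqs  # shape: (num_offsets, embed_dim // 2)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)

    # Offsets close to the origin get clearly distinct embeddings; very large
    # offsets collapse toward nearly the same embedding.
    emb = compact_rel_pos_sketch(torch.arange(-4.0, 5.0), 16)
    print(emb.shape)  # torch.Size([9, 16])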