diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 949af6c19..6a091f2d5 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -964,80 +964,89 @@ class SimpleCombiner(torch.nn.Module):
 class RelPositionalEncoding(torch.nn.Module):
     """Relative positional encoding module.

-    See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
-    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py

     Args:
         embed_dim: Embedding dimension.
+        max_offset: determines the largest offset that can be coded distinctly from smaller
+           offsets.  If you change this, the embedding will change.
         dropout_rate: Dropout rate.
-        max_len: Maximum input length.
-
+        max_len: Maximum input length: just a heuristic for initialization.
     """

-    def __init__(
-        self, embed_dim: int, dropout_rate: float, max_len: int = 5000
+    def __init__(
+        self, embed_dim: int,
+        dropout_rate: float,
+        max_offset: float = 1000.0,
+        max_len: int = 10
     ) -> None:
         """Construct a PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
+        self.max_offset = max_offset
         self.embed_dim = embed_dim
+        assert embed_dim % 2 == 0
         self.dropout = torch.nn.Dropout(dropout_rate)
         self.pe = None
-        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+        self.extend_pe(torch.tensor(0.0).expand(max_len))

     def extend_pe(self, x: Tensor) -> None:
         """Reset the positional encodings."""
         if self.pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
-            if self.pe.size(1) >= x.size(0) * 2 - 1:
+            if self.pe.size(0) >= x.size(0) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
                 if self.pe.dtype != x.dtype or str(self.pe.device) != str(
                     x.device
                 ):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
-        # Suppose `i` means to the position of query vecotr and `j` means the
-        # position of key vector. We use position relative positions when keys
-        # are to the left (i>j) and negative relative positions otherwise (i<j).
-        pe_positive = torch.zeros(x.size(0), self.embed_dim)
-        pe_negative = torch.zeros(x.size(0), self.embed_dim)
-        position = torch.arange(0, x.size(0), dtype=torch.float32).unsqueeze(1)
-        div_term = torch.exp(
-            torch.arange(0, self.embed_dim, 2, dtype=torch.float32)
-            * -(math.log(10000.0) / self.embed_dim)
-        )
-        pe_positive[:, 0::2] = torch.sin(position * div_term)
-        pe_positive[:, 1::2] = torch.cos(position * div_term)
-        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
-        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
-
-        # Reserve the order of positive indices and concat both positive and
-        # negative indices. This is used to support the shifting trick
-        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
-        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
-        pe_negative = pe_negative[1:].unsqueeze(0)
-        pe = torch.cat([pe_positive, pe_negative], dim=1)
-        self.pe = pe.to(device=x.device, dtype=x.dtype)
-
-    def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
+        # e.g. ratio might be 1.1
+        ratio = self.max_offset ** (2.0 / self.embed_dim)
+        # e.g. width_factor == 0.1, if ratio is 1.1.  determines steepness of sigmoid.
+        width_factor = ratio - 1.0
+
+        max_val = 1000.0
+
+        # centers of sigmoids (positive)
+        centers = (1.0 / (ratio - 1.0)) * (ratio ** torch.arange(self.embed_dim // 2,
+                                                                 device=x.device))
+        # adjust centers so the 1st value is 0.5, the 2nd value is about 1.5, and so on.
+        centers = centers + (0.5 - centers[0])
+        centers = centers.reshape(self.embed_dim // 2, 1).expand(self.embed_dim // 2, 2).contiguous()
+        centers[:, 1] *= -1.0
+        centers = centers.reshape(self.embed_dim)
+        widths = (centers.abs() * width_factor).clamp(min=1.0)
+
+        # shape: (2*T - 1, embed_dim)
+        pe = ((x - centers) / widths).sigmoid()
+
+        self.pe = pe.to(dtype=x.dtype)
+
+
+    def forward(self, x: torch.Tensor) -> Tensor:
         """Add positional encoding.

         Args:
             x (torch.Tensor): Input tensor (time, batch, `*`).

         Returns:
-            torch.Tensor: Encoded tensor (batch, time, `*`).
             torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
""" self.extend_pe(x) pos_emb = self.pe[ - :, - self.pe.size(1) // 2 + self.pe.size(0) // 2 - x.size(0) - + 1 : self.pe.size(1) // 2 # noqa E203 + + 1 : self.pe.size(0) // 2 # noqa E203 + x.size(0), + : ] + pos_emb = pos_emb.unsqueeze(0) # now: (1, 2*time-1, embed_dim) return self.dropout(pos_emb)