diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index e502c991e..eff4f65c2 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -896,32 +896,27 @@ class CompactRelPositionalEncoding(torch.nn.Module):
     Args:
         embed_dim: Embedding dimension.
         dropout_rate: Dropout rate.
-        random_shift: standard deviation of random distance by which we shift each time, if
-           training.
         max_len: Maximum input length: just a heuristic for initialization.
     """
     def __init__(
         self,
         embed_dim: int,
-        dropout_rate: FloatLike = 0.0,
-        random_shift: FloatLike = 1.0,
+        dropout_rate: float,
         max_len: int = 1000
     ) -> None:
         """Construct a CompactRelPositionalEncoding object."""
         super(CompactRelPositionalEncoding, self).__init__()
         self.embed_dim = embed_dim
         assert embed_dim % 2 == 0
-        self.dropout_rate = dropout_rate
-        self.random_shift = random_shift
+        self.dropout = torch.nn.Dropout(dropout_rate)
         self.pe = None
-        self.extend_pe(torch.tensor(0.0).expand(max_len), 0)
+        self.extend_pe(torch.tensor(0.0).expand(max_len))
 
-    def extend_pe(self, x: Tensor, shift: int) -> None:
+    def extend_pe(self, x: Tensor) -> None:
         """Reset the positional encodings."""
-        T = x.size(0) + abs(shift)
         if self.pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
-            if self.pe.size(0) >= T * 2 - 1:
+            if self.pe.size(0) >= x.size(0) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
                 if self.pe.dtype != x.dtype or str(self.pe.device) != str(
                     x.device
@@ -929,6 +924,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
 
+        T = x.size(0)
         # if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]
         x = torch.arange(-(T-1), T, device=x.device).to(torch.float32).unsqueeze(1)
 
@@ -983,22 +979,16 @@ class CompactRelPositionalEncoding(torch.nn.Module):
             torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
         """
-
-        shift = int(round(random.normalvariate(0, 1) * float(self.random_shift))) if self.training else 0
-        self.extend_pe(x, shift)
-
+        self.extend_pe(x)
         pos_emb = self.pe[
             self.pe.size(0) // 2
-            - x.size(0) + shift
+            - x.size(0)
             + 1 : self.pe.size(0) // 2  # noqa E203
-            + x.size(0) + shift,
+            + x.size(0),
             :
         ]
         pos_emb = pos_emb.unsqueeze(0)
-        pos_emb = torch.nn.functional.dropout(pos_emb,
-                                              p=float(self.dropout_rate),
-                                              training=self.training)
-        return pos_emb
+        return self.dropout(pos_emb)
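
For reference, a minimal sketch of how the simplified module can be exercised after this change. It assumes the method in the last hunk is the module's forward, that the first dimension of the input is time (as the slicing on x.size(0) suggests), and an illustrative import path:

    import torch
    from zipformer import CompactRelPositionalEncoding  # illustrative import path

    # dropout_rate no longer has a default, so it must be passed explicitly.
    pos_enc = CompactRelPositionalEncoding(embed_dim=192, dropout_rate=0.1)
    pos_enc.eval()  # nn.Dropout is inert in eval mode, so the output is deterministic

    x = torch.randn(50, 8, 192)      # (time, batch, channel); only x.size(0) is used
    pos_emb = pos_enc(x)             # slices self.pe symmetrically around its center
    assert pos_emb.shape == (1, 2 * 50 - 1, 192)  # (1, 2*time-1, embed_dim)

With random_shift removed, the returned embedding depends only on the input length (plus dropout noise in training mode), and replacing the functional dropout call with an nn.Dropout submodule lets the module rely on self.training implicitly.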