diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index efcb25754..c6d44bbe1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -890,27 +890,32 @@ class CompactRelPositionalEncoding(torch.nn.Module):
     Args:
         embed_dim: Embedding dimension.
         dropout_rate: Dropout rate.
+        random_shift: standard deviation of the random distance by which we shift the
+           positional embedding each time, if training.
         max_len: Maximum input length: just a heuristic for initialization.
     """
 
     def __init__(
        self,
        embed_dim: int,
-       dropout_rate: float,
+       dropout_rate: FloatLike = 0.0,
+       random_shift: FloatLike = 1.0,
        max_len: int = 1000
    ) -> None:
         """Construct a CompactRelPositionalEncoding object."""
         super(CompactRelPositionalEncoding, self).__init__()
         self.embed_dim = embed_dim
         assert embed_dim % 2 == 0
-        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.dropout_rate = dropout_rate
+        self.random_shift = random_shift
         self.pe = None
-        self.extend_pe(torch.tensor(0.0).expand(max_len))
+        self.extend_pe(torch.tensor(0.0).expand(max_len), 0)
 
-    def extend_pe(self, x: Tensor) -> None:
+    def extend_pe(self, x: Tensor, shift: int) -> None:
         """Reset the positional encodings."""
+        T = x.size(0) + abs(shift)
         if self.pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
-            if self.pe.size(0) >= x.size(0) * 2 - 1:
+            if self.pe.size(0) >= T * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
                 if self.pe.dtype != x.dtype or str(self.pe.device) != str(
                     x.device
@@ -918,7 +923,6 @@ class CompactRelPositionalEncoding(torch.nn.Module):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
 
-        T = x.size(0)
         # if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]
         x = torch.arange(-(T-1), T, device=x.device).to(torch.float32).unsqueeze(1)
 
@@ -973,16 +977,22 @@ class CompactRelPositionalEncoding(torch.nn.Module):
             torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
         """
-        self.extend_pe(x)
+
+        shift = int(round(random.normalvariate(0, 1) * float(self.random_shift))) if self.training else 0
+        self.extend_pe(x, shift)
+
         pos_emb = self.pe[
             self.pe.size(0) // 2
-            - x.size(0)
+            - x.size(0) + shift
             + 1 : self.pe.size(0) // 2  # noqa E203
-            + x.size(0),
+            + x.size(0) + shift,
             :
         ]
         pos_emb = pos_emb.unsqueeze(0)
-        return self.dropout(pos_emb)
+        pos_emb = torch.nn.functional.dropout(pos_emb,
+                                              p=float(self.dropout_rate),
+                                              training=self.training)
+        return pos_emb
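
Note (not part of the patch): for intuition, here is a minimal standalone sketch of the indexing this change introduces. The values of T and shift below are hypothetical; it shows how the training-time shift recentres the 2*T-1 window taken from the positional-encoding table, and why extend_pe() sizes the table for T + |shift| positions.

import torch

# Illustrative values; a real shift is drawn as
# round(normalvariate(0, 1) * random_shift) at training time.
T = 4        # number of input frames
shift = 1    # hypothetical random shift

# extend_pe() builds the table for T + |shift| positions, so the
# shifted window always stays inside it.
T_pe = T + abs(shift)
positions = torch.arange(-(T_pe - 1), T_pe)  # rows of self.pe: -(T_pe-1) .. T_pe-1
center = positions.numel() // 2              # index of relative position 0

# Same slice as in forward(): 2*T - 1 relative positions, recentred by shift.
window = positions[center - T + shift + 1 : center + T + shift]
print(window)  # tensor([-2, -1,  0,  1,  2,  3,  4])

With shift = 0 (the eval-mode case), the slice reduces to the symmetric window -(T-1) .. T-1 that forward() produced before this patch, so inference behavior is unchanged.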