Revert 419->420 change, regarding random shift in pos embedding

This commit is contained in:
Daniel Povey 2022-11-20 13:07:20 +08:00
parent b9871cc4f5
commit 8b3303594c

View File

@ -896,32 +896,27 @@ class CompactRelPositionalEncoding(torch.nn.Module):
Args: Args:
embed_dim: Embedding dimension. embed_dim: Embedding dimension.
dropout_rate: Dropout rate. dropout_rate: Dropout rate.
random_shift: standard deviation of random distance by which we shift each time, if
training.
max_len: Maximum input length: just a heuristic for initialization. max_len: Maximum input length: just a heuristic for initialization.
""" """
def __init__( def __init__(
self, embed_dim: int, self, embed_dim: int,
dropout_rate: FloatLike = 0.0, dropout_rate: float,
random_shift: FloatLike = 1.0,
max_len: int = 1000 max_len: int = 1000
) -> None: ) -> None:
"""Construct a CompactRelPositionalEncoding object.""" """Construct a CompactRelPositionalEncoding object."""
super(CompactRelPositionalEncoding, self).__init__() super(CompactRelPositionalEncoding, self).__init__()
self.embed_dim = embed_dim self.embed_dim = embed_dim
assert embed_dim % 2 == 0 assert embed_dim % 2 == 0
self.dropout_rate = dropout_rate self.dropout = torch.nn.Dropout(dropout_rate)
self.random_shift = random_shift
self.pe = None self.pe = None
self.extend_pe(torch.tensor(0.0).expand(max_len), 0) self.extend_pe(torch.tensor(0.0).expand(max_len))
def extend_pe(self, x: Tensor, shift: int) -> None: def extend_pe(self, x: Tensor) -> None:
"""Reset the positional encodings.""" """Reset the positional encodings."""
T = x.size(0) + abs(shift)
if self.pe is not None: if self.pe is not None:
# self.pe contains both positive and negative parts # self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1 # the length of self.pe is 2 * input_len - 1
if self.pe.size(0) >= T * 2 - 1: if self.pe.size(0) >= x.size(0) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device # Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str( if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device x.device
@ -929,6 +924,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):
self.pe = self.pe.to(dtype=x.dtype, device=x.device) self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return return
T = x.size(0)
# if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
x = torch.arange(-(T-1), T, x = torch.arange(-(T-1), T,
device=x.device).to(torch.float32).unsqueeze(1) device=x.device).to(torch.float32).unsqueeze(1)
@ -983,22 +979,16 @@ class CompactRelPositionalEncoding(torch.nn.Module):
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
""" """
self.extend_pe(x)
shift = int(round(random.normalvariate(0, 1) * float(self.random_shift))) if self.training else 0
self.extend_pe(x, shift)
pos_emb = self.pe[ pos_emb = self.pe[
self.pe.size(0) // 2 self.pe.size(0) // 2
- x.size(0) + shift - x.size(0)
+ 1 : self.pe.size(0) // 2 # noqa E203 + 1 : self.pe.size(0) // 2 # noqa E203
+ x.size(0) + shift, + x.size(0),
: :
] ]
pos_emb = pos_emb.unsqueeze(0) pos_emb = pos_emb.unsqueeze(0)
pos_emb = torch.nn.functional.dropout(pos_emb, return self.dropout(pos_emb)
p=float(self.dropout_rate),
training=self.training)
return pos_emb