Introduce random shift with stddev=1.0 into pos_emb

This commit is contained in:
Daniel Povey 2022-11-18 12:06:12 +08:00
parent e9806950f5
commit f7c99ed1d1

View File

@ -890,27 +890,32 @@ class CompactRelPositionalEncoding(torch.nn.Module):
Args: Args:
embed_dim: Embedding dimension. embed_dim: Embedding dimension.
dropout_rate: Dropout rate. dropout_rate: Dropout rate.
random_shift: standard deviation of random distance by which we shift each time, if
training.
max_len: Maximum input length: just a heuristic for initialization. max_len: Maximum input length: just a heuristic for initialization.
""" """
def __init__( def __init__(
self, embed_dim: int, self, embed_dim: int,
dropout_rate: float, dropout_rate: FloatLike = 0.0,
random_shift: FloatLike = 1.0,
max_len: int = 1000 max_len: int = 1000
) -> None: ) -> None:
"""Construct a CompactRelPositionalEncoding object.""" """Construct a CompactRelPositionalEncoding object."""
super(CompactRelPositionalEncoding, self).__init__() super(CompactRelPositionalEncoding, self).__init__()
self.embed_dim = embed_dim self.embed_dim = embed_dim
assert embed_dim % 2 == 0 assert embed_dim % 2 == 0
self.dropout = torch.nn.Dropout(dropout_rate) self.dropout_rate = dropout_rate
self.random_shift = random_shift
self.pe = None self.pe = None
self.extend_pe(torch.tensor(0.0).expand(max_len)) self.extend_pe(torch.tensor(0.0).expand(max_len), 0)
def extend_pe(self, x: Tensor) -> None: def extend_pe(self, x: Tensor, shift: int) -> None:
"""Reset the positional encodings.""" """Reset the positional encodings."""
T = x.size(0) + abs(shift)
if self.pe is not None: if self.pe is not None:
# self.pe contains both positive and negative parts # self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1 # the length of self.pe is 2 * input_len - 1
if self.pe.size(0) >= x.size(0) * 2 - 1: if self.pe.size(0) >= T * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device # Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str( if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device x.device
@ -918,7 +923,6 @@ class CompactRelPositionalEncoding(torch.nn.Module):
self.pe = self.pe.to(dtype=x.dtype, device=x.device) self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return return
T = x.size(0)
# if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
x = torch.arange(-(T-1), T, x = torch.arange(-(T-1), T,
device=x.device).to(torch.float32).unsqueeze(1) device=x.device).to(torch.float32).unsqueeze(1)
@ -973,16 +977,22 @@ class CompactRelPositionalEncoding(torch.nn.Module):
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
""" """
self.extend_pe(x)
shift = 0 if self.training else int(round(random.normalvariate(0, 1) * float(self.random_shift)))
self.extend_pe(x, shift)
pos_emb = self.pe[ pos_emb = self.pe[
self.pe.size(0) // 2 self.pe.size(0) // 2
- x.size(0) - x.size(0) + shift
+ 1 : self.pe.size(0) // 2 # noqa E203 + 1 : self.pe.size(0) // 2 # noqa E203
+ x.size(0), + x.size(0) + shift,
: :
] ]
pos_emb = pos_emb.unsqueeze(0) pos_emb = pos_emb.unsqueeze(0)
return self.dropout(pos_emb) pos_emb = torch.nn.functional.dropout(pos_emb,
p=float(self.dropout_rate),
training=self.training)
return pos_emb