Introduce random shift with stddev=1.0 into pos_emb

Daniel Povey 2022-11-18 12:06:12 +08:00
parent e9806950f5
commit f7c99ed1d1


@@ -890,27 +890,32 @@ class CompactRelPositionalEncoding(torch.nn.Module):
     Args:
         embed_dim: Embedding dimension.
         dropout_rate: Dropout rate.
+        random_shift: standard deviation of random distance by which we shift each time, if
+           training.
         max_len: Maximum input length: just a heuristic for initialization.
     """
     def __init__(
         self, embed_dim: int,
-        dropout_rate: float,
+        dropout_rate: FloatLike = 0.0,
+        random_shift: FloatLike = 1.0,
         max_len: int = 1000
     ) -> None:
         """Construct a CompactRelPositionalEncoding object."""
         super(CompactRelPositionalEncoding, self).__init__()
         self.embed_dim = embed_dim
         assert embed_dim % 2 == 0
-        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.dropout_rate = dropout_rate
+        self.random_shift = random_shift
         self.pe = None
-        self.extend_pe(torch.tensor(0.0).expand(max_len))
+        self.extend_pe(torch.tensor(0.0).expand(max_len), 0)

-    def extend_pe(self, x: Tensor) -> None:
+    def extend_pe(self, x: Tensor, shift: int) -> None:
         """Reset the positional encodings."""
+        T = x.size(0) + abs(shift)
         if self.pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
-            if self.pe.size(0) >= x.size(0) * 2 - 1:
+            if self.pe.size(0) >= T * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
                 if self.pe.dtype != x.dtype or str(self.pe.device) != str(
                     x.device
@@ -918,7 +923,6 @@ class CompactRelPositionalEncoding(torch.nn.Module):
                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return

-        T = x.size(0)
         # if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]
         x = torch.arange(-(T-1), T,
                          device=x.device).to(torch.float32).unsqueeze(1)
@@ -973,16 +977,22 @@ class CompactRelPositionalEncoding(torch.nn.Module):
             torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
         """
-        self.extend_pe(x)
+        shift = 0 if not self.training else int(round(random.normalvariate(0, 1) * float(self.random_shift)))
+        self.extend_pe(x, shift)
         pos_emb = self.pe[
             self.pe.size(0) // 2
-            - x.size(0)
+            - x.size(0) + shift
             + 1 : self.pe.size(0) // 2  # noqa E203
-            + x.size(0),
+            + x.size(0) + shift,
             :
         ]
         pos_emb = pos_emb.unsqueeze(0)
-        return self.dropout(pos_emb)
+        pos_emb = torch.nn.functional.dropout(pos_emb,
+                                              p=float(self.dropout_rate),
+                                              training=self.training)
+        return pos_emb
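
For reference, here is a small self-contained sketch of the idea behind this commit (this is not the icefall code: the function name shifted_rel_pos_window, the zero-centred pe table layout, and applying the shift only while training are assumptions taken from the commit message and the docstring above):

import random
import torch

def shifted_rel_pos_window(pe: torch.Tensor, seq_len: int, training: bool,
                           shift_std: float = 1.0) -> torch.Tensor:
    # pe: (2 * max_len - 1, embed_dim) table of relative positional embeddings,
    # with relative offset 0 stored at row pe.size(0) // 2.
    # Sample an integer shift from N(0, shift_std) in training mode; use 0 otherwise.
    shift = int(round(random.normalvariate(0, shift_std))) if training else 0
    # The table must cover offsets up to (seq_len - 1) + |shift| on either side,
    # which is what extend_pe's "T = x.size(0) + abs(shift)" guarantees in the commit.
    assert pe.size(0) >= 2 * (seq_len + abs(shift)) - 1
    center = pe.size(0) // 2
    # Select rows for relative offsets -(seq_len-1)+shift .. (seq_len-1)+shift,
    # mirroring the slice in forward():
    # self.pe[center - x.size(0) + shift + 1 : center + x.size(0) + shift].
    return pe[center - seq_len + shift + 1 : center + seq_len + shift]

# Example: a table covering max_len = 16 positions, model dimension 192.
pe = torch.randn(2 * 16 - 1, 192)
window = shifted_rel_pos_window(pe, seq_len=6, training=True)
print(window.shape)  # torch.Size([11, 192]); the window moves by `shift` rows each call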