mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Introduce random shift with stddev=1.0 into pos_emb
This commit is contained in:
parent
e9806950f5
commit
f7c99ed1d1
@ -890,27 +890,32 @@ class CompactRelPositionalEncoding(torch.nn.Module):
|
||||
Args:
|
||||
embed_dim: Embedding dimension.
|
||||
dropout_rate: Dropout rate.
|
||||
random_shift: standard deviation of random distance by which we shift each time, if
|
||||
training.
|
||||
max_len: Maximum input length: just a heuristic for initialization.
|
||||
"""
|
||||
def __init__(
|
||||
self, embed_dim: int,
|
||||
dropout_rate: float,
|
||||
dropout_rate: FloatLike = 0.0,
|
||||
random_shift: FloatLike = 1.0,
|
||||
max_len: int = 1000
|
||||
) -> None:
|
||||
"""Construct a CompactRelPositionalEncoding object."""
|
||||
super(CompactRelPositionalEncoding, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
assert embed_dim % 2 == 0
|
||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||
self.dropout_rate = dropout_rate
|
||||
self.random_shift = random_shift
|
||||
self.pe = None
|
||||
self.extend_pe(torch.tensor(0.0).expand(max_len))
|
||||
self.extend_pe(torch.tensor(0.0).expand(max_len), 0)
|
||||
|
||||
def extend_pe(self, x: Tensor) -> None:
|
||||
def extend_pe(self, x: Tensor, shift: int) -> None:
|
||||
"""Reset the positional encodings."""
|
||||
T = x.size(0) + abs(shift)
|
||||
if self.pe is not None:
|
||||
# self.pe contains both positive and negative parts
|
||||
# the length of self.pe is 2 * input_len - 1
|
||||
if self.pe.size(0) >= x.size(0) * 2 - 1:
|
||||
if self.pe.size(0) >= T * 2 - 1:
|
||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
||||
x.device
|
||||
@ -918,7 +923,6 @@ class CompactRelPositionalEncoding(torch.nn.Module):
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
return
|
||||
|
||||
T = x.size(0)
|
||||
# if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
|
||||
x = torch.arange(-(T-1), T,
|
||||
device=x.device).to(torch.float32).unsqueeze(1)
|
||||
@ -973,16 +977,22 @@ class CompactRelPositionalEncoding(torch.nn.Module):
|
||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
|
||||
shift = 0 if self.training else int(round(random.normalvariate(0, 1) * float(self.random_shift)))
|
||||
self.extend_pe(x, shift)
|
||||
|
||||
pos_emb = self.pe[
|
||||
self.pe.size(0) // 2
|
||||
- x.size(0)
|
||||
- x.size(0) + shift
|
||||
+ 1 : self.pe.size(0) // 2 # noqa E203
|
||||
+ x.size(0),
|
||||
+ x.size(0) + shift,
|
||||
:
|
||||
]
|
||||
pos_emb = pos_emb.unsqueeze(0)
|
||||
return self.dropout(pos_emb)
|
||||
pos_emb = torch.nn.functional.dropout(pos_emb,
|
||||
p=float(self.dropout_rate),
|
||||
training=self.training)
|
||||
return pos_emb
|
||||
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user