mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Introduce random shift with stddev=1.0 into pos_emb
This commit is contained in:
parent
e9806950f5
commit
f7c99ed1d1
@ -890,27 +890,32 @@ class CompactRelPositionalEncoding(torch.nn.Module):
|
|||||||
Args:
|
Args:
|
||||||
embed_dim: Embedding dimension.
|
embed_dim: Embedding dimension.
|
||||||
dropout_rate: Dropout rate.
|
dropout_rate: Dropout rate.
|
||||||
|
random_shift: standard deviation of random distance by which we shift each time, if
|
||||||
|
training.
|
||||||
max_len: Maximum input length: just a heuristic for initialization.
|
max_len: Maximum input length: just a heuristic for initialization.
|
||||||
"""
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self, embed_dim: int,
|
self, embed_dim: int,
|
||||||
dropout_rate: float,
|
dropout_rate: FloatLike = 0.0,
|
||||||
|
random_shift: FloatLike = 1.0,
|
||||||
max_len: int = 1000
|
max_len: int = 1000
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Construct a CompactRelPositionalEncoding object."""
|
"""Construct a CompactRelPositionalEncoding object."""
|
||||||
super(CompactRelPositionalEncoding, self).__init__()
|
super(CompactRelPositionalEncoding, self).__init__()
|
||||||
self.embed_dim = embed_dim
|
self.embed_dim = embed_dim
|
||||||
assert embed_dim % 2 == 0
|
assert embed_dim % 2 == 0
|
||||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
self.dropout_rate = dropout_rate
|
||||||
|
self.random_shift = random_shift
|
||||||
self.pe = None
|
self.pe = None
|
||||||
self.extend_pe(torch.tensor(0.0).expand(max_len))
|
self.extend_pe(torch.tensor(0.0).expand(max_len), 0)
|
||||||
|
|
||||||
def extend_pe(self, x: Tensor) -> None:
|
def extend_pe(self, x: Tensor, shift: int) -> None:
|
||||||
"""Reset the positional encodings."""
|
"""Reset the positional encodings."""
|
||||||
|
T = x.size(0) + abs(shift)
|
||||||
if self.pe is not None:
|
if self.pe is not None:
|
||||||
# self.pe contains both positive and negative parts
|
# self.pe contains both positive and negative parts
|
||||||
# the length of self.pe is 2 * input_len - 1
|
# the length of self.pe is 2 * input_len - 1
|
||||||
if self.pe.size(0) >= x.size(0) * 2 - 1:
|
if self.pe.size(0) >= T * 2 - 1:
|
||||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
||||||
x.device
|
x.device
|
||||||
@ -918,7 +923,6 @@ class CompactRelPositionalEncoding(torch.nn.Module):
|
|||||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||||
return
|
return
|
||||||
|
|
||||||
T = x.size(0)
|
|
||||||
# if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
|
# if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
|
||||||
x = torch.arange(-(T-1), T,
|
x = torch.arange(-(T-1), T,
|
||||||
device=x.device).to(torch.float32).unsqueeze(1)
|
device=x.device).to(torch.float32).unsqueeze(1)
|
||||||
@ -973,16 +977,22 @@ class CompactRelPositionalEncoding(torch.nn.Module):
|
|||||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.extend_pe(x)
|
|
||||||
|
shift = 0 if self.training else int(round(random.normalvariate(0, 1) * float(self.random_shift)))
|
||||||
|
self.extend_pe(x, shift)
|
||||||
|
|
||||||
pos_emb = self.pe[
|
pos_emb = self.pe[
|
||||||
self.pe.size(0) // 2
|
self.pe.size(0) // 2
|
||||||
- x.size(0)
|
- x.size(0) + shift
|
||||||
+ 1 : self.pe.size(0) // 2 # noqa E203
|
+ 1 : self.pe.size(0) // 2 # noqa E203
|
||||||
+ x.size(0),
|
+ x.size(0) + shift,
|
||||||
:
|
:
|
||||||
]
|
]
|
||||||
pos_emb = pos_emb.unsqueeze(0)
|
pos_emb = pos_emb.unsqueeze(0)
|
||||||
return self.dropout(pos_emb)
|
pos_emb = torch.nn.functional.dropout(pos_emb,
|
||||||
|
p=float(self.dropout_rate),
|
||||||
|
training=self.training)
|
||||||
|
return pos_emb
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user