Implement GaussProjDrop

Daniel Povey 2022-06-07 11:51:24 +08:00
parent 4352a16f57
commit 40a0934b4e
2 changed files with 76 additions and 43 deletions

View File

@@ -640,7 +640,7 @@ class ScaledEmbedding(nn.Module):
        return s.format(**self.__dict__)

class RandProjDrop(torch.nn.Module):
class GaussProjDrop(torch.nn.Module):
    """
    This has an effect similar to torch.nn.Dropout, but does not privilege the on-axis directions.
    The directions of dropout are fixed when the class is initialized, and are orthogonal.
@@ -652,8 +652,13 @@ class RandProjDrop(torch.nn.Module):
                 num_channels: int,
                 dropout_rate: float = 0.1,
                 channel_dim: int = -1):
        super(RandProjDrop, self).__init__()
        super(GaussProjDrop, self).__init__()
        self.dropout_rate = dropout_rate
        # this formula for rand_scale was found empirically, trying to match the
        # statistics of dropout in terms of cross-correlation with the input, see
        # _test_gauss_proj_drop()
        self.rand_scale = (dropout_rate / (1-dropout_rate)) ** 0.5  # * (num_channels ** -0.5)
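        # (Sketch of why this value is plausible, assuming the rotated copy z
        # added in forward() is roughly uncorrelated with x and has the same
        # magnitude: y = x + rand_scale * z gives
        # E[y^2] = (1 + rand_scale^2) * E[x^2] = E[x^2] / (1 - dropout_rate),
        # the same second moment as "inverted" dropout, while E[x*y] = E[x^2]
        # matches dropout's cross-correlation with the input.)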
        self.channel_dim = channel_dim
        rand_mat = torch.randn(num_channels, num_channels)
@@ -661,31 +666,52 @@ class RandProjDrop(torch.nn.Module):
        self.register_buffer('U', U)  # a random orthogonal square matrix; will be a buffer.
    def _randperm_like(self, x: Tensor):
        """
        Returns random permutations of the integers [0, 1, ..., x.shape[-1]-1],
        with the same shape as x.  All dimensions of x other than the last
        are treated as batch dimensions.

        Torch's randperm does not support a batch dimension, so we simulate it
        pseudo-randomly.  For now this requires x.shape[-1] to be either a power
        of 2 or 3 times a power of 2 (the values we normally use for channel
        dims); the number-theoretic construction below needs x.shape[-1] to have
        no prime factors other than 2 and 3.
        """
        n = x.shape[-1]
        assert n & (n-1) == 0 or (n % 3 == 0 and ((n//3) & (n//3 - 1)) == 0)
        b = x.numel() // n
        randint = random.randint(1, 1000)  # 0 would give every row the same permutation
        perm = torch.randperm(n, device=x.device)
        # ensure all elements of batch_rand are coprime to n: each element has
        # the form 6*randint*k + 1, which is congruent to 1 both mod 2 and
        # mod 3, the only possible prime factors of n.  This ensures that
        # multiplying the permutation by batch_rand and taking modulo n leaves
        # us with permutations.
        batch_rand = torch.arange(b, device=x.device) * (randint * 6) + 1
        batch_rand = batch_rand.unsqueeze(-1)
        ans = (perm * batch_rand) % n
        ans = ans.reshape(x.shape)
        return ans
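    # Example of the construction above: with n = 4, perm = [2, 0, 3, 1] and
    # batch_rand rows [1] and [7], the two output rows are [2, 0, 3, 1] and
    # (7 * [2, 0, 3, 1]) % 4 = [2, 0, 1, 3] -- both valid permutations, because
    # 1 and 7 are coprime to 4.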
    def forward(self, x: Tensor) -> Tensor:
        if not self.training:
            x = torch.nn.functional.dropout(x, self.dropout_rate,
                                            training=False)
            return x
        else:
            self._randomize_U()
            x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
            U = self.U.clone()
            x = torch.matmul(x, U)
            x = torch.nn.functional.dropout(x, self.dropout_rate,
                                            training=True)
            x = torch.matmul(x, U.t())
            x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
            return x
            x_bypass = x  # will be used for "+ I"
            perm = self._randperm_like(x)
            x = torch.gather(x, -1, perm)
            # self.U will act like a different matrix for every row of x,
            # because of the random permutation.
            x = torch.matmul(x, self.U)
            x_next = torch.empty_like(x)
            # scatter_ uses perm in the opposite way from gather, i.e. it
            # applies the inverse permutation.
            x_next.scatter_(-1, perm, x)
            x = (x_next * self.rand_scale + x_bypass)
            return x
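    # Net effect of the training-mode branch above, as a sketch: with P the
    # per-row permutation from _randperm_like, each row of x becomes
    # x + rand_scale * P^{-1}((P x) U), i.e. the identity plus a scaled copy
    # of x passed through U conjugated by a random permutation.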
    def _randomize_U(self):
        dim = random.randint(0, 1)
        U = self.U
        num_channels = U.shape[0]
        # pick a place to split U in two pieces and swap them: a cyclic shift
        # of its rows or columns, which keeps U orthogonal.
        r = random.randint(1, num_channels - 2)
        U_part1 = U.narrow(dim, 0, r)
        U_part2 = U.narrow(dim, r, num_channels - r)
        U = torch.cat((U_part2, U_part1), dim=dim)
        self.U[:] = U
def _test_activation_balancer_sign():
    probs = torch.arange(0, 1, 0.01)
@@ -760,23 +786,30 @@ def _test_double_swish_deriv():
    torch.autograd.gradcheck(m, x)
def _test_proj_drop():
    x = torch.randn(30000, 300)
    m = RandProjDrop(300, 0.1)
    y = m(x)
    xmag = (x*x).mean()
    ymag = (y*y).mean()
    print(f"xmag = {xmag}, ymag = {ymag}")
    # assert abs((ymag / xmag) - 0.9) < 0.02
    m.eval()
    y = m(x)
    ymag = (y*y).mean()
    print(f"xmag[eval] = {xmag}, ymag = {ymag}")
    # assert abs((ymag / xmag) - 0.9) < 0.02
def _test_gauss_proj_drop():
    x = torch.randn(30000, 384)
    for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
        m1 = torch.nn.Dropout(dropout_rate)
        m2 = GaussProjDrop(384, dropout_rate)
        for mode in ['train', 'eval']:
            y1 = m1(x)
            y2 = m2(x)
            xmag = (x*x).mean()
            y1mag = (y1*y1).mean()
            cross1 = (x*y1).mean()
            y2mag = (y2*y2).mean()
            cross2 = (x*y2).mean()
            print(f"rate={dropout_rate}, mode={mode}, xmag = {xmag}, y1mag = {y1mag}, y2mag = {y2mag}, cross1={cross1}, cross2={cross2}")
            # switch both modules to eval mode for the second iteration.
            m1.eval()
            m2.eval()
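    # Expected qualitatively, per the rand_scale derivation in __init__: cross1
    # and cross2 should both stay close to xmag, and y1mag and y2mag should
    # both be inflated by roughly 1/(1 - dropout_rate) in train mode while
    # matching xmag in eval mode.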
if __name__ == "__main__":
    _test_proj_drop()
    _test_activation_balancer_sign()
    _test_activation_balancer_magnitude()
    _test_basic_norm()
    _test_double_swish_deriv()
    _test_gauss_proj_drop()
    if False:
        _test_activation_balancer_sign()
        _test_activation_balancer_magnitude()
        _test_basic_norm()
        _test_double_swish_deriv()

View File

@@ -29,7 +29,7 @@ from scaling import (
    ScaledConv1d,
    ScaledConv2d,
    ScaledLinear,
    RandProjDrop,
    GaussProjDrop,
)
from torch import Tensor, nn
@@ -197,7 +197,7 @@ class ConformerEncoderLayer(nn.Module):
            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
        )
        self.dropout = RandProjDrop(d_model, dropout)
        self.dropout = GaussProjDrop(d_model, dropout)

    def forward(
        self,
@@ -369,7 +369,7 @@ class RelPositionalEncoding(torch.nn.Module):
        """Construct a PositionalEncoding object."""
        super(RelPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.dropout = RandProjDrop(d_model, dropout_rate)
        self.dropout = GaussProjDrop(d_model, dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
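
A minimal usage sketch (not part of the commit; it assumes the modified scaling.py above is importable, as the second file's import block does):

    import torch
    from scaling import GaussProjDrop

    drop = GaussProjDrop(num_channels=384, dropout_rate=0.1, channel_dim=-1)
    x = torch.randn(10, 100, 384)  # (batch, time, channel); 384 = 3 * 2**7
    y = drop(x)         # train mode: x plus a scaled, randomly-rotated copy of itself
    drop.eval()
    assert torch.equal(drop(x), x)  # eval mode: input passes through unchanged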