Implement GaussProjDrop

Repository: https://github.com/k2-fsa/icefall.git
parent 4352a16f57
commit 40a0934b4e

This commit renames RandProjDrop to GaussProjDrop in scaling.py, replaces its rotate/dropout/rotate-back forward pass with a permuted random projection added to an identity bypass, and updates conformer.py to use the new name.
scaling.py

@@ -640,7 +640,7 @@ class ScaledEmbedding(nn.Module):
         return s.format(**self.__dict__)
 
 
-class RandProjDrop(torch.nn.Module):
+class GaussProjDrop(torch.nn.Module):
     """
     This has an effect similar to torch.nn.Dropout, but does not privilege the on-axis directions.
     The directions of dropout are fixed when the class is initialized, and are orthogonal.
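As context for the docstring above (my gloss, not part of the commit): ordinary dropout zeroes individual coordinates, so its noise is aligned with the coordinate axes. A minimal sketch of the rotated alternative this family of modules is built around, assuming an orthogonal U of shape (C, C) obtained the same way as in __init__ below:

import torch

def rotated_dropout(x: torch.Tensor, U: torch.Tensor, p: float) -> torch.Tensor:
    # Rotate into the random orthogonal basis U, zero whole directions
    # there, then rotate back; no coordinate axis is privileged.
    y = torch.nn.functional.dropout(torch.matmul(x, U), p, training=True)
    return torch.matmul(y, U.t())

# The SVD of a Gaussian matrix yields an orthogonal U, as in __init__.
U, _, _ = torch.randn(384, 384).svd()
x = torch.randn(10, 384)
print(rotated_dropout(x, U, 0.1).shape)  # torch.Size([10, 384])

The removed RandProjDrop forward pass below does essentially this; the new GaussProjDrop keeps the spirit but adds the rotated noise to an identity bypass instead of passing the whole signal through the rotation.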
@@ -652,8 +652,13 @@ class RandProjDrop(torch.nn.Module):
                  num_channels: int,
                  dropout_rate: float = 0.1,
                  channel_dim: int = -1):
-        super(RandProjDrop, self).__init__()
+        super(GaussProjDrop, self).__init__()
         self.dropout_rate = dropout_rate
+        # this formula for rand_scale was found empirically, trying to match the
+        # statistics of dropout in terms of cross-correlation with the input, see
+        # _test_gauss_proj_drop()
+        self.rand_scale = (dropout_rate / (1-dropout_rate)) ** 0.5  # * (num_channels ** -0.5)
+
         self.channel_dim = channel_dim
+
         rand_mat = torch.randn(num_channels, num_channels)
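One way to read the empirical rand_scale formula (my gloss, not in the commit): standard dropout at rate p scales kept values by 1/(1-p), so E[y^2] = E[x^2]/(1-p) while the cross-correlation E[x*y] stays at E[x^2]. If one instead adds independent noise at the input's own scale, y = x + s*n, then E[x*y] is again E[x^2] and E[y^2] = E[x^2]*(1 + s^2); matching dropout's output power gives 1 + s^2 = 1/(1-p), i.e. s = sqrt(p/(1-p)), which is exactly rand_scale. A quick numerical check:

import torch

p = 0.1
x = torch.randn(1_000_000)
y = torch.nn.functional.dropout(x, p, training=True)
print((x * y).mean() / (x * x).mean())    # ~1.0: cross-correlation preserved
print((y * y).mean() / (x * x).mean())    # ~1/(1-p) = 1.111...

s = (p / (1 - p)) ** 0.5                  # rand_scale from the diff above
y2 = x + s * torch.randn_like(x)          # additive unit-scale noise
print((y2 * y2).mean() / (x * x).mean())  # also ~1/(1-p)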
@@ -661,31 +666,52 @@ class RandProjDrop(torch.nn.Module):
         self.register_buffer('U', U)  # a random orthogonal square matrix. will be a buffer.
 
 
+    def _randperm_like(self, x: Tensor):
+        """
+        Returns random permutations of the integers [0,1,..x.shape[-1]-1],
+        with the same shape as x.  All dimensions of x other than the last dimension
+        will be treated as batch dimensions.
+
+        Torch's randperm does not support a batch dimension, so we pseudo-randomly simulate it.
+
+        For now, requires x.shape[-1] to be either a power of 2 or 3 times a power of 2, as
+        we normally set channel dims.  This is required for some number theoretic stuff.
+        """
+        n = x.shape[-1]
+
+        assert n & (n-1) == 0 or (n//3 & (n//3 - 1)) == 0
+
+        b = x.numel() // n
+        randint = random.randint(0, 1000)
+        perm = torch.randperm(n, device=x.device)
+        # ensure all elements of batch_rand are coprime to n; this will ensure
+        # that multiplying the permutation by batch_rand and taking modulo
+        # n leaves us with permutations.
+        batch_rand = torch.arange(b, device=x.device) * (randint * 6) + 1
+        batch_rand = batch_rand.unsqueeze(-1)
+        ans = (perm * batch_rand) % n
+        ans = ans.reshape(x.shape)
+        return ans
+
     def forward(self, x: Tensor) -> Tensor:
         if not self.training:
-            x = torch.nn.functional.dropout(x, self.dropout_rate,
-                                            training=False)
             return x
         else:
-            self._randomize_U()
-            x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
-            U = self.U.clone()
-            x = torch.matmul(x, U)
-            x = torch.nn.functional.dropout(x, self.dropout_rate,
-                                            training=True)
-            x = torch.matmul(x, U.t())
-            x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
-            return x
-
-    def _randomize_U(self):
-        dim = random.randint(0, 1)
-        U = self.U
-        num_channels = U.shape[0]
-        # pick place to split U in two pieces.
-        r = random.randint(1, num_channels - 2)
-        U_part1 = U.narrow(dim, 0, r)
-        U_part2 = U.narrow(dim, r, num_channels-r)
-        U = torch.cat((U_part2, U_part1), dim=dim)
-        self.U[:] = U
+            x_bypass = x  # will be used for "+ I"
+            perm = self._randperm_like(x)
+            x = torch.gather(x, -1, perm)
+            # self.U will act like a different matrix for every row of x, because of the random
+            # permutation.
+            x = torch.matmul(x, self.U)
+            x_next = torch.empty_like(x)
+            # scatter_ uses perm in opposite way
+            # from gather, inverting it.
+            x_next.scatter_(-1, perm, x)
+            x = (x_next * self.rand_scale + x_bypass)
+            return x
 
 
 def _test_activation_balancer_sign():
     probs = torch.arange(0, 1, 0.01)
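The coprimality trick in _randperm_like, spelled out (my gloss): if gcd(k, n) = 1, then the map i -> (k*i) mod n is a bijection on {0, ..., n-1}, so multiplying a permutation elementwise by k and reducing mod n yields another permutation. The batch_rand entries all have the form 6*m + 1, which is divisible by neither 2 nor 3, hence coprime to any n whose only prime factors are 2 and 3; those are exactly the shapes the assert allows (a power of 2, or 3 times a power of 2). A small check:

import math
import torch

n = 384                        # 3 * 2**7, satisfies the assert in _randperm_like
perm = torch.randperm(n)
for k in (1, 7, 6 * 123 + 1):  # numbers of the form 6*m + 1
    assert math.gcd(k, n) == 1                   # coprime to 2 and 3, hence to n
    q = (perm * k) % n
    assert sorted(q.tolist()) == list(range(n))  # still a permutation
print("ok")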
@@ -760,23 +786,30 @@ def _test_double_swish_deriv():
     torch.autograd.gradcheck(m, x)
 
 
-def _test_proj_drop():
-    x = torch.randn(30000, 300)
-    m = RandProjDrop(300, 0.1)
-    y = m(x)
-    xmag = (x*x).mean()
-    ymag = (y*y).mean()
-    print(f"xmag = {xmag}, ymag = {ymag}")
-    #assert abs((ymag / xmag) - 0.9) < 0.02
-    m.eval()
-    y = m(x)
-    ymag = (y*y).mean()
-    print(f"xmag[eval] = {xmag}, ymag = {ymag}")
-    #assert abs((ymag / xmag) - 0.9) < 0.02
+def _test_gauss_proj_drop():
+    x = torch.randn(30000, 384)
+
+    for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
+        m1 = torch.nn.Dropout(dropout_rate)
+        m2 = GaussProjDrop(384, dropout_rate)
+        for mode in ['train', 'eval']:
+            y1 = m1(x)
+            y2 = m2(x)
+            xmag = (x*x).mean()
+            y1mag = (y1*y1).mean()
+            cross1 = (x*y1).mean()
+            y2mag = (y2*y2).mean()
+            cross2 = (x*y2).mean()
+            print(f"rate={dropout_rate}, mode={mode}, xmag = {xmag}, y1mag = {y1mag}, y2mag = {y2mag}, cross1={cross1}, cross2={cross2}")
+            m1.eval()
+            m2.eval()
 
 
 if __name__ == "__main__":
-    _test_proj_drop()
-    _test_activation_balancer_sign()
-    _test_activation_balancer_magnitude()
-    _test_basic_norm()
-    _test_double_swish_deriv()
+    _test_gauss_proj_drop()
+    if False:
+        _test_activation_balancer_sign()
+        _test_activation_balancer_magnitude()
+        _test_basic_norm()
+        _test_double_swish_deriv()
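For completeness, a hypothetical usage sketch (not part of the commit; import path and shapes assumed, matching how conformer.py uses the module below):

import torch
from scaling import GaussProjDrop   # assumes scaling.py is on the Python path

drop = GaussProjDrop(num_channels=384, dropout_rate=0.1, channel_dim=-1)
x = torch.randn(8, 100, 384)        # (batch, time, channels), channels last

drop.train()
y = drop(x)                  # x + rand_scale * (permuted random projection of x)

drop.eval()
assert torch.equal(drop(x), x)      # eval mode is the identity in this version

Note that the new forward pass always operates on the last dimension; unlike the removed code, it never transposes with channel_dim, so inputs are assumed channels-last here.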
conformer.py

@@ -29,7 +29,7 @@ from scaling import (
     ScaledConv1d,
     ScaledConv2d,
     ScaledLinear,
-    RandProjDrop,
+    GaussProjDrop,
 )
 from torch import Tensor, nn
@@ -197,7 +197,7 @@ class ConformerEncoderLayer(nn.Module):
             channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
         )
 
-        self.dropout = RandProjDrop(d_model, dropout)
+        self.dropout = GaussProjDrop(d_model, dropout)
 
     def forward(
         self,
@@ -369,7 +369,7 @@ class RelPositionalEncoding(torch.nn.Module):
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
-        self.dropout = RandProjDrop(d_model, dropout_rate)
+        self.dropout = GaussProjDrop(d_model, dropout_rate)
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))