Mirror of https://github.com/k2-fsa/icefall.git
Implement GaussProjDrop

Commit 40a0934b4e (parent 4352a16f57). The diff renames RandProjDrop to GaussProjDrop, reimplements its training-time forward pass as a scaled orthogonal projection (applied through per-row random channel permutations) added to the input, rewrites its test to compare against torch.nn.Dropout, and updates the conformer code that constructs it.
@@ -640,7 +640,7 @@ class ScaledEmbedding(nn.Module):
         return s.format(**self.__dict__)


-class RandProjDrop(torch.nn.Module):
+class GaussProjDrop(torch.nn.Module):
     """
     This has an effect similar to torch.nn.Dropout, but does not privilege the on-axis directions.
     The directions of dropout are fixed when the class is initialized, and are orthogonal.
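For orientation (this sketch is not part of the commit): the behaviour the docstring describes is ordinary dropout applied in a fixed random orthogonal basis rather than along the coordinate axes. Assuming a square orthogonal U, the idea looks roughly like this; the helper name and shapes are illustrative only.

import torch

def dropout_in_orthogonal_basis(x: torch.Tensor, U: torch.Tensor, p: float) -> torch.Tensor:
    # Rotate into the basis given by the columns of U, drop coordinates there,
    # and rotate back.  Because U is orthogonal, the dropped directions are the
    # columns of U instead of the on-axis unit vectors.
    x_rot = x @ U
    x_rot = torch.nn.functional.dropout(x_rot, p, training=True)
    return x_rot @ U.t()

# A random orthogonal U, e.g. from the QR decomposition of a Gaussian matrix:
num_channels = 8
U, _ = torch.linalg.qr(torch.randn(num_channels, num_channels))
y = dropout_in_orthogonal_basis(torch.randn(4, num_channels), U, p=0.1)

The old RandProjDrop forward shown further down did essentially this; the new GaussProjDrop keeps the fixed orthogonal buffer U but replaces the explicit dropout with a scaled projection added to the input.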
@@ -652,8 +652,13 @@ class RandProjDrop(torch.nn.Module):
                  num_channels: int,
                  dropout_rate: float = 0.1,
                  channel_dim: int = -1):
-        super(RandProjDrop, self).__init__()
+        super(GaussProjDrop, self).__init__()
         self.dropout_rate = dropout_rate
+        # this formula for rand_scale was found empirically, trying to match the
+        # statistics of dropout in terms of cross-correlation with the input, see
+        # _test_gauss_proj_drop()
+        self.rand_scale = (dropout_rate / (1-dropout_rate)) ** 0.5 # * (num_channels ** -0.5)
+
         self.channel_dim = channel_dim

         rand_mat = torch.randn(num_channels, num_channels)
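A quick numerical check (not from the commit) of where sqrt(dropout_rate / (1 - dropout_rate)) comes from: with PyTorch's inverted dropout, the perturbation added to x has power p / (1 - p) relative to x while leaving the output fully correlated with the input, so that is the natural scale for a dropout-like additive perturbation.

import torch

p = 0.1
x = torch.randn(1_000_000)
y = torch.nn.functional.dropout(x, p, training=True)  # inverted dropout: x * mask / (1 - p)
noise = y - x
# Relative power of the injected noise: ~ p / (1 - p) = 0.111...
print(((noise * noise).mean() / (x * x).mean()).item())
# Dropout is unbiased, so E[x*y] = E[x*x] and the ratio below is ~ 1.0
print(((x * y).mean() / (x * x).mean()).item())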
@@ -661,31 +666,52 @@ class RandProjDrop(torch.nn.Module):
         self.register_buffer('U', U) # a random orthogonal square matrix. will be a buffer.


+    def _randperm_like(self, x: Tensor):
+        """
+        Returns random permutations of the integers [0,1,..x.shape[-1]-1],
+        with the same shape as x.  All dimensions of x other than the last dimension
+        will be treated as batch dimensions.
+
+        Torch's randperm does not support a batch dimension, so we pseudo-randomly simulate it.
+
+        For now, requires x.shape[-1] to be either a power of 2 or 3 times a power of 2, as
+        we normally set channel dims.  This is required for some number theoretic stuff.
+        """
+        n = x.shape[-1]
+
+        assert n & (n-1) == 0 or (n//3 & (n//3 - 1)) == 0
+
+        b = x.numel() // n
+        randint = random.randint(0, 1000)
+        perm = torch.randperm(n, device=x.device)
+        # ensure all elements of batch_rand are coprime to n; this will ensure
+        # that multiplying the permutation by batch_rand and taking modulo
+        # n leaves us with permutations.
+        batch_rand = torch.arange(b, device=x.device) * (randint * 6) + 1
+        batch_rand = batch_rand.unsqueeze(-1)
+        ans = (perm * batch_rand) % n
+        ans = ans.reshape(x.shape)
+        return ans
+
     def forward(self, x: Tensor) -> Tensor:
         if not self.training:
-            x = torch.nn.functional.dropout(x, self.dropout_rate,
-                                            training=False)
+            return x
         else:
-            self._randomize_U()
-            x = x.transpose(self.channel_dim, -1) # (..., num_channels)
-            U = self.U.clone()
-            x = torch.matmul(x, U)
-            x = torch.nn.functional.dropout(x, self.dropout_rate,
-                                            training=True)
-            x = torch.matmul(x, U.t())
             x = x.transpose(self.channel_dim, -1) # (..., num_channels)
+            x_bypass = x # will be used for "+ I"
+            perm = self._randperm_like(x)
+            x = torch.gather(x, -1, perm)
+            # self.U will act like a different matrix for every row of x, because of the random
+            # permutation.
+            x = torch.matmul(x, self.U)
+            x_next = torch.empty_like(x)
+            # scatter_ uses perm in opposite way
+            # from gather, inverting it.
+            x_next.scatter_(-1, perm, x)
+            x = (x_next * self.rand_scale + x_bypass)
         return x

-    def _randomize_U(self):
-        dim = random.randint(0, 1)
-        U = self.U
-        num_channels = U.shape[0]
-        # pick place to split U in two pieces.
-        r = random.randint(1, num_channels - 2)
-        U_part1 = U.narrow(dim, 0, r)
-        U_part2 = U.narrow(dim, r, num_channels-r)
-        U = torch.cat((U_part2, U_part1), dim=dim)
-        self.U[:] = U
-

 def _test_activation_balancer_sign():
     probs = torch.arange(0, 1, 0.01)
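Two things in the new training path are worth spelling out. First, _randperm_like relies on the fact that multiplying a permutation of 0..n-1 by an integer coprime to n (mod n) gives another permutation; batch_rand = arange(b) * (6 * randint) + 1 is always odd and never divisible by 3, hence coprime to any n that is a power of 2 or 3 times a power of 2, which is what the assert enforces. Second, gather / matmul / scatter_ makes the single buffer U act like a row- and column-permuted copy of itself for every row of x. A small verification sketch, not part of the commit (the names idx, inv and V are illustrative):

import torch

n = 12                       # 3 * 2**2, the kind of channel count the assert allows
perm = torch.randperm(n)
for k in (1, 7, 25, 55):     # all coprime to 12, like the entries of batch_rand
    q = (perm * k) % n
    assert sorted(q.tolist()) == list(range(n))   # still a permutation

# One row of the forward pass: gather, multiply by U, scatter back ...
U, _ = torch.linalg.qr(torch.randn(n, n))
x = torch.randn(1, n)
idx = ((perm * 7) % n).unsqueeze(0)
proj = torch.empty_like(x)
proj.scatter_(-1, idx, torch.gather(x, -1, idx) @ U)

# ... which is the same as multiplying x by U with its rows and columns permuted,
# so each row of a batch effectively sees its own random rotation.
inv = torch.empty(n, dtype=torch.long)
inv[idx[0]] = torch.arange(n)
V = U[inv][:, inv]
assert torch.allclose(proj, x @ V, atol=1e-5)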
@@ -760,22 +786,29 @@ def _test_double_swish_deriv():
     torch.autograd.gradcheck(m, x)


-def _test_proj_drop():
-    x = torch.randn(30000, 300)
-    m = RandProjDrop(300, 0.1)
-    y = m(x)
-    xmag = (x*x).mean()
-    ymag = (y*y).mean()
-    print(f"xmag = {xmag}, ymag = {ymag}")
-    #assert abs((ymag / xmag) - 0.9) < 0.02
-    m.eval()
-    y = m(x)
-    ymag = (y*y).mean()
-    print(f"xmag[eval] = {xmag}, ymag = {ymag}")
-    #assert abs((ymag / xmag) - 0.9) < 0.02
+def _test_gauss_proj_drop():
+    x = torch.randn(30000, 384)
+
+    for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
+        m1 = torch.nn.Dropout(dropout_rate)
+        m2 = GaussProjDrop(384, dropout_rate)
+        for mode in ['train', 'eval']:
+            y1 = m1(x)
+            y2 = m2(x)
+            xmag = (x*x).mean()
+            y1mag = (y1*y1).mean()
+            cross1 = (x*y1).mean()
+            y2mag = (y2*y2).mean()
+            cross2 = (x*y2).mean()
+            print(f"rate={dropout_rate}, mode={mode}, xmag = {xmag}, y1mag = {y1mag}, y2mag = {y2mag}, cross1={cross1}, cross2={cross2}")
+            m1.eval()
+            m2.eval()


 if __name__ == "__main__":
-    _test_proj_drop()
-    _test_activation_balancer_sign()
-    _test_activation_balancer_magnitude()
-    _test_basic_norm()
+    _test_gauss_proj_drop()
+    if False:
+        _test_activation_balancer_sign()
+        _test_activation_balancer_magnitude()
+        _test_basic_norm()
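For reading the statistics the test prints: with inverted dropout (PyTorch's default) and zero-mean input the baseline values are easy to pin down, and rand_scale is being tuned so that GaussProjDrop's y2mag and cross2 land near the same place. A sketch of the torch.nn.Dropout baseline only (not part of the commit):

import torch

# In train mode, inverted dropout gives E[y*y] = E[x*x] / (1 - p) and E[x*y] = E[x*x];
# in eval mode both modules are the identity, so every ratio is exactly 1.
x = torch.randn(30000, 384)
xmag = (x * x).mean().item()
for p in [0.2, 0.1, 0.01, 0.05]:
    y = torch.nn.functional.dropout(x, p, training=True)
    ymag = (y * y).mean().item()
    cross = (x * y).mean().item()
    print(f"p={p}: ymag/xmag={ymag / xmag:.3f} (expect {1 / (1 - p):.3f}), "
          f"cross/xmag={cross / xmag:.3f} (expect 1.000)")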
The remaining hunks are from the file that defines the conformer encoder; they update the import from scaling and the two places the module is constructed.
@@ -29,7 +29,7 @@ from scaling import (
     ScaledConv1d,
     ScaledConv2d,
     ScaledLinear,
-    RandProjDrop,
+    GaussProjDrop,
 )
 from torch import Tensor, nn

@@ -197,7 +197,7 @@ class ConformerEncoderLayer(nn.Module):
             channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
         )

-        self.dropout = RandProjDrop(d_model, dropout)
+        self.dropout = GaussProjDrop(d_model, dropout)

     def forward(
         self,
@@ -369,7 +369,7 @@ class RelPositionalEncoding(torch.nn.Module):
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
-        self.dropout = RandProjDrop(d_model, dropout_rate)
+        self.dropout = GaussProjDrop(d_model, dropout_rate)
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
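These hunks are pure substitutions at the construction sites, so GaussProjDrop is being used as a drop-in replacement for nn.Dropout on channels-last activations. A minimal usage sketch under that assumption (the module and shapes are illustrative; the import mirrors the one in the diff and assumes scaling.py is on the path):

import torch
from scaling import GaussProjDrop   # as imported by the encoder file above

class TinyBlock(torch.nn.Module):
    # Illustrative module: GaussProjDrop used exactly where torch.nn.Dropout would go.
    def __init__(self, d_model: int = 384, dropout: float = 0.1):
        super().__init__()
        self.linear = torch.nn.Linear(d_model, d_model)
        self.dropout = GaussProjDrop(d_model, dropout)   # channel_dim defaults to -1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.linear(x))

block = TinyBlock().train()
y = block(torch.randn(10, 100, 384))   # (batch, time, channel): channels last

Note that d_model has to satisfy the _randperm_like assert (a power of 2, or 3 times a power of 2), which holds for 384 and for the usual attention dimensions such as 256 or 512.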