Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00

Commit 1a184596b6 ("A little code refactoring")
Parent: bb1bee4a7b
@@ -27,10 +27,7 @@ from scaling import (
     BasicNorm,
     DoubleSwish,
     ScaledConv1d,
-    ScaledConv2d,
-    ScaledLinear,
-    StructuredConv1d,
-    StructuredLinear,
+    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
 )
 from torch import Tensor, nn

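The inline comment on the ScaledLinear import is the point of the refactoring: in this directory ScaledLinear is only meant to scale down the initial parameter values, rather than carrying learned scale parameters as the class of the same name does in other recipe directories. A hedged sketch of that expectation (it assumes the refactored scaling.py from this commit is importable; the sizes are arbitrary):

import torch
from scaling import ScaledLinear  # the factory-style version introduced later in this commit

lin = ScaledLinear(80, 256, initial_scale=0.25)  # illustrative sizes only
# Expected here: a plain nn.Linear whose initial weights are merely scaled down,
# not a subclass with extra weight_scale / bias_scale parameters.
assert type(lin) is torch.nn.Linear
assert not hasattr(lin, "weight_scale")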
@@ -1023,9 +1020,7 @@ class Conv2dSubsampling(nn.Module):
             DoubleSwish(),
         )
         out_height = (((in_channels - 1) // 2 - 1) // 2)
-        self.out = StructuredLinear(
-            (out_height, layer3_channels), (out_channels,)
-        )
+        self.out = nn.Linear(out_height * layer3_channels, out_channels)
         # set learn_eps=False because out_norm is preceded by `out`, and `out`
         # itself has learned scale, so the extra degree of freedom is not
         # needed.
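For concreteness, with the hypothetical values in_channels=80, layer3_channels=128 and out_channels=512 (chosen for illustration, not taken from this diff), the replacement line sizes the output projection like this:

import torch
from torch import nn

in_channels, layer3_channels, out_channels = 80, 128, 512  # hypothetical sizes

# Two stride-2 convolutions shrink the frequency axis twice:
out_height = ((in_channels - 1) // 2 - 1) // 2   # ((80 - 1) // 2 - 1) // 2 == 19

# The refactored module flattens (out_height, layer3_channels) into a single
# feature dimension and uses an ordinary nn.Linear instead of StructuredLinear.
out = nn.Linear(out_height * layer3_channels, out_channels)  # 19 * 128 -> 512

x = torch.randn(10, out_height * layer3_channels)
assert out(x).shape == (10, out_channels)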
@@ -314,11 +314,12 @@ class StructuredConv1d(nn.Conv1d):



-class ScaledLinear(nn.Linear):
+def ScaledLinear(*args,
+                 initial_scale: float = 1.0,
+                 **kwargs ) -> nn.Linear:
     """
-    A modified version of nn.Linear that gives an easy way to set the
-    default initial parameter scale.
+    Behaves like a constructor of a modified version of nn.Linear
+    that gives an easy way to set the default initial parameter scale.

     Args:
         Accepts the standard args and kwargs that nn.Linear accepts
@@ -330,67 +331,42 @@ class ScaledLinear(nn.Linear):
            Another option, if you want to do something like this, is
            to re-initialize the parameters.
     """
-
-    def __init__(
-        self,
-        *args,
-        initial_scale: float = 1.0,
-        **kwargs
-    ):
-        super(ScaledLinear, self).__init__(*args, **kwargs)
-        with torch.no_grad():
-            self.weight[:] *= initial_scale
-            if self.bias is not None:
-                torch.nn.init.uniform_(self.bias,
-                                       -0.1 * initial_scale,
-                                       0.1 * initial_scale)
+    ans = nn.Linear(*args, **kwargs)
+    with torch.no_grad():
+        ans.weight[:] *= initial_scale
+        if ans.bias is not None:
+            torch.nn.init.uniform_(ans.bias,
+                                   -0.1 * initial_scale,
+                                   0.1 * initial_scale)
+    return ans


-class ScaledConv1d(nn.Conv1d):
-    # See docs for ScaledLinear
-    def __init__(
-        self,
-        *args,
-        initial_scale: float = 1.0,
-        **kwargs
-    ):
-        super(ScaledConv1d, self).__init__(*args, **kwargs)
-        with torch.no_grad():
-            self.weight[:] *= initial_scale
-            if self.bias is not None:
-                torch.nn.init.uniform_(self.bias,
-                                       -0.1 * initial_scale,
-                                       0.1 * initial_scale)
-
-    def get_weight(self):  # TODO: delete
-        return self.weight
-
-    def get_bias(self):  # TODO: delete
-        return self.bias
-
-
-class ScaledConv2d(nn.Conv2d):
-    # See docs for ScaledLinear
-    def __init__(
-        self,
-        *args,
-        initial_scale: float = 1.0,
-        **kwargs
-    ):
-        super(ScaledConv2d, self).__init__(*args, **kwargs)
-        with torch.no_grad():
-            self.weight[:] *= initial_scale
-            if self.bias is not None:
-                torch.nn.init.uniform_(self.bias,
-                                       -0.1 * initial_scale,
-                                       0.1 * initial_scale)
-
-    def get_weight(self):
-        return self.weight
-
-    def get_bias(self):
-        return self.bias
+def ScaledConv1d(*args,
+                 initial_scale: float = 1.0,
+                 **kwargs ) -> nn.Linear:
+    """
+    Behaves like a constructor of a modified version of nn.Conv1d
+    that gives an easy way to set the default initial parameter scale.
+
+    Args:
+        Accepts the standard args and kwargs that nn.Linear accepts
+        e.g. in_features, out_features, bias=False.
+
+        initial_scale: you can override this if you want to increase
+           or decrease the initial magnitude of the module's output
+           (affects the initialization of weight_scale and bias_scale).
+           Another option, if you want to do something like this, is
+           to re-initialize the parameters.
+    """
+    ans = nn.Conv1d(*args, **kwargs)
+    with torch.no_grad():
+        ans.weight[:] *= initial_scale
+        if ans.bias is not None:
+            torch.nn.init.uniform_(ans.bias,
+                                   -0.1 * initial_scale,
+                                   0.1 * initial_scale)
+    return ans


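Both new-side functions follow the same pattern: construct the stock torch module, rescale its freshly initialized parameters in place, and return it. A self-contained sketch of that idea (a re-implementation for illustration, given a lower-case name so it is not mistaken for the file's own code):

import torch
from torch import nn


def scaled_linear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
    # Build an ordinary nn.Linear, then shrink (or grow) its initial parameters.
    ans = nn.Linear(*args, **kwargs)
    with torch.no_grad():
        ans.weight[:] *= initial_scale
        if ans.bias is not None:
            torch.nn.init.uniform_(ans.bias,
                                   -0.1 * initial_scale,
                                   0.1 * initial_scale)
    return ans


# Call sites keep looking like a class instantiation, but the object that comes
# back is a plain nn.Linear with no extra parameters or custom forward().
proj = scaled_linear(256, 512, bias=True, initial_scale=0.25)
assert type(proj) is nn.Linear
baseline = nn.Linear(256, 512, bias=True)
print(proj.weight.abs().mean().item() / baseline.weight.abs().mean().item())  # roughly 0.25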
@@ -497,80 +473,6 @@ class DoubleSwish(torch.nn.Module):


-class GaussProjDrop(torch.nn.Module):
-    """
-    This has an effect similar to torch.nn.Dropout, but does not privilege the on-axis directions.
-    The directions of dropout are fixed when the class is initialized, and are orthogonal.
-
-    dropout_rate: the dropout probability (actually will define the number of zeroed-out directions)
-    channel_dim: the axis corresponding to the channel, e.g. -1, 0, 1, 2.
-    """
-    def __init__(self,
-                 num_channels: int,
-                 dropout_rate: float = 0.1,
-                 channel_dim: int = -1):
-        super(GaussProjDrop, self).__init__()
-        self.dropout_rate = dropout_rate
-        # this formula for rand_scale was found empirically, trying to match the
-        # statistics of dropout in terms of cross-correlation with the input, see
-        # _test_gauss_proj_drop()
-        self.rand_scale = (dropout_rate / (1-dropout_rate)) ** 0.5  # * (num_channels ** -0.5)
-
-        self.channel_dim = channel_dim
-
-        rand_mat = torch.randn(num_channels, num_channels)
-        U, _, _ = rand_mat.svd()
-        self.register_buffer('U', U)  # a random orthogonal square matrix. will be a buffer.
-
-    def _randperm_like(self, x: Tensor):
-        """
-        Returns random permutations of the integers [0,1,..x.shape[-1]-1],
-        with the same shape as x. All dimensions of x other than the last dimension
-        will be treated as batch dimensions.
-
-        Torch's randperm does not support a batch dimension, so we pseudo-randomly simulate it.
-
-        For now, requires x.shape[-1] to be either a power of 2 or 3 times a power of 2, as
-        we normally set channel dims. This is required for some number theoretic stuff.
-        """
-        n = x.shape[-1]
-
-        assert n & (n-1) == 0 or (n//3 & (n//3 - 1)) == 0
-
-        b = x.numel() // n
-        randint = random.randint(0, 1000)
-        perm = torch.randperm(n, device=x.device)
-        # ensure all elements of batch_rand are coprime to n; this will ensure
-        # that multiplying the permutation by batch_rand and taking modulo
-        # n leaves us with permutations.
-        batch_rand = torch.arange(b, device=x.device) * (randint * 6) + 1
-        batch_rand = batch_rand.unsqueeze(-1)
-        ans = (perm * batch_rand) % n
-        ans = ans.reshape(x.shape)
-        return ans
-
-    def forward(self, x: Tensor) -> Tensor:
-        if not self.training:
-            return x
-        else:
-            x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
-            x_bypass = x  # will be used for "+ I"
-            perm = self._randperm_like(x)
-            x = torch.gather(x, -1, perm)
-            # self.U will act like a different matrix for every row of x, because of the random
-            # permutation.
-            x = torch.matmul(x, self.U)
-            x_next = torch.empty_like(x)
-            # scatter_ uses perm in opposite way
-            # from gather, inverting it.
-            x_next.scatter_(-1, perm, x)
-            x = (x_next * self.rand_scale + x_bypass)
-            return x
-
-
 def _test_activation_balancer_sign():
     probs = torch.arange(0, 1, 0.01)
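The removed _randperm_like relied on a number-theoretic fact: if perm is a permutation of 0..n-1 and k is coprime to n, then (perm * k) % n is again a permutation, which is why batch_rand entries take the form 6*r + 1 and why n must be a power of 2 or 3 times a power of 2. A small standalone check of that property (illustrative only, not part of the diff; n=384 is just an example channel count):

import torch

n = 384  # 3 * 2**7, the kind of channel count the removed assert allowed
perm = torch.randperm(n)

# Numbers of the form 6*r + 1 are coprime to both 2 and 3, hence coprime to n
# whenever n is a power of 2 or 3 times a power of 2.
for r in (0, 1, 17, 1000):
    k = 6 * r + 1
    scrambled = (perm * k) % n
    # Still a permutation: every index 0..n-1 occurs exactly once.
    assert scrambled.sort().values.equal(torch.arange(n))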
@@ -644,52 +546,11 @@ def _test_double_swish_deriv():
     m = DoubleSwish()
     torch.autograd.gradcheck(m, x)


-def _test_structured_linear():
-    m = StructuredLinear((2, 100), (3, 100), bias=True)
-    assert m.weight.shape == (3, 100, 2, 100)
-    assert m.bias.shape == (3, 100)
-    x = torch.randn(50, 200)
-    y = m(x)
-    assert y.shape == (50, 300)
-
-def _test_structured_conv1d():
-    m = StructuredConv1d((2, 100), (3, 100), kernel_size=3, padding=1, bias=True)
-    assert m.weight.shape == (3, 100, 2, 100, 3)
-    assert m.bias.shape == (3, 100)
-    T = 39
-    x = torch.randn(50, 200, T)
-    y = m(x)
-    assert y.shape == (50, 300, T)
-
-def _test_gauss_proj_drop():
-    D = 384
-    x = torch.randn(30000, D)
-
-    for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
-        m1 = torch.nn.Dropout(dropout_rate)
-        m2 = GaussProjDrop(D, dropout_rate)
-        for mode in ['train', 'eval']:
-            y1 = m1(x)
-            y2 = m2(x)
-            xmag = (x*x).mean()
-            y1mag = (y1*y1).mean()
-            cross1 = (x*y1).mean()
-            y2mag = (y2*y2).mean()
-            cross2 = (x*y2).mean()
-            print(f"rate={dropout_rate}, mode={mode}, xmag = {xmag}, y1mag = {y1mag}, y2mag = {y2mag}, cross1={cross1}, cross2={cross2}")
-            m1.eval()
-            m2.eval()
-
-
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
-    _test_structured_linear()
-    _test_structured_conv1d()
-    _test_gauss_proj_drop()
     _test_activation_balancer_sign()
     _test_activation_balancer_magnitude()
     _test_basic_norm()