mirror of https://github.com/k2-fsa/icefall.git

commit 950cd4a3e8 (parent e891a65735)
Introduce normalization..
@@ -713,6 +713,25 @@ class GaussProjDrop(torch.nn.Module):
         x = (x_next * self.rand_scale + x_bypass)
         return x
 
+class PseudoNormalizeFunction(torch.autograd.Function):
+    """
+    Function object that is the identity function in the forward pass; in the
+    backward pass it removes the component of the derivative in the direction
+    of x itself (as if x had gone through some kind of normalization layer).
+    """
+    @staticmethod
+    def forward(ctx, x: Tensor) -> Tensor:
+        ctx.save_for_backward(x)
+        return x
+
+    @staticmethod
+    def backward(ctx, x_grad: Tensor) -> Tensor:
+        x, = ctx.saved_tensors
+        eps = 1.0e-20
+        x_sumsq = (x**2).sum() + eps
+        grad_x_sum = (x_grad * x).sum()
+        return x_grad - x * (grad_x_sum / x_sumsq)  # drop component along x
+
 
 def _compute_correlation_loss(cov: Tensor,
                               eps: float) -> Tensor:
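The backward pass above computes x_grad - x * (<x_grad, x> / (<x, x> + eps)),
i.e. it projects the incoming gradient onto the subspace orthogonal to x. A
minimal sketch verifying that property (an illustration, not part of the
commit; the class body is copied from the hunk above):

import torch
from torch import Tensor

class PseudoNormalizeFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        ctx.save_for_backward(x)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tensor:
        x, = ctx.saved_tensors
        x_sumsq = (x**2).sum() + 1.0e-20
        return x_grad - x * ((x_grad * x).sum() / x_sumsq)

x = torch.randn(5, 7, requires_grad=True)
w = torch.randn(5, 7)               # the incoming gradient will be w
(PseudoNormalizeFunction.apply(x) * w).sum().backward()
# x.grad == w - x * (<w, x> / <x, x>), whose inner product with x is ~0:
assert (x.grad * x).sum().abs() < 1e-4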
@@ -740,7 +759,9 @@ def _update_cov_stats(cov: Tensor,
       x: Tensor of features/activations, of shape (num_frames, num_channels)
       beta: The decay constant for the stats, e.g. 0.8.
     """
-    return cov * beta + torch.matmul(x.t(), x) * (1-beta)
+    new_cov = torch.matmul(x.t(), x)
+    new_cov = PseudoNormalizeFunction.apply(new_cov)
+    return cov * beta + new_cov * (1-beta)
 
 
 class DecorrelateFunction(torch.autograd.Function):
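The replacement keeps the stats themselves unchanged (PseudoNormalizeFunction
is the identity in the forward pass) and only alters the gradient flowing into
them. The update is an exponential moving average of per-batch covariance
stats with an effective window of roughly 1/(1-beta) batches; a small sketch
of that averaging behavior (an illustration, not from the commit):

import torch

torch.manual_seed(0)
beta = 0.8              # effective window of about 1 / (1 - beta) = 5 batches
cov = torch.zeros(4, 4)
for _ in range(100):
    x = torch.randn(1000, 4)          # (num_frames, num_channels)
    cov = cov * beta + torch.matmul(x.t(), x) * (1 - beta)
# For unit-variance white noise the stats settle near num_frames * I:
print(torch.diag(cov) / 1000)         # each diagonal entry is close to 1.0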
@@ -889,6 +910,8 @@ class Decorrelate(torch.nn.Module):
             cov = torch.matmul(x.t(), x)
             with torch.no_grad():
                 self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
+                m = self.cov.max()
+                assert m == m  # fails only if the stats contain NaN
 
         return ans  # ans == x.
 
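The added assertion exploits the IEEE-754 rule that NaN compares unequal to
everything, including itself, so "assert m == m" trips exactly when the
running covariance stats have gone NaN. A standalone illustration (not from
the commit):

import torch

finite = torch.tensor(1.0)
assert finite == finite          # any non-NaN value equals itself
nan = torch.tensor(float("nan"))
assert not (nan == nan)          # NaN != NaN, which is what the check catches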
@@ -987,6 +1010,14 @@ def _test_gauss_proj_drop():
     m1.eval()
     m2.eval()
 
+def _test_pseudo_normalize():
+    x = torch.randn(3, 4)
+    x.requires_grad = True
+    y = PseudoNormalizeFunction.apply(x)
+    l = y.sin().sum()
+    l.backward()
+    assert (x.grad * x).sum().abs() < 0.1  # grad should be orthogonal to x
+
 def _test_decorrelate():
     D = 384
     x = torch.randn(30000, D)
@@ -1014,6 +1045,7 @@ if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    _test_pseudo_normalize()
     _test_decorrelate()
     _test_gauss_proj_drop()
     _test_activation_balancer_sign()
@@ -94,7 +94,7 @@ class Conformer(EncoderInterface):
             aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
         )
 
-        self.decorrelate = Decorrelate(d_model, scale=0.1)
+        self.decorrelate = Decorrelate(d_model, scale=0.05)
 
 
     def forward(