mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-09 09:04:19 +00:00)
Various bug fixes
This commit is contained in:
parent 40a0934b4e
commit cd6b707e2b
@@ -713,6 +713,100 @@ class GaussProjDrop(torch.nn.Module):
         return x


+class Decorrelate(torch.nn.Module):
+    """
+    This module is something similar to dropout; it is a random transformation that
+    does nothing in eval mode.
+    It is designed specifically to encourage the input data to be decorrelated, i.e.
+    to have a diagonal covariance matrix (not necessarily unity).
+
+    To save time, in training mode we only apply it on randomly selected minibatches.
+
+    Args:
+      apply_prob:   The probability with which we apply this each time, in
+                    training mode.  This is to save time (but of course it
+                    will tend to make the effect weaker).
+      dropout_rate: This number determines the scale of the random multiplicative
+                    noise, in such a way that the self-correlation and cross-correlation
+                    statistics match those of dropout with the same `dropout_rate`
+                    (assuming we applied the transform, e.g. if apply_prob == 1.0).
+      eps:          An epsilon used to prevent division by zero.
+      channel_dim:  The dimension of the input corresponding to channels, e.g. -1.
+    """
+    def __init__(self,
+                 apply_prob: float = 0.25,
+                 dropout_rate: float = 0.1,
+                 eps: float = 1.0e-04,
+                 channel_dim: int = -1):
+        super(Decorrelate, self).__init__()
+        self.apply_prob = apply_prob
+        self.dropout_rate = dropout_rate
+        self.channel_dim = channel_dim
+        self.eps = eps
+
+    def _get_covar(self, x: Tensor) -> Tensor:
+        """
+        Returns the uncentered covariance matrix associated with feature matrix x,
+        detached from its input.
+        Args:
+            x: Tensor of shape (*, num_channels)
+        Returns:
+            Covariance matrix `cov`, of shape (num_channels, num_channels)
+        """
+        x = x.detach()
+        x = x.reshape(-1, x.shape[-1])
+        x = x * (x.shape[0] ** -0.5)  # avoid overflow in half precision
+        return torch.matmul(x.t(), x)
+
+    def _normalize_covar(self, cov: Tensor, eps: float) -> Tensor:
+        """
+        Normalizes a covariance matrix so that its diagonal is 1, by multiplying
+        on both sides by its diagonal**-0.5.
+        Args:
+            cov: matrix to normalize
+            eps: floating point value >0, used to prevent division by zero.
+        """
+        diag = cov.diag()
+        inv_sqrt_diag = (diag + eps) ** -0.5
+        cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
+        return cov
+
+    def forward(self, x: Tensor) -> Tensor:
+        if not self.training or random.random() > self.apply_prob:
+            return x
+        else:
+            x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
+            x_bypass = x  # will be used for "+ I"
+
+            with torch.cuda.amp.autocast(enabled=False):
+                cov = self._get_covar(x)
+                cov = self._normalize_covar(cov, self.eps)
+                avg_squared_eig = (cov ** 2).sum(dim=0).mean()
+
+            # The odd-looking formula below was obtained empirically, to match
+            # the self-product and cross-correlation statistics of dropout.
+            rand_scale = ((self.dropout_rate / (1.0 - self.dropout_rate)) ** 0.5) / avg_squared_eig
+
+            # By multiplying by `cov`, then randomizing the sign of elements, then
+            # multiplying by `cov` again, we generate something that has more noise
+            # in directions corresponding to larger eigenvalues of `cov`.
+            # (Actually we scale by the square of the eigenvalue, which is not very
+            # desirable, but was easy to implement in a fast way.)
+            x = torch.matmul(x * rand_scale, cov)
+            rand_mask = (torch.rand_like(x) > 0.5)
+            # Randomize the sign of elements of x.  It is important to write the
+            # expression this way, so that only rand_mask needs to be stored for
+            # backprop.
+            x = x - 2 * (rand_mask * x)
+            x = torch.matmul(x, cov)
+            x = x + x_bypass
+            x = x.transpose(self.channel_dim, -1)
+            return x
+
+
 def _test_activation_balancer_sign():
     probs = torch.arange(0, 1, 0.01)
     N = 1000
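Editor's note: two quick sanity checks on the tricks used in the class above (a sketch, not part of the commit). First, `_normalize_covar`-style scaling should leave a matrix with unit diagonal; second, the expression `x - 2 * (rand_mask * x)` in `forward` is just elementwise multiplication by random +/-1 signs, written so that only the boolean mask needs to be stored for backprop:

    import torch

    torch.manual_seed(0)
    x = torch.randn(1000, 8)

    # 1) Normalizing an (uncentered) covariance by diag**-0.5 on both sides
    #    gives a matrix whose diagonal is ~1.
    cov = torch.matmul(x.t(), x) / x.shape[0]
    inv_sqrt_diag = (cov.diag() + 1.0e-04) ** -0.5
    cov_norm = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
    assert torch.allclose(cov_norm.diag(), torch.ones(8), atol=1e-3)

    # 2) x - 2*(mask*x) flips the sign of exactly the masked entries.
    rand_mask = torch.rand_like(x) > 0.5
    signs = 1.0 - 2.0 * rand_mask.float()   # entries are +1 or -1
    assert torch.allclose(x - 2 * (rand_mask * x), signs * x)

Sandwiching the sign flip between two multiplications by the symmetric matrix `cov` scales the noise in each eigendirection by that eigenvalue twice, which is the eigenvalue-squared behaviour the comment in `forward` refers to.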
@@ -805,10 +899,30 @@ def _test_gauss_proj_drop():
     m1.eval()
     m2.eval()

+def _test_decorrelate():
+    x = torch.randn(30000, 384)
+
+    for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
+        m1 = torch.nn.Dropout(dropout_rate)
+        m2 = Decorrelate(apply_prob=1.0, dropout_rate=dropout_rate)
+        # The first iteration runs in train mode (the default); both modules
+        # are switched to eval mode at the bottom of the loop body.
+        for mode in ['train', 'eval']:
+            y1 = m1(x)
+            y2 = m2(x)
+            xmag = (x * x).mean()
+            y1mag = (y1 * y1).mean()
+            cross1 = (x * y1).mean()
+            y2mag = (y2 * y2).mean()
+            cross2 = (x * y2).mean()
+            print(f"rate={dropout_rate}, mode={mode}, xmag = {xmag}, y1mag = {y1mag}, y2mag = {y2mag}, cross1={cross1}, cross2={cross2}")
+            m1.eval()
+            m2.eval()
+
+
 if __name__ == "__main__":
-    _test_gauss_proj_drop()
+    _test_decorrelate()
     if False:
+        _test_gauss_proj_drop()
         _test_activation_balancer_sign()
         _test_activation_balancer_magnitude()
         _test_basic_norm()
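Editor's note on reading the printout: `torch.nn.Dropout` with rate p in train mode zeroes entries with probability p and rescales survivors by 1/(1-p), so one expects cross1 ~= xmag and y1mag ~= xmag/(1-p); the `rand_scale` formula in `Decorrelate` is tuned so that y2mag and cross2 land near the same targets. A standalone check of the dropout side (a sketch, not from the commit):

    import torch
    import torch.nn.functional as F

    p = 0.1
    x = torch.randn(30000, 384)
    y = F.dropout(x, p=p, training=True)
    # The three printed values should be approximately equal:
    print((x * x).mean().item())               # xmag
    print((x * y).mean().item())               # cross ~= xmag
    print(((y * y).mean() * (1 - p)).item())   # ymag * (1 - p) ~= xmag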
@@ -29,7 +29,8 @@ from scaling import (
     ScaledConv1d,
     ScaledConv2d,
     ScaledLinear,
-    GaussProjDrop,
+    Decorrelate,
+
 )
 from torch import Tensor, nn
@@ -197,7 +198,9 @@ class ConformerEncoderLayer(nn.Module):
             channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
         )

-        self.dropout = GaussProjDrop(d_model, dropout)
+        self.dropout = torch.nn.Dropout(dropout)
+        self.decorrelate = Decorrelate(apply_prob=0.25, dropout_rate=0.05)
+

     def forward(
         self,
@@ -259,6 +262,10 @@ class ConformerEncoderLayer(nn.Module):
         # feed forward module
         src = src + self.dropout(self.feed_forward(src))

+        # Encourage the dimensions of `src` to be un-correlated with each other;
+        # this will help Adam converge better.
+        src = self.decorrelate(src)
+
         src = self.norm_final(self.balancer(src))

         if alpha != 1.0:
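Editor's note: a minimal usage sketch for the layer wired in above (the shapes assume the (seq_len, batch, d_model) layout used in conformer.py; not part of the commit). In eval mode, and on the roughly 1 - apply_prob fraction of training minibatches where the random draw skips it, `Decorrelate` returns its input unchanged:

    import torch
    from scaling import Decorrelate   # as imported in conformer.py

    layer = Decorrelate(apply_prob=1.0, dropout_rate=0.05)
    src = torch.randn(100, 2, 256)    # (seq_len, batch, d_model)

    layer.train()
    out = layer(src)                  # randomized decorrelating transform

    layer.eval()
    assert torch.equal(layer(src), src)  # identity in eval mode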
@@ -369,7 +376,7 @@ class RelPositionalEncoding(torch.nn.Module):
         """Construct a PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
-        self.dropout = GaussProjDrop(d_model, dropout_rate)
+        self.dropout = torch.nn.Dropout(dropout_rate)
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))