Reorganize Whiten() code; configs are not the same as before. Also remove MaxEig for self_attn module
parent 9919a05612
commit fc728f2738

@@ -31,6 +31,8 @@ from scaling import (
     DoubleSwish,
     ScaledConv1d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+    Whiten,
+    _diag,
 )
 from torch import Tensor, nn

@@ -801,129 +803,6 @@ class RelPositionalEncoding(torch.nn.Module):
         return self.dropout(pos_emb)
 
 
-def _diag(x: Tensor):  # like .diag(), but works for tensors with 3 dims.
-    if x.ndim == 2:
-        return x.diag()
-    else:
-        (batch, dim, dim) = x.shape
-        x = x.reshape(batch, dim * dim)
-        x = x[:, ::dim+1]
-        assert x.shape == (batch, dim)
-        return x
-
-
-class WhiteningPenaltyFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx,
-                x: Tensor,
-                whitening_limit: float,
-                grad_scale: float) -> Tensor:
-        ctx.save_for_backward(x)
-        ctx.whitening_limit = whitening_limit
-        ctx.grad_scale = grad_scale
-        return x
-
-    @staticmethod
-    def backward(ctx,
-                 x_grad: Tensor):
-        x_orig, = ctx.saved_tensors
-        with torch.enable_grad():
-            with torch.cuda.amp.autocast(enabled=False):
-                x_detached = x_orig.to(torch.float32).detach()
-                x_detached.requires_grad = True
-                assert x_detached.ndim >= 3
-                x = x_detached.reshape(-1, x_detached.shape[-2],
-                                       x_detached.shape[-1]).transpose(0, 1)
-                (num_groups, num_frames, channels_per_group) = x.shape
-
-                # subtract the mean so we use the centered, not uncentered, covariance.
-                # My experience has been that when we "mess with the gradients" like this,
-                # it's better not do anything that tries to move the mean around, because
-                # that can easily cause instability.
-                x = x - x.mean(dim=1, keepdim=True)
-
-                # x_covar: (num_groups, channels_per_group, channels_per_group)
-                x_covar = torch.matmul(x.transpose(1, 2), x)
-                # normalize x_covar so that its average diagonal element is 1.
-                x_covar = x_covar / (_diag(x_covar).mean() + 1.0e-20)
-                # x_covar_sq: (num_groups, channels_per_group, channels_per_group).
-                # if the normalized x_covar were just `num_groups` copies of the
-                # identity matrix, x_covar_sq will have the same value.  But
-                # in general, it will be larger than that.
-                x_covar_sq = torch.matmul(x_covar, x_covar)
-
-                metric = _diag(x_covar_sq).mean()
-
-                if random.random() < 0.005 or __name__ == "__main__":
-                    logging.info(f"Whitening: num_groups={num_groups}, channels_per_group={channels_per_group}, "
-                                 f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}")
-
-                (metric - ctx.whitening_limit).relu().backward()
-                penalty_grad = x_detached.grad
-                scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() /
-                                          (penalty_grad.norm() + 1.0e-20))
-                penalty_grad = penalty_grad * scale
-        return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
-
-
-class Whiten(nn.Module):
-    def __init__(
-            self,
-            whitening_limit: float,
-            prob: float,
-            grad_scale: float):
-        """
-        Args:
-          num_groups: the number of groups to divide the input into before
-            whitening it.  We will attempt to make the feature covariance
-            within each group, after mean subtraction, as "white" as possible
-            while having the same trace across all groups.
-          whitening_limit: a value greater than 1.0, that dictates how much
-            freedom we have to violate the constraints.  1.0 would mean perfectly
-            white, with exactly the same trace across groups; larger values
-            give more freedom.  E.g. 2.0.
-          prob: the probability with which we apply this object (also affects
-            grad scale).  e.g. 0.25
-          grad_scale: determines the scale on the gradient term from this object,
-            relative to the rest of the gradient on the attention weights;
-            will be divided by `prob`.  e.g. 0.005
-        """
-        super(Whiten, self).__init__()
-        assert whitening_limit >= 1
-        assert 0 < prob <= 1
-        assert grad_scale >= 0
-        self.whitening_limit = whitening_limit
-        self.prob = prob
-        self.grad_scale = grad_scale
-
-    def forward(self,
-                x: Tensor) -> Tensor:
-        """
-        In the forward pass, this function just returns the input unmodified.
-        In the backward pass, it will modify the gradients to ensure that the
-        distribution in each group has close to (lambda times I) as the covariance
-        after mean subtraction, with the same lambda across groups.
-        For whitening_limit > 1, there will be more freedom to violate this
-        constraint.
-
-        Args:
-          x: the input of shape (*, num_groups, channels_per_group)
-
-        Returns:
-          x, unmodified.  You should make sure
-          you use the returned value, or the graph will be freed
-          and nothing will happen in backprop.
-        """
-        if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
-            return x
-        else:
-            return WhiteningPenaltyFunction.apply(x,
-                                                  self.whitening_limit,
-                                                  self.grad_scale / self.prob)
-
-
 class RelPositionMultiheadAttention(nn.Module):
     r"""Multi-Head Attention layer with relative position encoding

@@ -958,20 +837,20 @@ class RelPositionMultiheadAttention(nn.Module):
 
         self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
 
-        # self.whiten is applied on the values in forward()
-        self.whiten_values = Whiten(whitening_limit=1.1,
-                                    prob=1.0 if __name__ == "__main__" else 0.1,
-                                    grad_scale=0.0025)
+        # self.whiten_values is applied on the values in forward()
+        self.whiten_values = Whiten(num_groups=num_heads,
+                                    whitening_limit=1.1,
+                                    prob=(0.025, 0.25),
+                                    grad_scale=0.025)
         # self.whiten_keys is applied on the keys in forward()
-        self.whiten_keys = Whiten(whitening_limit=1.1,
-                                  prob=1.0 if __name__ == "__main__" else 0.1,
-                                  grad_scale=0.0025)
+        self.whiten_keys = Whiten(num_groups=num_heads,
+                                  whitening_limit=1.1,
+                                  prob=(0.025, 0.25),
+                                  grad_scale=0.025)
 
 
         self.in_balancer = ActivationBalancer(3 * embed_dim // 2,
                                               channel_dim=-1, max_abs=5.0)
-        self.in_max_eig = MaxEig(3 * embed_dim // 2,
-                                 channel_dim=-1)
         self.out_proj = ScaledLinear(
             embed_dim // 2, embed_dim, bias=True, initial_scale=0.05
         )

@@ -980,10 +859,10 @@ class RelPositionMultiheadAttention(nn.Module):
         self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
                                       initial_scale=0.05)
         # self.whiten_values2 is applied on the values in forward2()
-        self.whiten_values2 = Whiten(whitening_limit=1.1,
-                                     prob=1.0 if __name__ == "__main__" else 0.1,
-                                     grad_scale=0.0025)
-
+        self.whiten_values2 = Whiten(num_groups=num_heads,
+                                     whitening_limit=1.1,
+                                     prob=(0.025, 0.25),
+                                     grad_scale=0.025)
 
         # linear transformation for positional encoding (projects to a scalar per head,
         # which will be added to the score).

@@ -1037,7 +916,7 @@ class RelPositionMultiheadAttention(nn.Module):
             and S is the sequence length.
         """
         x, weights = self.multi_head_attention_forward(
-            self.in_max_eig(self.in_balancer(self.in_proj(x))),
+            self.in_balancer(self.in_proj(x)),
             self.linear_pos(pos_emb),
             self.embed_dim,
             self.num_heads,

@@ -1155,6 +1034,8 @@ class RelPositionMultiheadAttention(nn.Module):
         # self-attention
         q, k, v = x.chunk(3, dim=-1)
 
+        k = self.whiten_keys(k)  # does nothing in the forward pass.
+        v = self.whiten_values(v)  # does nothing in the forward pass.
 
         if attn_mask is not None:
             assert (

@@ -1207,11 +1088,7 @@ class RelPositionMultiheadAttention(nn.Module):
 
         q = (q * scaling).contiguous().view(seq_len, bsz, num_heads, head_dim)
         k = k.contiguous().view(-1, bsz, num_heads, head_dim)
-        k = self.whiten_keys(k)  # does nothing in the forward pass.
-        v = v.contiguous().view(-1, bsz, num_heads, head_dim)
-        v = self.whiten_values(v)  # does nothing in the forward pass.
-        v = v.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
-
+        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
 
         if key_padding_mask is not None:
             assert key_padding_mask.size(0) == bsz, "{} == {}".format(

@@ -1297,7 +1174,6 @@ class RelPositionMultiheadAttention(nn.Module):
         head_dim = embed_dim // (num_heads * 2)
         # v: (tgt_len, bsz, embed_dim // 2)
         v = self.in_proj2(x)
-        v = v.contiguous().view(-1, bsz, num_heads, head_dim)
         v = self.whiten_values2(v)  # does nothing in the forward pass.
         v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
 

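For context, a minimal construction sketch (not part of this diff) of the new self-attention whitening config: the fixed prob=0.1, grad_scale=0.0025 setting above is replaced by a (min_prob, max_prob) schedule with grad_scale=0.025, and num_groups is now passed explicitly, one group per attention head. Note that the old module divided grad_scale by prob when applying the penalty, so the effective per-application scale is comparable. The import path and head count below are illustrative assumptions.

import torch
from scaling import Whiten  # assumed path: the scaling.py changed in the hunks below

num_heads = 8  # hypothetical head count; one whitening group per head
whiten_keys = Whiten(num_groups=num_heads,
                     whitening_limit=1.1,
                     prob=(0.025, 0.25),  # was: prob=1.0 if __name__ == "__main__" else 0.1
                     grad_scale=0.025)    # was: grad_scale=0.0025
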
@@ -421,6 +421,172 @@ class ActivationBalancer(torch.nn.Module):
         return x
 
 
+def _diag(x: Tensor):  # like .diag(), but works for tensors with 3 dims.
+    if x.ndim == 2:
+        return x.diag()
+    else:
+        (batch, dim, dim) = x.shape
+        x = x.reshape(batch, dim * dim)
+        x = x[:, ::dim+1]
+        assert x.shape == (batch, dim)
+        return x
+
+
+def _whitening_metric(x: Tensor,
+                      num_groups: int):
+    """
+    Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
+    the centered feature covariance are the same within each group's covariance matrix
+    and also between groups.
+    Args:
+        x: a Tensor of shape (*, num_channels)
+        num_groups: the number of groups of channels, a number >=1 that divides num_channels
+    Returns:
+        Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
+        greater than 1.0 otherwise.
+    """
+    assert x.dtype != torch.float16
+    x = x.reshape(-1, x.shape[-1])
+    (num_frames, num_channels) = x.shape
+    assert num_channels % num_groups == 0
+    channels_per_group = num_channels // num_groups
+    x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
+    # x now has shape (num_groups, num_frames, channels_per_group)
+    # subtract the mean so we use the centered, not uncentered, covariance.
+    # My experience has been that when we "mess with the gradients" like this,
+    # it's better not to do anything that tries to move the mean around, because
+    # that can easily cause instability.
+    x = x - x.mean(dim=1, keepdim=True)
+    # x_covar: (num_groups, channels_per_group, channels_per_group)
+    x_covar = torch.matmul(x.transpose(1, 2), x)
+    x_covar_mean_diag = _diag(x_covar).mean()
+    # the following expression is what we'd get if we took the matrix product
+    # of each covariance and measured the mean of its trace, i.e.
+    # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
+    x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group)
+    # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
+    metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20)
+    return metric
+
+
+class WhiteningPenaltyFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx,
+                x: Tensor,
+                num_groups: int,
+                whitening_limit: float,
+                grad_scale: float) -> Tensor:
+        ctx.save_for_backward(x)
+        ctx.num_groups = num_groups
+        ctx.whitening_limit = whitening_limit
+        ctx.grad_scale = grad_scale
+        return x
+
+    @staticmethod
+    def backward(ctx,
+                 x_grad: Tensor):
+        x_orig, = ctx.saved_tensors
+        with torch.enable_grad():
+            with torch.cuda.amp.autocast(enabled=False):
+                x_detached = x_orig.to(torch.float32).detach()
+                x_detached.requires_grad = True
+
+                metric = _whitening_metric(x_detached, ctx.num_groups)
+
+                if random.random() < 0.005 or __name__ == "__main__":
+                    logging.info(f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
+                                 f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}")
+
+                (metric - ctx.whitening_limit).relu().backward()
+                penalty_grad = x_detached.grad
+                scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() /
+                                          (penalty_grad.norm() + 1.0e-20))
+                penalty_grad = penalty_grad * scale
+        return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
+
+
+class Whiten(nn.Module):
+    def __init__(
+            self,
+            num_groups: int,
+            whitening_limit: float,
+            prob: Union[float, Tuple[float,float]],
+            grad_scale: float):
+        """
+        Args:
+          num_groups: the number of groups to divide the channel dim into before
+            whitening.  We will attempt to make the feature covariance
+            within each group, after mean subtraction, as "white" as possible,
+            while having the same trace across all groups.
+          whitening_limit: a value greater than 1.0, that dictates how much
+            freedom we have to violate the constraints.  1.0 would mean perfectly
+            white, with exactly the same trace across groups; larger values
+            give more freedom.  E.g. 2.0.
+          prob: the probability with which we apply the gradient modification
+            (also affects the grad scale).  May be supplied as a float,
+            or as a pair (min_prob, max_prob)
+
+          grad_scale: determines the scale on the gradient term from this object,
+            relative to the rest of the gradient on the attention weights.
+            E.g. 0.02 (you may want to use smaller values than this if prob is large)
+        """
+        super(Whiten, self).__init__()
+        assert num_groups >= 1
+        assert whitening_limit >= 1
+        assert grad_scale >= 0
+        self.num_groups = num_groups
+        self.whitening_limit = whitening_limit
+        if isinstance(prob, float):
+            assert 0 < prob <= 1
+            self.prob = prob
+        else:
+            (self.min_prob, self.max_prob) = prob
+            assert 0 < self.min_prob < self.max_prob <= 1
+            self.prob = self.max_prob
+
+        self.grad_scale = grad_scale
+
+    def forward(self,
+                x: Tensor) -> Tensor:
+        """
+        In the forward pass, this function just returns the input unmodified.
+        In the backward pass, it will modify the gradients to ensure that the
+        distribution in each group has close to (lambda times I) as the covariance
+        after mean subtraction, with the same lambda across groups.
+        For whitening_limit > 1, there will be more freedom to violate this
+        constraint.
+
+        Args:
+          x: the input of shape (*, num_channels)
+
+        Returns:
+          x, unmodified.  You should make sure
+          you use the returned value, or the graph will be freed
+          and nothing will happen in backprop.
+        """
+        if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
+            return x
+        else:
+            if hasattr(self, 'min_prob') and random.random() < 0.25:
+                # occasionally switch between min_prob and max_prob, based on whether
+                # we are above or below the threshold.
+                if _whitening_metric(x, self.num_groups) > self.whitening_limit:
+                    # there would be a change to the grad.
+                    self.prob = self.max_prob
+                else:
+                    self.prob = self.min_prob
+
+            return WhiteningPenaltyFunction.apply(x,
+                                                  self.num_groups,
+                                                  self.whitening_limit,
+                                                  self.grad_scale)
+
+
 class MaxEig(torch.nn.Module):
     """
     Modifies the backpropped derivatives of a function to try to discourage

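To make _whitening_metric above concrete, here is a standalone sketch (not part of this diff) that mirrors its formula, using torch.diagonal in place of the private _diag helper so it runs on its own. The metric is 1.0 when every group's centered covariance is a multiple of the identity with the same trace across groups, and rises well above 1.0 when a single direction dominates.

import torch

def whitening_metric(x: torch.Tensor, num_groups: int) -> torch.Tensor:
    # Mirrors _whitening_metric from the diff: x is (*, num_channels),
    # num_groups must divide num_channels; returns a scalar >= 1.0.
    x = x.reshape(-1, x.shape[-1])
    num_frames, num_channels = x.shape
    channels_per_group = num_channels // num_groups
    x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
    x = x - x.mean(dim=1, keepdim=True)           # centered covariance
    x_covar = torch.matmul(x.transpose(1, 2), x)  # (num_groups, c, c)
    covar_mean_diag = torch.diagonal(x_covar, dim1=-2, dim2=-1).mean()
    covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group)
    return covarsq_mean_diag / (covar_mean_diag ** 2 + 1.0e-20)

white = torch.randn(1000, 64)                                  # roughly white data
skewed = white + 5.0 * torch.randn(1000, 1) * torch.randn(64)  # one dominant direction
print(whitening_metric(white, num_groups=4))   # close to 1.0
print(whitening_metric(skewed, num_groups=4))  # much larger than 1.0
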
@@ -632,6 +798,35 @@ def _test_max_eig():
     assert not torch.allclose(x.grad, y_grad)
 
 
+def _test_whiten():
+    for proportion in [0.1, 0.5, 10.0]:
+        logging.info(f"_test_whiten(): proportion = {proportion}")
+        x = torch.randn(100, 128)
+        direction = torch.randn(128)
+        coeffs = torch.randn(100, 1)
+        x += proportion * direction * coeffs
+
+        x.requires_grad = True
+
+        num_channels = 128
+        m = Whiten(1,  # num_groups
+                   5.0,  # whitening_limit,
+                   prob=1.0,
+                   grad_scale=0.1)  # grad_scale
+
+        for _ in range(4):
+            y = m(x)
+
+        y_grad = torch.randn_like(x)
+        y.backward(gradient=y_grad)
+
+        if proportion < 0.2:
+            assert torch.allclose(x.grad, y_grad)
+        elif proportion > 1.0:
+            assert not torch.allclose(x.grad, y_grad)
+
+
 def _test_activation_balancer_sign():
     probs = torch.arange(0, 1, 0.01)

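A small sketch (not part of this diff) of how WhiteningPenaltyFunction.backward above combines the penalty with the incoming gradient: the penalty direction comes from autograd (and is all zeros while the metric stays below whitening_limit), but its magnitude is pinned to grad_scale times the norm of the incoming gradient, so the nudge stays proportionate to the rest of the gradient. A random tensor stands in for the penalty direction here.

import torch

x_grad = torch.randn(10, 128)        # gradient arriving from the rest of the model
penalty_grad = torch.randn(10, 128)  # stand-in for d/dx of (metric - limit).relu()
grad_scale = 0.025

scale = grad_scale * (x_grad.norm() / (penalty_grad.norm() + 1.0e-20))
combined = x_grad + penalty_grad * scale
# The perturbation has norm grad_scale * ||x_grad||, i.e. about 2.5% here:
print((combined - x_grad).norm() / x_grad.norm())  # ~0.025
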
@@ -714,6 +909,7 @@ if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    _test_whiten()
     _test_max_eig()
     _test_activation_balancer_sign()
     _test_activation_balancer_magnitude()

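Finally, a minimal end-to-end sketch (not part of this diff) of the reorganized module, assuming the updated scaling.py is importable as `scaling`. Whiten is an identity in the forward pass and only adjusts gradients in backward, and the penalty is attached to the returned tensor, so the returned value must be used.

import torch
from scaling import Whiten  # assumed import path

whiten = Whiten(num_groups=8,          # e.g. one group per attention head
                whitening_limit=1.1,   # tolerated departure from "white"
                prob=(0.025, 0.25),    # (min_prob, max_prob), as in the attention module
                grad_scale=0.025)

x = torch.randn(50, 4, 256, requires_grad=True)  # (*, num_channels); 256 % 8 == 0
y = whiten(x)                  # forward pass: y is x, unchanged
(y ** 2).sum().backward()      # backward may add a small whitening penalty to x.grad
print(torch.equal(y, x), x.grad.shape)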