Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Implement max-eig-proportion..

parent 5f27cbdb44
commit 3d72a65de8
@@ -249,8 +249,6 @@ class ConformerEncoderLayer(nn.Module):

        # multi-headed self-attention module
        src_att = self.self_attn(
            src,
            src,
            src,
            pos_emb=pos_emb,
            attn_mask=src_mask,
@@ -490,9 +488,7 @@ class RelPositionMultiheadAttention(nn.Module):

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        x: Tensor,
        pos_emb: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
@@ -500,7 +496,7 @@ class RelPositionMultiheadAttention(nn.Module):
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
            x: input to be projected to query, key, value
            pos_emb: Positional embedding tensor
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When given a binary mask and a value is True,
@@ -513,11 +509,7 @@ class RelPositionMultiheadAttention(nn.Module):

        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
            - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension.
            - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension.
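As a quick sketch of the shapes documented above (the tensor names and sizes below are made up for illustration and are not taken from the repository): x has one row per (position, batch) pair, while pos_emb has one row per relative offset, which is why its second dimension is 2*L-1.

import torch

seq_len, batch, embed_dim = 10, 4, 512                 # L, N, E in the docstring's notation
x = torch.randn(seq_len, batch, embed_dim)             # input that gets projected to q, k, v
pos_emb = torch.randn(1, 2 * seq_len - 1, embed_dim)   # one embedding per relative offset
# offsets run from -(L-1) to L-1, giving 2*L-1 rows in total
assert pos_emb.shape[1] == 2 * seq_len - 1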
@@ -540,9 +532,7 @@ class RelPositionMultiheadAttention(nn.Module):
        L is the target sequence length, S is the source sequence length.
        """
        return self.multi_head_attention_forward(
            query,
            key,
            value,
            self.in_balancer(self.in_proj(x)),
            pos_emb,
            self.embed_dim,
            self.num_heads,
@@ -584,11 +574,9 @@ class RelPositionMultiheadAttention(nn.Module):

    def multi_head_attention_forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        x: Tensor,
        pos_emb: Tensor,
        embed_dim_to_check: int,
        embed_dim: int,
        num_heads: int,
        in_proj_weight: Tensor,
        in_proj_bias: Tensor,
@@ -604,7 +592,7 @@ class RelPositionMultiheadAttention(nn.Module):
        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
            pos_emb: Positional embedding tensor
            embed_dim_to_check: total dimension of the model.
            embed_dim: total dimension of the model.
            num_heads: parallel attention heads.
            in_proj_weight, in_proj_bias: input projection weight and bias.
            dropout_p: probability of an element to be zeroed.
@@ -646,9 +634,7 @@ class RelPositionMultiheadAttention(nn.Module):
        L is the target sequence length, S is the source sequence length.
        """

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == embed_dim_to_check
        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
        tgt_len, bsz, _ = x.size()

        head_dim = embed_dim // num_heads
        assert (
@@ -657,62 +643,10 @@ class RelPositionMultiheadAttention(nn.Module):

        scaling = float(head_dim) ** -0.5

        def linear(x, w, b):
            return self.in_balancer(nn.functional.linear(x, w, b))

        if torch.equal(query, key) and torch.equal(key, value):
            # self-attention
            q, k, v = linear(
                query, in_proj_weight, in_proj_bias
            ).chunk(3, dim=-1)
        # self-attention
        q, k, v = x.chunk(3, dim=-1)

        elif torch.equal(key, value):
            # encoder-decoder attention
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = linear(query, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = None
            _w = in_proj_weight[_start:, :]
            if _b is not None:
                _b = _b[_start:]
            k, v = linear(key, _w, _b).chunk(2, dim=-1)

        else:
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = linear(query, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = embed_dim * 2
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            k = linear(key, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim * 2
            _end = None
            _w = in_proj_weight[_start:, :]
            if _b is not None:
                _b = _b[_start:]
            v = linear(value, _w, _b)

        if attn_mask is not None:
            assert (
@@ -732,15 +666,15 @@ class RelPositionMultiheadAttention(nn.Module):

            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0)
                if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                if list(attn_mask.size()) != [1, tgt_len, tgt_len]:
                    raise RuntimeError(
                        "The size of the 2D attn_mask is not correct."
                    )
            elif attn_mask.dim() == 3:
                if list(attn_mask.size()) != [
                    bsz * num_heads,
                    query.size(0),
                    key.size(0),
                    tgt_len,
                    tgt_len,
                ]:
                    raise RuntimeError(
                        "The size of the 3D attn_mask is not correct."

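A hedged sketch of the packed input projection these hunks move toward: the caller applies self.in_proj (a Linear mapping E to 3*E, followed in the real code by in_balancer, omitted here) once, and the attention code recovers q, k, v with chunk(3, dim=-1). The names in_proj and embed_dim below mirror the diff; everything else is illustrative.

import torch
import torch.nn as nn

embed_dim, seq_len, batch = 512, 10, 4
in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)   # packed q/k/v projection

src = torch.randn(seq_len, batch, embed_dim)
projected = in_proj(src)                   # (L, N, 3*E), computed once by the caller
q, k, v = projected.chunk(3, dim=-1)       # each (L, N, E)

# the same q would come out of slicing the packed weight, as the removed inline code did
q_ref = nn.functional.linear(src, in_proj.weight[:embed_dim], in_proj.bias[:embed_dim])
assert torch.allclose(q, q_ref, atol=1e-5)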
@@ -254,7 +254,7 @@ class ScaledAdam(Optimizer):
        if ans < 1.0:
            state["num_clipped"] += 1
        if ans < 0.1:
            logging.warn("Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
            logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
        return ans

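For context on the one-character fix above (not repository code, just a demonstration of the difference): without the f prefix the braces are printed literally instead of being interpolated.

ans, model_norm_threshold = 0.05, 2.0
print("Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
# prints: Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}
print(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
# prints: Scaling gradients by 0.05, model_norm_threshold=2.0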
@@ -145,71 +145,6 @@ def find_direction_coeffs(x: Tensor,
    return cur_direction, coeffs


def get_max_eig_proportion(x: Tensor,
                           prev_direction: Tensor,
                           subtract_mean: bool) -> Tuple[Tensor, Tensor]:
    """
    Figure out (an approximation to) the proportion of the variance of a set of
    feature vectors that can be attributed to the top eigen-direction.
    Args:
      x: a Tensor of shape (*, num_channels). There must be more than one frame,
          i.e. x.numel() // num_channels > 1.
      prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
          of the top eigen-direction, or a random direction if this is the first
          iteration. Expected to be without gradient. Does not have to be
          normalized.
      subtract_mean: if True, we will first subtract the mean of x, over the
          frames. Suggest to make this true in most circumstances.

    Returns: (cur_direction, max_proportion), where:
      cur_direction: a Tensor of shape (num_channels,) that is the current
          estimate of the top eigen-direction. Detached / not intended to be
          differentiable.
      proportion: a scalar Tensor containing the proportion of the variance
          of the input that is in direction `cur_direction`. This is with
          gradient, that can be propagated back to x.
    """
    num_channels = x.shape[-1]
    assert prev_direction.shape == (num_channels,)
    x = x.reshape(-1, num_channels)
    if subtract_mean:
        x = x - x.mean(dim=0)

    with torch.no_grad():
        cur_norm = prev_direction.norm()

        prev_direction = prev_direction / cur_norm
        is_ok = (cur_norm / cur_norm == 1.0)
        # if there was a problem like NaN or inf, restart. this should be very rare.
        prev_direction = torch.where(is_ok.unsqueeze(-1).expand(prev_direction.shape),
                                     prev_direction,
                                     torch.randn_like(prev_direction) * (num_channels ** -0.5))

    # `coeffs` are the coefficients of `prev_direction` in x.
    coeffs = (x * prev_direction).sum(dim=1, keepdim=True)

    x_norm = x.norm()
    x_coeffs1_norm = (x - coeffs * prev_direction).norm()

    with torch.no_grad():
        cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20)

    x_coeffs2_norm = (x - coeffs * cur_direction).norm()

    # for the returned direction interpolate with prev_direction so that
    # even if x == 0, we get a nonzero new direction.
    ans_direction = 0.5 * (prev_direction + cur_direction)

    x_sumsq = (x**2).sum() + 1.0e-20
    x_remaining_sumsq = ((x - coeffs * cur_direction) ** 2).sum() + 1.0e-20

    proportion = (x - x_remaining_sumsq) / x_sumsq

    return (ans_direction, proportion)

    print(f"x_norm={x_norm}, x_coeffs1_norm={x_coeffs1_norm}, x_coeffs2_norm={x_coeffs2_norm}")

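As a hedged cross-check of the quantity the removed helper was approximating (this is standalone example code, not from the repository): the proportion is the top eigenvalue of the covariance of the mean-subtracted frames divided by the total variance, and a few power-iteration steps on a direction vector recover roughly the same number.

import torch

torch.manual_seed(0)
x = torch.randn(1000, 64)
x[:, 0] *= 5.0                          # make one direction dominate the variance
x = x - x.mean(dim=0)                   # the subtract_mean=True case

# exact answer via an eigendecomposition of the covariance
cov = (x.t() @ x) / x.shape[0]
eigvals = torch.linalg.eigvalsh(cov)    # ascending order
exact = eigvals[-1] / eigvals.sum()

# a few power-iteration steps, in the same spirit as the removed helper
d = torch.randn(64)
d = d / d.norm()
for _ in range(10):
    coeffs = (x * d).sum(dim=1, keepdim=True)   # projection of each frame onto d
    d = (x * coeffs).sum(dim=0)                 # roughly: covariance times d
    d = d / d.norm()
residual = x - (x * d).sum(dim=1, keepdim=True) * d
approx = 1.0 - (residual ** 2).sum() / (x ** 2).sum()
print(exact.item(), approx.item())      # both close to 0.28 for this data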
class MaxEigLimiterFunction(torch.autograd.Function):
@@ -233,17 +168,18 @@ class MaxEigLimiterFunction(torch.autograd.Function):
        if subtract_mean:
            x = x - x.mean(dim=0)
        new_direction, coeffs = find_direction_coeffs(x, direction)
        x_var = (x**2).sum()
        x_var = (x**2).mean()
        x_residual = x - coeffs * new_direction
        x_residual_var = (x_residual**2).sum()
        x_residual_var = (x_residual**2).mean()
        # `variance_proportion` is the proportion of the variance accounted for
        # by the top eigen-direction.
        variance_proportion = (x_var - x_residual_var) / x_var
        variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)

        ans_direction = direction + new_direction # ensure nonzero even if x == 0
        ans_direction = ans_direction / ans_direction.norm()

        logging.info(f"variance_proportion = {variance_proportion.item()}")
        if random.random() < 0.01:
            logging.info(f"variance_proportion = {variance_proportion.item()}")

        # Caution: this causes a CUDA sync, which is not ideal.
        if variance_proportion >= max_variance_proportion:
@@ -262,7 +198,6 @@ class MaxEigLimiterFunction(torch.autograd.Function):
        if not hasattr(ctx, 'channel_dim'):
            # the top eig's proportion of the variance was below the threshold.
            return x_grad, None, None, None, None, None, None

        with torch.enable_grad():
            (x_orig, coeffs, new_direction) = ctx.saved_tensors
            x_orig.requires_grad = True
@@ -271,16 +206,16 @@ class MaxEigLimiterFunction(torch.autograd.Function):
            new_direction.requires_grad = False
            if ctx.subtract_mean:
                x = x - x.mean(dim=0)
            x_var = (x**2).sum()
            x_var = (x ** 2).mean()
            x_residual = x - coeffs * new_direction
            x_residual_var = (x_residual**2).sum()
            x_residual_var = (x_residual ** 2).mean()
            # `variance_proportion` is the proportion of the variance accounted for
            # by the top eigen-direction. This is to be minimized.
            variance_proportion = (x_var - x_residual_var) / x_var
            variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
            variance_proportion.backward()
        x_orig_grad = x_orig.grad
        x_extra_grad = x_orig.grad * x_orig.grad.norm() / (x_orig_grad.norm() + 1.0e-20)
        return x_grad + x_extra_grad, None, None, None, None, None, None
        x_orig_grad = x_orig.grad
        x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20)
        return x_grad + x_extra_grad.detach(), None, None, None, None, None, None

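A hedged, self-contained sketch of the gradient the revised backward pass adds (the real logic is the autograd.Function above; grad_scale and incoming_grad below are made-up stand-ins for ctx.grad_scale and x_grad, and the mean subtraction is omitted for brevity): the penalty is the gradient of variance_proportion with respect to x, rescaled so its norm is a small fraction of the incoming gradient's norm.

import torch

x = torch.randn(200, 32, requires_grad=True)
direction = torch.randn(32)
direction = direction / direction.norm()

coeffs = (x * direction).sum(dim=1, keepdim=True)
x_var = (x ** 2).mean()
x_residual_var = ((x - coeffs * direction) ** 2).mean()
variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
variance_proportion.backward()

grad_scale = 0.01                        # stand-in for ctx.grad_scale
incoming_grad = torch.randn_like(x)      # stand-in for x_grad from later layers
extra_grad = x.grad * grad_scale * incoming_grad.norm() / (x.grad.norm() + 1.0e-20)
combined_grad = incoming_grad + extra_grad.detach()
print((extra_grad.norm() / incoming_grad.norm()).item())   # approximately grad_scale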
class BasicNorm(torch.nn.Module):
@@ -448,7 +383,9 @@ class ActivationBalancer(torch.nn.Module):
        self.max_var_per_eig = max_var_per_eig
        if max_var_per_eig > 0.0:
            with torch.no_grad():
                direction = torch.randn(num_channels)
                # arbitrary.. would use randn() but want to leave the rest of the model's
                # random parameters unchanged for comparison
                direction = torch.arange(num_channels).to(torch.float)
                direction = direction / direction.norm()
                self.register_buffer('max_eig_direction', direction)
        else:
@@ -460,15 +397,16 @@ class ActivationBalancer(torch.nn.Module):
            return x

        if self.max_var_per_eig > 0:
            x, new_direction = MaxEigLimiterFunction.apply(
                x, self.max_eig_direction,
                self.channel_dim,
                0.1, # prob
                True, # subtract_mean
                self.max_var_per_eig,
                self.max_factor,
            )
            self.max_eig_direction[:] = new_direction
            with torch.cuda.amp.autocast(enabled=False):
                x, new_direction = MaxEigLimiterFunction.apply(
                    x, self.max_eig_direction,
                    self.channel_dim,
                    0.25, # prob
                    True, # subtract_mean
                    self.max_var_per_eig,
                    self.max_factor,
                )
                self.max_eig_direction[:] = new_direction.detach()

        return ActivationBalancerFunction.apply(
            x,
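A hedged aside on the autocast(enabled=False) wrapper added above (illustrative only; it needs a CUDA device to show anything): disabling autocast inside a mixed-precision region keeps that block's arithmetic in float32, which is friendlier to the squared-sum statistics the eig limiter computes.

import torch

if torch.cuda.is_available():
    x = torch.randn(16, 8, device="cuda")
    with torch.cuda.amp.autocast(enabled=True):
        y_half = x @ x.t()                        # autocast runs this matmul in float16
        with torch.cuda.amp.autocast(enabled=False):
            y_full = x.float() @ x.t().float()    # forced back to float32
    print(y_half.dtype, y_full.dtype)             # torch.float16 torch.float32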
@@ -628,17 +566,12 @@ def _test_double_swish_deriv():
    torch.autograd.gradcheck(m, x)


def _test_get_max_eig_proportion():
    x = torch.randn(100, 128)
    d = torch.randn(128) * (128 ** -0.5)
    get_max_eig_proportion(x, d, True)

if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    _test_max_eig_limiter()
    _test_get_max_eig_proportion()
    _test_activation_balancer_sign()
    _test_activation_balancer_magnitude()
    _test_basic_norm()