mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Merge branch 'scaled_adam_exp4_max_var_per_eig' into scaled_adam_exp7
# Conflicts: # egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
This commit is contained in:
commit
5f27cbdb44
@ -173,7 +173,8 @@ class ConformerEncoderLayer(nn.Module):
|
||||
|
||||
self.feed_forward = nn.Sequential(
|
||||
nn.Linear(d_model, dim_feedforward),
|
||||
ActivationBalancer(channel_dim=-1, max_abs=10.0),
|
||||
ActivationBalancer(dim_feedforward,
|
||||
channel_dim=-1, max_abs=10.0),
|
||||
DoubleSwish(),
|
||||
nn.Dropout(dropout),
|
||||
ScaledLinear(dim_feedforward, d_model,
|
||||
@ -182,7 +183,8 @@ class ConformerEncoderLayer(nn.Module):
|
||||
|
||||
self.feed_forward_macaron = nn.Sequential(
|
||||
nn.Linear(d_model, dim_feedforward),
|
||||
ActivationBalancer(channel_dim=-1, max_abs=10.0),
|
||||
ActivationBalancer(dim_feedforward,
|
||||
channel_dim=-1, max_abs=10.0),
|
||||
DoubleSwish(),
|
||||
nn.Dropout(dropout),
|
||||
ScaledLinear(dim_feedforward, d_model,
|
||||
@ -196,7 +198,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
|
||||
# try to ensure the output is close to zero-mean (or at least, zero-median).
|
||||
self.balancer = ActivationBalancer(
|
||||
channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
|
||||
d_model, channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
|
||||
)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
@ -464,8 +466,12 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
|
||||
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
|
||||
self.in_balancer = ActivationBalancer(channel_dim=-1, max_abs=5.0)
|
||||
self.proj_balancer = ActivationBalancer(channel_dim=-1, max_abs=10.0)
|
||||
self.in_balancer = ActivationBalancer(3 * embed_dim,
|
||||
channel_dim=-1, max_abs=5.0,
|
||||
max_var_per_eig=0.1)
|
||||
self.proj_balancer = ActivationBalancer(embed_dim,
|
||||
channel_dim=-1, max_abs=10.0,
|
||||
min_positive=0.0, max_positive=1.0)
|
||||
self.out_proj = ScaledLinear(
|
||||
embed_dim, embed_dim, bias=True, initial_scale=0.5
|
||||
)
|
||||
@ -900,6 +906,7 @@ class ConvolutionModule(nn.Module):
|
||||
# it will be in a better position to start learning something, i.e. to latch onto
|
||||
# the correct range.
|
||||
self.deriv_balancer1 = ActivationBalancer(
|
||||
2 * channels,
|
||||
channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
|
||||
)
|
||||
|
||||
@ -914,7 +921,7 @@ class ConvolutionModule(nn.Module):
|
||||
)
|
||||
|
||||
self.deriv_balancer2 = ActivationBalancer(
|
||||
channel_dim=1, min_positive=0.05, max_positive=1.0
|
||||
channels, channel_dim=1, min_positive=0.05, max_positive=1.0
|
||||
)
|
||||
|
||||
self.activation = DoubleSwish()
|
||||
@ -1000,7 +1007,8 @@ class Conv2dSubsampling(nn.Module):
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
),
|
||||
ActivationBalancer(channel_dim=1),
|
||||
ActivationBalancer(layer1_channels,
|
||||
channel_dim=1),
|
||||
DoubleSwish(),
|
||||
nn.Conv2d(
|
||||
in_channels=layer1_channels,
|
||||
@ -1008,7 +1016,8 @@ class Conv2dSubsampling(nn.Module):
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
),
|
||||
ActivationBalancer(channel_dim=1),
|
||||
ActivationBalancer(layer2_channels,
|
||||
channel_dim=1),
|
||||
DoubleSwish(),
|
||||
nn.Conv2d(
|
||||
in_channels=layer2_channels,
|
||||
@ -1016,7 +1025,8 @@ class Conv2dSubsampling(nn.Module):
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
),
|
||||
ActivationBalancer(channel_dim=1),
|
||||
ActivationBalancer(layer3_channels,
|
||||
channel_dim=1),
|
||||
DoubleSwish(),
|
||||
)
|
||||
out_height = (((in_channels - 1) // 2 - 1) // 2)
|
||||
@ -1027,6 +1037,7 @@ class Conv2dSubsampling(nn.Module):
|
||||
self.out_norm = BasicNorm(out_channels, learn_eps=False)
|
||||
# constrain median of output to be close to zero.
|
||||
self.out_balancer = ActivationBalancer(
|
||||
out_channels,
|
||||
channel_dim=-1, min_positive=0.45, max_positive=0.55
|
||||
)
|
||||
|
||||
|
||||
@ -114,6 +114,173 @@ class ActivationBalancerFunction(torch.autograd.Function):
|
||||
return x_grad - neg_delta_grad, None, None, None, None, None, None
|
||||
|
||||
|
||||
def find_direction_coeffs(x: Tensor,
|
||||
prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
"""
|
||||
Figure out (an approximation to) the proportion of the variance of a set of
|
||||
feature vectors that can be attributed to the top eigen-direction.
|
||||
Args:
|
||||
x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
|
||||
prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
|
||||
of the top eigen-direction, or a random direction if this is the first
|
||||
iteration. Does not have to be normalized, but should be nonzero.
|
||||
|
||||
Returns: (cur_direction, coeffs), where:
|
||||
cur_direction: a Tensor of shape (num_channels,) that is the current
|
||||
estimate of the top eigen-direction.
|
||||
coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
|
||||
approximately minimizes, (x - coeffs * cur_direction).norm()
|
||||
"""
|
||||
(num_frames, num_channels) = x.shape
|
||||
assert num_channels > 1 and num_frames > 1
|
||||
|
||||
assert prev_direction.shape == (num_channels,)
|
||||
|
||||
# `coeffs` are the coefficients of `prev_direction` in x.
|
||||
# actually represent the coeffs up to a constant positive factor.
|
||||
coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
|
||||
|
||||
cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20)
|
||||
|
||||
return cur_direction, coeffs
|
||||
|
||||
|
||||
def get_max_eig_proportion(x: Tensor,
|
||||
prev_direction: Tensor,
|
||||
subtract_mean: bool) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Figure out (an approximation to) the proportion of the variance of a set of
|
||||
feature vectors that can be attributed to the top eigen-direction.
|
||||
Args:
|
||||
x: a Tensor of shape (*, num_channels). There must be more than one frame,
|
||||
i.e. x.numel() // num_channels > 1.
|
||||
prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
|
||||
of the top eigen-direction, or a random direction if this is the first
|
||||
iteration. Expected to be without gradient. Does not have to be
|
||||
normalized.
|
||||
subtract_mean: if True, we will first subtract the mean of x, over the
|
||||
frames. Suggest to make this true in most circumstances.
|
||||
|
||||
Returns: (cur_direction, max_proportion), where:
|
||||
cur_direction: a Tensor of shape (num_channels,) that is the current
|
||||
estimate of the top eigen-direction. Detached / not intended to be
|
||||
differentiable.
|
||||
proportion: a scalar Tensor containing the proportion of the variance
|
||||
of the input that is in direction `cur_direction`. This is with
|
||||
gradient, that can be propagated back to x.
|
||||
"""
|
||||
num_channels = x.shape[-1]
|
||||
assert prev_direction.shape == (num_channels,)
|
||||
x = x.reshape(-1, num_channels)
|
||||
if subtract_mean:
|
||||
x = x - x.mean(dim=0)
|
||||
|
||||
with torch.no_grad():
|
||||
cur_norm = prev_direction.norm()
|
||||
|
||||
prev_direction = prev_direction / cur_norm
|
||||
is_ok = (cur_norm / cur_norm == 1.0)
|
||||
# if there was a problem like NaN or inf, restart. this should be very rare.
|
||||
prev_direction = torch.where(is_ok.unsqueeze(-1).expand(prev_direction.shape),
|
||||
prev_direction,
|
||||
torch.randn_like(prev_direction) * (num_channels ** -0.5))
|
||||
|
||||
# `coeffs` are the coefficients of `prev_direction` in x.
|
||||
coeffs = (x * prev_direction).sum(dim=1, keepdim=True)
|
||||
|
||||
x_norm = x.norm()
|
||||
x_coeffs1_norm = (x - coeffs * prev_direction).norm()
|
||||
|
||||
with torch.no_grad():
|
||||
cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20)
|
||||
|
||||
x_coeffs2_norm = (x - coeffs * cur_direction).norm()
|
||||
|
||||
# for the returned direction interpolate with prev_direction so that
|
||||
# even if x == 0, we get a nonzero new direction.
|
||||
ans_direction = 0.5 * (prev_direction + cur_direction)
|
||||
|
||||
x_sumsq = (x**2).sum() + 1.0e-20
|
||||
x_remaining_sumsq = ((x - coeffs * cur_direction) ** 2).sum() + 1.0e-20
|
||||
|
||||
proportion = (x - x_remaining_sumsq) / x_sumsq
|
||||
|
||||
return (ans_direction, proportion)
|
||||
|
||||
print(f"x_norm={x_norm}, x_coeffs1_norm={x_coeffs1_norm}, x_coeffs2_norm={x_coeffs2_norm}")
|
||||
|
||||
|
||||
|
||||
|
||||
class MaxEigLimiterFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
x: Tensor,
|
||||
direction: Tensor,
|
||||
channel_dim: int,
|
||||
prob: float,
|
||||
subtract_mean: bool,
|
||||
max_variance_proportion: float,
|
||||
grad_scale: float) -> Tuple[Tensor, Tensor]:
|
||||
if random.random() > prob:
|
||||
return x, direction
|
||||
eps = 1.0e-20
|
||||
num_channels = x.shape[channel_dim]
|
||||
assert max_variance_proportion > 1.0 / num_channels
|
||||
orig_x = x
|
||||
x = x.transpose(channel_dim, -1).reshape(-1, num_channels)
|
||||
if subtract_mean:
|
||||
x = x - x.mean(dim=0)
|
||||
new_direction, coeffs = find_direction_coeffs(x, direction)
|
||||
x_var = (x**2).sum()
|
||||
x_residual = x - coeffs * new_direction
|
||||
x_residual_var = (x_residual**2).sum()
|
||||
# `variance_proportion` is the proportion of the variance accounted for
|
||||
# by the top eigen-direction.
|
||||
variance_proportion = (x_var - x_residual_var) / x_var
|
||||
|
||||
ans_direction = direction + new_direction # ensure nonzero even if x == 0
|
||||
ans_direction = ans_direction / ans_direction.norm()
|
||||
|
||||
logging.info(f"variance_proportion = {variance_proportion.item()}")
|
||||
|
||||
# Caution: this causes a CUDA sync, which is not ideal.
|
||||
if variance_proportion >= max_variance_proportion:
|
||||
ctx.channel_dim = channel_dim
|
||||
ctx.subtract_mean = subtract_mean
|
||||
ctx.grad_scale = grad_scale
|
||||
ctx.save_for_backward(orig_x.detach(),
|
||||
coeffs.detach(),
|
||||
new_direction.detach())
|
||||
|
||||
return orig_x, ans_direction
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, x_grad, *args):
|
||||
# the *args is all the other derivs, which should be None or zero.
|
||||
if not hasattr(ctx, 'channel_dim'):
|
||||
# the top eig's proportion of the variance was below the threshold.
|
||||
return x_grad, None, None, None, None, None, None
|
||||
|
||||
with torch.enable_grad():
|
||||
(x_orig, coeffs, new_direction) = ctx.saved_tensors
|
||||
x_orig.requires_grad = True
|
||||
num_channels = x_orig.shape[ctx.channel_dim]
|
||||
x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
|
||||
new_direction.requires_grad = False
|
||||
if ctx.subtract_mean:
|
||||
x = x - x.mean(dim=0)
|
||||
x_var = (x**2).sum()
|
||||
x_residual = x - coeffs * new_direction
|
||||
x_residual_var = (x_residual**2).sum()
|
||||
# `variance_proportion` is the proportion of the variance accounted for
|
||||
# by the top eigen-direction. This is to be minimized.
|
||||
variance_proportion = (x_var - x_residual_var) / x_var
|
||||
variance_proportion.backward()
|
||||
x_orig_grad = x_orig.grad
|
||||
x_extra_grad = x_orig.grad * x_orig.grad.norm() / (x_orig_grad.norm() + 1.0e-20)
|
||||
return x_grad + x_extra_grad, None, None, None, None, None, None
|
||||
|
||||
|
||||
class BasicNorm(torch.nn.Module):
|
||||
@ -236,6 +403,7 @@ class ActivationBalancer(torch.nn.Module):
|
||||
|
||||
|
||||
Args:
|
||||
num_channels: the number of channels
|
||||
channel_dim: the dimension/axis corresponding to the channel, e.g.
|
||||
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
|
||||
min_positive: the minimum, per channel, of the proportion of the time
|
||||
@ -252,29 +420,56 @@ class ActivationBalancer(torch.nn.Module):
|
||||
max_abs: the maximum average-absolute-value difference from the mean
|
||||
value per channel, which we allow, before we start to modify
|
||||
the derivatives to prevent this.
|
||||
max_var_per_eig: the maximum proportion of the variance of the
|
||||
features/channels, after mean subtraction, that can come from
|
||||
any given eigenvalue.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_channels: int,
|
||||
channel_dim: int,
|
||||
min_positive: float = 0.05,
|
||||
max_positive: float = 0.95,
|
||||
max_factor: float = 0.01,
|
||||
min_abs: float = 0.2,
|
||||
max_abs: float = 100.0,
|
||||
max_var_per_eig: float = 0.0,
|
||||
):
|
||||
super(ActivationBalancer, self).__init__()
|
||||
self.num_channels = num_channels
|
||||
self.channel_dim = channel_dim
|
||||
self.min_positive = min_positive
|
||||
self.max_positive = max_positive
|
||||
self.max_factor = max_factor
|
||||
self.min_abs = min_abs
|
||||
self.max_abs = max_abs
|
||||
assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
|
||||
self.max_var_per_eig = max_var_per_eig
|
||||
if max_var_per_eig > 0.0:
|
||||
with torch.no_grad():
|
||||
direction = torch.randn(num_channels)
|
||||
direction = direction / direction.norm()
|
||||
self.register_buffer('max_eig_direction', direction)
|
||||
else:
|
||||
self.max_eig_direction = None
|
||||
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if torch.jit.is_scripting():
|
||||
return x
|
||||
|
||||
if self.max_var_per_eig > 0:
|
||||
x, new_direction = MaxEigLimiterFunction.apply(
|
||||
x, self.max_eig_direction,
|
||||
self.channel_dim,
|
||||
0.1, # prob
|
||||
True, # subtract_mean
|
||||
self.max_var_per_eig,
|
||||
self.max_factor,
|
||||
)
|
||||
self.max_eig_direction[:] = new_direction
|
||||
|
||||
return ActivationBalancerFunction.apply(
|
||||
x,
|
||||
self.channel_dim,
|
||||
@ -326,6 +521,35 @@ class DoubleSwish(torch.nn.Module):
|
||||
return DoubleSwishFunction.apply(x)
|
||||
|
||||
|
||||
def _test_max_eig_limiter():
|
||||
|
||||
for proportion in [0.1, 0.5, 10.0]:
|
||||
logging.info(f"proportion = {proportion}")
|
||||
x = torch.randn(100, 128)
|
||||
direction = torch.randn(128)
|
||||
coeffs = torch.randn(100, 1)
|
||||
x += proportion * direction * coeffs
|
||||
|
||||
x.requires_grad = True
|
||||
|
||||
y, new_direction = MaxEigLimiterFunction.apply(x, direction,
|
||||
1, # channel_dim
|
||||
1.0, # prob
|
||||
True, # subtract_mean
|
||||
0.5, # max_variance_proportion
|
||||
0.1, # grad_scale
|
||||
)
|
||||
|
||||
cosine = (new_direction * direction).sum() / (new_direction.norm() * direction.norm())
|
||||
logging.info(f"Direction cosine = {cosine}")
|
||||
|
||||
y_grad = torch.randn_like(x)
|
||||
y.backward(gradient=y_grad)
|
||||
|
||||
if proportion < 0.2:
|
||||
assert torch.allclose(x.grad, y_grad)
|
||||
elif proportion > 1.0:
|
||||
assert not torch.allclose(x.grad, y_grad)
|
||||
|
||||
|
||||
|
||||
@ -336,6 +560,7 @@ def _test_activation_balancer_sign():
|
||||
x = x.detach()
|
||||
x.requires_grad = True
|
||||
m = ActivationBalancer(
|
||||
probs.numel(),
|
||||
channel_dim=0,
|
||||
min_positive=0.05,
|
||||
max_positive=0.98,
|
||||
@ -361,6 +586,7 @@ def _test_activation_balancer_magnitude():
|
||||
x = x.detach()
|
||||
x.requires_grad = True
|
||||
m = ActivationBalancer(
|
||||
magnitudes.numel(),
|
||||
channel_dim=0,
|
||||
min_positive=0.0,
|
||||
max_positive=1.0,
|
||||
@ -402,10 +628,17 @@ def _test_double_swish_deriv():
|
||||
torch.autograd.gradcheck(m, x)
|
||||
|
||||
|
||||
def _test_get_max_eig_proportion():
|
||||
x = torch.randn(100, 128)
|
||||
d = torch.randn(128) * (128 ** -0.5)
|
||||
get_max_eig_proportion(x, d, True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
_test_max_eig_limiter()
|
||||
_test_get_max_eig_proportion()
|
||||
_test_activation_balancer_sign()
|
||||
_test_activation_balancer_magnitude()
|
||||
_test_basic_norm()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user