Merge branch 'scaled_adam_exp4_max_var_per_eig' into scaled_adam_exp7

# Conflicts:
#	egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py

commit 5f27cbdb44
Daniel Povey, 2022-09-18 21:23:59 +08:00
2 changed files with 253 additions and 9 deletions

egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py

@@ -173,7 +173,8 @@ class ConformerEncoderLayer(nn.Module):
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
-            ActivationBalancer(channel_dim=-1, max_abs=10.0),
+            ActivationBalancer(dim_feedforward,
+                               channel_dim=-1, max_abs=10.0),
             DoubleSwish(),
             nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model,
@@ -182,7 +183,8 @@ class ConformerEncoderLayer(nn.Module):
         self.feed_forward_macaron = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
-            ActivationBalancer(channel_dim=-1, max_abs=10.0),
+            ActivationBalancer(dim_feedforward,
+                               channel_dim=-1, max_abs=10.0),
             DoubleSwish(),
             nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model,
@@ -196,7 +198,7 @@ class ConformerEncoderLayer(nn.Module):
         # try to ensure the output is close to zero-mean (or at least, zero-median).
         self.balancer = ActivationBalancer(
-            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+            d_model, channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
         )
         self.dropout = nn.Dropout(dropout)
@@ -464,8 +466,12 @@ class RelPositionMultiheadAttention(nn.Module):
         ), "embed_dim must be divisible by num_heads"
         self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
-        self.in_balancer = ActivationBalancer(channel_dim=-1, max_abs=5.0)
-        self.proj_balancer = ActivationBalancer(channel_dim=-1, max_abs=10.0)
+        self.in_balancer = ActivationBalancer(3 * embed_dim,
+                                              channel_dim=-1, max_abs=5.0,
+                                              max_var_per_eig=0.1)
+        self.proj_balancer = ActivationBalancer(embed_dim,
+                                                channel_dim=-1, max_abs=10.0,
+                                                min_positive=0.0, max_positive=1.0)
         self.out_proj = ScaledLinear(
             embed_dim, embed_dim, bias=True, initial_scale=0.5
         )
@@ -900,6 +906,7 @@ class ConvolutionModule(nn.Module):
         # it will be in a better position to start learning something, i.e. to latch onto
         # the correct range.
         self.deriv_balancer1 = ActivationBalancer(
+            2 * channels,
             channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
         )
@@ -914,7 +921,7 @@ class ConvolutionModule(nn.Module):
         )
         self.deriv_balancer2 = ActivationBalancer(
-            channel_dim=1, min_positive=0.05, max_positive=1.0
+            channels, channel_dim=1, min_positive=0.05, max_positive=1.0
         )
         self.activation = DoubleSwish()
@@ -1000,7 +1007,8 @@ class Conv2dSubsampling(nn.Module):
                 kernel_size=3,
                 padding=1,
             ),
-            ActivationBalancer(channel_dim=1),
+            ActivationBalancer(layer1_channels,
+                               channel_dim=1),
             DoubleSwish(),
             nn.Conv2d(
                 in_channels=layer1_channels,
@@ -1008,7 +1016,8 @@ class Conv2dSubsampling(nn.Module):
                 kernel_size=3,
                 stride=2,
             ),
-            ActivationBalancer(channel_dim=1),
+            ActivationBalancer(layer2_channels,
+                               channel_dim=1),
             DoubleSwish(),
             nn.Conv2d(
                 in_channels=layer2_channels,
@@ -1016,7 +1025,8 @@ class Conv2dSubsampling(nn.Module):
                 kernel_size=3,
                 stride=2,
             ),
-            ActivationBalancer(channel_dim=1),
+            ActivationBalancer(layer3_channels,
+                               channel_dim=1),
             DoubleSwish(),
         )
         out_height = (((in_channels - 1) // 2 - 1) // 2)
@@ -1027,6 +1037,7 @@ class Conv2dSubsampling(nn.Module):
         self.out_norm = BasicNorm(out_channels, learn_eps=False)
         # constrain median of output to be close to zero.
         self.out_balancer = ActivationBalancer(
+            out_channels,
             channel_dim=-1, min_positive=0.45, max_positive=0.55
         )
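
In short: every ActivationBalancer constructed in conformer.py now receives its channel count as a new first positional argument, matching the constructor change in scaling.py below. A minimal usage sketch (sizes are illustrative; assumes the patched modules are importable as in this recipe):

import torch
from torch import nn

from scaling import ActivationBalancer, DoubleSwish, ScaledLinear

d_model, dim_feedforward = 512, 2048  # illustrative sizes

feed_forward = nn.Sequential(
    nn.Linear(d_model, dim_feedforward),
    # the channel count (here dim_feedforward) is now the first positional argument
    ActivationBalancer(dim_feedforward, channel_dim=-1, max_abs=10.0),
    DoubleSwish(),
    nn.Dropout(0.1),
    ScaledLinear(dim_feedforward, d_model),
)
y = feed_forward(torch.randn(10, 4, d_model))  # (seq_len, batch, d_model)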

egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py

@@ -114,6 +114,173 @@ class ActivationBalancerFunction(torch.autograd.Function):
         return x_grad - neg_delta_grad, None, None, None, None, None, None
+
+
+def find_direction_coeffs(x: Tensor,
+                          prev_direction: Tensor) -> Tuple[Tensor, Tensor]:
+    """
+    Figure out (an approximation to) the top eigen-direction of a set of feature
+    vectors, and the coefficient of that direction in each feature vector.
+
+    Args:
+       x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
+       prev_direction: a Tensor of shape (num_channels,), that is our previous
+            estimate of the top eigen-direction, or a random direction if this is
+            the first iteration.  Does not have to be normalized, but should be nonzero.
+
+    Returns: (cur_direction, coeffs), where:
+       cur_direction: a Tensor of shape (num_channels,) that is the current
+            estimate of the top eigen-direction.
+       coeffs: a Tensor of shape (num_frames, 1) that minimizes, or approximately
+            minimizes, (x - coeffs * cur_direction).norm()
+    """
+    (num_frames, num_channels) = x.shape
+    assert num_channels > 1 and num_frames > 1
+    assert prev_direction.shape == (num_channels,)
+    # `coeffs` are the coefficients of `prev_direction` in x; since `prev_direction`
+    # need not be normalized, they represent the coeffs only up to a positive factor.
+    coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
+    cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20)
+    return cur_direction, coeffs
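
find_direction_coeffs is essentially one step of power iteration on the scatter matrix of x: coeffs projects each frame onto the previous direction, and the updated direction is the coeffs-weighted combination of the frames. An illustrative check (not part of the patch; assumes the function above is in scope) that repeated calls converge to the dominant eigen-direction:

import torch

torch.manual_seed(0)
num_frames, num_channels = 1000, 64
d_true = torch.randn(num_channels)
d_true = d_true / d_true.norm()
# data whose variance is dominated by the direction d_true
x = torch.randn(num_frames, num_channels) + 5.0 * torch.randn(num_frames, 1) * d_true
x = x - x.mean(dim=0)

direction = torch.randn(num_channels)  # arbitrary starting estimate
for _ in range(10):
    direction, coeffs = find_direction_coeffs(x, direction)

cosine = (direction * d_true).sum() / direction.norm()
print(f"|cosine with true top direction| = {abs(cosine):.4f}")  # close to 1.0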
+
+
+def get_max_eig_proportion(x: Tensor,
+                           prev_direction: Tensor,
+                           subtract_mean: bool) -> Tuple[Tensor, Tensor]:
+    """
+    Figure out (an approximation to) the proportion of the variance of a set of
+    feature vectors that can be attributed to the top eigen-direction.
+
+    Args:
+       x: a Tensor of shape (*, num_channels).  There must be more than one frame,
+            i.e. x.numel() // num_channels > 1.
+       prev_direction: a Tensor of shape (num_channels,), that is our previous
+            estimate of the top eigen-direction, or a random direction if this is
+            the first iteration.  Expected to be without gradient.  Does not have
+            to be normalized.
+       subtract_mean: if True, we will first subtract the mean of x over the
+            frames.  Suggest to make this True in most circumstances.
+
+    Returns: (cur_direction, proportion), where:
+       cur_direction: a Tensor of shape (num_channels,) that is the current
+            estimate of the top eigen-direction.  Detached / not intended to be
+            differentiable.
+       proportion: a scalar Tensor containing the proportion of the variance of
+            the input that is in direction `cur_direction`.  This has gradient
+            that can be propagated back to x.
+    """
+    num_channels = x.shape[-1]
+    assert prev_direction.shape == (num_channels,)
+    x = x.reshape(-1, num_channels)
+    if subtract_mean:
+        x = x - x.mean(dim=0)
+    with torch.no_grad():
+        cur_norm = prev_direction.norm()
+        prev_direction = prev_direction / cur_norm
+        is_ok = (cur_norm / cur_norm == 1.0)
+        # if there was a problem like NaN or inf, restart.  This should be very rare.
+        prev_direction = torch.where(is_ok.unsqueeze(-1).expand(prev_direction.shape),
+                                     prev_direction,
+                                     torch.randn_like(prev_direction) * (num_channels ** -0.5))
+    # `coeffs` are the coefficients of `prev_direction` in x.
+    coeffs = (x * prev_direction).sum(dim=1, keepdim=True)
+    x_norm = x.norm()
+    x_coeffs1_norm = (x - coeffs * prev_direction).norm()
+    with torch.no_grad():
+        cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20)
+    x_coeffs2_norm = (x - coeffs * cur_direction).norm()
+    # for the returned direction, interpolate with prev_direction so that
+    # even if x == 0, we get a nonzero new direction.
+    ans_direction = 0.5 * (prev_direction + cur_direction)
+    x_sumsq = (x**2).sum() + 1.0e-20
+    x_remaining_sumsq = ((x - coeffs * cur_direction) ** 2).sum() + 1.0e-20
+    proportion = (x_sumsq - x_remaining_sumsq) / x_sumsq
+    # debug output, normally disabled:
+    # print(f"x_norm={x_norm}, x_coeffs1_norm={x_coeffs1_norm}, x_coeffs2_norm={x_coeffs2_norm}")
+    return (ans_direction, proportion)
+
+
+class MaxEigLimiterFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+            ctx,
+            x: Tensor,
+            direction: Tensor,
+            channel_dim: int,
+            prob: float,
+            subtract_mean: bool,
+            max_variance_proportion: float,
+            grad_scale: float) -> Tuple[Tensor, Tensor]:
+        if random.random() > prob:
+            return x, direction
+        eps = 1.0e-20
+        num_channels = x.shape[channel_dim]
+        assert max_variance_proportion > 1.0 / num_channels
+        orig_x = x
+        x = x.transpose(channel_dim, -1).reshape(-1, num_channels)
+        if subtract_mean:
+            x = x - x.mean(dim=0)
+        new_direction, coeffs = find_direction_coeffs(x, direction)
+        x_var = (x**2).sum()
+        x_residual = x - coeffs * new_direction
+        x_residual_var = (x_residual**2).sum()
+        # `variance_proportion` is the proportion of the variance accounted for
+        # by the top eigen-direction.
+        variance_proportion = (x_var - x_residual_var) / (x_var + eps)
+        ans_direction = direction + new_direction  # ensure nonzero even if x == 0
+        ans_direction = ans_direction / ans_direction.norm()
+        logging.info(f"variance_proportion = {variance_proportion.item()}")
+        # Caution: the comparison below causes a CUDA sync, which is not ideal.
+        if variance_proportion >= max_variance_proportion:
+            ctx.channel_dim = channel_dim
+            ctx.subtract_mean = subtract_mean
+            ctx.grad_scale = grad_scale
+            ctx.save_for_backward(orig_x.detach(),
+                                  coeffs.detach(),
+                                  new_direction.detach())
+        return orig_x, ans_direction
+
+    @staticmethod
+    def backward(ctx, x_grad, *args):
+        # the *args are the derivatives w.r.t. the other outputs, which should
+        # be None or zero.
+        if not hasattr(ctx, 'channel_dim'):
+            # the top eig's proportion of the variance was below the threshold,
+            # or this call was randomly skipped: pass the gradient through unchanged.
+            return x_grad, None, None, None, None, None, None
+        with torch.enable_grad():
+            (x_orig, coeffs, new_direction) = ctx.saved_tensors
+            x_orig.requires_grad = True
+            num_channels = x_orig.shape[ctx.channel_dim]
+            x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
+            new_direction.requires_grad = False
+            if ctx.subtract_mean:
+                x = x - x.mean(dim=0)
+            x_var = (x**2).sum()
+            x_residual = x - coeffs * new_direction
+            x_residual_var = (x_residual**2).sum()
+            # `variance_proportion` is the proportion of the variance accounted for
+            # by the top eigen-direction.  This is what we want to reduce.
+            variance_proportion = (x_var - x_residual_var) / x_var
+            variance_proportion.backward()
+        x_orig_grad = x_orig.grad
+        # add the penalty gradient, rescaled so that its norm is `grad_scale`
+        # times the norm of the incoming gradient.
+        x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20)
+        return x_grad + x_extra_grad, None, None, None, None, None, None
+
+
 class BasicNorm(torch.nn.Module):
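
For reference, the quantity that forward() and backward() both compute is the one named by the max_var_per_eig docstring further down. With mean-subtracted frames x_i, coefficients c_i and direction d,

\mathrm{variance\_proportion}
    \;=\; \frac{\sum_i \lVert x_i \rVert^2 \;-\; \sum_i \lVert x_i - c_i\, d \rVert^2}
               {\sum_i \lVert x_i \rVert^2}
    \;\approx\; \frac{\lambda_{\max}}{\sum_j \lambda_j},

where the approximation holds once d is close to the top eigen-direction of the scatter matrix of the mean-subtracted frames. The backward pass adds the gradient of this scalar, rescaled so its norm is grad_scale times the norm of the incoming gradient, and only when the proportion exceeds max_variance_proportion.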
@@ -236,6 +403,7 @@ class ActivationBalancer(torch.nn.Module):
     Args:
+       num_channels: the number of channels
        channel_dim: the dimension/axis corresponding to the channel, e.g.
           -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
        min_positive: the minimum, per channel, of the proportion of the time
@@ -252,29 +420,56 @@ class ActivationBalancer(torch.nn.Module):
        max_abs: the maximum average-absolute-value difference from the mean
           value per channel, which we allow, before we start to modify
           the derivatives to prevent this.
+       max_var_per_eig: the maximum proportion of the variance of the
+          features/channels, after mean subtraction, that can come from
+          any given eigenvalue.
     """

     def __init__(
         self,
+        num_channels: int,
         channel_dim: int,
         min_positive: float = 0.05,
         max_positive: float = 0.95,
         max_factor: float = 0.01,
         min_abs: float = 0.2,
         max_abs: float = 100.0,
+        max_var_per_eig: float = 0.0,
     ):
         super(ActivationBalancer, self).__init__()
+        self.num_channels = num_channels
         self.channel_dim = channel_dim
         self.min_positive = min_positive
         self.max_positive = max_positive
         self.max_factor = max_factor
         self.min_abs = min_abs
         self.max_abs = max_abs
+        assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
+        self.max_var_per_eig = max_var_per_eig
+        if max_var_per_eig > 0.0:
+            with torch.no_grad():
+                direction = torch.randn(num_channels)
+                direction = direction / direction.norm()
+            self.register_buffer('max_eig_direction', direction)
+        else:
+            self.max_eig_direction = None

     def forward(self, x: Tensor) -> Tensor:
         if torch.jit.is_scripting():
             return x
+        if self.max_var_per_eig > 0:
+            x, new_direction = MaxEigLimiterFunction.apply(
+                x, self.max_eig_direction,
+                self.channel_dim,
+                0.1,   # prob
+                True,  # subtract_mean
+                self.max_var_per_eig,
+                self.max_factor,
+            )
+            self.max_eig_direction[:] = new_direction
         return ActivationBalancerFunction.apply(
             x,
             self.channel_dim,
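
A minimal usage sketch of the new constructor arguments (sizes illustrative; assumes the patched module is importable as `scaling`): the balancer is still an identity in the forward pass, and with probability 0.1 per call the top-eig proportion is measured and, if it exceeds max_var_per_eig, penalized through the backward pass.

import torch
from scaling import ActivationBalancer

balancer = ActivationBalancer(
    256,                  # num_channels: the new, required first argument
    channel_dim=-1,
    max_abs=5.0,
    max_var_per_eig=0.1,  # no single eigen-direction may carry more than 10% of the variance
)
x = torch.randn(20, 4, 256, requires_grad=True)
y = balancer(x)            # forward pass is the identity
y.sum().backward()         # the constraints only modify the gradient
print(balancer.max_eig_direction.shape)  # running estimate of the top eigen-direction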
@@ -326,6 +521,35 @@ class DoubleSwish(torch.nn.Module):
         return DoubleSwishFunction.apply(x)


+def _test_max_eig_limiter():
+    for proportion in [0.1, 0.5, 10.0]:
+        logging.info(f"proportion = {proportion}")
+        x = torch.randn(100, 128)
+        direction = torch.randn(128)
+        coeffs = torch.randn(100, 1)
+        x += proportion * direction * coeffs
+        x.requires_grad = True
+
+        y, new_direction = MaxEigLimiterFunction.apply(x, direction,
+                                                       1,     # channel_dim
+                                                       1.0,   # prob
+                                                       True,  # subtract_mean
+                                                       0.5,   # max_variance_proportion
+                                                       0.1,   # grad_scale
+                                                       )
+        cosine = (new_direction * direction).sum() / (new_direction.norm() * direction.norm())
+        logging.info(f"Direction cosine = {cosine}")
+
+        y_grad = torch.randn_like(x)
+        y.backward(gradient=y_grad)
+        if proportion < 0.2:
+            assert torch.allclose(x.grad, y_grad)
+        elif proportion > 1.0:
+            assert not torch.allclose(x.grad, y_grad)
@@ -336,6 +560,7 @@ def _test_activation_balancer_sign():
     x = x.detach()
     x.requires_grad = True
     m = ActivationBalancer(
+        probs.numel(),
         channel_dim=0,
         min_positive=0.05,
         max_positive=0.98,
@@ -361,6 +586,7 @@ def _test_activation_balancer_magnitude():
     x = x.detach()
     x.requires_grad = True
     m = ActivationBalancer(
+        magnitudes.numel(),
         channel_dim=0,
         min_positive=0.0,
         max_positive=1.0,
@@ -402,10 +628,17 @@ def _test_double_swish_deriv():
     torch.autograd.gradcheck(m, x)


+def _test_get_max_eig_proportion():
+    x = torch.randn(100, 128)
+    d = torch.randn(128) * (128 ** -0.5)
+    get_max_eig_proportion(x, d, True)
+
+
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    _test_max_eig_limiter()
+    _test_get_max_eig_proportion()
     _test_activation_balancer_sign()
     _test_activation_balancer_magnitude()
     _test_basic_norm()