Merge branch 'scaled_adam_exp7c' into scaled_adam_exp11c

Daniel Povey 2022-09-22 18:15:44 +08:00
commit 03a77f8ae5
3 changed files with 204 additions and 89 deletions

View File

@@ -173,7 +173,8 @@ class ConformerEncoderLayer(nn.Module):
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
-            ActivationBalancer(channel_dim=-1, max_abs=10.0),
+            ActivationBalancer(dim_feedforward,
+                               channel_dim=-1, max_abs=10.0),
             DoubleSwish(),
             nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model,
@@ -182,7 +183,8 @@ class ConformerEncoderLayer(nn.Module):
         self.feed_forward_macaron = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
-            ActivationBalancer(channel_dim=-1, max_abs=10.0),
+            ActivationBalancer(dim_feedforward,
+                               channel_dim=-1, max_abs=10.0),
             DoubleSwish(),
             nn.Dropout(dropout),
             ScaledLinear(dim_feedforward, d_model,
@@ -196,7 +198,10 @@ class ConformerEncoderLayer(nn.Module):
         # try to ensure the output is close to zero-mean (or at least, zero-median).
         self.balancer = ActivationBalancer(
-            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
+            d_model, channel_dim=-1,
+            min_positive=0.45, max_positive=0.55,
+            max_abs=6.0,
+            max_var_per_eig=0.2,
         )
         self.dropout = nn.Dropout(dropout)
@@ -247,8 +252,6 @@ class ConformerEncoderLayer(nn.Module):
         # multi-headed self-attention module
         src_att = self.self_attn(
-            src,
-            src,
             src,
             pos_emb=pos_emb,
             attn_mask=src_mask,
@@ -464,8 +467,12 @@ class RelPositionMultiheadAttention(nn.Module):
         ), "embed_dim must be divisible by num_heads"
         self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
-        self.in_balancer = ActivationBalancer(channel_dim=-1, max_abs=5.0)
-        self.proj_balancer = ActivationBalancer(channel_dim=-1, max_abs=10.0)
+        self.in_balancer = ActivationBalancer(3 * embed_dim,
+                                              channel_dim=-1, max_abs=5.0,
+                                              max_var_per_eig=0.2)
+        self.proj_balancer = ActivationBalancer(embed_dim,
+                                                channel_dim=-1, max_abs=10.0,
+                                                min_positive=0.0, max_positive=1.0)
         self.out_proj = ScaledLinear(
             embed_dim, embed_dim, bias=True, initial_scale=0.5
         )
@@ -484,9 +491,7 @@ class RelPositionMultiheadAttention(nn.Module):
     def forward(
         self,
-        query: Tensor,
-        key: Tensor,
-        value: Tensor,
+        x: Tensor,
         pos_emb: Tensor,
         key_padding_mask: Optional[Tensor] = None,
         need_weights: bool = True,
@@ -494,7 +499,7 @@ class RelPositionMultiheadAttention(nn.Module):
     ) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
         Args:
-            query, key, value: map a query and a set of key-value pairs to an output.
+            x: input to be projected to query, key, value
             pos_emb: Positional embedding tensor
             key_padding_mask: if provided, specified padding elements in the key will
                 be ignored by the attention. When given a binary mask and a value is True,
@@ -507,11 +512,7 @@ class RelPositionMultiheadAttention(nn.Module):
         Shape:
             - Inputs:
-            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
-              the embedding dimension.
-            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
-              the embedding dimension.
-            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+            - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
               the embedding dimension.
             - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
               the embedding dimension.
@@ -534,9 +535,7 @@ class RelPositionMultiheadAttention(nn.Module):
             L is the target sequence length, S is the source sequence length.
         """
         return self.multi_head_attention_forward(
-            query,
-            key,
-            value,
+            self.in_balancer(self.in_proj(x)),
             pos_emb,
             self.embed_dim,
             self.num_heads,
@@ -578,11 +577,9 @@ class RelPositionMultiheadAttention(nn.Module):
     def multi_head_attention_forward(
         self,
-        query: Tensor,
-        key: Tensor,
-        value: Tensor,
+        x: Tensor,
         pos_emb: Tensor,
-        embed_dim_to_check: int,
+        embed_dim: int,
         num_heads: int,
         in_proj_weight: Tensor,
         in_proj_bias: Tensor,
@@ -598,7 +595,7 @@ class RelPositionMultiheadAttention(nn.Module):
         Args:
             query, key, value: map a query and a set of key-value pairs to an output.
             pos_emb: Positional embedding tensor
-            embed_dim_to_check: total dimension of the model.
+            embed_dim: total dimension of the model.
             num_heads: parallel attention heads.
             in_proj_weight, in_proj_bias: input projection weight and bias.
             dropout_p: probability of an element to be zeroed.
@@ -640,9 +637,7 @@ class RelPositionMultiheadAttention(nn.Module):
             L is the target sequence length, S is the source sequence length.
         """
-        tgt_len, bsz, embed_dim = query.size()
-        assert embed_dim == embed_dim_to_check
-        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
+        tgt_len, bsz, _ = x.size()

         head_dim = embed_dim // num_heads
         assert (
@@ -651,62 +646,10 @@ class RelPositionMultiheadAttention(nn.Module):
         scaling = float(head_dim) ** -0.5

-        def linear(x, w, b):
-            return self.in_balancer(nn.functional.linear(x, w, b))
-
-        if torch.equal(query, key) and torch.equal(key, value):
-            # self-attention
-            q, k, v = linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
-
-        elif torch.equal(key, value):
-            # encoder-decoder attention
-            # This is inline in_proj function with in_proj_weight and in_proj_bias
-            _b = in_proj_bias
-            _start = 0
-            _end = embed_dim
-            _w = in_proj_weight[_start:_end, :]
-            if _b is not None:
-                _b = _b[_start:_end]
-            q = linear(query, _w, _b)
-            # This is inline in_proj function with in_proj_weight and in_proj_bias
-            _b = in_proj_bias
-            _start = embed_dim
-            _end = None
-            _w = in_proj_weight[_start:, :]
-            if _b is not None:
-                _b = _b[_start:]
-            k, v = linear(key, _w, _b).chunk(2, dim=-1)
-
-        else:
-            # This is inline in_proj function with in_proj_weight and in_proj_bias
-            _b = in_proj_bias
-            _start = 0
-            _end = embed_dim
-            _w = in_proj_weight[_start:_end, :]
-            if _b is not None:
-                _b = _b[_start:_end]
-            q = linear(query, _w, _b)
-            # This is inline in_proj function with in_proj_weight and in_proj_bias
-            _b = in_proj_bias
-            _start = embed_dim
-            _end = embed_dim * 2
-            _w = in_proj_weight[_start:_end, :]
-            if _b is not None:
-                _b = _b[_start:_end]
-            k = linear(key, _w, _b)
-            # This is inline in_proj function with in_proj_weight and in_proj_bias
-            _b = in_proj_bias
-            _start = embed_dim * 2
-            _end = None
-            _w = in_proj_weight[_start:, :]
-            if _b is not None:
-                _b = _b[_start:]
-            v = linear(value, _w, _b)
+        # self-attention
+        q, k, v = x.chunk(3, dim=-1)

         if attn_mask is not None:
             assert (
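The simplification above works because in_proj is a single Linear of width 3 * embed_dim: projecting once in forward() (see the self.in_balancer(self.in_proj(x)) call earlier in this diff) and then splitting with chunk(3, dim=-1) yields the same q, k, v that the removed inline slicing produced, and the ActivationBalancer is a no-op in the forward pass, so applying it once to the fused projection does not change those values. A standalone sketch of that equivalence (illustrative only, not part of the patch; the shapes and names below are made up):

import torch
import torch.nn as nn

embed_dim = 8
x = torch.randn(5, 2, embed_dim)               # (seq_len, batch, embed_dim)
in_proj = nn.Linear(embed_dim, 3 * embed_dim)  # plays the role of self.in_proj

# One fused projection, then split into query/key/value:
q1, k1, v1 = in_proj(x).chunk(3, dim=-1)

# The same result from three separate projections using slices of the weight/bias,
# which is what the removed inline code computed:
w, b = in_proj.weight, in_proj.bias
q2 = nn.functional.linear(x, w[:embed_dim], b[:embed_dim])
k2 = nn.functional.linear(x, w[embed_dim:2 * embed_dim], b[embed_dim:2 * embed_dim])
v2 = nn.functional.linear(x, w[2 * embed_dim:], b[2 * embed_dim:])

assert torch.allclose(q1, q2, atol=1e-6)
assert torch.allclose(k1, k2, atol=1e-6)
assert torch.allclose(v1, v2, atol=1e-6)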
@@ -726,15 +669,15 @@ class RelPositionMultiheadAttention(nn.Module):
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0)
-                if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
+                if list(attn_mask.size()) != [1, tgt_len, tgt_len]:
                     raise RuntimeError(
                         "The size of the 2D attn_mask is not correct."
                     )
             elif attn_mask.dim() == 3:
                 if list(attn_mask.size()) != [
                     bsz * num_heads,
-                    query.size(0),
-                    key.size(0),
+                    tgt_len,
+                    tgt_len,
                 ]:
                     raise RuntimeError(
                         "The size of the 3D attn_mask is not correct."
@@ -900,6 +843,7 @@ class ConvolutionModule(nn.Module):
         # it will be in a better position to start learning something, i.e. to latch onto
         # the correct range.
         self.deriv_balancer1 = ActivationBalancer(
+            2 * channels,
             channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
         )
@@ -914,7 +858,7 @@ class ConvolutionModule(nn.Module):
         )

         self.deriv_balancer2 = ActivationBalancer(
-            channel_dim=1, min_positive=0.05, max_positive=1.0
+            channels, channel_dim=1, min_positive=0.05, max_positive=1.0
         )

         self.activation = DoubleSwish()
@@ -1000,7 +944,8 @@ class Conv2dSubsampling(nn.Module):
                 kernel_size=3,
                 padding=1,
             ),
-            ActivationBalancer(channel_dim=1),
+            ActivationBalancer(layer1_channels,
+                               channel_dim=1),
             DoubleSwish(),
             nn.Conv2d(
                 in_channels=layer1_channels,
@@ -1008,7 +953,8 @@ class Conv2dSubsampling(nn.Module):
                 kernel_size=3,
                 stride=2,
             ),
-            ActivationBalancer(channel_dim=1),
+            ActivationBalancer(layer2_channels,
+                               channel_dim=1),
             DoubleSwish(),
             nn.Conv2d(
                 in_channels=layer2_channels,
@@ -1016,7 +962,8 @@ class Conv2dSubsampling(nn.Module):
                 kernel_size=3,
                 stride=2,
             ),
-            ActivationBalancer(channel_dim=1),
+            ActivationBalancer(layer3_channels,
+                               channel_dim=1),
             DoubleSwish(),
         )
         out_height = (((in_channels - 1) // 2 - 1) // 2)
@@ -1027,6 +974,7 @@ class Conv2dSubsampling(nn.Module):
         self.out_norm = BasicNorm(out_channels, learn_eps=False)
         # constrain median of output to be close to zero.
         self.out_balancer = ActivationBalancer(
+            out_channels,
             channel_dim=-1, min_positive=0.45, max_positive=0.55
         )

View File

@@ -254,7 +254,7 @@ class ScaledAdam(Optimizer):
             if ans < 1.0:
                 state["num_clipped"] += 1
             if ans < 0.1:
-                logging.warn("Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
+                logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
             return ans
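The only change here is the added f prefix: without it the braces are logged literally instead of being interpolated. A trivial standalone illustration (not from the patch; logging.warning is used below, the non-deprecated alias of logging.warn):

import logging

logging.basicConfig(level=logging.WARNING)
ans, model_norm_threshold = 0.05, 2.0

logging.warning("Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
# logs: Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}

logging.warning(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
# logs: Scaling gradients by 0.05, model_norm_threshold=2.0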

View File

@@ -114,6 +114,108 @@ class ActivationBalancerFunction(torch.autograd.Function):
         return x_grad - neg_delta_grad, None, None, None, None, None, None
+
+
+def find_direction_coeffs(x: Tensor,
+                          prev_direction: Tensor) -> Tuple[Tensor, Tensor]:
+    """
+    Figure out (an approximation to) the proportion of the variance of a set of
+    feature vectors that can be attributed to the top eigen-direction.
+    Args:
+      x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
+      prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
+            of the top eigen-direction, or a random direction if this is the first
+            iteration.  Does not have to be normalized, but should be nonzero.
+    Returns: (cur_direction, coeffs), where:
+      cur_direction: a Tensor of shape (num_channels,) that is the current
+            estimate of the top eigen-direction.
+      coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
+            approximately minimizes, (x - coeffs * cur_direction).norm()
+    """
+    (num_frames, num_channels) = x.shape
+    assert num_channels > 1 and num_frames > 1
+    assert prev_direction.shape == (num_channels,)
+    # `coeffs` are the coefficients of `prev_direction` in x;
+    # they actually represent the coeffs only up to a constant positive factor.
+    coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
+    cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20)
+    return cur_direction, coeffs
+
+
+class MaxEigLimiterFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+            ctx,
+            x: Tensor,
+            direction: Tensor,
+            channel_dim: int,
+            prob: float,
+            subtract_mean: bool,
+            max_variance_proportion: float,
+            grad_scale: float) -> Tuple[Tensor, Tensor]:
+        if random.random() > prob:
+            return x, direction
+        eps = 1.0e-20
+        num_channels = x.shape[channel_dim]
+        assert max_variance_proportion > 1.0 / num_channels
+        orig_x = x
+        x = x.transpose(channel_dim, -1).reshape(-1, num_channels)
+        if subtract_mean:
+            x = x - x.mean(dim=0)
+        new_direction, coeffs = find_direction_coeffs(x, direction)
+        x_var = (x ** 2).mean()
+        x_residual = x - coeffs * new_direction
+        x_residual_var = (x_residual ** 2).mean()
+        # `variance_proportion` is the proportion of the variance accounted for
+        # by the top eigen-direction.
+        variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
+        ans_direction = direction + new_direction  # ensure nonzero even if x == 0
+        ans_direction = ans_direction / ans_direction.norm()
+        if random.random() < 0.001:
+            logging.info(f"variance_proportion = {variance_proportion.item()}")
+        # Caution: this causes a CUDA sync, which is not ideal.
+        if variance_proportion >= max_variance_proportion:
+            ctx.channel_dim = channel_dim
+            ctx.subtract_mean = subtract_mean
+            ctx.grad_scale = grad_scale
+            ctx.save_for_backward(orig_x.detach(),
+                                  coeffs.detach(),
+                                  new_direction.detach())
+        return orig_x, ans_direction
+
+    @staticmethod
+    def backward(ctx, x_grad, *args):
+        # the *args is all the other derivs, which should be None or zero.
+        if not hasattr(ctx, 'channel_dim'):
+            # the top eig's proportion of the variance was below the threshold.
+            return x_grad, None, None, None, None, None, None
+        with torch.enable_grad():
+            (x_orig, coeffs, new_direction) = ctx.saved_tensors
+            x_orig.requires_grad = True
+            num_channels = x_orig.shape[ctx.channel_dim]
+            x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
+            new_direction.requires_grad = False
+            if ctx.subtract_mean:
+                x = x - x.mean(dim=0)
+            x_var = (x ** 2).mean()
+            x_residual = x - coeffs * new_direction
+            x_residual_var = (x_residual ** 2).mean()
+            # `variance_proportion` is the proportion of the variance accounted for
+            # by the top eigen-direction.  This is to be minimized.
+            variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
+            variance_proportion.backward()
+        x_orig_grad = x_orig.grad
+        x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20)
+        return x_grad + x_extra_grad.detach(), None, None, None, None, None, None


 class BasicNorm(torch.nn.Module):
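find_direction_coeffs is essentially one step of power iteration: it refines a previous guess at the dominant direction and returns the per-frame coefficients that best explain x along it. MaxEigLimiterFunction then uses the residual to measure how much of the variance that single direction accounts for, and adds a gradient penalty only when that proportion exceeds max_variance_proportion. A small standalone sketch of the quantity being limited (assumes the two definitions added above are in scope; the data is synthetic):

import torch

torch.manual_seed(0)
num_frames, num_channels = 200, 64

# Features whose variance is dominated by a single direction.
top = torch.randn(num_channels)
x = torch.randn(num_frames, num_channels) + 3.0 * torch.randn(num_frames, 1) * top
x = x - x.mean(dim=0)                   # subtract_mean=True in the limiter

direction = torch.randn(num_channels)   # arbitrary starting direction
for _ in range(4):                      # a few refinement steps
    direction, coeffs = find_direction_coeffs(x, direction)

x_var = (x ** 2).mean()
x_residual_var = ((x - coeffs * direction) ** 2).mean()
variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
print(variance_proportion)  # large (roughly 0.9 here), since one direction dominates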
@@ -236,6 +338,7 @@ class ActivationBalancer(torch.nn.Module):
     Args:
+       num_channels: the number of channels
        channel_dim: the dimension/axis corresponding to the channel, e.g.
           -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
        min_positive: the minimum, per channel, of the proportion of the time
@@ -252,29 +355,60 @@ class ActivationBalancer(torch.nn.Module):
        max_abs: the maximum average-absolute-value difference from the mean
           value per channel, which we allow, before we start to modify
           the derivatives to prevent this.
+       max_var_per_eig: the maximum proportion of the variance of the
+           features/channels, after mean subtraction, that can come from
+           any given eigenvalue.
     """

     def __init__(
         self,
+        num_channels: int,
         channel_dim: int,
         min_positive: float = 0.05,
         max_positive: float = 0.95,
         max_factor: float = 0.01,
         min_abs: float = 0.2,
         max_abs: float = 100.0,
+        max_var_per_eig: float = 0.0,
     ):
         super(ActivationBalancer, self).__init__()
+        self.num_channels = num_channels
         self.channel_dim = channel_dim
         self.min_positive = min_positive
         self.max_positive = max_positive
         self.max_factor = max_factor
         self.min_abs = min_abs
         self.max_abs = max_abs
+        assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
+        self.max_var_per_eig = max_var_per_eig
+        if max_var_per_eig > 0.0:
+            with torch.no_grad():
+                # arbitrary.. would use randn() but want to leave the rest of the model's
+                # random parameters unchanged for comparison
+                direction = torch.arange(num_channels).to(torch.float)
+                direction = direction / direction.norm()
+            self.register_buffer('max_eig_direction', direction)
+        else:
+            self.max_eig_direction = None

     def forward(self, x: Tensor) -> Tensor:
         if torch.jit.is_scripting():
             return x
+        if self.max_var_per_eig > 0:
+            max_eig_prob = 0.25
+            with torch.cuda.amp.autocast(enabled=False):
+                x, new_direction = MaxEigLimiterFunction.apply(
+                    x, self.max_eig_direction,
+                    self.channel_dim,
+                    max_eig_prob,
+                    True,  # subtract_mean
+                    self.max_var_per_eig,
+                    self.max_factor / max_eig_prob,
+                )
+                self.max_eig_direction[:] = new_direction.detach()
         return ActivationBalancerFunction.apply(
             x,
             self.channel_dim,
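After this change ActivationBalancer takes the channel count as its first positional argument, and a nonzero max_var_per_eig registers a max_eig_direction buffer and routes the input through MaxEigLimiterFunction on a random 25% of calls. A minimal usage sketch of the updated constructor (assumes the ActivationBalancer defined in this file is in scope; the shapes and values are arbitrary):

import torch

balancer = ActivationBalancer(
    256,                  # num_channels, the new first positional argument
    channel_dim=-1,
    min_positive=0.45,
    max_positive=0.55,
    max_abs=6.0,
    max_var_per_eig=0.2,  # discourage any one eigen-direction above 20% of the variance
)

x = torch.randn(10, 50, 256, requires_grad=True)
y = balancer(x)       # numerically the identity in the forward pass
y.sum().backward()    # the constraints act by adjusting x.grad in backward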
@@ -326,6 +460,35 @@ class DoubleSwish(torch.nn.Module):
         return DoubleSwishFunction.apply(x)


+def _test_max_eig_limiter():
+    for proportion in [0.1, 0.5, 10.0]:
+        logging.info(f"proportion = {proportion}")
+        x = torch.randn(100, 128)
+        direction = torch.randn(128)
+        coeffs = torch.randn(100, 1)
+        x += proportion * direction * coeffs
+        x.requires_grad = True
+
+        y, new_direction = MaxEigLimiterFunction.apply(x, direction,
+                                                       1,     # channel_dim
+                                                       1.0,   # prob
+                                                       True,  # subtract_mean
+                                                       0.5,   # max_variance_proportion
+                                                       0.1,   # grad_scale
+                                                       )
+
+        cosine = (new_direction * direction).sum() / (new_direction.norm() * direction.norm())
+        logging.info(f"Direction cosine = {cosine}")
+
+        y_grad = torch.randn_like(x)
+        y.backward(gradient=y_grad)
+
+        if proportion < 0.2:
+            assert torch.allclose(x.grad, y_grad)
+        elif proportion > 1.0:
+            assert not torch.allclose(x.grad, y_grad)
@@ -336,6 +499,7 @@ def _test_activation_balancer_sign():
     x = x.detach()
     x.requires_grad = True
     m = ActivationBalancer(
+        probs.numel(),
         channel_dim=0,
         min_positive=0.05,
         max_positive=0.98,
@@ -361,6 +525,7 @@ def _test_activation_balancer_magnitude():
     x = x.detach()
     x.requires_grad = True
     m = ActivationBalancer(
+        magnitudes.numel(),
        channel_dim=0,
        min_positive=0.0,
        max_positive=1.0,
@@ -402,10 +567,12 @@ def _test_double_swish_deriv():
     torch.autograd.gradcheck(m, x)


 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    _test_max_eig_limiter()
     _test_activation_balancer_sign()
     _test_activation_balancer_magnitude()
     _test_basic_norm()