Merge branch 'scaled_adam_exp724' into scaled_adam_exp726

Daniel Povey 2022-12-17 15:46:57 +08:00
commit 744dca1c9b
2 changed files with 45 additions and 176 deletions


@@ -463,120 +463,6 @@ class BasicNorm(torch.nn.Module):
return x * scales
class LinearWithAuxLossFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor, weight: Tensor,
aux_grad_scale: float) -> Tensor:
"""
Returns matmul(x, weight.t()).
In the backward pass it will include an auxiliary loss based on predicting x from
matmul(y, weight).
"""
if torch.is_autocast_enabled():
x = x.to(torch.float16)
ctx.save_for_backward(x, weight)
ctx.aux_grad_scale = aux_grad_scale
return torch.matmul(x, weight.t())
@staticmethod
def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, Tensor, None]:
x, weight = ctx.saved_tensors
x_grad = torch.matmul(ans_grad, weight.to(ans_grad.dtype))
weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(),
x.reshape(-1, x.shape[-1]).to(ans_grad.dtype))
with torch.cuda.amp.autocast(enabled=False):
with torch.enable_grad():
x = x.to(weight.dtype)
x, weight = x.detach(), weight.detach()
weight.requires_grad = True
# recompute y as we need the gradient; this is easier to implement than
# saving y in the context.
y = torch.matmul(x, weight.t())
z = torch.matmul(y, weight)
# subtract mean
dims_to_mean = tuple(range(x.ndim-1))
x = x - x.mean(dim=dims_to_mean)
z = z - z.mean(dim=dims_to_mean)
# compute optimal scale on z
with torch.no_grad():
alpha = (x * z).sum() / ((z * z).sum() + 1.0e-20)
diff = x - alpha * z
# meansq is the loss function.
meansq = (diff ** 2).mean()
meansq.backward()
weight_aux_grad = weight.grad
with torch.cuda.amp.autocast(enabled=False):
weight_grad_norm = weight_grad.to(torch.float32).norm()
aux_grad_norm = weight_aux_grad.norm()
weight_grad_scale = ctx.aux_grad_scale * weight_grad_norm / (aux_grad_norm + 1.0e-20)
weight_grad = weight_grad + (weight_grad_scale * weight_aux_grad).to(weight_grad.dtype)
return x_grad, weight_grad, None
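# A minimal standalone sketch (hypothetical helper, not part of this diff) of the
# norm-matching rule used in backward() above: the auxiliary gradient is rescaled so
# that its norm equals aux_grad_scale times the norm of the ordinary weight gradient
# before being added to it.
def _blend_aux_grad_sketch(weight_grad: Tensor, weight_aux_grad: Tensor,
                           aux_grad_scale: float) -> Tensor:
    main_norm = weight_grad.to(torch.float32).norm()
    aux_norm = weight_aux_grad.norm()
    scale = aux_grad_scale * main_norm / (aux_norm + 1.0e-20)
    return weight_grad + (scale * weight_aux_grad).to(weight_grad.dtype)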
class LinearWithAuxLoss(nn.Module):
"""
A linear layer with an auxiliary loss that you can put on a schedule, that
encourages it to correspond to the largest-variance directions of the
input features.
Suppose the input is x, and this layer computes:
y = M x
(the bias is applied separately), then we define:
z = exp(alpha) * M^T y
where alpha is learnable; and the auxiliary loss will be:
aux_loss = normalize_mean(z - x)^2.
(normalize_mean refers to subtracting the average value per channel,
over the minibatch).
In the backward pass we compute the derivative of the auxiliary loss
and add it to the weight and bias grads, with a scale chosen such
that the extra grad's norm equals `aux_grad_scale` times the norm
of the existing grad.
"""
def __init__(self,
in_channels: int,
out_channels: int,
bias: bool = True,
aux_grad_scale: Optional[FloatLike] = None,
prob: FloatLike = 0.25,
initial_scale: float = 1.0,
):
super().__init__()
if aux_grad_scale is None:
aux_grad_scale = ScheduledFloat((0.0, 1.0), (1000.0, 0.1),
(2000.0, 0.01), (8000.0, 0.0))
self.aux_grad_scale = aux_grad_scale
self.prob = prob
self.weight = nn.Parameter(torch.randn(out_channels, in_channels)
* (in_channels ** -0.5) * initial_scale)
if bias:
self.bias = nn.Parameter(torch.randn(out_channels) *
0.01 * initial_scale)
else:
self.register_parameter('bias', None)
def forward(self,
x: Tensor):
aux_grad_scale = float(self.aux_grad_scale)
if (not self.training or torch.jit.is_scripting() or
aux_grad_scale == 0.0 or random.random() > float(self.prob)):
return torch.nn.functional.linear(x, self.weight, self.bias)
else:
ans = LinearWithAuxLossFunction.apply(x, self.weight,
aux_grad_scale)
if self.bias is not None:
ans += self.bias
return ans
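# A rough sketch of how the default aux_grad_scale schedule above behaves, assuming
# ScheduledFloat interpolates piecewise-linearly in the batch count (this helper is
# illustrative only, not part of this diff):
def _piecewise_linear_sketch(batch_count: float, points) -> float:
    # points: sorted (batch_count, value) pairs, e.g. the default schedule
    # [(0.0, 1.0), (1000.0, 0.1), (2000.0, 0.01), (8000.0, 0.0)].
    (x0, y0) = points[0]
    if batch_count <= x0:
        return y0
    for (x1, y1) in points[1:]:
        if batch_count <= x1:
            return y0 + (y1 - y0) * (batch_count - x0) / (x1 - x0)
        (x0, y0) = (x1, y1)
    return y0
# e.g. _piecewise_linear_sketch(500.0, [(0.0, 1.0), (1000.0, 0.1), (2000.0, 0.01), (8000.0, 0.0)])
# returns 0.55, halfway between 1.0 and 0.1.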
def ScaledLinear(*args,
@@ -711,12 +597,14 @@ class ActivationBalancer(torch.nn.Module):
scale_gain_factor: FloatLike = 0.04,
min_abs: FloatLike = 0.2,
max_abs: FloatLike = 100.0,
min_prob: FloatLike = 0.1,
prob: Optional[FloatLike] = None,
):
super(ActivationBalancer, self).__init__()
# CAUTION: this code expects self.batch_count to be overwritten in the main training
# loop.
self.batch_count = 0
if prob is None:
prob = ScheduledFloat((0.0, 0.4), (8000.0, 0.1), default=0.4)
self.prob = prob
# actually self.num_channels is no longer needed except for an assertion.
self.num_channels = num_channels
@@ -726,7 +614,6 @@ class ActivationBalancer(torch.nn.Module):
self.max_factor = max_factor
self.min_abs = min_abs
self.max_abs = max_abs
self.min_prob = min_prob
self.sign_gain_factor = sign_gain_factor
self.scale_gain_factor = scale_gain_factor
@@ -738,9 +625,7 @@ class ActivationBalancer(torch.nn.Module):
if torch.jit.is_scripting() or not x.requires_grad:
return _no_op(x)
# the prob of doing some work exponentially decreases from 0.5 till it hits
# a floor at min_prob (==0.1, by default)
prob = max(float(self.min_prob), 0.5 ** (1 + (self.batch_count / 4000.0)))
prob = float(self.prob)
if random.random() < prob:
assert x.shape[self.channel_dim] == self.num_channels
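For reference, a small standalone sketch (hypothetical helpers, not from this diff) of the two probability rules swapped above: the old exponential decay with a floor at min_prob versus the new schedule, assuming ScheduledFloat interpolates linearly in the batch count:

def old_prob(batch_count: float, min_prob: float = 0.1) -> float:
    # exponential decay from 0.5, floored at min_prob
    return max(min_prob, 0.5 ** (1 + batch_count / 4000.0))

def new_prob(batch_count: float) -> float:
    # linear decay from 0.4 at batch 0 to 0.1 at batch 8000, then constant
    return 0.4 + (0.1 - 0.4) * min(batch_count, 8000.0) / 8000.0

# Both give 0.25 at batch_count == 4000; the new schedule starts lower (0.4 vs 0.5)
# and reaches its 0.1 floor slightly earlier than the old decay does.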


@@ -36,7 +36,6 @@ from scaling import (
ScaledConv1d,
ScaledConv2d,
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
LinearWithAuxLoss,
Whiten,
Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
penalize_abs_values_gt,
@@ -355,16 +354,11 @@ class Zipformer(EncoderInterface):
def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
return ScheduledFloat((0.0, x),
(12000.0, ratio * x),
(20000.0, ratio * x),
default=x)
def _aux_grad_scale() -> float:
return 0.2
def _aux_grad_prob_out() -> ScheduledFloat:
return 0.0 # ScheduledFloat((0.0, 0.25), (1000.0, 0.05), (8000.0, 0.0125))
def _aux_grad_prob_in() -> ScheduledFloat:
return 0.0 # ScheduledFloat((0.0, 0.25), (1000.0, 0.0))
#return ScheduledFloat((0.0, 0.25), (1000.0, 0.05), (8000.0, 0.0125))
def _balancer_schedule(min_prob: float):
return ScheduledFloat((0.0, 0.4), (8000.0, min_prob))
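Assuming the same piecewise-linear reading of ScheduledFloat, the _whitening_schedule change above only stretches the ramp; a rough sketch with x = 4.0 and ratio = 3.0 (illustrative helper, not from this diff):

def whitening_limit(batch_count: float, x: float = 4.0, ratio: float = 3.0,
                    ramp_end: float = 20000.0) -> float:
    # ramps linearly from x at batch 0 up to ratio * x at ramp_end, then stays flat
    frac = min(max(batch_count / ramp_end, 0.0), 1.0)
    return x + (ratio * x - x) * frac

# whitening_limit(10000.0) == 8.0 with the new 20000-batch ramp; the old 12000-batch
# ramp would already be at about 10.67 by the same point.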
@@ -398,8 +392,8 @@ class ZipformerEncoderLayer(nn.Module):
# treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
# to work correctly.
layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0),
nonlin_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (20000, 0.0), default=0),
attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0),
bypass_max: FloatLike = 1.0,
@@ -410,10 +404,10 @@ class ZipformerEncoderLayer(nn.Module):
# probability of skipping the entire layer.
self.layer_skip_rate = copy.deepcopy(layer_skip_rate)
# skip probability for dynamic modules (meaning: anything but feedforward).
self.dynamic_skip_rate = copy.deepcopy(dynamic_skip_rate)
# an additional skip probability that applies to NonlinAttentionModule to stop it from contributing too much early on.
self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
# an additional skip probability that applies to ConvModule to stop it from
# contributing too much early on.
self.nonlin_skip_rate = copy.deepcopy(nonlin_skip_rate)
self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
# min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
# ever becoming zero.
@@ -507,7 +501,7 @@ class ZipformerEncoderLayer(nn.Module):
src_orig = src
# dropout rate for non-feedforward submodules
dynamic_skip_rate = float(self.dynamic_skip_rate) if self.training else 0.0
attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0
# attn_weights: (num_heads, batch_size, seq_len, seq_len)
if self.self_attn_weights is not None:
@@ -528,7 +522,7 @@ class ZipformerEncoderLayer(nn.Module):
# skip the layer
return src, attn_weights
use_self_attn = (random.random() >= dynamic_skip_rate)
use_self_attn = (random.random() >= attention_skip_rate)
if use_self_attn:
selected_attn_weights = attn_weights[head_offset:head_offset+2]
if random.random() < float(self.const_attention_rate):
@@ -541,7 +535,7 @@ class ZipformerEncoderLayer(nn.Module):
selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1)
if torch.jit.is_scripting() or (use_self_attn and random.random() >= float(self.nonlin_skip_rate)):
if torch.jit.is_scripting() or use_self_attn:
src = src + self.nonlin_attention_module(src,
selected_attn_weights[0:1])
@@ -555,7 +549,7 @@ class ZipformerEncoderLayer(nn.Module):
src = src + self.self_attn(
src, attn_weights)
if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
if torch.jit.is_scripting() or random.random() >= float(self.conv_skip_rate):
src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
src = src + self.feed_forward2(src)
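The skip pattern used throughout this forward pass, sketched in isolation (the module and rate names here are placeholders, not the actual attributes):

import random
import torch

def maybe_apply(src: torch.Tensor, module: torch.nn.Module,
                skip_rate: float, training: bool) -> torch.Tensor:
    # With probability skip_rate (training only) skip the submodule entirely;
    # otherwise add its output residually, as the encoder layer does above.
    if training and random.random() < skip_rate:
        return src
    return src + module(src)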
@@ -750,6 +744,7 @@ class AttentionDownsample(torch.nn.Module):
"""
super(AttentionDownsample, self).__init__()
self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
self.bias = nn.Parameter(torch.zeros(downsample))
self.name = None # will be set from training code
self.dropout = copy.deepcopy(dropout)
@@ -783,8 +778,9 @@ class AttentionDownsample(torch.nn.Module):
assert src.shape[0] == d_seq_len * ds
src = src.reshape(d_seq_len, ds, batch_size, in_channels)
# scores: (d_seq_len, downsample, batch_size)
# scores: (d_seq_len, downsample, batch_size, 1)
scores = (src * self.query).sum(dim=-1, keepdim=True)
scores = scores + self.bias.unsqueeze(-1).unsqueeze(-1)
scores = penalize_abs_values_gt(scores,
limit=20.0,
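A shape-only sketch of the new per-position bias added to the downsampling scores (random placeholder values; names follow the code above):

import torch

d_seq_len, ds, batch_size, in_channels = 50, 2, 8, 256
src = torch.randn(d_seq_len, ds, batch_size, in_channels)
query = torch.randn(in_channels) * (in_channels ** -0.5)
bias = torch.zeros(ds)   # one learnable offset per frame within the downsampling window

scores = (src * query).sum(dim=-1, keepdim=True)    # (d_seq_len, ds, batch_size, 1)
scores = scores + bias.unsqueeze(-1).unsqueeze(-1)  # (ds, 1, 1) broadcasts over dim 1
assert scores.shape == (d_seq_len, ds, batch_size, 1)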
@@ -820,7 +816,7 @@ class SimpleUpsample(torch.nn.Module):
num_channels: int,
upsample: int):
super(SimpleUpsample, self).__init__()
self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01)
self.upsample = upsample
def forward(self,
src: Tensor) -> Tensor:
@@ -829,10 +825,9 @@ class SimpleUpsample(torch.nn.Module):
Returns a tensor of shape
( (seq_len*upsample), batch_size, num_channels)
"""
upsample = self.bias.shape[0]
upsample = self.upsample
(seq_len, batch_size, num_channels) = src.shape
src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
src = src + self.bias.unsqueeze(1)
src = src.reshape(seq_len * upsample, batch_size, num_channels)
return src
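With the learned bias removed, SimpleUpsample reduces to repeating each frame; a minimal standalone sketch:

import torch

def simple_upsample(src: torch.Tensor, upsample: int) -> torch.Tensor:
    # (seq_len, batch_size, num_channels) -> (seq_len * upsample, batch_size, num_channels),
    # repeating each input frame `upsample` times.
    (seq_len, batch_size, num_channels) = src.shape
    src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
    return src.reshape(seq_len * upsample, batch_size, num_channels)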
@@ -1296,15 +1291,11 @@ class AttentionSqueeze(nn.Module):
super().__init__()
self.bottleneck_dim = bottleneck_dim
self.in_proj = LinearWithAuxLoss(embed_dim, hidden_dim,
bias=False,
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in())
self.to_bottleneck_proj = LinearWithAuxLoss(embed_dim,
bottleneck_dim,
aux_grad_scale=_aux_grad_scale(),
prob=_aux_grad_prob_in())
self.in_proj = nn.Linear(embed_dim, hidden_dim,
bias=False)
self.to_bottleneck_proj = nn.Linear(embed_dim,
bottleneck_dim)
# bottleneck_balancer is before the activation. Mostly, for well-trained
# instances of this module, the mean absolute values per channel are in
@@ -1315,7 +1306,6 @@ class AttentionSqueeze(nn.Module):
min_positive=0.2, max_positive=0.8,
min_abs=0.05,
max_abs=ScheduledFloat((0.0, 0.5), (4000.0, 1.0), default=1.0),
min_prob=0.1,
)
self.bottleneck_activation = TanSwish() # in bottleneck
self.activation = Identity() # for diagnostics
@@ -1329,13 +1319,13 @@ class AttentionSqueeze(nn.Module):
hidden_dim, channel_dim=-1,
min_positive=0.2, max_positive=0.8,
min_abs=0.2, max_abs=1.0,
min_prob=0.05,
prob=_balancer_schedule(0.05),
)
self.activation_balancer = ActivationBalancer(
hidden_dim, channel_dim=-1,
min_positive=0.2, max_positive=0.8,
min_abs=0.2, max_abs=1.0,
min_prob=0.05,
prob=_balancer_schedule(0.05),
)
self.activation_whiten = Whiten(num_groups=1,
whitening_limit=_whitening_schedule(4.0, ratio=3.0),
@@ -1343,17 +1333,16 @@ class AttentionSqueeze(nn.Module):
grad_scale=0.01)
self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, hidden_dim)
self.from_bottleneck_proj = nn.Linear(bottleneck_dim, hidden_dim)
self.out_proj = LinearWithAuxLoss(hidden_dim, embed_dim,
aux_grad_scale=_aux_grad_scale(),
prob=_aux_grad_prob_out(),
bias=False, initial_scale=0.05)
self.out_proj = ScaledLinear(hidden_dim, embed_dim,
bias=False, initial_scale=0.05)
self.out_balancer = ActivationBalancer(
embed_dim, channel_dim=-1,
min_positive=0.3, max_positive=0.7,
min_abs=ScheduledFloat((0.0, 0.001), (8000.0, 0.005)),
min_abs=ScheduledFloat((0.0, 0.001), (4000.0, 0.005)),
prob=0.05, # out of concern for memory usage
)
@@ -1404,21 +1393,19 @@ class FeedforwardModule(nn.Module):
feedforward_dim: int,
dropout: FloatLike):
super(FeedforwardModule, self).__init__()
self.in_proj = LinearWithAuxLoss(embed_dim, feedforward_dim,
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in())
self.in_proj = nn.Linear(embed_dim, feedforward_dim)
self.hidden_balancer = ActivationBalancer(feedforward_dim,
channel_dim=-1,
min_positive=0.3,
max_positive=1.0,
min_abs=0.75,
max_abs=5.0,
min_prob=0.25)
max_abs=5.0)
self.activation = SwooshL()
self.dropout = Dropout2(dropout)
self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim,
initial_scale=0.01,
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())
self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
initial_scale=0.01)
self.out_whiten = Whiten(num_groups=1,
whitening_limit=_whitening_schedule(7.5),
prob=(0.025, 0.25),
@@ -1488,11 +1475,11 @@ class NonlinAttentionModule(nn.Module):
self.balancer2 = ActivationBalancer(
channels, channel_dim=-1,
min_positive=0.3, max_positive=0.7,
min_abs=ScheduledFloat((0.0, 0.001), (8000.0, 0.005)),
min_abs=ScheduledFloat((0.0, 0.001), (4000.0, 0.005)),
prob=0.05, # out of concern for memory usage
)
def forward(self,
x: Tensor,
attn_weights: Tensor,
@@ -1560,9 +1547,8 @@ class ConvolutionModule(nn.Module):
bottleneck_dim = channels
self.in_proj = LinearWithAuxLoss(
self.in_proj = nn.Linear(
channels, 2 * bottleneck_dim,
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in()
)
@@ -1618,9 +1604,8 @@ class ConvolutionModule(nn.Module):
prob=(0.025, 0.25),
grad_scale=0.01)
self.out_proj = LinearWithAuxLoss(
self.out_proj = ScaledLinear(
bottleneck_dim, channels,
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out(),
initial_scale=0.05,
)
@@ -1834,8 +1819,7 @@ class Conv2dSubsampling(nn.Module):
self.scale_max = 1.0
self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.1))
self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels,
aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())
self.out = nn.Linear(out_height * layer3_channels, out_channels)
self.dropout = Dropout2(dropout)