Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

Commit 744dca1c9b
Merge branch 'scaled_adam_exp724' into scaled_adam_exp726
@@ -463,120 +463,6 @@ class BasicNorm(torch.nn.Module):
        return x * scales


class LinearWithAuxLossFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, weight: Tensor,
                aux_grad_scale: float) -> Tensor:
        """
        Returns matmul(x, weight.t()).
        In the backward pass it will include an auxiliary loss based on predicting x from
        matmul(y, weight).
        """
        if torch.is_autocast_enabled():
            x = x.to(torch.float16)
        ctx.save_for_backward(x, weight)
        ctx.aux_grad_scale = aux_grad_scale
        return torch.matmul(x, weight.t())

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, Tensor, None]:
        x, weight = ctx.saved_tensors

        x_grad = torch.matmul(ans_grad, weight.to(ans_grad.dtype))
        weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(),
                                   x.reshape(-1, x.shape[-1]).to(ans_grad.dtype))

        with torch.cuda.amp.autocast(enabled=False):
            with torch.enable_grad():
                x = x.to(weight.dtype)
                x, weight = x.detach(), weight.detach()
                weight.requires_grad = True
                # recompute y as we need the gradient; this is easier to implement than
                # saving y in the context.
                y = torch.matmul(x, weight.t())
                z = torch.matmul(y, weight)
                # subtract mean
                dims_to_mean = tuple(range(x.ndim-1))
                x = x - x.mean(dim=dims_to_mean)
                z = z - z.mean(dim=dims_to_mean)
                # compute optimal scale on z
                with torch.no_grad():
                    alpha = (x * z).sum() / ((z * z).sum() + 1.0e-20)
                diff = x - alpha * z
                # meansq is the loss function.
                meansq = (diff ** 2).mean()
                meansq.backward()
                weight_aux_grad = weight.grad

        with torch.cuda.amp.autocast(enabled=False):
            weight_grad_norm = weight_grad.to(torch.float32).norm()
            aux_grad_norm = weight_aux_grad.norm()
            weight_grad_scale = ctx.aux_grad_scale * weight_grad_norm / (aux_grad_norm + 1.0e-20)
            weight_grad = weight_grad + (weight_grad_scale * weight_aux_grad).to(weight_grad.dtype)

        return x_grad, weight_grad, None


class LinearWithAuxLoss(nn.Module):
    """
    A linear layer with an auxiliary loss that you can put on a schedule, which
    encourages it to correspond to the largest-variance directions of the
    input features.

    Suppose the input is x, and this layer computes:
          y = M x
    (the bias is applied separately); then we define:
          z = exp(alpha) * M^T y
    where alpha is learnable, and the auxiliary loss will be:
          aux_loss = normalize_mean(z - x)^2
    (normalize_mean refers to subtracting the average value per channel,
    over the minibatch).
    In the backward pass we compute the derivative of the auxiliary loss
    and add it to the weight and bias grads, with a scale chosen such
    that the extra grad's norm equals `aux_grad_scale` times the norm
    of the existing grad.
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 bias: bool = True,
                 aux_grad_scale: Optional[FloatLike] = None,
                 prob: FloatLike = 0.25,
                 initial_scale: float = 1.0,
                 ):
        super().__init__()
        if aux_grad_scale is None:
            aux_grad_scale = ScheduledFloat((0.0, 1.0), (1000.0, 0.1),
                                            (2000.0, 0.01), (8000.0, 0.0))

        self.aux_grad_scale = aux_grad_scale
        self.prob = prob

        self.weight = nn.Parameter(torch.randn(out_channels, in_channels)
                                   * (in_channels ** -0.5) * initial_scale)
        if bias:
            self.bias = nn.Parameter(torch.randn(out_channels) *
                                     0.01 * initial_scale)
        else:
            self.register_parameter('bias', None)

    def forward(self,
                x: Tensor):
        aux_grad_scale = float(self.aux_grad_scale)
        if (not self.training or torch.jit.is_scripting() or
                aux_grad_scale == 0.0 or random.random() > float(self.prob)):
            return torch.nn.functional.linear(x, self.weight, self.bias)
        else:
            ans = LinearWithAuxLossFunction.apply(x, self.weight,
                                                  aux_grad_scale)
            if self.bias is not None:
                ans += self.bias
            return ans


def ScaledLinear(*args,
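As context for the code removed above: the auxiliary loss in LinearWithAuxLossFunction.backward() is a reconstruction objective, asking that matmul(y, weight), with a best-fitting scale alpha solved in closed form, reproduce the mean-normalized input x. The snippet below is a hedged, standalone sketch of that computation for readers of this page; it is not part of the commit, and its shapes and data are made up for illustration.

# Hedged sketch (not the commit's code): the auxiliary reconstruction loss that
# LinearWithAuxLossFunction.backward() computes, on toy data.
import torch

torch.manual_seed(0)
x = torch.randn(8, 16)                                             # (batch, in_channels)
weight = (torch.randn(32, 16) * 16 ** -0.5).requires_grad_(True)   # (out_channels, in_channels)

y = x @ weight.t()                          # the layer's output
z = y @ weight                              # map back to the input space
x0 = x - x.mean(dim=0)                      # "normalize_mean": subtract per-channel mean
z0 = z - z.mean(dim=0)
with torch.no_grad():                       # optimal scale on z, as in the diff
    alpha = (x0 * z0).sum() / ((z0 * z0).sum() + 1.0e-20)
aux_loss = ((x0 - alpha * z0) ** 2).mean()  # "meansq" in the diff
aux_loss.backward()

# In backward(), weight.grad from this loss is rescaled so its norm equals
# aux_grad_scale * ||weight_grad|| and is then added to the ordinary gradient.
print(float(aux_loss), weight.grad.norm().item())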
@@ -711,12 +597,14 @@ class ActivationBalancer(torch.nn.Module):
            scale_gain_factor: FloatLike = 0.04,
            min_abs: FloatLike = 0.2,
            max_abs: FloatLike = 100.0,
            min_prob: FloatLike = 0.1,
            prob: Optional[FloatLike] = None,
    ):
        super(ActivationBalancer, self).__init__()
        # CAUTION: this code expects self.batch_count to be overwritten in the main training
        # loop.
        self.batch_count = 0

        if prob is None:
            prob = ScheduledFloat((0.0, 0.4), (8000.0, 0.1), default=0.4)
        self.prob = prob

        # actually self.num_channels is no longer needed except for an assertion.
        self.num_channels = num_channels
@@ -726,7 +614,6 @@ class ActivationBalancer(torch.nn.Module):
        self.max_factor = max_factor
        self.min_abs = min_abs
        self.max_abs = max_abs
        self.min_prob = min_prob
        self.sign_gain_factor = sign_gain_factor
        self.scale_gain_factor = scale_gain_factor

@@ -738,9 +625,7 @@ class ActivationBalancer(torch.nn.Module):
        if torch.jit.is_scripting() or not x.requires_grad:
            return _no_op(x)

        # the prob of doing some work exponentially decreases from 0.5 till it hits
        # a floor at min_prob (==0.1, by default)
        prob = max(float(self.min_prob), 0.5 ** (1 + (self.batch_count / 4000.0)))
        prob = float(self.prob)

        if random.random() < prob:
            assert x.shape[self.channel_dim] == self.num_channels
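The new prob default above, like most tunable values in this diff, is a ScheduledFloat: icefall's float-like object whose value (as I read it) is interpolated piecewise-linearly between (batch_count, value) points as training proceeds, with the batch count set from the training loop. A minimal stand-in, purely to illustrate that behaviour and not the class in scaling.py:

# Hedged illustration only: a tiny stand-in for ScheduledFloat's piecewise-linear
# interpolation between (batch_count, value) pairs.
from bisect import bisect_right

def scheduled_value(points, batch_count):
    """points: sorted list of (batch_count, value); linear interpolation, clamped at the ends."""
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    if batch_count <= xs[0]:
        return ys[0]
    if batch_count >= xs[-1]:
        return ys[-1]
    i = bisect_right(xs, batch_count)
    t = (batch_count - xs[i - 1]) / (xs[i] - xs[i - 1])
    return ys[i - 1] + t * (ys[i] - ys[i - 1])

# The ActivationBalancer default above decays prob from 0.4 at batch 0 to 0.1 at batch 8000.
print(scheduled_value([(0.0, 0.4), (8000.0, 0.1)], 4000.0))   # -> 0.25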
@@ -36,7 +36,6 @@ from scaling import (
    ScaledConv1d,
    ScaledConv2d,
    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
    LinearWithAuxLoss,
    Whiten,
    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
    penalize_abs_values_gt,
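The comment on ScaledLinear above is worth unpacking, since the rest of this diff swaps LinearWithAuxLoss for either nn.Linear or ScaledLinear: in this recipe ScaledLinear is not a separate layer type, just an nn.Linear whose initial parameter values are scaled down. A hedged sketch of that idea; scaled_linear below is a stand-in, not the function defined in scaling.py, whose exact signature may differ:

# Hedged sketch: an nn.Linear whose initial parameters are shrunk by `initial_scale`.
import torch
import torch.nn as nn

def scaled_linear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
    linear = nn.Linear(*args, **kwargs)
    with torch.no_grad():
        linear.weight[:] *= initial_scale
        if linear.bias is not None:
            linear.bias[:] *= initial_scale
    return linear

proj = scaled_linear(256, 256, bias=False, initial_scale=0.05)
print(proj.weight.abs().mean())   # roughly 0.05 x the default nn.Linear initialization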
@@ -355,16 +354,11 @@ class Zipformer(EncoderInterface):

def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
    return ScheduledFloat((0.0, x),
                          (12000.0, ratio * x),
                          (20000.0, ratio * x),
                          default=x)

def _aux_grad_scale() -> float:
    return 0.2

def _aux_grad_prob_out() -> ScheduledFloat:
    return 0.0  # ScheduledFloat((0.0, 0.25), (1000.0, 0.05), (8000.0, 0.0125))

def _aux_grad_prob_in() -> ScheduledFloat:
    return 0.0  # ScheduledFloat((0.0, 0.25), (1000.0, 0.0))
    # return ScheduledFloat((0.0, 0.25), (1000.0, 0.05), (8000.0, 0.0125))

def _balancer_schedule(min_prob: float):
    return ScheduledFloat((0.0, 0.4), (8000.0, min_prob))

@@ -398,8 +392,8 @@ class ZipformerEncoderLayer(nn.Module):
            # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
            # to work correctly.
            layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
            dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0),
            nonlin_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (20000, 0.0), default=0),
            attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
            conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
            const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
            bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0),
            bypass_max: FloatLike = 1.0,
@@ -410,10 +404,10 @@ class ZipformerEncoderLayer(nn.Module):
        # probability of skipping the entire layer.
        self.layer_skip_rate = copy.deepcopy(layer_skip_rate)
        # skip probability for dynamic modules (meaning: anything but feedforward).
        self.dynamic_skip_rate = copy.deepcopy(dynamic_skip_rate)
        # an additional skip probability that applies to NonlinAttentionModule to stop it from
        self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
        # an additional skip probability that applies to ConvModule to stop it from
        # contributing too much early on.
        self.nonlin_skip_rate = copy.deepcopy(nonlin_skip_rate)
        self.conv_skip_rate = copy.deepcopy(conv_skip_rate)

        # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
        # ever becoming zero.
@@ -507,7 +501,7 @@ class ZipformerEncoderLayer(nn.Module):
        src_orig = src

        # dropout rate for non-feedforward submodules
        dynamic_skip_rate = float(self.dynamic_skip_rate) if self.training else 0.0
        attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0

        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
        if self.self_attn_weights is not None:
@@ -528,7 +522,7 @@ class ZipformerEncoderLayer(nn.Module):
                # skip the layer
                return src, attn_weights

        use_self_attn = (random.random() >= dynamic_skip_rate)
        use_self_attn = (random.random() >= attention_skip_rate)
        if use_self_attn:
            selected_attn_weights = attn_weights[head_offset:head_offset+2]
            if random.random() < float(self.const_attention_rate):
@@ -541,7 +535,7 @@ class ZipformerEncoderLayer(nn.Module):
                selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
                selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1)

        if torch.jit.is_scripting() or (use_self_attn and random.random() >= float(self.nonlin_skip_rate)):
        if torch.jit.is_scripting() or use_self_attn:
            src = src + self.nonlin_attention_module(src,
                                                     selected_attn_weights[0:1])

@@ -555,7 +549,7 @@ class ZipformerEncoderLayer(nn.Module):
            src = src + self.self_attn(
                src, attn_weights)

        if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
        if torch.jit.is_scripting() or random.random() >= float(self.conv_skip_rate):
            src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)

        src = src + self.feed_forward2(src)
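The forward-pass changes above all follow the same stochastic-skip pattern: during training, each submodule's residual contribution is dropped with a scheduled probability, while at inference time (or when scripting) the module always runs. A self-contained, hedged sketch of that pattern, with a fixed skip_rate standing in for the ScheduledFloat-driven rates used in the commit:

# Hedged sketch of the stochastic-skip pattern used above; not the commit's code.
import random
import torch
import torch.nn as nn

class SkippableResidual(nn.Module):
    def __init__(self, module: nn.Module, skip_rate: float = 0.2):
        super().__init__()
        self.module = module
        self.skip_rate = skip_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Skip only during training; at eval time the module always contributes.
        if self.training and random.random() < self.skip_rate:
            return x
        return x + self.module(x)

layer = SkippableResidual(nn.Linear(4, 4), skip_rate=0.2)
print(layer(torch.randn(2, 4)).shape)   # torch.Size([2, 4])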
@@ -750,6 +744,7 @@ class AttentionDownsample(torch.nn.Module):
        """
        super(AttentionDownsample, self).__init__()
        self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
        self.bias = nn.Parameter(torch.zeros(downsample))

        self.name = None  # will be set from training code
        self.dropout = copy.deepcopy(dropout)
@@ -783,8 +778,9 @@ class AttentionDownsample(torch.nn.Module):
        assert src.shape[0] == d_seq_len * ds

        src = src.reshape(d_seq_len, ds, batch_size, in_channels)
        # scores: (d_seq_len, downsample, batch_size)
        # scores: (d_seq_len, downsample, batch_size, 1)
        scores = (src * self.query).sum(dim=-1, keepdim=True)
        scores = scores + self.bias.unsqueeze(-1).unsqueeze(-1)

        scores = penalize_abs_values_gt(scores,
                                        limit=20.0,
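For orientation, the scoring lines above implement attention-style pooling over each group of downsample frames: a single query scores the frames, the newly added per-position bias shifts those scores, and a weighted sum yields one output frame per group. A hedged sketch of that idea with toy shapes; the softmax normalization below is an assumption for illustration, while the real class applies penalize_abs_values_gt and its own normalization:

# Hedged sketch: attention-weighted downsampling over groups of `ds` frames.
import torch

ds, d_seq_len, batch, channels = 2, 5, 3, 8
src = torch.randn(d_seq_len, ds, batch, channels)
query = torch.randn(channels) * channels ** -0.5
bias = torch.zeros(ds)                                # per-position bias added in this commit

scores = (src * query).sum(dim=-1, keepdim=True)      # (d_seq_len, ds, batch, 1)
scores = scores + bias.unsqueeze(-1).unsqueeze(-1)    # broadcast over batch and the last dim
weights = scores.softmax(dim=1)                       # normalize over the `ds` frames (assumed)
pooled = (src * weights).sum(dim=1)                   # (d_seq_len, batch, channels)
print(pooled.shape)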
@@ -820,7 +816,7 @@ class SimpleUpsample(torch.nn.Module):
                 num_channels: int,
                 upsample: int):
        super(SimpleUpsample, self).__init__()
        self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01)
        self.upsample = upsample

    def forward(self,
                src: Tensor) -> Tensor:
@@ -829,10 +825,9 @@ class SimpleUpsample(torch.nn.Module):
        Returns a tensor of shape
            ( (seq_len*upsample), batch_size, num_channels)
        """
        upsample = self.bias.shape[0]
        upsample = self.upsample
        (seq_len, batch_size, num_channels) = src.shape
        src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
        src = src + self.bias.unsqueeze(1)
        src = src.reshape(seq_len * upsample, batch_size, num_channels)
        return src

@@ -1296,15 +1291,11 @@ class AttentionSqueeze(nn.Module):
        super().__init__()
        self.bottleneck_dim = bottleneck_dim

        self.in_proj = LinearWithAuxLoss(embed_dim, hidden_dim,
                                         bias=False,
                                         aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in())

        self.to_bottleneck_proj = LinearWithAuxLoss(embed_dim,
                                                    bottleneck_dim,
                                                    aux_grad_scale=_aux_grad_scale(),
                                                    prob=_aux_grad_prob_in())
        self.in_proj = nn.Linear(embed_dim, hidden_dim,
                                 bias=False)

        self.to_bottleneck_proj = nn.Linear(embed_dim,
                                            bottleneck_dim)

        # bottleneck_balancer is before the activation. Mostly, for well-trained
        # instances of this module, the mean absolute values per channel are in
@@ -1315,7 +1306,6 @@ class AttentionSqueeze(nn.Module):
            min_positive=0.2, max_positive=0.8,
            min_abs=0.05,
            max_abs=ScheduledFloat((0.0, 0.5), (4000.0, 1.0), default=1.0),
            min_prob=0.1,
        )
        self.bottleneck_activation = TanSwish()  # in bottleneck
        self.activation = Identity()  # for diagnostics
@@ -1329,13 +1319,13 @@ class AttentionSqueeze(nn.Module):
            hidden_dim, channel_dim=-1,
            min_positive=0.2, max_positive=0.8,
            min_abs=0.2, max_abs=1.0,
            min_prob=0.05,
            prob=_balancer_schedule(0.05),
        )
        self.activation_balancer = ActivationBalancer(
            hidden_dim, channel_dim=-1,
            min_positive=0.2, max_positive=0.8,
            min_abs=0.2, max_abs=1.0,
            min_prob=0.05,
            prob=_balancer_schedule(0.05),
        )
        self.activation_whiten = Whiten(num_groups=1,
                                        whitening_limit=_whitening_schedule(4.0, ratio=3.0),
@@ -1343,17 +1333,16 @@ class AttentionSqueeze(nn.Module):
                                        grad_scale=0.01)

        self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, hidden_dim)
        self.from_bottleneck_proj = nn.Linear(bottleneck_dim, hidden_dim)

        self.out_proj = LinearWithAuxLoss(hidden_dim, embed_dim,
                                          aux_grad_scale=_aux_grad_scale(),
                                          prob=_aux_grad_prob_out(),
        self.out_proj = ScaledLinear(hidden_dim, embed_dim,
                                     bias=False, initial_scale=0.05)

        self.out_balancer = ActivationBalancer(
            embed_dim, channel_dim=-1,
            min_positive=0.3, max_positive=0.7,
            min_abs=ScheduledFloat((0.0, 0.001), (8000.0, 0.005)),
            min_abs=ScheduledFloat((0.0, 0.001), (4000.0, 0.005)),
            prob=0.05,  # out of concern for memory usage
        )

@@ -1404,21 +1393,19 @@ class FeedforwardModule(nn.Module):
                 feedforward_dim: int,
                 dropout: FloatLike):
        super(FeedforwardModule, self).__init__()
        self.in_proj = LinearWithAuxLoss(embed_dim, feedforward_dim,
                                         aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in())
        self.in_proj = nn.Linear(embed_dim, feedforward_dim)

        self.hidden_balancer = ActivationBalancer(feedforward_dim,
                                                  channel_dim=-1,
                                                  min_positive=0.3,
                                                  max_positive=1.0,
                                                  min_abs=0.75,
                                                  max_abs=5.0,
                                                  min_prob=0.25)
                                                  max_abs=5.0)
        self.activation = SwooshL()
        self.dropout = Dropout2(dropout)
        self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim,
                                          initial_scale=0.01,
                                          aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())
        self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
                                     initial_scale=0.01)

        self.out_whiten = Whiten(num_groups=1,
                                 whitening_limit=_whitening_schedule(7.5),
                                 prob=(0.025, 0.25),
@@ -1488,11 +1475,11 @@ class NonlinAttentionModule(nn.Module):
        self.balancer2 = ActivationBalancer(
            channels, channel_dim=-1,
            min_positive=0.3, max_positive=0.7,
            min_abs=ScheduledFloat((0.0, 0.001), (8000.0, 0.005)),
            min_abs=ScheduledFloat((0.0, 0.001), (4000.0, 0.005)),
            prob=0.05,  # out of concern for memory usage
        )

    def forward(self,
                x: Tensor,
                attn_weights: Tensor,
@@ -1560,9 +1547,8 @@ class ConvolutionModule(nn.Module):
            bottleneck_dim = channels

        self.in_proj = LinearWithAuxLoss(
        self.in_proj = nn.Linear(
            channels, 2 * bottleneck_dim,
            aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in()
        )

@@ -1618,9 +1604,8 @@ class ConvolutionModule(nn.Module):
                                 prob=(0.025, 0.25),
                                 grad_scale=0.01)

        self.out_proj = LinearWithAuxLoss(
        self.out_proj = ScaledLinear(
            bottleneck_dim, channels,
            aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out(),
            initial_scale=0.05,
        )

@@ -1834,8 +1819,7 @@ class Conv2dSubsampling(nn.Module):
        self.scale_max = 1.0
        self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.1))

        self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels,
                                     aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())
        self.out = nn.Linear(out_height * layer3_channels, out_channels)

        self.dropout = Dropout2(dropout)