diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 672fd3465..208d15735 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -463,120 +463,6 @@ class BasicNorm(torch.nn.Module):
         return x * scales


-class LinearWithAuxLossFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x: Tensor, weight: Tensor,
-                aux_grad_scale: float) -> Tensor:
-        """
-        Returns matmul(x, weight.t()).
-        In the backward pass it will include an auxiliary loss based on predicting x from
-        matmul(y, weight).
-        """
-        if torch.is_autocast_enabled():
-            x = x.to(torch.float16)
-        ctx.save_for_backward(x, weight)
-        ctx.aux_grad_scale = aux_grad_scale
-        return torch.matmul(x, weight.t())
-
-
-    @staticmethod
-    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, Tensor, Tensor, None]:
-        x, weight = ctx.saved_tensors
-
-        x_grad = torch.matmul(ans_grad, weight.to(ans_grad.dtype))
-        weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(),
-                                   x.reshape(-1, x.shape[-1]).to(ans_grad.dtype))
-
-
-        with torch.cuda.amp.autocast(enabled=False):
-            with torch.enable_grad():
-                x = x.to(weight.dtype)
-                x, weight = x.detach(), weight.detach()
-                weight.requires_grad = True
-                # recompute y as we need the gradient; this is easier to implement than
-                # saving y in the context.
-                y = torch.matmul(x, weight.t())
-                z = torch.matmul(y, weight)
-                # subtract mean
-                dims_to_mean = tuple(range(x.ndim-1))
-                x = x - x.mean(dim=dims_to_mean)
-                z = z - z.mean(dim=dims_to_mean)
-                # compute optimal scale on z
-                with torch.no_grad():
-                    alpha = (x * z).sum() / ((z * z).sum() + 1.0e-20)
-                diff = x - alpha * z
-                # meansq is the loss function.
-                meansq = (diff ** 2).mean()
-                meansq.backward()
-                weight_aux_grad = weight.grad
-
-
-        with torch.cuda.amp.autocast(enabled=False):
-            weight_grad_norm = weight_grad.to(torch.float32).norm()
-            aux_grad_norm = weight_aux_grad.norm()
-            weight_grad_scale = ctx.aux_grad_scale * weight_grad_norm / (aux_grad_norm + 1.0e-20)
-            weight_grad = weight_grad + (weight_grad_scale * weight_aux_grad).to(weight_grad.dtype)
-
-        return x_grad, weight_grad, None
-
-
-
-class LinearWithAuxLoss(nn.Module):
-    """
-    A linear layer with an auxiliary loss that you can put on a schedule, that
-    encourages it to correspond to the largest-variance directions of the
-    input features.
-
-    Suppose the input is x, and this layer computes:
-          y = M x
-    (the bias is applied separately), then we define:
-          z = exp(alpha) * M^T y
-    where alpha is learnable; and the auxiliary loss will be:
-       aux_loss = normalize_mean(z - x)^2.
-    (normalize_mean refers to subtracting the average value per channel,
-    over the minibatch).
-    In the backward pass we compute the derivative of the auxiliary loss
-    and add it to the weight and bias grads, with a scale chosen such
-    that the extra grad's norm equals `aux_grad_scales` times the norm
-    of the existing grad.
-    """
-    def __init__(self,
-                 in_channels: int,
-                 out_channels: int,
-                 bias: bool = True,
-                 aux_grad_scale: Optional[FloatLike] = None,
-                 prob: FloatLike = 0.25,
-                 initial_scale: float = 1.0,
-                 ):
-        super().__init__()
-        if aux_grad_scale is None:
-            aux_grad_scale = ScheduledFloat((0.0, 1.0), (1000.0, 0.1),
-                                            (2000.0, 0.01), (8000.0, 0.0))
-
-        self.aux_grad_scale = aux_grad_scale
-        self.prob = prob
-
-        self.weight = nn.Parameter(torch.randn(out_channels, in_channels)
-                                   * (in_channels ** -0.5) * initial_scale)
-        if bias:
-            self.bias = nn.Parameter(torch.randn(out_channels) *
-                                     0.01 * initial_scale)
-        else:
-            self.register_parameter('bias', None)
-
-
-    def forward(self,
-                x: Tensor):
-        aux_grad_scale = float(self.aux_grad_scale)
-        if (not self.training or torch.jit.is_scripting() or
-            aux_grad_scale == 0.0 or random.random() > float(self.prob)):
-            return torch.nn.functional.linear(x, self.weight, self.bias)
-        else:
-            ans = LinearWithAuxLossFunction.apply(x, self.weight,
-                                                  aux_grad_scale)
-            if self.bias is not None:
-                ans += self.bias
-            return ans


 def ScaledLinear(*args,
@@ -711,12 +597,14 @@ class ActivationBalancer(torch.nn.Module):
         scale_gain_factor: FloatLike = 0.04,
         min_abs: FloatLike = 0.2,
         max_abs: FloatLike = 100.0,
-        min_prob: FloatLike = 0.1,
+        prob: Optional[FloatLike] = None,
     ):
         super(ActivationBalancer, self).__init__()
-        # CAUTION: this code expects self.batch_count to be overwritten in the main training
-        # loop.
-        self.batch_count = 0
+
+
+        if prob is None:
+            prob = ScheduledFloat((0.0, 0.4), (8000.0, 0.1), default=0.4)
+        self.prob = prob

         # actually self.num_channels is no longer needed except for an assertion.
         self.num_channels = num_channels
@@ -726,7 +614,6 @@ class ActivationBalancer(torch.nn.Module):
         self.max_factor = max_factor
         self.min_abs = min_abs
         self.max_abs = max_abs
-        self.min_prob = min_prob
         self.sign_gain_factor = sign_gain_factor
         self.scale_gain_factor = scale_gain_factor

@@ -738,9 +625,7 @@ class ActivationBalancer(torch.nn.Module):
         if torch.jit.is_scripting() or not x.requires_grad:
             return _no_op(x)

-        # the prob of doing some work exponentially decreases from 0.5 till it hits
-        # a floor at min_prob (==0.1, by default)
-        prob = max(float(self.min_prob), 0.5 ** (1 + (self.batch_count / 4000.0)))
+        prob = float(self.prob)

         if random.random() < prob:
             assert x.shape[self.channel_dim] == self.num_channels
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index ac8e5c19d..50c2e6b71 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -36,7 +36,6 @@ from scaling import (
     ScaledConv1d,
     ScaledConv2d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
-    LinearWithAuxLoss,
     Whiten,
     Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
     penalize_abs_values_gt,
@@ -355,16 +354,11 @@ class Zipformer(EncoderInterface):

 def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
     return ScheduledFloat((0.0, x),
-                          (12000.0, ratio * x),
+                          (20000.0, ratio * x),
                           default=x)


-def _aux_grad_scale() -> float:
-    return 0.2
-def _aux_grad_prob_out() -> ScheduledFloat:
-    return 0.0 # ScheduledFloat((0.0, 0.25), (1000.0, 0.05), (8000.0, 0.0125))
-def _aux_grad_prob_in() -> ScheduledFloat:
-    return 0.0 # ScheduledFloat((0.0, 0.25), (1000.0, 0.0))
-    #return ScheduledFloat((0.0, 0.25), (1000.0, 0.05), (8000.0, 0.0125))
+def _balancer_schedule(min_prob: float):
+    return ScheduledFloat((0.0, 0.4), (8000.0, min_prob))


@@ -398,8 +392,8 @@ class ZipformerEncoderLayer(nn.Module):
             # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
             # to work correctly.
             layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
-            dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0),
-            nonlin_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (20000, 0.0), default=0),
+            attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
+            conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
             const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
             bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0),
             bypass_max: FloatLike = 1.0,
@@ -410,10 +404,10 @@ class ZipformerEncoderLayer(nn.Module):
         # probability of skipping the entire layer.
         self.layer_skip_rate = copy.deepcopy(layer_skip_rate)
         # skip probability for dynamic modules (meaning: anything but feedforward).
-        self.dynamic_skip_rate = copy.deepcopy(dynamic_skip_rate)
-        # an additional skip probability that applies to NoninAttentionModule to stop it from
+        self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
+        # an additional skip probability that applies to ConvModule to stop it from
         # contributing too much early on.
-        self.nonlin_skip_rate = copy.deepcopy(nonlin_skip_rate)
+        self.conv_skip_rate = copy.deepcopy(conv_skip_rate)

         # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads
         # ever becoming zero.
@@ -507,7 +501,7 @@ class ZipformerEncoderLayer(nn.Module):
         src_orig = src

         # dropout rate for non-feedforward submodules
-        dynamic_skip_rate = float(self.dynamic_skip_rate) if self.training else 0.0
+        attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0

         # attn_weights: (num_heads, batch_size, seq_len, seq_len)
         if self.self_attn_weights is not None:
@@ -528,7 +522,7 @@ class ZipformerEncoderLayer(nn.Module):
                 # skip the layer
                 return src, attn_weights

-            use_self_attn = (random.random() >= dynamic_skip_rate)
+            use_self_attn = (random.random() >= attention_skip_rate)
             if use_self_attn:
                 selected_attn_weights = attn_weights[head_offset:head_offset+2]
                 if random.random() < float(self.const_attention_rate):
@@ -541,7 +535,7 @@ class ZipformerEncoderLayer(nn.Module):
                     selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
                     selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1)

-        if torch.jit.is_scripting() or (use_self_attn and random.random() >= float(self.nonlin_skip_rate)):
+        if torch.jit.is_scripting() or use_self_attn:
             src = src + self.nonlin_attention_module(src,
                                                      selected_attn_weights[0:1])

@@ -555,7 +549,7 @@ class ZipformerEncoderLayer(nn.Module):
             src = src + self.self_attn(
                 src, attn_weights)

-        if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate:
+        if torch.jit.is_scripting() or random.random() >= float(self.conv_skip_rate):
             src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)

         src = src + self.feed_forward2(src)
@@ -750,6 +744,7 @@ class AttentionDownsample(torch.nn.Module):
         """
         super(AttentionDownsample, self).__init__()
         self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
+        self.bias = nn.Parameter(torch.zeros(downsample))

         self.name = None  # will be set from training code
         self.dropout = copy.deepcopy(dropout)
@@ -783,8 +778,9 @@ class AttentionDownsample(torch.nn.Module):
             assert src.shape[0] == d_seq_len * ds

         src = src.reshape(d_seq_len, ds, batch_size, in_channels)
-        # scores: (d_seq_len, downsample, batch_size)
+        # scores: (d_seq_len, downsample, batch_size, 1)
         scores = (src * self.query).sum(dim=-1, keepdim=True)
+        scores = scores + self.bias.unsqueeze(-1).unsqueeze(-1)

         scores = penalize_abs_values_gt(scores,
                                         limit=20.0,
@@ -820,7 +816,7 @@ class SimpleUpsample(torch.nn.Module):
                  num_channels: int,
                  upsample: int):
         super(SimpleUpsample, self).__init__()
-        self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01)
+        self.upsample = upsample

     def forward(self,
                 src: Tensor) -> Tensor:
@@ -829,10 +825,9 @@ class SimpleUpsample(torch.nn.Module):
         Returns a tensor of shape
            ( (seq_len*upsample), batch_size, num_channels)
         """
-        upsample = self.bias.shape[0]
+        upsample = self.upsample
         (seq_len, batch_size, num_channels) = src.shape
         src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
-        src = src + self.bias.unsqueeze(1)
         src = src.reshape(seq_len * upsample, batch_size, num_channels)
         return src

@@ -1296,15 +1291,11 @@ class AttentionSqueeze(nn.Module):
         super().__init__()
         self.bottleneck_dim = bottleneck_dim

-        self.in_proj = LinearWithAuxLoss(embed_dim, hidden_dim,
-                                         bias=False,
-                                         aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in())
-
-        self.to_bottleneck_proj = LinearWithAuxLoss(embed_dim,
-                                                    bottleneck_dim,
-                                                    aux_grad_scale=_aux_grad_scale(),
-                                                    prob=_aux_grad_prob_in())
+        self.in_proj = nn.Linear(embed_dim, hidden_dim,
+                                 bias=False)
+        self.to_bottleneck_proj = nn.Linear(embed_dim,
+                                            bottleneck_dim)

         # bottleneck_balancer is before the actiation.  Mostly, for well-trained
         # instances of this module, the mean absolute values per channel are in
@@ -1315,7 +1306,6 @@ class AttentionSqueeze(nn.Module):
             min_positive=0.2, max_positive=0.8,
             min_abs=0.05,
             max_abs=ScheduledFloat((0.0, 0.5), (4000.0, 1.0), default=1.0),
-            min_prob=0.1,
         )
         self.bottleneck_activation = TanSwish()  # in bottleneck
         self.activation = Identity()  # for diagnostics
@@ -1329,13 +1319,13 @@ class AttentionSqueeze(nn.Module):
             hidden_dim, channel_dim=-1,
             min_positive=0.2, max_positive=0.8,
             min_abs=0.2, max_abs=1.0,
-            min_prob=0.05,
+            prob=_balancer_schedule(0.05),
         )
         self.activation_balancer = ActivationBalancer(
             hidden_dim, channel_dim=-1,
             min_positive=0.2, max_positive=0.8,
             min_abs=0.2, max_abs=1.0,
-            min_prob=0.05,
+            prob=_balancer_schedule(0.05),
         )
         self.activation_whiten = Whiten(num_groups=1,
                                         whitening_limit=_whitening_schedule(4.0, ratio=3.0),
@@ -1343,17 +1333,16 @@ class AttentionSqueeze(nn.Module):
                                         prob=(0.025, 0.25),
                                         grad_scale=0.01)

-        self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, hidden_dim)
+        self.from_bottleneck_proj = nn.Linear(bottleneck_dim, hidden_dim)

-        self.out_proj = LinearWithAuxLoss(hidden_dim, embed_dim,
-                                          aux_grad_scale=_aux_grad_scale(),
-                                          prob=_aux_grad_prob_out(),
-                                          bias=False, initial_scale=0.05)
+        self.out_proj = ScaledLinear(hidden_dim, embed_dim,
+                                     bias=False, initial_scale=0.05)

         self.out_balancer = ActivationBalancer(
             embed_dim, channel_dim=-1,
             min_positive=0.3, max_positive=0.7,
-            min_abs=ScheduledFloat((0.0, 0.001), (8000.0, 0.005)),
+            min_abs=ScheduledFloat((0.0, 0.001), (4000.0, 0.005)),
+            prob=0.05,  # out of concern for memory usage
         )


@@ -1404,21 +1393,19 @@ class FeedforwardModule(nn.Module):
                  feedforward_dim: int,
                  dropout: FloatLike):
         super(FeedforwardModule, self).__init__()
-        self.in_proj = LinearWithAuxLoss(embed_dim, feedforward_dim,
-                                         aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in())
+        self.in_proj = nn.Linear(embed_dim, feedforward_dim)
         self.hidden_balancer = ActivationBalancer(feedforward_dim,
                                                   channel_dim=-1,
                                                   min_positive=0.3,
                                                   max_positive=1.0,
                                                   min_abs=0.75,
-                                                  max_abs=5.0,
-                                                  min_prob=0.25)
+                                                  max_abs=5.0)
         self.activation = SwooshL()
         self.dropout = Dropout2(dropout)
-        self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim,
-                                          initial_scale=0.01,
-                                          aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())
+        self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
+                                     initial_scale=0.01)
+
         self.out_whiten = Whiten(num_groups=1,
                                  whitening_limit=_whitening_schedule(7.5),
                                  prob=(0.025, 0.25),
                                  grad_scale=0.01)
@@ -1488,11 +1475,11 @@ class NonlinAttentionModule(nn.Module):
         self.balancer2 = ActivationBalancer(
             channels, channel_dim=-1,
             min_positive=0.3, max_positive=0.7,
-            min_abs=ScheduledFloat((0.0, 0.001), (8000.0, 0.005)),
+            min_abs=ScheduledFloat((0.0, 0.001), (4000.0, 0.005)),
+            prob=0.05,  # out of concern for memory usage
         )


-
     def forward(self,
                 x: Tensor,
                 attn_weights: Tensor,
@@ -1560,9 +1547,8 @@ class ConvolutionModule(nn.Module):

         bottleneck_dim = channels

-        self.in_proj = LinearWithAuxLoss(
+        self.in_proj = nn.Linear(
             channels, 2 * bottleneck_dim,
-            aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in()
         )


@@ -1618,9 +1604,8 @@ class ConvolutionModule(nn.Module):
                                  prob=(0.025, 0.25),
                                  grad_scale=0.01)

-        self.out_proj = LinearWithAuxLoss(
+        self.out_proj = ScaledLinear(
             bottleneck_dim, channels,
-            aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out(),
             initial_scale=0.05,
         )

@@ -1834,8 +1819,7 @@ class Conv2dSubsampling(nn.Module):
         self.scale_max = 1.0
         self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.1))

-        self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels,
-                                     aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out())
+        self.out = nn.Linear(out_height * layer3_channels, out_channels)

         self.dropout = Dropout2(dropout)
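
Note (illustration only, not part of the patch): the ActivationBalancer change above replaces the batch_count-based probability decay with a `prob` argument that defaults to ScheduledFloat((0.0, 0.4), (8000.0, 0.1), default=0.4). The sketch below compares the two behaviours at a few batch counts, assuming ScheduledFloat interpolates piecewise-linearly in the batch count; `piecewise_linear` below is a hypothetical stand-in written only for this comparison, not the class from scaling.py.

    # Sketch: old vs. new ActivationBalancer probability schedule.
    # Both formulas are taken from the diff above; `piecewise_linear` is an
    # assumed stand-in for how ScheduledFloat is expected to evaluate.

    def piecewise_linear(batch_count: float, points) -> float:
        """points: (batch_count, value) pairs, sorted by batch_count."""
        if batch_count <= points[0][0]:
            return points[0][1]
        for (x0, y0), (x1, y1) in zip(points, points[1:]):
            if batch_count <= x1:
                return y0 + (batch_count - x0) / (x1 - x0) * (y1 - y0)
        return points[-1][1]

    def old_prob(batch_count: float, min_prob: float = 0.1) -> float:
        # formula removed by this patch: exponential decay from 0.5, floored at min_prob
        return max(min_prob, 0.5 ** (1 + batch_count / 4000.0))

    new_schedule = [(0.0, 0.4), (8000.0, 0.1)]  # default introduced by this patch

    for bc in [0, 2000, 4000, 8000, 16000]:
        print(bc, round(old_prob(bc), 3), round(piecewise_linear(bc, new_schedule), 3))

In both cases the balancer only does its correction work on a random subset of batches; the new form makes that subset an explicit schedule passed to the module instead of relying on self.batch_count being overwritten by the training loop.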