From 1506b83c7b6e6ca440e37b9c1ff26ae3699d6029 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 15 Dec 2022 19:25:21 +0800 Subject: [PATCH 01/10] Change nonlin_skip_rate to be conv_skip_rate. --- .../ASR/pruned_transducer_stateless7/zipformer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index f0a52f605..0c0b6ae21 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -399,7 +399,7 @@ class ZipformerEncoderLayer(nn.Module): # to work correctly. layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0), dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0), - nonlin_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (20000, 0.0), default=0), + conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (20000, 0.0), default=0), const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0), bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0), bypass_max: FloatLike = 1.0, @@ -411,9 +411,9 @@ class ZipformerEncoderLayer(nn.Module): self.layer_skip_rate = copy.deepcopy(layer_skip_rate) # skip probability for dynamic modules (meaning: anything but feedforward). self.dynamic_skip_rate = copy.deepcopy(dynamic_skip_rate) - # an additional skip probability that applies to NoninAttentionModule to stop it from + # an additional skip probability that applies to ConvModule to stop it from # contributing too much early on. - self.nonlin_skip_rate = copy.deepcopy(nonlin_skip_rate) + self.conv_skip_rate = copy.deepcopy(conv_skip_rate) # min and max for self.bypass_scale, applied with probability 0.5 to avoid grads # ever becoming zero. @@ -541,7 +541,7 @@ class ZipformerEncoderLayer(nn.Module): selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)) selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1) - if torch.jit.is_scripting() or (use_self_attn and random.random() >= float(self.nonlin_skip_rate)): + if torch.jit.is_scripting() or use_self_attn: src = src + self.nonlin_attention_module(src, selected_attn_weights[0:1]) @@ -555,7 +555,7 @@ class ZipformerEncoderLayer(nn.Module): src = src + self.self_attn( src, attn_weights) - if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate: + if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate + float(self.conv_skip_rate): src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask) src = src + self.feed_forward2(src) From 864ff96322352ab4e71b422fd97dc2173de3f072 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 15 Dec 2022 19:27:29 +0800 Subject: [PATCH 02/10] Remove nonlin_skip_rate, introduce conv_skip_rate. --- .../ASR/pruned_transducer_stateless7/zipformer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 0c0b6ae21..0f45d0212 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -398,8 +398,8 @@ class ZipformerEncoderLayer(nn.Module): # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom() # to work correctly. 
layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0), - dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0), - conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (20000, 0.0), default=0), + attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0), + conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (16000, 0.0), default=0), const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0), bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0), bypass_max: FloatLike = 1.0, @@ -410,7 +410,7 @@ class ZipformerEncoderLayer(nn.Module): # probability of skipping the entire layer. self.layer_skip_rate = copy.deepcopy(layer_skip_rate) # skip probability for dynamic modules (meaning: anything but feedforward). - self.dynamic_skip_rate = copy.deepcopy(dynamic_skip_rate) + self.attention_skip_rate = copy.deepcopy(attention_skip_rate) # an additional skip probability that applies to ConvModule to stop it from # contributing too much early on. self.conv_skip_rate = copy.deepcopy(conv_skip_rate) @@ -507,7 +507,7 @@ class ZipformerEncoderLayer(nn.Module): src_orig = src # dropout rate for non-feedforward submodules - dynamic_skip_rate = float(self.dynamic_skip_rate) if self.training else 0.0 + attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0 # attn_weights: (num_heads, batch_size, seq_len, seq_len) if self.self_attn_weights is not None: @@ -528,7 +528,7 @@ class ZipformerEncoderLayer(nn.Module): # skip the layer return src, attn_weights - use_self_attn = (random.random() >= dynamic_skip_rate) + use_self_attn = (random.random() >= attention_skip_rate) if use_self_attn: selected_attn_weights = attn_weights[head_offset:head_offset+2] if random.random() < float(self.const_attention_rate): @@ -555,7 +555,7 @@ class ZipformerEncoderLayer(nn.Module): src = src + self.self_attn( src, attn_weights) - if torch.jit.is_scripting() or random.random() >= dynamic_skip_rate + float(self.conv_skip_rate): + if torch.jit.is_scripting() or random.random() >= float(self.conv_skip_rate): src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask) src = src + self.feed_forward2(src) From 3213c18a2261652e8b0a99ecbd8970d960ca13c7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 16 Dec 2022 14:58:15 +0800 Subject: [PATCH 03/10] Changes to schedules: _whitening_schedule longer, min_abs schedule on attention_squeeze+nonlin_attention shorter; dip in conv_skip_rate. --- .../ASR/pruned_transducer_stateless7/zipformer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 0f45d0212..4f823501d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -355,7 +355,7 @@ class Zipformer(EncoderInterface): def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: return ScheduledFloat((0.0, x), - (12000.0, ratio * x), + (20000.0, ratio * x), default=x) def _aux_grad_scale() -> float: @@ -399,7 +399,7 @@ class ZipformerEncoderLayer(nn.Module): # to work correctly. 
layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0), attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0), - conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (16000, 0.0), default=0), + conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0), bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0), bypass_max: FloatLike = 1.0, @@ -1408,7 +1408,7 @@ class AttentionSqueeze(nn.Module): self.out_balancer = ActivationBalancer( embed_dim, channel_dim=-1, min_positive=0.3, max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.001), (8000.0, 0.005)), + min_abs=ScheduledFloat((0.0, 0.001), (4000.0, 0.005)), ) @@ -1543,7 +1543,7 @@ class NonlinAttentionModule(nn.Module): self.balancer2 = ActivationBalancer( channels, channel_dim=-1, min_positive=0.3, max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.001), (8000.0, 0.005)), + min_abs=ScheduledFloat((0.0, 0.001), (4000.0, 0.005)), ) From 56ac7354dfc171008098072b96635b68f1691df3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 16 Dec 2022 15:07:42 +0800 Subject: [PATCH 04/10] Remove LinearWithAuxLoss; simplify schedule of prob in ActivationBalancer. --- .../pruned_transducer_stateless7/scaling.py | 129 +----------------- .../pruned_transducer_stateless7/zipformer.py | 59 +++----- 2 files changed, 28 insertions(+), 160 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index ced3b96d1..b126bd99d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -407,120 +407,6 @@ class BasicNorm(torch.nn.Module): return x * scales -class LinearWithAuxLossFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, weight: Tensor, - aux_grad_scale: float) -> Tensor: - """ - Returns matmul(x, weight.t()). - In the backward pass it will include an auxiliary loss based on predicting x from - matmul(y, weight). - """ - if torch.is_autocast_enabled(): - x = x.to(torch.float16) - ctx.save_for_backward(x, weight) - ctx.aux_grad_scale = aux_grad_scale - return torch.matmul(x, weight.t()) - - - @staticmethod - def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, Tensor, Tensor, None]: - x, weight = ctx.saved_tensors - - x_grad = torch.matmul(ans_grad, weight.to(ans_grad.dtype)) - weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(), - x.reshape(-1, x.shape[-1]).to(ans_grad.dtype)) - - - with torch.cuda.amp.autocast(enabled=False): - with torch.enable_grad(): - x = x.to(weight.dtype) - x, weight = x.detach(), weight.detach() - weight.requires_grad = True - # recompute y as we need the gradient; this is easier to implement than - # saving y in the context. - y = torch.matmul(x, weight.t()) - z = torch.matmul(y, weight) - # subtract mean - dims_to_mean = tuple(range(x.ndim-1)) - x = x - x.mean(dim=dims_to_mean) - z = z - z.mean(dim=dims_to_mean) - # compute optimal scale on z - with torch.no_grad(): - alpha = (x * z).sum() / ((z * z).sum() + 1.0e-20) - diff = x - alpha * z - # meansq is the loss function. 
- meansq = (diff ** 2).mean() - meansq.backward() - weight_aux_grad = weight.grad - - - with torch.cuda.amp.autocast(enabled=False): - weight_grad_norm = weight_grad.to(torch.float32).norm() - aux_grad_norm = weight_aux_grad.norm() - weight_grad_scale = ctx.aux_grad_scale * weight_grad_norm / (aux_grad_norm + 1.0e-20) - weight_grad = weight_grad + (weight_grad_scale * weight_aux_grad).to(weight_grad.dtype) - - return x_grad, weight_grad, None - - - -class LinearWithAuxLoss(nn.Module): - """ - A linear layer with an auxiliary loss that you can put on a schedule, that - encourages it to correspond to the largest-variance directions of the - input features. - - Suppose the input is x, and this layer computes: - y = M x - (the bias is applied separately), then we define: - z = exp(alpha) * M^T y - where alpha is learnable; and the auxiliary loss will be: - aux_loss = normalize_mean(z - x)^2. - (normalize_mean refers to subtracting the average value per channel, - over the minibatch). - In the backward pass we compute the derivative of the auxiliary loss - and add it to the weight and bias grads, with a scale chosen such - that the extra grad's norm equals `aux_grad_scales` times the norm - of the existing grad. - """ - def __init__(self, - in_channels: int, - out_channels: int, - bias: bool = True, - aux_grad_scale: Optional[FloatLike] = None, - prob: FloatLike = 0.25, - initial_scale: float = 1.0, - ): - super().__init__() - if aux_grad_scale is None: - aux_grad_scale = ScheduledFloat((0.0, 1.0), (1000.0, 0.1), - (2000.0, 0.01), (8000.0, 0.0)) - - self.aux_grad_scale = aux_grad_scale - self.prob = prob - - self.weight = nn.Parameter(torch.randn(out_channels, in_channels) - * (in_channels ** -0.5) * initial_scale) - if bias: - self.bias = nn.Parameter(torch.randn(out_channels) * - 0.01 * initial_scale) - else: - self.register_parameter('bias', None) - - - def forward(self, - x: Tensor): - aux_grad_scale = float(self.aux_grad_scale) - if (not self.training or torch.jit.is_scripting() or - aux_grad_scale == 0.0 or random.random() > float(self.prob)): - return torch.nn.functional.linear(x, self.weight, self.bias) - else: - ans = LinearWithAuxLossFunction.apply(x, self.weight, - aux_grad_scale) - if self.bias is not None: - ans += self.bias - return ans def ScaledLinear(*args, @@ -655,12 +541,14 @@ class ActivationBalancer(torch.nn.Module): scale_gain_factor: FloatLike = 0.04, min_abs: FloatLike = 0.2, max_abs: FloatLike = 100.0, - min_prob: FloatLike = 0.1, + prob: Optional[FloatLike] = None, ): super(ActivationBalancer, self).__init__() - # CAUTION: this code expects self.batch_count to be overwritten in the main training - # loop. - self.batch_count = 0 + + + if prob is None: + prob = ScheduledFloat((0.0, 0.4), (8000.0, 0.1)) + self.prob = prob # actually self.num_channels is no longer needed except for an assertion. 
self.num_channels = num_channels @@ -670,7 +558,6 @@ class ActivationBalancer(torch.nn.Module): self.max_factor = max_factor self.min_abs = min_abs self.max_abs = max_abs - self.min_prob = min_prob self.sign_gain_factor = sign_gain_factor self.scale_gain_factor = scale_gain_factor @@ -682,9 +569,7 @@ class ActivationBalancer(torch.nn.Module): if torch.jit.is_scripting() or not x.requires_grad: return _no_op(x) - # the prob of doing some work exponentially decreases from 0.5 till it hits - # a floor at min_prob (==0.1, by default) - prob = max(float(self.min_prob), 0.5 ** (1 + (self.batch_count / 4000.0))) + prob = float(self.prob) if random.random() < prob: assert x.shape[self.channel_dim] == self.num_channels diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 4f823501d..347f34ad0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -36,7 +36,6 @@ from scaling import ( ScaledConv1d, ScaledConv2d, ScaledLinear, # not as in other dirs.. just scales down initial parameter values. - LinearWithAuxLoss, Whiten, Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. penalize_abs_values_gt, @@ -358,13 +357,8 @@ def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: (20000.0, ratio * x), default=x) -def _aux_grad_scale() -> float: - return 0.2 -def _aux_grad_prob_out() -> ScheduledFloat: - return 0.0 # ScheduledFloat((0.0, 0.25), (1000.0, 0.05), (8000.0, 0.0125)) -def _aux_grad_prob_in() -> ScheduledFloat: - return 0.0 # ScheduledFloat((0.0, 0.25), (1000.0, 0.0)) - #return ScheduledFloat((0.0, 0.25), (1000.0, 0.05), (8000.0, 0.0125)) +def _balancer_schedule(min_prob: float): + return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) @@ -1351,15 +1345,11 @@ class AttentionSqueeze(nn.Module): super().__init__() self.bottleneck_dim = bottleneck_dim - self.in_proj = LinearWithAuxLoss(embed_dim, hidden_dim, - bias=False, - aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in()) - - self.to_bottleneck_proj = LinearWithAuxLoss(embed_dim, - bottleneck_dim, - aux_grad_scale=_aux_grad_scale(), - prob=_aux_grad_prob_in()) + self.in_proj = nn.Linear(embed_dim, hidden_dim, + bias=False) + self.to_bottleneck_proj = nn.Linear(embed_dim, + bottleneck_dim) # bottleneck_balancer is before the actiation. 
Mostly, for well-trained # instances of this module, the mean absolute values per channel are in @@ -1370,7 +1360,6 @@ class AttentionSqueeze(nn.Module): min_positive=0.2, max_positive=0.8, min_abs=0.05, max_abs=ScheduledFloat((0.0, 0.5), (4000.0, 1.0), default=1.0), - min_prob=0.1, ) self.bottleneck_activation = TanSwish() # in bottleneck self.activation = Identity() # for diagnostics @@ -1384,13 +1373,13 @@ class AttentionSqueeze(nn.Module): hidden_dim, channel_dim=-1, min_positive=0.2, max_positive=0.8, min_abs=0.2, max_abs=1.0, - min_prob=0.05, + prob=_balancer_schedule(0.05), ) self.activation_balancer = ActivationBalancer( hidden_dim, channel_dim=-1, min_positive=0.2, max_positive=0.8, min_abs=0.2, max_abs=1.0, - min_prob=0.05, + prob=_balancer_schedule(0.05), ) self.activation_whiten = Whiten(num_groups=1, whitening_limit=_whitening_schedule(4.0, ratio=3.0), @@ -1398,17 +1387,16 @@ class AttentionSqueeze(nn.Module): grad_scale=0.01) - self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, hidden_dim) + self.from_bottleneck_proj = nn.Linear(bottleneck_dim, hidden_dim) - self.out_proj = LinearWithAuxLoss(hidden_dim, embed_dim, - aux_grad_scale=_aux_grad_scale(), - prob=_aux_grad_prob_out(), - bias=False, initial_scale=0.05) + self.out_proj = ScaledLinear(hidden_dim, embed_dim, + bias=False, initial_scale=0.05) self.out_balancer = ActivationBalancer( embed_dim, channel_dim=-1, min_positive=0.3, max_positive=0.7, min_abs=ScheduledFloat((0.0, 0.001), (4000.0, 0.005)), + prob=0.05, # out of concern for memory usage ) @@ -1459,21 +1447,19 @@ class FeedforwardModule(nn.Module): feedforward_dim: int, dropout: FloatLike): super(FeedforwardModule, self).__init__() - self.in_proj = LinearWithAuxLoss(embed_dim, feedforward_dim, - aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in()) + self.in_proj = nn.Linear(embed_dim, feedforward_dim) self.hidden_balancer = ActivationBalancer(feedforward_dim, channel_dim=-1, min_positive=0.3, max_positive=1.0, min_abs=0.75, - max_abs=5.0, - min_prob=0.25) + max_abs=5.0) self.activation = SwooshL() self.dropout = Dropout2(dropout) - self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim, - initial_scale=0.01, - aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out()) + self.out_proj = ScaledLinear(feedforward_dim, embed_dim, + initial_scale=0.01) + self.out_whiten = Whiten(num_groups=1, whitening_limit=_whitening_schedule(7.5), prob=(0.025, 0.25), @@ -1544,10 +1530,10 @@ class NonlinAttentionModule(nn.Module): channels, channel_dim=-1, min_positive=0.3, max_positive=0.7, min_abs=ScheduledFloat((0.0, 0.001), (4000.0, 0.005)), + prob=0.05, # out of concern for memory usage ) - def forward(self, x: Tensor, attn_weights: Tensor, @@ -1615,9 +1601,8 @@ class ConvolutionModule(nn.Module): bottleneck_dim = channels - self.in_proj = LinearWithAuxLoss( + self.in_proj = nn.Linear( channels, 2 * bottleneck_dim, - aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_in() ) @@ -1673,9 +1658,8 @@ class ConvolutionModule(nn.Module): prob=(0.025, 0.25), grad_scale=0.01) - self.out_proj = LinearWithAuxLoss( + self.out_proj = ScaledLinear( bottleneck_dim, channels, - aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out(), initial_scale=0.05, ) @@ -1818,8 +1802,7 @@ class Conv2dSubsampling(nn.Module): self.scale_max = 1.0 self.scale_min = ScheduledFloat((0.0, 0.9), (4000.0, 0.1)) - self.out = LinearWithAuxLoss(out_height * layer3_channels, out_channels, - aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out()) + self.out = nn.Linear(out_height * 
layer3_channels, out_channels) self.dropout = Dropout2(dropout) From 8e6c7ef3e28ee1b95161ba280872c2eea3ae1fe4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 16 Dec 2022 15:08:46 +0800 Subject: [PATCH 05/10] Adjust default prob of ActivationBalancer. --- egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index b126bd99d..be66d00f6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -547,7 +547,7 @@ class ActivationBalancer(torch.nn.Module): if prob is None: - prob = ScheduledFloat((0.0, 0.4), (8000.0, 0.1)) + prob = ScheduledFloat((0.0, 0.4), (8000.0, 0.1), default=0.4) self.prob = prob # actually self.num_channels is no longer needed except for an assertion. From 66465c8be47968e757b16bae62a4fc2ce06c9f54 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 16 Dec 2022 15:12:04 +0800 Subject: [PATCH 06/10] Give attention_skip_rate a longer tail --- egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 347f34ad0..4ed61cf5a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -392,7 +392,7 @@ class ZipformerEncoderLayer(nn.Module): # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom() # to work correctly. layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0), - attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0), + attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0), bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0), From bc002a9eda8ac912cff235460dcdef2fd51b2f19 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 16 Dec 2022 16:30:49 +0800 Subject: [PATCH 07/10] Reduce const_attention_rate --- egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 4ed61cf5a..42dfa642f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -394,7 +394,7 @@ class ZipformerEncoderLayer(nn.Module): layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0), attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), - const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0), + const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), default=0), bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0), bypass_max: FloatLike = 1.0, ) -> None: From 
4eb3e9784847e125468d8f017aea476f68e24ad7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 16 Dec 2022 17:59:15 +0800 Subject: [PATCH 08/10] Remove bias from SimpleUpsample, add one to AttentionDownsample --- .../ASR/pruned_transducer_stateless7/zipformer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 42dfa642f..692f8ffbc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -799,6 +799,7 @@ class AttentionDownsample(torch.nn.Module): """ super(AttentionDownsample, self).__init__() self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5)) + self.bias = nn.Parameter(torch.zeros(downsample)) self.name = None # will be set from training code self.dropout = copy.deepcopy(dropout) @@ -832,8 +833,9 @@ class AttentionDownsample(torch.nn.Module): assert src.shape[0] == d_seq_len * ds src = src.reshape(d_seq_len, ds, batch_size, in_channels) - # scores: (d_seq_len, downsample, batch_size) + # scores: (d_seq_len, downsample, batch_size, 1) scores = (src * self.query).sum(dim=-1, keepdim=True) + scores = scores + self.bias.unsqueeze(-1).unsqueeze(-1) scores = penalize_abs_values_gt(scores, limit=20.0, @@ -869,7 +871,7 @@ class SimpleUpsample(torch.nn.Module): num_channels: int, upsample: int): super(SimpleUpsample, self).__init__() - self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) + self.upsample = upsample def forward(self, src: Tensor) -> Tensor: @@ -878,10 +880,9 @@ class SimpleUpsample(torch.nn.Module): Returns a tensor of shape ( (seq_len*upsample), batch_size, num_channels) """ - upsample = self.bias.shape[0] + upsample = self.upsample (seq_len, batch_size, num_channels) = src.shape src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) - src = src + self.bias.unsqueeze(1) src = src.reshape(seq_len * upsample, batch_size, num_channels) return src From b9326e1ef25a79016de574db0a9f35ecdaa72a42 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 16 Dec 2022 18:07:43 +0800 Subject: [PATCH 09/10] Fix to print statement --- egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index be66d00f6..43d4e7bec 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -799,7 +799,7 @@ class WithLoss(torch.autograd.Function): ctx.y_shape = y.shape if random.random() < 0.002 and name is not None: loss_sum = y.sum().item() - logging.info(f"WithLoss: name={ctx.name}, loss-sum={loss_sum:.3e}") + logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}") return x @staticmethod def backward(ctx, ans_grad: Tensor): From 35b63c138767539cfeb0dd073fbee06c9fc0e6ea Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 17 Dec 2022 13:27:17 +0800 Subject: [PATCH 10/10] Revert "Reduce const_attention_rate" This reverts commit bc002a9eda8ac912cff235460dcdef2fd51b2f19. 
--- egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 692f8ffbc..da2ce25ae 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -394,7 +394,7 @@ class ZipformerEncoderLayer(nn.Module): layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0), attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), - const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), default=0), + const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0), bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0), bypass_max: FloatLike = 1.0, ) -> None:
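
A note on the schedule objects tuned throughout this series: layer_skip_rate, attention_skip_rate, conv_skip_rate, const_attention_rate, bypass_min and the ActivationBalancer prob are all ScheduledFloat values (defined in scaling.py) that are evaluated against the training batch count, so each behaves as a piecewise-linear function of training progress. The sketch below is NOT the icefall implementation; it is a minimal stand-in that assumes linear interpolation between (batch_count, value) breakpoints, and the names PiecewiseSchedule and skip_conv_module are illustrative only. The gate it demonstrates follows the `random.random() >= float(self.conv_skip_rate)` test used in ZipformerEncoderLayer.forward().

    # Minimal stand-in for the ScheduledFloat schedules used above (NOT the
    # icefall implementation).  Assumes linear interpolation between
    # (batch_count, value) breakpoints; the real class in scaling.py is driven
    # by the batch count that the training loop writes into the modules.

    import random


    class PiecewiseSchedule:
        """Piecewise-linear schedule over (batch_count, value) breakpoints.

        Outside the breakpoint range the value is clamped to the nearest
        endpoint, so a schedule written like
            ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0))
        starts at 0.2, reaches 0.05 at batch 4000 and decays to 0.0 by 16000.
        """

        def __init__(self, *points):
            self.points = sorted(points)

        def value(self, batch_count: float) -> float:
            (x0, y0) = self.points[0]
            if batch_count <= x0:
                return y0
            for (x1, y1) in self.points[1:]:
                if batch_count <= x1:
                    t = (batch_count - x0) / (x1 - x0)
                    return y0 + t * (y1 - y0)   # linear interpolation
                (x0, y0) = (x1, y1)
            return y0  # past the last breakpoint: hold the final value


    def skip_conv_module(conv_skip_rate: float, training: bool = True) -> bool:
        # Modeled on the gate in ZipformerEncoderLayer.forward(): the conv module
        # runs when random.random() >= conv_skip_rate, i.e. it is skipped with
        # probability conv_skip_rate.  In this sketch it always runs at inference.
        if not training:
            return False
        return random.random() < conv_skip_rate


    if __name__ == "__main__":
        # The conv_skip_rate schedule from patch 03:
        # 0.2 at batch 0, 0.05 at batch 4000, 0.0 from batch 16000 on.
        conv_skip_rate = PiecewiseSchedule((0.0, 0.2), (4000.0, 0.05), (16000.0, 0.0))
        for batch in (0, 2000, 4000, 10000, 16000, 30000):
            rate = conv_skip_rate.value(batch)
            print(f"batch={batch:6d}  conv_skip_rate={rate:.3f}  "
                  f"skipped this step: {skip_conv_module(rate)}")

Running the example prints the scheduled rate decaying from 0.2 toward 0.0, which is the same shape that patches 02, 03 and 06 give to the attention and convolution skip rates (high early in training, tapering to zero by batch 16000).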