diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 39b08e169..b380fa145 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -561,11 +561,13 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                 x: Tensor,
                 num_groups: int,
                 whitening_limit: float,
-                grad_scale: float) -> Tensor:
+                grad_scale: float,
+                name: Optional[str]) -> Tensor:
         ctx.save_for_backward(x)
         ctx.num_groups = num_groups
         ctx.whitening_limit = whitening_limit
         ctx.grad_scale = grad_scale
+        ctx.name = name
         return x
 
     @staticmethod
@@ -580,7 +582,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                 metric = _whitening_metric(x_detached, ctx.num_groups)
 
                 if random.random() < 0.005 or __name__ == "__main__":
-                    logging.info(f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
+                    logging.info(f"Whitening: name={ctx.name}, num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
                                  f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}")
 
                 (metric - ctx.whitening_limit).relu().backward()
@@ -588,7 +590,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                 scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() /
                                           (penalty_grad.norm() + 1.0e-20))
                 penalty_grad = penalty_grad * scale
-        return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
+        return x_grad + penalty_grad.to(x_grad.dtype), None, None, None, None
@@ -630,7 +632,7 @@ class Whiten(nn.Module):
         (self.min_prob, self.max_prob) = prob
         assert 0 < self.min_prob < self.max_prob <= 1
         self.prob = self.max_prob
-
+        self.name = None  # will be set in training loop
         self.grad_scale = grad_scale
 
     def forward(self,
@@ -666,7 +668,8 @@ class Whiten(nn.Module):
             return WhiteningPenaltyFunction.apply(x,
                                                   self.num_groups,
                                                   self.whitening_limit,
-                                                  self.grad_scale)
+                                                  self.grad_scale,
+                                                  self.name)
 
 
 class WithLoss(torch.autograd.Function):
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index b138b8236..9d9c01f0b 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -98,8 +98,8 @@ def set_batch_count(
     for name, module in model.named_modules():
         if hasattr(module, 'batch_count'):
             module.batch_count = batch_count
-            if hasattr(module, 'name'):
-                module.name = name
+        if hasattr(module, 'name'):
+            module.name = name
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 5410b36b9..d20be6145 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -367,7 +367,7 @@ class ZipformerEncoderLayer(nn.Module):
             # to work correctly.
             layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
             dynamic_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0),
-            squeeze_const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
+            const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
             bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0),
             bypass_max: FloatLike = 1.0,
     ) -> None:
@@ -382,7 +382,7 @@ class ZipformerEncoderLayer(nn.Module):
         # ever becoming zero.
         self.bypass_min = copy.deepcopy(bypass_min)
         self.bypass_max = copy.deepcopy(bypass_max)
-        self.squeeze_const_attention_rate = copy.deepcopy(squeeze_const_attention_rate)
+        self.const_attention_rate = copy.deepcopy(const_attention_rate)
 
         self.self_attn_weights = RelPositionMultiheadAttentionWeights(
             embed_dim, pos_dim=pos_dim, num_heads=num_heads,
@@ -480,27 +480,23 @@ class ZipformerEncoderLayer(nn.Module):
             key_padding_mask=src_key_padding_mask,
         )
 
-
-        squeeze_weights = attn_weights[1:2]
-        if random.random() < float(self.squeeze_const_attention_rate):
-            # this form of dropout makes the attention-weights used for the
-            # squeeze-excite modules constant wherever they are not masked. The intention
-            # is to encourage these modules to do something similar to an averaging-over-time
-            # operation.
-            squeeze_weights = (squeeze_weights > 0.0).to(squeeze_weights.dtype)
-            # make sure they sum to 1 over the last axis.
-            squeeze_weights = squeeze_weights * (1.0 / squeeze_weights.sum(dim=-1, keepdim=True))
+        first_attn_weights = attn_weights[0:3]
+        if random.random() < float(self.const_attention_rate):
+            # Make attention weights constant. The intention is to
+            # encourage these modules to do something similar to an
+            # averaging-over-time operation.
+            first_attn_weights = (first_attn_weights > 0.0).to(first_attn_weights.dtype)
+            first_attn_weights = first_attn_weights * (1.0 / first_attn_weights.sum(dim=-1, keepdim=True))
 
         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.nonlin_attention_module(src,
-                                                     attn_weights[0:1])
-
+                                                     first_attn_weights[0:1])
         src = src + self.feed_forward1(src)
 
         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze1(src, squeeze_weights)
+            src = src + self.attention_squeeze1(src, first_attn_weights[1:2])
 
         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.self_attn(
@@ -513,7 +509,7 @@ class ZipformerEncoderLayer(nn.Module):
 
         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze2(src, squeeze_weights)
+            src = src + self.attention_squeeze2(src, first_attn_weights[2:3])
 
         src = self.norm_final(self.balancer(src))
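Note on the scaling.py hunks: torch.autograd.Function.backward must return one value per argument of forward (excluding ctx), so adding the name argument to WhiteningPenaltyFunction.forward is what forces the extra trailing None in its backward. A minimal standalone sketch of that pattern, assuming nothing beyond stock PyTorch; the ScaleBy class and its names are illustrative, not code from this patch:

# Sketch only: why the patched backward() returns five values.
# Non-differentiable forward() arguments (num_groups, whitening_limit,
# grad_scale, and the new `name`) each get a None gradient.
from typing import Optional

import torch
from torch import Tensor


class ScaleBy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, factor: float, name: Optional[str]) -> Tensor:
        ctx.factor = factor
        ctx.name = name  # stored only for logging; it has no gradient
        return x * factor

    @staticmethod
    def backward(ctx, y_grad: Tensor):
        # One return value per forward() input: a Tensor for x,
        # None for `factor`, and None for `name`.
        return y_grad * ctx.factor, None, None


if __name__ == "__main__":
    x = torch.randn(4, requires_grad=True)
    y = ScaleBy.apply(x, 2.0, "demo")
    y.sum().backward()
    print(x.grad)  # every element equals 2.0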
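Note on the zipformer.py hunks: when random.random() < const_attention_rate, the attention weights fed to the nonlin-attention and attention-squeeze modules are replaced by a uniform distribution over the unmasked positions, pushing those modules toward an averaging-over-time behaviour. A standalone sketch of that operation, assuming weights of shape (num_heads, batch_size, seq_len, seq_len) whose masked positions are exactly zero; the helper name make_attention_constant is an assumption for illustration:

# Sketch only: the effect of the const_attention_rate branch in the diff above.
import torch


def make_attention_constant(attn_weights: torch.Tensor) -> torch.Tensor:
    # Replace every attention distribution with a uniform distribution over
    # the positions that had nonzero weight, so each row still sums to 1 and
    # the consuming module acts like an average over time.
    mask = (attn_weights > 0.0).to(attn_weights.dtype)
    return mask * (1.0 / mask.sum(dim=-1, keepdim=True))


if __name__ == "__main__":
    w = torch.softmax(torch.randn(3, 2, 5, 5), dim=-1)
    const_w = make_attention_constant(w)
    print(const_w.sum(dim=-1))  # all ones: rows still sum to 1
    print(const_w[0, 0, 0])     # uniform: every entry is 0.2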