diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index e1d220f5f..3b443e1c1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -711,7 +711,8 @@ class ActivationBalancer(torch.nn.Module): return _no_op(x) -def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor: +def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float, + name: str = None) -> Tensor: """ Returns x unmodified, but in backprop will put a penalty for the excess of the absolute values of elements of x over the limit "limit". E.g. if @@ -721,6 +722,8 @@ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor: in automatic mixed precision training. For this reasons we use this, it shouldn't really matter, or may even be helpful; we just use this to disallow really implausible values of scores to be given to softmax. + + The name is for randomly printed debug info. """ x_sign = x.sign() over_limit = (x.abs() - limit) > 0 @@ -734,7 +737,7 @@ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor: aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) # note: we don't do sum() here on aux)_loss, but it's as if we had done # sum() due to how with_loss() works. - x = with_loss(x, aux_loss) + x = with_loss(x, aux_loss, name) # you must use x for something, or this will be ineffective. return x @@ -907,17 +910,23 @@ class Whiten(nn.Module): class WithLoss(torch.autograd.Function): @staticmethod - def forward(ctx, x: Tensor, y: Tensor): - ctx.y_shape = y.shape + def forward(ctx, x: Tensor, y: Tensor, name: str): + ctx.name = name + ctx.save_for_backward(y) # just for printing the name, and the shape return x @staticmethod def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones(ctx.y_shape, + y, = ctx.saved_tensors + if random.random() < 0.002 and ctx.name is not None: + loss_sum = y.sum().item() + logging.info(f"WithLoss: name={ctx.name}, loss-sum={loss_sum:.3e}") + + return ans_grad, torch.ones(y.shape, dtype=ans_grad.dtype, - device=ans_grad.device) -def with_loss(x, y): + device=ans_grad.device), None +def with_loss(x, y, name): # returns x but adds y.sum() to the loss function. - return WithLoss.apply(x, y) + return WithLoss.apply(x, y, name) class ScaleGradFunction(torch.autograd.Function): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 315b10730..3ea11189a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -801,6 +801,8 @@ class AttentionDownsample(torch.nn.Module): super(AttentionDownsample, self).__init__() self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5)) + self.name = None # will be set from training code + # fill in the extra dimensions with a projection of the input if out_channels > in_channels: self.extra_proj = nn.Linear(in_channels * downsample, @@ -833,8 +835,9 @@ class AttentionDownsample(torch.nn.Module): scores = (src * self.query).sum(dim=-1, keepdim=True) scores = penalize_abs_values_gt(scores, - limit=10.0, - penalty=1.0e-04) + limit=20.0, + penalty=1.0e-04, + name=self.name) weights = scores.softmax(dim=1) @@ -1207,7 +1210,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module): # under normal circumstances. attn_scores = penalize_abs_values_gt(attn_scores, limit=25.0, - penalty=1.0e-04) + penalty=1.0e-04, + name=self.name) assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) @@ -1870,6 +1874,7 @@ class AttentionCombine(nn.Module): num_inputs)) self.bias = torch.nn.Parameter(torch.zeros(num_inputs)) + self.name = None # will be set from training code assert 0 <= random_prob <= 1, random_prob assert 0 <= single_prob <= 1, single_prob @@ -1926,7 +1931,8 @@ class AttentionCombine(nn.Module): if self.training and random.random() < 0.1: scores = penalize_abs_values_gt(scores, limit=10.0, - penalty=1.0e-04) + penalty=1.0e-04, + name=self.name) weights = scores.softmax(dim=1)