diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 1a230b24b..75b9b87bd 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -691,7 +691,8 @@ class ActivationBalancer(torch.nn.Module):
         return _no_op(x)
 
 
-def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
+def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float,
+                           name: str = None) -> Tensor:
     """
     Returns x unmodified, but in backprop will put a penalty for the excess of
     the absolute values of elements of x over the limit "limit".  E.g. if
@@ -701,6 +702,8 @@ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
     in automatic mixed precision training.  For this reasons we use this,
     it shouldn't really matter, or may even be helpful; we just use this
     to disallow really implausible values of scores to be given to softmax.
+
+    The name is for randomly printed debug info.
     """
     x_sign = x.sign()
     over_limit = (x.abs() - limit) > 0
@@ -714,7 +717,7 @@ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
     aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
     # note: we don't do sum() here on aux)_loss, but it's as if we had done
     # sum() due to how with_loss() works.
-    x = with_loss(x, aux_loss)
+    x = with_loss(x, aux_loss, name)
     # you must use x for something, or this will be ineffective.
     return x
 
@@ -887,17 +890,23 @@ class Whiten(nn.Module):
 
 class WithLoss(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x: Tensor, y: Tensor):
-        ctx.y_shape = y.shape
+    def forward(ctx, x: Tensor, y: Tensor, name: str):
+        ctx.name = name
+        ctx.save_for_backward(y)  # just for printing the name, and the shape
         return x
     @staticmethod
     def backward(ctx, ans_grad: Tensor):
-        return ans_grad, torch.ones(ctx.y_shape,
+        y, = ctx.saved_tensors
+        if random.random() < 0.002 and ctx.name is not None:
+            loss_sum = y.sum().item()
+            logging.info(f"WithLoss: name={ctx.name}, loss-sum={loss_sum:.3e}")
+
+        return ans_grad, torch.ones(y.shape,
                                     dtype=ans_grad.dtype,
-                                    device=ans_grad.device)
-def with_loss(x, y):
+                                    device=ans_grad.device), None
+def with_loss(x, y, name):
     # returns x but adds y.sum() to the loss function.
-    return WithLoss.apply(x, y)
+    return WithLoss.apply(x, y, name)
 
 
 class ScaleGradFunction(torch.autograd.Function):
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index c2b3e81b2..95dc8b6ae 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -797,6 +797,8 @@ class AttentionDownsample(torch.nn.Module):
         super(AttentionDownsample, self).__init__()
         self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
 
+        self.name = None  # will be set from training code
+
         # fill in the extra dimensions with a projection of the input
         if out_channels > in_channels:
             self.extra_proj = nn.Linear(in_channels * downsample,
@@ -830,7 +832,8 @@ class AttentionDownsample(torch.nn.Module):
 
         scores = penalize_abs_values_gt(scores,
                                         limit=20.0,
-                                        penalty=1.0e-04)
+                                        penalty=1.0e-04,
+                                        name=self.name)
 
         weights = scores.softmax(dim=1)
 
@@ -1203,7 +1206,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             # under normal circumstances.
             attn_scores = penalize_abs_values_gt(attn_scores,
                                                  limit=25.0,
-                                                 penalty=1.0e-04)
+                                                 penalty=1.0e-04,
+                                                 name=self.name)
 
         assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
 
@@ -1866,6 +1870,7 @@ class AttentionCombine(nn.Module):
                                                      num_inputs))
         self.bias = torch.nn.Parameter(torch.zeros(num_inputs))
 
+        self.name = None  # will be set from training code
        assert 0 <= random_prob <= 1, random_prob
        assert 0 <= single_prob <= 1, single_prob
 
@@ -1922,7 +1927,8 @@ class AttentionCombine(nn.Module):
         if self.training and random.random() < 0.1:
             scores = penalize_abs_values_gt(scores,
                                             limit=10.0,
-                                            penalty=1.0e-04)
+                                            penalty=1.0e-04,
+                                            name=self.name)
 
         weights = scores.softmax(dim=1)
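
Not part of the diff: a minimal sketch of how the training code might fill in the `self.name = None` attributes added above ("will be set from training code"), so that the occasional `WithLoss: name=..., loss-sum=...` log line identifies which module the penalty came from. The helper name `set_module_names` and the use of `named_modules()` are illustrative assumptions, not the actual icefall training code.

# Hypothetical helper (assumption, not from this diff): label every submodule
# that declares a `name` attribute with its qualified path in the model tree.
import torch.nn as nn


def set_module_names(model: nn.Module) -> None:
    # named_modules() yields (qualified_name, module) pairs for the whole tree.
    for fqn, module in model.named_modules():
        if hasattr(module, "name") and module.name is None:
            module.name = fqn

With something like this applied to the model before training, the debug line fires on average once per 500 backward passes through a penalized module (the 0.002 sampling probability in `WithLoss.backward`), naming the module whose scores exceeded the limit.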