diff --git a/egs/libriheavy/LM/zipformer1/subformer.py b/egs/libriheavy/LM/zipformer1/subformer.py
index 6813c0aca..bb2f0e33c 100644
--- a/egs/libriheavy/LM/zipformer1/subformer.py
+++ b/egs/libriheavy/LM/zipformer1/subformer.py
@@ -541,16 +541,14 @@ class SubformerEncoderLayer(nn.Module):
         attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
 
         if True:
-            selected_attn_weights = attn_weights[0:2]
+            selected_attn_weights = attn_weights[0:1]
             if random.random() < float(self.const_attention_rate):
                 # Make attention weights constant.  The intention is to
                 # encourage these modules to do something similar to an
                 # averaging-over-time operation.
                 # only need the mask, can just use the 1st one and expand later
-                selected_attn_weights = selected_attn_weights[0:1]
                 selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype)
                 selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
-                selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1)
 
 
         na = self.balancer_na(self.nonlin_attention(src,
@@ -1383,6 +1381,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
         self.score_penalty = AbsValuePenalizer(
             limit=25.0, penalty=1.0e-04, prob=0.1)
+        self.name = None  # for diagnostics, will be set in train.py
 
         key_head_dim = query_head_dim
         in_proj_dim = (query_head_dim + key_head_dim + pos_dim) * num_heads
@@ -1622,7 +1621,7 @@ class MultiheadAttentionWeights(nn.Module):
         self.dropout = dropout
         self.score_penalty = AbsValuePenalizer(
             limit=25.0, penalty=1.0e-04, prob=0.1)
-
+        self.name = None  # for diagnostics, will be set in train.py
 
         # the initial_scale is supposed to take over the "scaling" factor of
         # head_dim ** -0.5 that has been used in previous forms of attention,
diff --git a/egs/libriheavy/LM/zipformer1/train.py b/egs/libriheavy/LM/zipformer1/train.py
index 0d59fea3d..903586e8e 100755
--- a/egs/libriheavy/LM/zipformer1/train.py
+++ b/egs/libriheavy/LM/zipformer1/train.py
@@ -404,7 +404,7 @@ def get_params() -> AttributeDict:
             "warm_step": 2000,
             "env_info": get_env_info(),
             "bytes_per_segment": 2048,
-            "batch_size": 16,
+            "batch_size": 14,
             "train_file_list": "train.txt",
             "valid_file_list": "valid.txt",
             "num_workers": 4,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 210bc1e2f..22fc2608a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -253,8 +253,10 @@ class CutoffEstimator:
 
 class SoftmaxFunction(torch.autograd.Function):
     """
-    Tries to handle half-precision derivatives in a randomized way that should
-    be more accurate for training than the default behavior.
+    An overloaded version of backprop for softmax that tries to save memory by
+    creating fp16 output (the default version of softmax creates fp32 output).
+    We convert back to fp32 internally during the backward pass, to minimize
+    roundoff error.
     """
     @staticmethod
     def forward(ctx, x: Tensor, dim: int):
@@ -272,13 +274,30 @@
     @staticmethod
     def backward(ctx, ans_grad: Tensor):
         ans, = ctx.saved_tensors
-        with torch.cuda.amp.autocast(enabled=False):
-            ans_grad = ans_grad.to(torch.float32)
-            ans = ans.to(torch.float32)
-            x_grad = ans_grad * ans
-            ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
-            x_grad -= ans
-            return x_grad, None
+        try:
+            with torch.cuda.amp.autocast(enabled=False):
+                if ans.dtype != torch.float32:
+                    ans = ans.to(torch.float32)
+                    x_grad = ans.mul_(ans_grad.to(torch.float32))
+                else:
+                    # out-of-place since it's not a copy
+                    x_grad = ans_grad.to(torch.float32) * ans
+                ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
+                x_grad -= ans
+                return x_grad, None
+        except Exception as e:
+            logging.info(f"Caught exception in SoftmaxFunction backward: {e}, size={list(ans.shape)}, dim={ctx.dim}, will try in half precision.")
+            x_grad = None
+
+
+            ans, = ctx.saved_tensors
+            ans_grad.mul_(ans)
+            x_grad = ans_grad
+            ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
+            x_grad -= ans
+            return x_grad, None
+
+