From f9f546968c7d6aca2b3e493864a676ae57237a35 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 11 Feb 2023 18:46:05 +0800
Subject: [PATCH] Revert warmup_batches change; make code change to avoid nan
 in attn_weights

---
 egs/librispeech/ASR/pruned_transducer_stateless7/train.py | 3 +--
 .../ASR/pruned_transducer_stateless7/zipformer.py         | 8 ++++++--
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 12ecb0521..52f25ae15 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -1119,8 +1119,7 @@ def run(rank, world_size, args):
         clipping_scale=2.0,
     )
 
-    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs,
-                     warmup_batches=1000.0)
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 
     if checkpoints and "optimizer" in checkpoints:
         logging.info("Loading optimizer state dict")
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 36f09d211..3c7953b5a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1412,13 +1412,17 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 
         if attn_mask is not None:
             assert attn_mask.dtype == torch.bool
-            attn_scores = attn_scores.masked_fill(attn_mask, float("-inf"))
+            # Use -1000 rather than -inf to avoid nan's in rows where attn_mask and
+            # key_padding_mask together mask out all scores.  -1000 is negative
+            # enough that exp(-1000) is exactly zero; this matters because the
+            # const_attention_rate logic compares the final weights with zero.
+            attn_scores = attn_scores.masked_fill(attn_mask, -1000)
 
         if key_padding_mask is not None:
             assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
             attn_scores = attn_scores.masked_fill(
                 key_padding_mask.unsqueeze(1),
-                float("-inf"),
+                -1000,
             )
 
         # We use our own version of softmax, defined in scaling.py, which should
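
The sketch below is not part of the patch. It is a minimal illustration, using plain torch.softmax (the real code path uses a custom softmax defined in scaling.py) and made-up tensor shapes, of why filling masked scores with float("-inf") can produce nan rows when every key is masked, while a large finite negative value such as -1000 keeps the row finite and still drives masked keys to exactly zero weight whenever at least one key is unmasked.

# Illustration only; names and shapes here are invented for the example.
import torch

scores = torch.randn(2, 4, 4)  # toy (batch, query, key) attention scores

# Suppose attn_mask and key_padding_mask together mask out every key for one query.
fully_masked = torch.zeros(2, 4, 4, dtype=torch.bool)
fully_masked[0, 1, :] = True  # all keys masked for batch 0, query 1

# With -inf, softmax over an all-masked row is 0/0, so the row becomes nan.
weights_inf = scores.masked_fill(fully_masked, float("-inf")).softmax(dim=-1)
print(weights_inf[0, 1])  # tensor([nan, nan, nan, nan])

# With a large finite negative fill, the all-masked row stays finite
# (uniform weights here), so no nan propagates into the model.
weights_finite = scores.masked_fill(fully_masked, -1000.0).softmax(dim=-1)
print(weights_finite[0, 1])  # tensor([0.2500, 0.2500, 0.2500, 0.2500])

# When only some keys are masked, -1000 still gives the masked keys exactly
# zero weight, because exp(-1000 - max_score) underflows to 0.0 in float32,
# which is what the comparison against zero in the patched comment relies on.
partial = torch.zeros(2, 4, 4, dtype=torch.bool)
partial[0, 0, 2:] = True
weights_partial = scores.masked_fill(partial, -1000.0).softmax(dim=-1)
print(weights_partial[0, 0])  # last two entries are exactly 0.0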