From db543866d892c51f15a6f64215d47b18a8ca32f7 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 11 Feb 2023 17:34:04 +0800
Subject: [PATCH] Remove unused debug statement.

---
 .../ASR/pruned_transducer_stateless7/zipformer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 4376967a0..36f09d211 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -471,7 +471,7 @@ class Zipformer(EncoderInterface):
         seq_len = x.shape[0]
 
         # t is frame index, shape (seq_len,)
-        t = torch.arange(seq_len, dtype=torch.int32)
+        t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
         # c is chunk index for each frame, shape (seq_len,)
         c = t // chunk_size
         src_c = c
@@ -479,7 +479,7 @@ class Zipformer(EncoderInterface):
 
         attn_mask = torch.logical_or(src_c > tgt_c,
                                      src_c < tgt_c - left_context_chunks)
-        if __name__ == "__main__" or random.random() < 0.1:
+        if __name__ == "__main__":
             logging.info(f"attn_mask = {attn_mask}")
         return attn_mask
 
@@ -1412,7 +1412,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 
         if attn_mask is not None:
             assert attn_mask.dtype == torch.bool
-            attn_scores.masked_fill_(attn_mask, float("-inf"))
+            attn_scores = attn_scores.masked_fill(attn_mask, float("-inf"))
 
         if key_padding_mask is not None:
             assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
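
Note: a minimal, self-contained sketch of the behavior the first and third
hunks address, assuming an optional CUDA device; the tensor names and shapes
below are made up for illustration and are not the ones in zipformer.py.

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(10, 4, device=device)  # stand-in for the encoder input

    # torch.arange() allocates on the CPU by default, so arithmetic that mixes
    # t with GPU-resident tensors would raise a device-mismatch error; creating
    # t directly on x.device (as the patch does) keeps everything co-located.
    t = torch.arange(x.shape[0], dtype=torch.int32, device=x.device)

    # masked_fill_() mutates the tensor in place, which can trip autograd's
    # in-place checks if the original values are needed for the backward pass;
    # the out-of-place masked_fill() returns a new tensor instead.
    attn_scores = torch.randn(10, 10, device=device)
    mask = torch.eye(10, dtype=torch.bool, device=device)
    attn_scores = attn_scores.masked_fill(mask, float("-inf"))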