From 8c846399a514a508db43fa2c2cc2e65084ada85d Mon Sep 17 00:00:00 2001
From: Bailey Machiko Hirota <53164945+baileyeet@users.noreply.github.com>
Date: Thu, 11 Sep 2025 15:32:18 +0900
Subject: [PATCH] Update egs/mls_english/ASR/zipformer/streaming_decode.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
---
 egs/mls_english/ASR/zipformer/streaming_decode.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/egs/mls_english/ASR/zipformer/streaming_decode.py b/egs/mls_english/ASR/zipformer/streaming_decode.py
index 7e3199e09..e8e330481 100755
--- a/egs/mls_english/ASR/zipformer/streaming_decode.py
+++ b/egs/mls_english/ASR/zipformer/streaming_decode.py
@@ -386,12 +386,12 @@ def streaming_forward(
     src_key_padding_mask = make_pad_mask(x_lens)
 
     # processed_mask is used to mask out initial states
-    processed_mask = torch.arange(left_context_len, device=x.device).expand(
+    processed_lens = states[-1]  # (batch,)
+    idx = torch.arange(left_context_len, device=x.device).unsqueeze(0).expand(
         x.size(0), left_context_len
     )
-    processed_lens = states[-1]  # (batch,)
-    # (batch, left_context_size)
-    processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
+    # True marks padding (not yet in cache); flip(1) keeps it on the left.
+    processed_mask = (idx >= processed_lens.unsqueeze(1)).flip(1)
     # Update processed lengths
     new_processed_lens = processed_lens + x_lens
 