diff --git a/egs/mls_english/ASR/zipformer/streaming_decode.py b/egs/mls_english/ASR/zipformer/streaming_decode.py index 7e3199e09..e8e330481 100755 --- a/egs/mls_english/ASR/zipformer/streaming_decode.py +++ b/egs/mls_english/ASR/zipformer/streaming_decode.py @@ -386,12 +386,12 @@ def streaming_forward( src_key_padding_mask = make_pad_mask(x_lens) # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( + processed_lens = states[-1] # (batch,) + idx = torch.arange(left_context_len, device=x.device).unsqueeze(0).expand( x.size(0), left_context_len ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # True means padding positions (not yet available in cache). + processed_mask = idx >= processed_lens.unsqueeze(1) # Update processed lengths new_processed_lens = processed_lens + x_lens