Update egs/mls_english/ASR/zipformer/streaming_decode.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
This commit is contained in:
Bailey Machiko Hirota 2025-09-11 15:32:18 +09:00 committed by GitHub
parent 9d389cdca7
commit 8c846399a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -386,12 +386,12 @@ def streaming_forward(
src_key_padding_mask = make_pad_mask(x_lens)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
processed_lens = states[-1] # (batch,)
idx = torch.arange(left_context_len, device=x.device).unsqueeze(0).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
# True means padding positions (not yet available in cache).
processed_mask = idx >= processed_lens.unsqueeze(1)
# Update processed lengths
new_processed_lens = processed_lens + x_lens