mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-09 05:55:26 +00:00
Update egs/mls_english/ASR/zipformer/streaming_decode.py
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
This commit is contained in:
parent
9d389cdca7
commit
8c846399a5
@ -386,12 +386,12 @@ def streaming_forward(
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
|
||||
# processed_mask is used to mask out initial states
|
||||
processed_mask = torch.arange(left_context_len, device=x.device).expand(
|
||||
processed_lens = states[-1] # (batch,)
|
||||
idx = torch.arange(left_context_len, device=x.device).unsqueeze(0).expand(
|
||||
x.size(0), left_context_len
|
||||
)
|
||||
processed_lens = states[-1] # (batch,)
|
||||
# (batch, left_context_size)
|
||||
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
|
||||
# True means padding positions (not yet available in cache).
|
||||
processed_mask = idx >= processed_lens.unsqueeze(1)
|
||||
# Update processed lengths
|
||||
new_processed_lens = processed_lens + x_lens
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user