mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-10 22:45:27 +00:00
Update egs/mls_english/ASR/zipformer/streaming_decode.py
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
This commit is contained in:
parent
9d389cdca7
commit
8c846399a5
@ -386,12 +386,12 @@ def streaming_forward(
|
|||||||
src_key_padding_mask = make_pad_mask(x_lens)
|
src_key_padding_mask = make_pad_mask(x_lens)
|
||||||
|
|
||||||
# processed_mask is used to mask out initial states
|
# processed_mask is used to mask out initial states
|
||||||
processed_mask = torch.arange(left_context_len, device=x.device).expand(
|
processed_lens = states[-1] # (batch,)
|
||||||
|
idx = torch.arange(left_context_len, device=x.device).unsqueeze(0).expand(
|
||||||
x.size(0), left_context_len
|
x.size(0), left_context_len
|
||||||
)
|
)
|
||||||
processed_lens = states[-1] # (batch,)
|
# True means padding positions (not yet available in cache).
|
||||||
# (batch, left_context_size)
|
processed_mask = idx >= processed_lens.unsqueeze(1)
|
||||||
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
|
|
||||||
# Update processed lengths
|
# Update processed lengths
|
||||||
new_processed_lens = processed_lens + x_lens
|
new_processed_lens = processed_lens + x_lens
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user