From 8d37175ffb4501f8599d98162270dd134852e18f Mon Sep 17 00:00:00 2001
From: pkufool
Date: Sat, 25 Jun 2022 17:44:16 +0800
Subject: [PATCH] make states in forward streaming optional

---
 egs/librispeech/ASR/pruned_transducer_stateless/decode.py | 1 -
 .../ASR/pruned_transducer_stateless2/conformer.py         | 8 +++++---
 .../ASR/pruned_transducer_stateless2/decode.py            | 1 -
 .../ASR/pruned_transducer_stateless3/decode.py            | 1 -
 .../ASR/pruned_transducer_stateless4/decode.py            | 1 -
 egs/librispeech/ASR/transducer_stateless/conformer.py     | 8 +++++---
 6 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
index 109a53536..b7558089c 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
@@ -377,7 +377,6 @@ def decode_one_batch(
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
-            states=[],
             chunk_size=params.decode_chunk_size,
             left_context=params.left_context,
             simulate_streaming=True,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
index 21589731e..e3a0d8e4e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -249,7 +249,7 @@ class Conformer(EncoderInterface):
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
-        states: List[Tensor],
+        states: Optional[List[Tensor]] = None,
         processed_lens: Optional[Tensor] = None,
         left_context: int = 64,
         right_context: int = 4,
@@ -311,6 +311,8 @@ class Conformer(EncoderInterface):
         lengths = (((x_lens - 1) >> 1) - 1) >> 1
 
         if not simulate_streaming:
+            assert states is not None
+            assert processed_lens is not None
             assert (
                 len(states) == 2
                 and states[0].shape
@@ -332,8 +334,6 @@ class Conformer(EncoderInterface):
 
             src_key_padding_mask = make_pad_mask(lengths)
 
-            assert processed_lens is not None
-
             processed_mask = torch.arange(left_context, device=x.device).expand(
                 x.size(0), left_context
             )
@@ -366,6 +366,8 @@ class Conformer(EncoderInterface):
                 x = x[0:-right_context, ...]
                 lengths -= right_context
         else:
+            assert states is None
+            states = []  # just to make torch.script.jit happy
             # this branch simulates streaming decoding using mask as we are
             # using in training time.
             src_key_padding_mask = make_pad_mask(lengths)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
index fc3e58bdc..60a948a99 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -391,7 +391,6 @@ def decode_one_batch(
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
-            states=[],
             chunk_size=params.decode_chunk_size,
             left_context=params.left_context,
             simulate_streaming=True,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
index c4e475c5c..44fc34640 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
@@ -377,7 +377,6 @@ def decode_one_batch(
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
-            states=[],
             chunk_size=params.decode_chunk_size,
             left_context=params.left_context,
             simulate_streaming=True,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
index 95750616b..d8ae8e026 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
@@ -403,7 +403,6 @@ def decode_one_batch(
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
-            states=[],
             chunk_size=params.decode_chunk_size,
             left_context=params.left_context,
             simulate_streaming=True,
diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py
index f460bffe5..d327656cd 100644
--- a/egs/librispeech/ASR/transducer_stateless/conformer.py
+++ b/egs/librispeech/ASR/transducer_stateless/conformer.py
@@ -245,7 +245,7 @@ class Conformer(Transformer):
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
-        states: List[torch.Tensor],
+        states: Optional[List[torch.Tensor]] = None,
         processed_lens: Optional[Tensor] = None,
         left_context: int = 64,
         right_context: int = 0,
@@ -302,6 +302,8 @@ class Conformer(Transformer):
         lengths = (((x_lens - 1) >> 1) - 1) >> 1
 
         if not simulate_streaming:
+            assert states is not None
+            assert processed_lens is not None
             assert (
                 len(states) == 2
                 and states[0].shape
@@ -322,8 +324,6 @@ class Conformer(Transformer):
             lengths -= 2  # we will cut off 1 frame on each side of encoder_embed output
             src_key_padding_mask = make_pad_mask(lengths)
 
-            assert processed_lens is not None
-
             processed_mask = torch.arange(left_context, device=x.device).expand(
                 x.size(0), left_context
             )
@@ -352,6 +352,8 @@ class Conformer(Transformer):
                 right_context=right_context,
             )  # (T, B, F)
         else:
+            assert states is None
+            states = []  # just to make torch.script.jit happy
             src_key_padding_mask = make_pad_mask(lengths)
             x = self.encoder_embed(x)
             x, pos_emb = self.encoder_pos(x)
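
With this change, the decode.py call sites no longer pass a placeholder states=[] when
simulate_streaming=True, and the conformers assert that states/processed_lens are supplied
only on the real streaming path. Below is a minimal standalone sketch of that calling
contract, assuming only PyTorch; the function body, names, and tensor shapes are
illustrative placeholders and not the icefall implementation itself.

from typing import List, Optional, Tuple

import torch
from torch import Tensor


def streaming_forward(
    x: Tensor,
    x_lens: Tensor,
    states: Optional[List[Tensor]] = None,
    processed_lens: Optional[Tensor] = None,
    simulate_streaming: bool = False,
) -> Tuple[Tensor, Tensor, List[Tensor]]:
    if not simulate_streaming:
        # Real streaming: cached attention states and processed_lens are required.
        assert states is not None
        assert processed_lens is not None
        new_states = states  # placeholder; the real encoder updates its caches
    else:
        # Simulated streaming: states must be omitted; the empty list only keeps
        # the return type stable (the same trick the patch uses for torch.jit.script).
        assert states is None
        new_states = []
    return x, x_lens, new_states


# Call site after the patch: no placeholder states=[] is needed when simulating
# streaming with an attention mask.
feature = torch.randn(2, 100, 80)
feature_lens = torch.tensor([100, 90])
encoder_out, encoder_out_lens, _ = streaming_forward(
    x=feature,
    x_lens=feature_lens,
    simulate_streaming=True,
)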