Make states in streaming_forward optional

pkufool 2022-06-25 17:44:16 +08:00
parent 61f3c87d48
commit 8d37175ffb
6 changed files with 10 additions and 10 deletions

View File

@@ -377,7 +377,6 @@ def decode_one_batch(
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
-            states=[],
             chunk_size=params.decode_chunk_size,
             left_context=params.left_context,
             simulate_streaming=True,

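Since states now defaults to None, simulated-streaming call sites like the one above simply drop the argument; passing states=[] would fail the new "assert states is None" added in the Conformer hunks below. A minimal sketch of the updated call, assuming the same decode_one_batch context as the diff (model, feature, and params come from the surrounding script):

    # Simulated streaming keeps no cache between calls, so `states` is
    # omitted and the encoder creates an empty list internally.
    encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
        x=feature,
        x_lens=feature_lens,
        chunk_size=params.decode_chunk_size,
        left_context=params.left_context,
        simulate_streaming=True,
    )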
View File

@@ -249,7 +249,7 @@ class Conformer(EncoderInterface):
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
-        states: List[Tensor],
+        states: Optional[List[Tensor]] = None,
         processed_lens: Optional[Tensor] = None,
         left_context: int = 64,
         right_context: int = 4,
@@ -311,6 +311,8 @@ class Conformer(EncoderInterface):
         lengths = (((x_lens - 1) >> 1) - 1) >> 1
         if not simulate_streaming:
+            assert states is not None
+            assert processed_lens is not None
             assert (
                 len(states) == 2
                 and states[0].shape
@@ -332,8 +334,6 @@ class Conformer(EncoderInterface):
             src_key_padding_mask = make_pad_mask(lengths)
-            assert processed_lens is not None
             processed_mask = torch.arange(left_context, device=x.device).expand(
                 x.size(0), left_context
             )
@@ -366,6 +366,8 @@ class Conformer(EncoderInterface):
                 x = x[0:-right_context, ...]
                 lengths -= right_context
         else:
+            assert states is None
+            states = []  # just to make torch.jit.script happy
             # this branch simulates streaming decoding with masks, as we
             # do at training time.
             src_key_padding_mask = make_pad_mask(lengths)

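The Conformer hunks above (and the second Conformer file below) follow one pattern: states defaults to None, the real-streaming branch asserts the caller supplied both states and processed_lens, and the simulated branch rebinds states to an empty list so that torch.jit.script sees a single concrete List[Tensor] on every path. A self-contained sketch of just that skeleton; streaming_forward_sketch is an illustrative name and its body is not the actual encoder:

    from typing import List, Optional, Tuple

    import torch
    from torch import Tensor


    def streaming_forward_sketch(
        x: Tensor,
        x_lens: Tensor,
        states: Optional[List[Tensor]] = None,
        processed_lens: Optional[Tensor] = None,
        simulate_streaming: bool = False,
    ) -> Tuple[Tensor, Tensor, List[Tensor]]:
        # Toy stand-in: only the states handling mirrors the diff above.
        if not simulate_streaming:
            # Real streaming: the caller must thread the cached states and
            # the count of already-processed frames through every call.
            assert states is not None
            assert processed_lens is not None
        else:
            # Simulated streaming keeps no cache across calls.
            assert states is None
            states = []  # rebind so the return type is a plain List[Tensor]
        return x, x_lens, states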
View File

@@ -391,7 +391,6 @@ def decode_one_batch(
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
-            states=[],
             chunk_size=params.decode_chunk_size,
             left_context=params.left_context,
             simulate_streaming=True,

View File

@@ -377,7 +377,6 @@ def decode_one_batch(
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
-            states=[],
             chunk_size=params.decode_chunk_size,
             left_context=params.left_context,
             simulate_streaming=True,

View File

@@ -403,7 +403,6 @@ def decode_one_batch(
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
-            states=[],
             chunk_size=params.decode_chunk_size,
             left_context=params.left_context,
             simulate_streaming=True,

View File

@@ -245,7 +245,7 @@ class Conformer(Transformer):
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
-        states: List[torch.Tensor],
+        states: Optional[List[torch.Tensor]] = None,
         processed_lens: Optional[Tensor] = None,
         left_context: int = 64,
         right_context: int = 0,
@@ -302,6 +302,8 @@ class Conformer(Transformer):
         lengths = (((x_lens - 1) >> 1) - 1) >> 1
         if not simulate_streaming:
+            assert states is not None
+            assert processed_lens is not None
             assert (
                 len(states) == 2
                 and states[0].shape
@@ -322,8 +324,6 @@ class Conformer(Transformer):
             lengths -= 2  # we will cut off 1 frame on each side of encoder_embed output
             src_key_padding_mask = make_pad_mask(lengths)
-            assert processed_lens is not None
             processed_mask = torch.arange(left_context, device=x.device).expand(
                 x.size(0), left_context
             )
@@ -352,6 +352,8 @@ class Conformer(Transformer):
                 right_context=right_context,
             )  # (T, B, F)
         else:
+            assert states is None
+            states = []  # just to make torch.jit.script happy
             src_key_padding_mask = make_pad_mask(lengths)
             x = self.encoder_embed(x)
             x, pos_emb = self.encoder_pos(x)
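
For contrast, a real-streaming caller still threads the cache explicitly, and omitting it now fails fast on the first assert instead of deep inside the attention code. A short demo against the sketch above; the tensor shapes are arbitrary placeholders, not the real cache layout:

    # Simulated streaming: states omitted; an empty list comes back.
    x = torch.randn(2, 10, 80)      # (batch, frames, features) - arbitrary
    x_lens = torch.tensor([10, 8])
    _, _, states = streaming_forward_sketch(x, x_lens, simulate_streaming=True)
    assert states == []

    # Real streaming: the cache and processed_lens are required.
    cache = [torch.zeros(2, 64, 144), torch.zeros(2, 64, 144)]  # made-up shapes
    _, _, cache = streaming_forward_sketch(
        x,
        x_lens,
        states=cache,
        processed_lens=torch.zeros(2, dtype=torch.int64),
        simulate_streaming=False,
    )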