make states in forward streaming optional

parent 61f3c87d48
commit 8d37175ffb
@@ -377,7 +377,6 @@ def decode_one_batch(
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
-            states=[],
             chunk_size=params.decode_chunk_size,
             left_context=params.left_context,
             simulate_streaming=True,
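The same caller-side simplification recurs in several hunks below: now that `states` defaults to `None`, simulated-streaming callers drop the dummy `states=[]` argument. A minimal before/after sketch of the calling convention; `model`, `feature`, `feature_lens`, and `params` are placeholders standing in for the recipe's real objects:

# Before this commit: a dummy empty state list was required even
# though simulate_streaming=True never used it.
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
    x=feature,
    x_lens=feature_lens,
    states=[],  # dummy, ignored in simulated streaming
    chunk_size=params.decode_chunk_size,
    left_context=params.left_context,
    simulate_streaming=True,
)

# After this commit: states may simply be omitted. Note that keeping
# states=[] would now trip the new `assert states is None` in the
# simulated-streaming branch, so callers must drop the argument.
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
    x=feature,
    x_lens=feature_lens,
    chunk_size=params.decode_chunk_size,
    left_context=params.left_context,
    simulate_streaming=True,
)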
@@ -249,7 +249,7 @@ class Conformer(EncoderInterface):
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
-        states: List[Tensor],
+        states: Optional[List[Tensor]] = None,
         processed_lens: Optional[Tensor] = None,
         left_context: int = 64,
         right_context: int = 4,
@@ -311,6 +311,8 @@ class Conformer(EncoderInterface):
         lengths = (((x_lens - 1) >> 1) - 1) >> 1
 
         if not simulate_streaming:
+            assert states is not None
+            assert processed_lens is not None
             assert (
                 len(states) == 2
                 and states[0].shape
@@ -332,8 +334,6 @@ class Conformer(EncoderInterface):
 
             src_key_padding_mask = make_pad_mask(lengths)
 
-            assert processed_lens is not None
-
             processed_mask = torch.arange(left_context, device=x.device).expand(
                 x.size(0), left_context
             )
@@ -366,6 +366,8 @@ class Conformer(EncoderInterface):
                 x = x[0:-right_context, ...]
                 lengths -= right_context
         else:
+            assert states is None
+            states = []  # just to make torch.script.jit happy
             # this branch simulates streaming decoding using mask as we are
             # using in training time.
             src_key_padding_mask = make_pad_mask(lengths)
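The `states = []` reassignment in the `else` branch looks redundant but matters for TorchScript: an `Optional[List[Tensor]]` parameter must be refined to a concrete `List[Tensor]` on every code path before later use. A small self-contained sketch of that pattern; the function name is invented for illustration:

from typing import List, Optional

import torch
from torch import Tensor


@torch.jit.script
def count_states(states: Optional[List[Tensor]] = None) -> int:
    if states is not None:
        # TorchScript narrows `states` to List[Tensor] in this branch.
        return len(states)
    # Reassigning an empty list (inferred as List[Tensor]) gives
    # `states` a concrete type on this path too, which keeps the
    # script compiler happy wherever `states` is used afterwards.
    states = []
    return len(states)


assert count_states() == 0
assert count_states([torch.zeros(1)]) == 1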
@@ -391,7 +391,6 @@ def decode_one_batch(
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
-            states=[],
             chunk_size=params.decode_chunk_size,
             left_context=params.left_context,
             simulate_streaming=True,
@@ -377,7 +377,6 @@ def decode_one_batch(
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
-            states=[],
             chunk_size=params.decode_chunk_size,
             left_context=params.left_context,
             simulate_streaming=True,
@@ -403,7 +403,6 @@ def decode_one_batch(
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
-            states=[],
             chunk_size=params.decode_chunk_size,
             left_context=params.left_context,
             simulate_streaming=True,
@@ -245,7 +245,7 @@ class Conformer(Transformer):
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
-        states: List[torch.Tensor],
+        states: Optional[List[torch.Tensor]] = None,
         processed_lens: Optional[Tensor] = None,
         left_context: int = 64,
         right_context: int = 0,
@@ -302,6 +302,8 @@ class Conformer(Transformer):
         lengths = (((x_lens - 1) >> 1) - 1) >> 1
 
         if not simulate_streaming:
+            assert states is not None
+            assert processed_lens is not None
             assert (
                 len(states) == 2
                 and states[0].shape
@@ -322,8 +324,6 @@ class Conformer(Transformer):
             lengths -= 2  # we will cut off 1 frame on each side of encoder_embed output
             src_key_padding_mask = make_pad_mask(lengths)
 
-            assert processed_lens is not None
-
             processed_mask = torch.arange(left_context, device=x.device).expand(
                 x.size(0), left_context
             )
@@ -352,6 +352,8 @@ class Conformer(Transformer):
                 right_context=right_context,
             )  # (T, B, F)
         else:
+            assert states is None
+            states = []  # just to make torch.script.jit happy
             src_key_padding_mask = make_pad_mask(lengths)
             x = self.encoder_embed(x)
             x, pos_emb = self.encoder_pos(x)
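Taken together, the contract after this commit is: real streaming supplies cached `states` (and `processed_lens`), while simulated streaming supplies neither, and both requirements are enforced by the new asserts. A hedged usage sketch of the two modes; `encoder` stands in for a constructed streaming Conformer, the shapes are illustrative only, and `get_init_state` is assumed to be the recipe's state-initialization helper:

import torch

feature = torch.randn(1, 1000, 80)   # (batch, time, feature), illustrative
feature_lens = torch.tensor([1000])

# Simulated streaming: masking emulates chunking, no state is threaded
# through, so states and processed_lens are simply omitted.
encoder_out, encoder_out_lens, _ = encoder.streaming_forward(
    x=feature,
    x_lens=feature_lens,
    chunk_size=16,
    left_context=64,
    simulate_streaming=True,
)

# Real chunk-wise streaming: states and processed_lens are mandatory.
states = encoder.get_init_state(left_context=64, device=feature.device)
processed_lens = torch.zeros(1, dtype=torch.int64)
encoder_out, encoder_out_lens, states = encoder.streaming_forward(
    x=feature,
    x_lens=feature_lens,
    states=states,
    processed_lens=processed_lens,
    left_context=64,
    simulate_streaming=False,
)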