Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-18 21:44:18 +00:00)

Commit 8d37175ffb: make states in forward streaming optional
Parent: 61f3c87d48
@@ -377,7 +377,6 @@ def decode_one_batch(
     encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
         x=feature,
         x_lens=feature_lens,
-        states=[],
         chunk_size=params.decode_chunk_size,
         left_context=params.left_context,
         simulate_streaming=True,
@@ -249,7 +249,7 @@ class Conformer(EncoderInterface):
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
-        states: List[Tensor],
+        states: Optional[List[Tensor]] = None,
         processed_lens: Optional[Tensor] = None,
         left_context: int = 64,
         right_context: int = 4,
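The core change is the `streaming_forward` signature: `states` becomes optional with a default of `None`. Below is an interface-only sketch of the resulting method, assuming the parameter list shown in the hunk above; the `chunk_size` and `simulate_streaming` defaults and the return annotation are assumptions inferred from the call sites, not taken from this diff.

from typing import List, Optional, Tuple

import torch
from torch import Tensor


class ConformerSketch:
    """Interface-only sketch; the real implementation lives in conformer.py."""

    def streaming_forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        states: Optional[List[Tensor]] = None,  # optional as of this commit
        processed_lens: Optional[Tensor] = None,
        left_context: int = 64,
        right_context: int = 4,
        chunk_size: int = 16,  # assumed default; the decode scripts always pass it
        simulate_streaming: bool = False,  # assumed default
    ) -> Tuple[Tensor, Tensor, List[Tensor]]:
        # Returns (encoder_out, encoder_out_lens, states), matching the
        # three-way unpacking at the call sites above.
        raise NotImplementedError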
@@ -311,6 +311,8 @@ class Conformer(EncoderInterface):
         lengths = (((x_lens - 1) >> 1) - 1) >> 1

         if not simulate_streaming:
+            assert states is not None
+            assert processed_lens is not None
             assert (
                 len(states) == 2
                 and states[0].shape
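Besides documenting the contract, the two new assertions refine the `Optional` types: TorchScript (and static checkers) will only allow indexing into `states` after the `None` case has been ruled out. A standalone illustration of that refinement, unrelated to the icefall classes:

from typing import List, Optional

import torch
from torch import Tensor


@torch.jit.script
def first_state(states: Optional[List[Tensor]] = None) -> Tensor:
    # Without this assert, TorchScript rejects indexing an Optional[List[Tensor]].
    assert states is not None
    return states[0]


print(first_state([torch.zeros(2, 3)]).shape)  # torch.Size([2, 3])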
@@ -332,8 +334,6 @@ class Conformer(EncoderInterface):

             src_key_padding_mask = make_pad_mask(lengths)

-            assert processed_lens is not None
-
             processed_mask = torch.arange(left_context, device=x.device).expand(
                 x.size(0), left_context
             )
@@ -366,6 +366,8 @@ class Conformer(EncoderInterface):
                 x = x[0:-right_context, ...]
                 lengths -= right_context
         else:
+            assert states is None
+            states = []  # just to make torch.script.jit happy
             # this branch simulates streaming decoding using mask as we are
             # using in training time.
             src_key_padding_mask = make_pad_mask(lengths)
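In the `simulate_streaming` branch no caches exist, but `states` still has to end up as a plain `List[Tensor]` on every path so that TorchScript can type the returned value, hence the `states = []` assignment. Note that the branch now also asserts `states is None`, so callers that keep passing `states=[]` together with `simulate_streaming=True` will trip the assert; that is why the decode scripts above drop the argument in the same commit. A self-contained sketch of the pattern (the tensor arithmetic is a placeholder for the real cached-state update):

from typing import List, Optional, Tuple

import torch
from torch import Tensor


@torch.jit.script
def forward_sketch(
    x: Tensor,
    states: Optional[List[Tensor]] = None,
    simulate_streaming: bool = False,
) -> Tuple[Tensor, List[Tensor]]:
    if not simulate_streaming:
        # Real streaming: caches are mandatory and get updated.
        assert states is not None
        states = [s + 1.0 for s in states]  # stand-in for the real cache update
    else:
        # Simulated streaming: no caches are used, but the variable must still
        # be a List[Tensor] on the return path.
        assert states is None
        states = []  # just to make torch.jit.script happy
    return x, states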
@@ -391,7 +391,6 @@ def decode_one_batch(
     encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
         x=feature,
         x_lens=feature_lens,
-        states=[],
         chunk_size=params.decode_chunk_size,
         left_context=params.left_context,
         simulate_streaming=True,
@@ -377,7 +377,6 @@ def decode_one_batch(
     encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
         x=feature,
         x_lens=feature_lens,
-        states=[],
         chunk_size=params.decode_chunk_size,
         left_context=params.left_context,
         simulate_streaming=True,
@@ -403,7 +403,6 @@ def decode_one_batch(
     encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
         x=feature,
         x_lens=feature_lens,
-        states=[],
         chunk_size=params.decode_chunk_size,
         left_context=params.left_context,
         simulate_streaming=True,
@@ -245,7 +245,7 @@ class Conformer(Transformer):
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
-        states: List[torch.Tensor],
+        states: Optional[List[torch.Tensor]] = None,
         processed_lens: Optional[Tensor] = None,
         left_context: int = 64,
         right_context: int = 0,
@@ -302,6 +302,8 @@ class Conformer(Transformer):
         lengths = (((x_lens - 1) >> 1) - 1) >> 1

         if not simulate_streaming:
+            assert states is not None
+            assert processed_lens is not None
             assert (
                 len(states) == 2
                 and states[0].shape
@@ -322,8 +324,6 @@ class Conformer(Transformer):
             lengths -= 2  # we will cut off 1 frame on each side of encoder_embed output
             src_key_padding_mask = make_pad_mask(lengths)

-            assert processed_lens is not None
-
             processed_mask = torch.arange(left_context, device=x.device).expand(
                 x.size(0), left_context
             )
@@ -352,6 +352,8 @@ class Conformer(Transformer):
                 right_context=right_context,
             )  # (T, B, F)
         else:
+            assert states is None
+            states = []  # just to make torch.script.jit happy
             src_key_padding_mask = make_pad_mask(lengths)
             x = self.encoder_embed(x)
             x, pos_emb = self.encoder_pos(x)
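Taken together, the calling convention after this commit is: omit `states` (and `processed_lens`) when simulating streaming with masks, and pass both when doing real chunk-by-chunk streaming. A hedged usage sketch follows; the first call mirrors the updated decode scripts above, while the second is only an assumption about how a chunk-wise decoding loop would supply its caches (`chunk_feature`, `chunk_feature_lens`, `states`, and `processed_lens` are hypothetical names maintained by that loop).

# Simulated streaming (mask-based), as in the updated decode scripts: no caches.
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
    x=feature,
    x_lens=feature_lens,
    chunk_size=params.decode_chunk_size,
    left_context=params.left_context,
    simulate_streaming=True,
)

# Real chunk-by-chunk streaming (sketch): caches are required, and the new
# asserts will fire if they are missing.
encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
    x=chunk_feature,
    x_lens=chunk_feature_lens,
    states=states,
    processed_lens=processed_lens,
    chunk_size=params.decode_chunk_size,
    left_context=params.left_context,
    simulate_streaming=False,
)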