mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00

Add simulate streaming decoding

parent 995f260f91
commit 1b14b13047
@@ -197,6 +197,162 @@ class Conformer(EncoderInterface):
        return x, lengths

    @torch.jit.export
    def streaming_forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        states: Optional[List[Tensor]] = None,
        processed_lens: Optional[Tensor] = None,
        left_context: int = 64,
        right_context: int = 4,
        chunk_size: int = 16,
        simulate_streaming: bool = False,
        warmup: float = 1.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """
        Args:
          x:
            The input tensor. Its shape is (batch_size, seq_len, feature_dim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.
          states:
            The decode states for previous frames, which contain the cached
            data. It has two elements: the first is the attn_cache, which has
            a shape of (encoder_layers, left_context, batch, attention_dim);
            the second is the conv_cache, which has a shape of
            (encoder_layers, cnn_module_kernel - 1, batch, conv_dim).
            Note: states will be modified in this function.
          processed_lens:
            How many frames (after subsampling) have been processed for each
            sequence.
          left_context:
            How many previous frames the attention can see in the current
            chunk. Note: it's not that each individual frame has
            `left_context` frames of left context; some have more.
          right_context:
            How many future frames the attention can see in the current chunk.
            Note: it's not that each individual frame has `right_context`
            frames of right context; some have more.
          chunk_size:
            The chunk size for decoding; this will be used to simulate
            streaming decoding using masking.
          simulate_streaming:
            If True, a masking strategy is used to simulate streaming (i.e.
            every chunk sees only a limited left and right context). The whole
            sequence is supposed to be sent at once when using
            simulate_streaming.
          warmup:
            A floating point value that gradually increases from 0 throughout
            training; when it is >= 1.0 we are "fully warmed up". It is used
            to turn modules on sequentially.
        Returns:
          Return a tuple containing 3 elements:
            - logits, its shape is (batch_size, output_seq_len, output_dim)
            - logit_lens, a tensor of shape (batch_size,) containing the
              number of frames in `logits` before padding.
            - decode_states, the updated states including the information
              of the current chunk.
        """

        # x: [N, T, C]
        # Caution: We assume the subsampling factor is 4!

        # lengths = ((x_lens - 1) // 2 - 1) // 2  # issues a warning
        #
        # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
        lengths = (((x_lens - 1) >> 1) - 1) >> 1
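        # For example (illustrative numbers, not from the diff): with
        # x_lens = 100 input frames, ((100 - 1) >> 1) = 49 and
        # (49 - 1) >> 1 = 24, which matches ((100 - 1) // 2 - 1) // 2 = 24
        # output frames for a subsampling factor of 4.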

        if not simulate_streaming:
            assert states is not None
            assert processed_lens is not None
            assert (
                len(states) == 2
                and states[0].shape
                == (self.encoder_layers, left_context, x.size(0), self.d_model)
                and states[1].shape
                == (
                    self.encoder_layers,
                    self.cnn_module_kernel - 1,
                    x.size(0),
                    self.d_model,
                )
            ), f"""The length of states MUST be equal to 2, and the shape of
            the first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)},
            given {states[0].shape}. The shape of the second element should be
            {(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
            given {states[1].shape}."""

            lengths -= 2  # we will cut off 1 frame on each side of encoder_embed output

            src_key_padding_mask = make_pad_mask(lengths)

            processed_mask = torch.arange(left_context, device=x.device).expand(
                x.size(0), left_context
            )
            processed_lens = processed_lens.view(x.size(0), 1)
            processed_mask = (processed_lens <= processed_mask).flip(1)

            src_key_padding_mask = torch.cat(
                [processed_mask, src_key_padding_mask], dim=1
            )
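            # The padding mask now covers the `left_context` cached frames as
            # well as the current chunk: cache positions beyond what has
            # actually been processed (processed_lens) are masked out, and
            # flip(1) aligns the valid cached frames to the right, adjacent
            # to the start of the current chunk.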

            embed = self.encoder_embed(x)

            # cut off 1 frame on each side of embed as they see the padding
            # value which causes a training and decoding mismatch.
            embed = embed[:, 1:-1, :]

            embed, pos_enc = self.encoder_pos(embed, left_context)
            embed = embed.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)

            x, states = self.encoder.chunk_forward(
                embed,
                pos_enc,
                src_key_padding_mask=src_key_padding_mask,
                warmup=warmup,
                states=states,
                left_context=left_context,
                right_context=right_context,
            )  # (T, B, F)
            if right_context > 0:
                x = x[0:-right_context, ...]
                lengths -= right_context
        else:
            assert states is None
            states = []  # just to make torch.jit.script happy
            # this branch simulates streaming decoding using the same masking
            # as is used at training time.
            src_key_padding_mask = make_pad_mask(lengths)
            x = self.encoder_embed(x)
            x, pos_emb = self.encoder_pos(x)
            x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

            assert x.size(0) == lengths.max().item()

            num_left_chunks = -1
            if left_context >= 0:
                assert left_context % chunk_size == 0
                num_left_chunks = left_context // chunk_size
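
            # With chunk_size = 16 and left_context = 64, for instance,
            # num_left_chunks = 4: each frame may attend to the frames in its
            # own 16-frame chunk and in the 4 preceding chunks.
            # subsequent_chunk_mask returns the allowed positions, so it is
            # negated to obtain the mask of disallowed positions expected by
            # the encoder.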
            mask = ~subsequent_chunk_mask(
                size=x.size(0),
                chunk_size=chunk_size,
                num_left_chunks=num_left_chunks,
                device=x.device,
            )
            x = self.encoder(
                x,
                pos_emb,
                mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                warmup=warmup,
            )  # (T, N, C)

            x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

        return x, lengths, states


class ConformerEncoderLayer(nn.Module):
    """
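
A minimal sketch of how the new entry point might be called in
simulated-streaming mode (illustrative only; the `model`, feature shapes, and
values below are assumptions, not part of the diff):

    import torch

    # model.encoder is assumed to be a Conformer built with causal convolution
    x = torch.randn(2, 1000, 80)        # (batch, frames, feature_dim)
    x_lens = torch.tensor([1000, 800])  # valid frames before padding
    encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
        x=x,
        x_lens=x_lens,
        chunk_size=16,
        left_context=64,
        simulate_streaming=True,
    )
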
@@ -96,6 +96,7 @@ Usage:

import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@@ -132,6 +133,8 @@ from icefall.utils import (
    write_error_stats,
)

LOG_EPS = math.log(1e-10)
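# Used below as the padding value for log-fbank features; log(1e-10) is the
# log-energy of a near-zero signal, i.e. it pads with what is effectively
# silence.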


def get_parser():
    parser = argparse.ArgumentParser(
@@ -298,6 +301,29 @@ def get_parser():
          fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--simulate-streaming",
        type=str2bool,
        default=False,
        help="""Whether to simulate streaming in decoding; this is a good way
        to test a streaming model.""",
    )

    parser.add_argument(
        "--decode-chunk-size",
        type=int,
        default=16,
        help="The chunk size for decoding (in frames after subsampling).",
    )

    parser.add_argument(
        "--left-context",
        type=int,
        default=64,
        help="Left context that can be seen during decoding (in frames after subsampling).",
    )

    add_model_arguments(parser)

    return parser
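
For example, the new flags could be used like this (a hypothetical
invocation; the script path and the remaining options depend on the recipe
being decoded):

    ./pruned_transducer_stateless/decode.py \
        --simulate-streaming 1 \
        --decode-chunk-size 16 \
        --left-context 64 \
        --decoding-method greedy_search
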
@@ -352,9 +378,26 @@ def decode_one_batch(
    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

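    # Pad `left_context` extra frames of LOG_EPS at the end of the features
    # (and extend feature_lens to match); presumably this gives the tail of
    # each utterance enough frames to fill the final decoding chunk.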
    feature_lens += params.left_context
    feature = torch.nn.functional.pad(
        feature,
        pad=(0, 0, 0, params.left_context),
        value=LOG_EPS,
    )

    if params.simulate_streaming:
        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
            x=feature,
            x_lens=feature_lens,
            chunk_size=params.decode_chunk_size,
            left_context=params.left_context,
            simulate_streaming=True,
        )
    else:
        encoder_out, encoder_out_lens = model.encoder(
            x=feature, x_lens=feature_lens
        )

    hyps = []

    if params.decoding_method == "fast_beam_search":
@@ -621,6 +664,10 @@ def main():
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if params.simulate_streaming:
        params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
        params.suffix += f"-left-context-{params.left_context}"

    if "fast_beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"
@@ -658,6 +705,11 @@ def main():
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    if params.simulate_streaming:
        assert (
            params.causal_convolution
        ), "Streaming decoding requires causal convolution"

    logging.info(params)

    logging.info("About to create model")