Add simulate streaming decoding

This commit is contained in:
pkufool 2022-07-05 17:08:42 +08:00
parent 995f260f91
commit 1b14b13047
2 changed files with 210 additions and 2 deletions

View File

@ -197,6 +197,162 @@ class Conformer(EncoderInterface):
return x, lengths return x, lengths
@torch.jit.export
def streaming_forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
states: Optional[List[Tensor]] = None,
processed_lens: Optional[Tensor] = None,
left_context: int = 64,
right_context: int = 4,
chunk_size: int = 16,
simulate_streaming: bool = False,
warmup: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
"""
Args:
x:
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
states:
The decode states for previous frames which contains the cached data.
It has two elements, the first element is the attn_cache which has
a shape of (encoder_layers, left_context, batch, attention_dim),
the second element is the conv_cache which has a shape of
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
Note: states will be modified in this function.
processed_lens:
How many frames (after subsampling) have been processed for each sequence.
left_context:
How many previous frames the attention can see in current chunk.
Note: It's not that each individual frame has `left_context` frames
of left context, some have more.
right_context:
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
of right context, some have more.
chunk_size:
The chunk size for decoding, this will be used to simulate streaming
decoding using masking.
simulate_streaming:
If setting True, it will use a masking strategy to simulate streaming
fashion (i.e. every chunk data only see limited left context and
right context). The whole sequence is supposed to be send at a time
When using simulate_streaming.
warmup:
A floating point value that gradually increases from 0 throughout
training; when it is >= 1.0 we are "fully warmed up". It is used
to turn modules on sequentially.
Returns:
Return a tuple containing 2 tensors:
- logits, its shape is (batch_size, output_seq_len, output_dim)
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
- decode_states, the updated states including the information
of current chunk.
"""
# x: [N, T, C]
# Caution: We assume the subsampling factor is 4!
# lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning
#
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 1) >> 1) - 1) >> 1
if not simulate_streaming:
assert states is not None
assert processed_lens is not None
assert (
len(states) == 2
and states[0].shape
== (self.encoder_layers, left_context, x.size(0), self.d_model)
and states[1].shape
== (
self.encoder_layers,
self.cnn_module_kernel - 1,
x.size(0),
self.d_model,
)
), f"""The length of states MUST be equal to 2, and the shape of
first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)},
given {states[0].shape}. the shape of second element should be
{(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
given {states[1].shape}."""
lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output
src_key_padding_mask = make_pad_mask(lengths)
processed_mask = torch.arange(left_context, device=x.device).expand(
x.size(0), left_context
)
processed_lens = processed_lens.view(x.size(0), 1)
processed_mask = (processed_lens <= processed_mask).flip(1)
src_key_padding_mask = torch.cat(
[processed_mask, src_key_padding_mask], dim=1
)
embed = self.encoder_embed(x)
# cut off 1 frame on each size of embed as they see the padding
# value which causes a training and decoding mismatch.
embed = embed[:, 1:-1, :]
embed, pos_enc = self.encoder_pos(embed, left_context)
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
x, states = self.encoder.chunk_forward(
embed,
pos_enc,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
states=states,
left_context=left_context,
right_context=right_context,
) # (T, B, F)
if right_context > 0:
x = x[0:-right_context, ...]
lengths -= right_context
else:
assert states is None
states = [] # just to make torch.script.jit happy
# this branch simulates streaming decoding using mask as we are
# using in training time.
src_key_padding_mask = make_pad_mask(lengths)
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
assert x.size(0) == lengths.max().item()
num_left_chunks = -1
if left_context >= 0:
assert left_context % chunk_size == 0
num_left_chunks = left_context // chunk_size
mask = ~subsequent_chunk_mask(
size=x.size(0),
chunk_size=chunk_size,
num_left_chunks=num_left_chunks,
device=x.device,
)
x = self.encoder(
x,
pos_emb,
mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
) # (T, N, C)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return x, lengths, states
class ConformerEncoderLayer(nn.Module): class ConformerEncoderLayer(nn.Module):
""" """

View File

@ -96,6 +96,7 @@ Usage:
import argparse import argparse
import logging import logging
import math
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
@ -132,6 +133,8 @@ from icefall.utils import (
write_error_stats, write_error_stats,
) )
LOG_EPS = math.log(1e-10)
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -298,6 +301,29 @@ def get_parser():
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
) )
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
add_model_arguments(parser) add_model_arguments(parser)
return parser return parser
@ -352,9 +378,26 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
if params.simulate_streaming:
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
chunk_size=params.decode_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens x=feature, x_lens=feature_lens
) )
hyps = [] hyps = []
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
@ -621,6 +664,10 @@ def main():
else: else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}" params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
@ -658,6 +705,11 @@ def main():
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params) logging.info(params)
logging.info("About to create model") logging.info("About to create model")