mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-18 13:34:20 +00:00

Add simulate streaming decoding

This commit is contained in:
parent 995f260f91
commit 1b14b13047
@@ -197,6 +197,162 @@ class Conformer(EncoderInterface):
 
         return x, lengths
 
+    @torch.jit.export
+    def streaming_forward(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        states: Optional[List[Tensor]] = None,
+        processed_lens: Optional[Tensor] = None,
+        left_context: int = 64,
+        right_context: int = 4,
+        chunk_size: int = 16,
+        simulate_streaming: bool = False,
+        warmup: float = 1.0,
+    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
+        """
+        Args:
+          x:
+            The input tensor. Its shape is (batch_size, seq_len, feature_dim).
+          x_lens:
+            A tensor of shape (batch_size,) containing the number of frames in
+            `x` before padding.
+          states:
+            The decode states for previous frames, which contain the cached data.
+            It has two elements: the first is the attn_cache, which has a shape
+            of (encoder_layers, left_context, batch, attention_dim); the second
+            is the conv_cache, which has a shape of
+            (encoder_layers, cnn_module_kernel - 1, batch, conv_dim).
+            Note: states will be modified in this function.
+          processed_lens:
+            How many frames (after subsampling) have been processed for each sequence.
+          left_context:
+            How many previous frames the attention can see in the current chunk.
+            Note: it's not that each individual frame has `left_context` frames
+            of left context; some have more.
+          right_context:
+            How many future frames the attention can see in the current chunk.
+            Note: it's not that each individual frame has `right_context` frames
+            of right context; some have more.
+          chunk_size:
+            The chunk size for decoding; this will be used to simulate streaming
+            decoding using masking.
+          simulate_streaming:
+            If True, a masking strategy is used to simulate streaming decoding
+            (i.e. each chunk of data sees only a limited left and right
+            context). The whole sequence is supposed to be sent at once when
+            using simulate_streaming.
+          warmup:
+            A floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up". It is used
+            to turn modules on sequentially.
+        Returns:
+          Return a tuple containing 3 tensors:
+            - logits, whose shape is (batch_size, output_seq_len, output_dim)
+            - logit_lens, a tensor of shape (batch_size,) containing the number
+              of frames in `logits` before padding.
+            - decode_states, the updated states, including the information
+              of the current chunk.
+        """
+
+        # x: [N, T, C]
+        # Caution: We assume the subsampling factor is 4!
+
+        # lengths = ((x_lens - 1) // 2 - 1) // 2  # issues a warning
+        #
+        # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
+        lengths = (((x_lens - 1) >> 1) - 1) >> 1
+
+        if not simulate_streaming:
+            assert states is not None
+            assert processed_lens is not None
+            assert (
+                len(states) == 2
+                and states[0].shape
+                == (self.encoder_layers, left_context, x.size(0), self.d_model)
+                and states[1].shape
+                == (
+                    self.encoder_layers,
+                    self.cnn_module_kernel - 1,
+                    x.size(0),
+                    self.d_model,
+                )
+            ), f"""The length of states MUST be equal to 2, and the shape of
+            the first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)},
+            given {states[0].shape}; the shape of the second element should be
+            {(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
+            given {states[1].shape}."""
+
+            lengths -= 2  # we will cut off 1 frame on each side of encoder_embed output
+
+            src_key_padding_mask = make_pad_mask(lengths)
+
+            processed_mask = torch.arange(left_context, device=x.device).expand(
+                x.size(0), left_context
+            )
+            processed_lens = processed_lens.view(x.size(0), 1)
+            processed_mask = (processed_lens <= processed_mask).flip(1)
+
+            src_key_padding_mask = torch.cat(
+                [processed_mask, src_key_padding_mask], dim=1
+            )
+
+            embed = self.encoder_embed(x)
+
+            # cut off 1 frame on each side of embed as those frames see the
+            # padding value, which causes a training and decoding mismatch.
+            embed = embed[:, 1:-1, :]
+
+            embed, pos_enc = self.encoder_pos(embed, left_context)
+            embed = embed.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
+
+            x, states = self.encoder.chunk_forward(
+                embed,
+                pos_enc,
+                src_key_padding_mask=src_key_padding_mask,
+                warmup=warmup,
+                states=states,
+                left_context=left_context,
+                right_context=right_context,
+            )  # (T, B, F)
+            if right_context > 0:
+                x = x[0:-right_context, ...]
+                lengths -= right_context
+        else:
+            assert states is None
+            states = []  # just to make torch.jit.script happy
+            # this branch simulates streaming decoding using the same masking
+            # as is used at training time.
+            src_key_padding_mask = make_pad_mask(lengths)
+            x = self.encoder_embed(x)
+            x, pos_emb = self.encoder_pos(x)
+            x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+
+            assert x.size(0) == lengths.max().item()
+
+            num_left_chunks = -1
+            if left_context >= 0:
+                assert left_context % chunk_size == 0
+                num_left_chunks = left_context // chunk_size
+
+            mask = ~subsequent_chunk_mask(
+                size=x.size(0),
+                chunk_size=chunk_size,
+                num_left_chunks=num_left_chunks,
+                device=x.device,
+            )
+            x = self.encoder(
+                x,
+                pos_emb,
+                mask=mask,
+                src_key_padding_mask=src_key_padding_mask,
+                warmup=warmup,
+            )  # (T, N, C)
+
+        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+
+        return x, lengths, states
+
 
 class ConformerEncoderLayer(nn.Module):
     """
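Note: the simulate-streaming branch above builds its attention mask with
subsequent_chunk_mask and then inverts it with `~`, because in the encoder a
True entry means "may not attend". As an illustration, here is a minimal,
self-contained sketch of such a chunk-visibility mask; chunk_attention_mask is
a hypothetical stand-in whose behavior is inferred from the keyword arguments
used above, not the icefall helper itself:

import torch

def chunk_attention_mask(
    size: int, chunk_size: int, num_left_chunks: int = -1
) -> torch.Tensor:
    # True means "may attend". Frame i belongs to chunk i // chunk_size;
    # it can see up to the end of its own chunk, plus at most
    # num_left_chunks chunks to the left (all of them if -1).
    ret = torch.zeros(size, size, dtype=torch.bool)
    for i in range(size):
        if num_left_chunks < 0:
            start = 0
        else:
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        ending = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, start:ending] = True
    return ret

# With chunk_size=2 and one left chunk, frame 5 (in chunk 2) sees frames 2..5:
print(chunk_attention_mask(8, 2, num_left_chunks=1)[5])
# tensor([False, False,  True,  True,  True,  True, False, False])

This also shows why frames differ in how much left context they get: the first
frame of a chunk sees exactly num_left_chunks * chunk_size earlier frames,
while the last frame of the chunk sees nearly a whole chunk more, which is
what the docstring's "some have more" remark is about.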
@@ -96,6 +96,7 @@ Usage:
 
 import argparse
 import logging
+import math
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -132,6 +133,8 @@ from icefall.utils import (
     write_error_stats,
 )
 
+LOG_EPS = math.log(1e-10)
+
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -298,6 +301,29 @@ def get_parser():
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )
 
+    parser.add_argument(
+        "--simulate-streaming",
+        type=str2bool,
+        default=False,
+        help="""Whether to simulate streaming in decoding; this is a good way
+        to test a streaming model.
+        """,
+    )
+
+    parser.add_argument(
+        "--decode-chunk-size",
+        type=int,
+        default=16,
+        help="The chunk size for decoding (in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--left-context",
+        type=int,
+        default=64,
+        help="Left context that can be seen during decoding (in frames after subsampling)",
+    )
+
     add_model_arguments(parser)
 
     return parser
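Note: the defaults above are linked. Inside streaming_forward, the simulated
branch asserts left_context % chunk_size == 0 and converts the two flag values
into a whole number of visible left chunks. A quick check with the defaults:

# Defaults of --decode-chunk-size and --left-context (frames after subsampling).
decode_chunk_size = 16
left_context = 64

assert left_context % decode_chunk_size == 0
num_left_chunks = left_context // decode_chunk_size
print(num_left_chunks)  # 4, i.e. each chunk attends to 4 chunks (64 frames) on its left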
@@ -352,9 +378,26 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
-    )
+    feature_lens += params.left_context
+    feature = torch.nn.functional.pad(
+        feature,
+        pad=(0, 0, 0, params.left_context),
+        value=LOG_EPS,
+    )
+
+    if params.simulate_streaming:
+        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
+            x=feature,
+            x_lens=feature_lens,
+            chunk_size=params.decode_chunk_size,
+            left_context=params.left_context,
+            simulate_streaming=True,
+        )
+    else:
+        encoder_out, encoder_out_lens = model.encoder(
+            x=feature, x_lens=feature_lens
+        )
 
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
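Note: the torch.nn.functional.pad call above appends params.left_context
frames of LOG_EPS to the time axis (F.pad pairs run from the last dimension
backwards, so (0, 0, 0, left_context) pads dim 1 of an (N, T, C) tensor on the
right), and feature_lens is bumped to match. A standalone sketch of the shape
arithmetic, with made-up sizes:

import math
import torch

LOG_EPS = math.log(1e-10)
left_context = 64

feature = torch.randn(2, 100, 80)  # (N, T, C) fbank features, sizes invented
padded = torch.nn.functional.pad(
    feature,
    pad=(0, 0, 0, left_context),  # (C_left, C_right, T_left, T_right)
    value=LOG_EPS,
)
print(padded.shape)  # torch.Size([2, 164, 80])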
@@ -621,6 +664,10 @@ def main():
     else:
         params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
 
+    if params.simulate_streaming:
+        params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
+        params.suffix += f"-left-context-{params.left_context}"
+
     if "fast_beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
@@ -658,6 +705,11 @@ def main():
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
 
+    if params.simulate_streaming:
+        assert (
+            params.causal_convolution
+        ), "Decoding in streaming requires causal convolution"
+
     logging.info(params)
 
     logging.info("About to create model")
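Note: returning to the length arithmetic in streaming_forward above, the bit
shifts compute the same floor division as the commented-out line, while
avoiding the warning that `//` on tensors triggers on the older PyTorch
versions the comment mentions. A quick equivalence check (input lengths
invented):

import torch

x_lens = torch.tensor([100, 163, 164])

shifted = (((x_lens - 1) >> 1) - 1) >> 1  # subsampling factor 4
divided = ((x_lens - 1) // 2 - 1) // 2
assert torch.equal(shifted, divided)
print(shifted)  # tensor([24, 40, 40])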