From 1b14b130474fa26da86dd9d017a7cff6b747d819 Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 5 Jul 2022 17:08:42 +0800 Subject: [PATCH] Add simulate streaming decoding --- .../pruned_transducer_stateless5/conformer.py | 156 ++++++++++++++++++ .../pruned_transducer_stateless5/decode.py | 56 ++++++- 2 files changed, 210 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index a0f37f148..8872ade03 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -197,6 +197,162 @@ class Conformer(EncoderInterface): return x, lengths + @torch.jit.export + def streaming_forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + states: Optional[List[Tensor]] = None, + processed_lens: Optional[Tensor] = None, + left_context: int = 64, + right_context: int = 4, + chunk_size: int = 16, + simulate_streaming: bool = False, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + states: + The decode states for previous frames which contains the cached data. + It has two elements, the first element is the attn_cache which has + a shape of (encoder_layers, left_context, batch, attention_dim), + the second element is the conv_cache which has a shape of + (encoder_layers, cnn_module_kernel-1, batch, conv_dim). + Note: states will be modified in this function. + processed_lens: + How many frames (after subsampling) have been processed for each sequence. + left_context: + How many previous frames the attention can see in current chunk. + Note: It's not that each individual frame has `left_context` frames + of left context, some have more. + right_context: + How many future frames the attention can see in current chunk. + Note: It's not that each individual frame has `right_context` frames + of right context, some have more. + chunk_size: + The chunk size for decoding, this will be used to simulate streaming + decoding using masking. + simulate_streaming: + If setting True, it will use a masking strategy to simulate streaming + fashion (i.e. every chunk data only see limited left context and + right context). The whole sequence is supposed to be send at a time + When using simulate_streaming. + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + Returns: + Return a tuple containing 2 tensors: + - logits, its shape is (batch_size, output_seq_len, output_dim) + - logit_lens, a tensor of shape (batch_size,) containing the number + of frames in `logits` before padding. + - decode_states, the updated states including the information + of current chunk. + """ + + # x: [N, T, C] + # Caution: We assume the subsampling factor is 4! + + # lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning + # + # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 + lengths = (((x_lens - 1) >> 1) - 1) >> 1 + + if not simulate_streaming: + assert states is not None + assert processed_lens is not None + assert ( + len(states) == 2 + and states[0].shape + == (self.encoder_layers, left_context, x.size(0), self.d_model) + and states[1].shape + == ( + self.encoder_layers, + self.cnn_module_kernel - 1, + x.size(0), + self.d_model, + ) + ), f"""The length of states MUST be equal to 2, and the shape of + first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)}, + given {states[0].shape}. the shape of second element should be + {(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)}, + given {states[1].shape}.""" + + lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output + + src_key_padding_mask = make_pad_mask(lengths) + + processed_mask = torch.arange(left_context, device=x.device).expand( + x.size(0), left_context + ) + processed_lens = processed_lens.view(x.size(0), 1) + processed_mask = (processed_lens <= processed_mask).flip(1) + + src_key_padding_mask = torch.cat( + [processed_mask, src_key_padding_mask], dim=1 + ) + + embed = self.encoder_embed(x) + + # cut off 1 frame on each size of embed as they see the padding + # value which causes a training and decoding mismatch. + embed = embed[:, 1:-1, :] + + embed, pos_enc = self.encoder_pos(embed, left_context) + embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + + x, states = self.encoder.chunk_forward( + embed, + pos_enc, + src_key_padding_mask=src_key_padding_mask, + warmup=warmup, + states=states, + left_context=left_context, + right_context=right_context, + ) # (T, B, F) + if right_context > 0: + x = x[0:-right_context, ...] + lengths -= right_context + else: + assert states is None + states = [] # just to make torch.script.jit happy + # this branch simulates streaming decoding using mask as we are + # using in training time. + src_key_padding_mask = make_pad_mask(lengths) + x = self.encoder_embed(x) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + assert x.size(0) == lengths.max().item() + + num_left_chunks = -1 + if left_context >= 0: + assert left_context % chunk_size == 0 + num_left_chunks = left_context // chunk_size + + mask = ~subsequent_chunk_mask( + size=x.size(0), + chunk_size=chunk_size, + num_left_chunks=num_left_chunks, + device=x.device, + ) + x = self.encoder( + x, + pos_emb, + mask=mask, + src_key_padding_mask=src_key_padding_mask, + warmup=warmup, + ) # (T, N, C) + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return x, lengths, states + class ConformerEncoderLayer(nn.Module): """ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index f87d23cc9..2d0965023 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -96,6 +96,7 @@ Usage: import argparse import logging +import math from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -132,6 +133,8 @@ from icefall.utils import ( write_error_stats, ) +LOG_EPS = math.log(1e-10) + def get_parser(): parser = argparse.ArgumentParser( @@ -298,6 +301,29 @@ def get_parser(): fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + add_model_arguments(parser) return parser @@ -352,9 +378,26 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, ) + + if params.simulate_streaming: + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) + hyps = [] if params.decoding_method == "fast_beam_search": @@ -621,6 +664,10 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" @@ -658,6 +705,11 @@ def main(): params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + logging.info(params) logging.info("About to create model")