From 118b09463df2f36bad3255664d0e6a80be95d576 Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 20 May 2022 15:26:45 +0800 Subject: [PATCH] Add more documents --- .../ASR/pruned_transducer_stateless/decode.py | 77 +++++++++++++--- .../ASR/pruned_transducer_stateless/train.py | 52 ++++++++--- .../ASR/transducer_stateless/conformer.py | 92 +++++++++++++++++-- 3 files changed, 189 insertions(+), 32 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index fee28d4fe..cdb8b5d1b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -66,6 +66,18 @@ Usage: --beam 8 \ --max-contexts 8 \ --max-states 64 + +(6) decode in streaming mode (take greedy search as an example) +./pruned_transducer_stateless/decode.py \ + --epoch 28 \ + --avg 15 \ + --simulate-streaming 1 \ + --causal-convolution 1 \ + --right-chunk-size 16 \ + --left-context 64 \ + --exp-dir ./pruned_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method greedy_search """ @@ -249,24 +261,66 @@ def get_parser(): help="""Maximum number of symbols per frame. Used only when --decoding_method is greedy_search""", ) + parser.add_argument( - "--streaming-mode", + "--dynamic-chunk-training", type=str2bool, default=False, - help=""" + help="""Whether to use dynamic_chunk_training, if you want a streaming + model, this requires to be True. + Note: not needed for decoding, adding it here to construct transducer model, + as we reuse the code in train.py. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=25, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + Note: not needed for decoding, adding it here to construct transducer model, + as we reuse the code in train.py. + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="""How many left context can be seen in chunks when calculating attention. + Note: not needed for decoding, adding it here to construct transducer model, + as we reuse the code in train.py. + """, + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. """, ) parser.add_argument( "--right-chunk-size", type=int, default=16, - help="right context to attend during decoding", + help="The chunk size for decoding (in frames after subsampling)", ) parser.add_argument( "--left-context", type=int, default=64, - help="left context to attend during decoding", + help="left context can be seen during decoding (in frames after subsampling)", ) return parser @@ -320,13 +374,13 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - if params.streaming_mode: + if params.simulate_streaming: encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, chunk_size=params.right_chunk_size, left_context=params.left_context, - streaming_data=False + simulate_streaming=True ) else: encoder_out, encoder_out_lens = model.encoder( @@ -554,7 +608,7 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.streaming_mode: + if params.simulate_streaming: params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}" params.suffix += f"-left-context-{params.left_context}" @@ -590,13 +644,14 @@ def main(): params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + logging.info(params) logging.info("About to create model") - # TODO(wei kang): make following config more elegant - params.dynamic_chunk_training=params.streaming_mode - params.short_chunk_size=25 - params.num_left_chunks=params.left_context // params.right_chunk_size model = get_transducer_model(params) if params.iter > 0: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index a922192e5..cb8f13c54 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -28,6 +28,19 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --exp-dir pruned_transducer_stateless/exp \ --full-libri 1 \ --max-duration 300 + +# train a streaming model +./pruned_transducer_stateless/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 0 \ + --exp-dir pruned_transducer_stateless/exp \ + --full-libri 1 \ + --dynamic-chunk-training 1 \ + --causal-convolution 1 \ + --short-chunk-size 25 \ + --num-left-chunks 4 \ + --max-duration 300 """ @@ -222,28 +235,40 @@ def get_parser(): """, ) + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=False, + help="""Whether to use dynamic_chunk_training, if you want a streaming + model, this requires to be True. + """, + ) + + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + parser.add_argument( "--short-chunk-size", type=int, default=25, - help="chunk length of dynamic training", + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, ) parser.add_argument( "--num-left-chunks", type=int, default=4, - help="chunk length of dynamic training", + help="How many left context can be seen in chunks when calculating attention.", ) - parser.add_argument( - "--dynamic-chunk-training", - type=str2bool, - default=False, - help="""Whether to use dynamic_chunk_training, if you want a streaming - model, this requires to be True - """, - ) return parser @@ -336,7 +361,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: dynamic_chunk_training=params.dynamic_chunk_training, short_chunk_size=params.short_chunk_size, num_left_chunks=params.num_left_chunks, - causal=True if params.dynamic_chunk_training else False, + causal=params.causal_convolution, ) return encoder @@ -789,6 +814,11 @@ def run(rank, world_size, args): params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() + if params.dynamic_chunk_training: + assert ( + params.causal_convolution + ), "dynamic_chunk_training requires causal convolution" + logging.info(params) logging.info("About to create model") diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 046345508..f2a051471 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -105,6 +105,26 @@ class Conformer(Transformer): cnn_module_kernel (int): Kernel size of convolution module normalize_before (bool): whether to use layer_norm before the first block. vgg_frontend (bool): whether to use vgg frontend. + dynamic_chunk_training (bool): whether to use dynamic chunk training, if + you want to train a streaming model, this is expected to be True. + When setting True, it will use a masking strategy to make the attention + see only limited left and right context. + short_chunk_threshold (float): a threshold to determinize the chunk size + to be used in masking training, if the randomly generated chunk size + is greater than ``max_len * short_chunk_threshold`` (max_len is the + max sequence length of current batch) then it will use + full context in training (i.e. with chunk size equals to max_len). + This will be used only when dynamic_chunk_training is True. + short_chunk_size (int): see docs above, if the randomly generated chunk + size equals to or less than ``max_len * short_chunk_threshold``, the + chunk size will be sampled uniformly from 1 to short_chunk_size. + This also will be used only when dynamic_chunk_training is True. + num_left_chunks (int): the left context attention can see in chunks, the + chunk size is decided by short_chunk_threshold and short_chunk_size. + A minus value means seeing full left context. + This also will be used only when dynamic_chunk_training is True. + causal (bool): Whether to use causal convolution in conformer encoder + layer. This MUST be True when using dynamic_chunk_training. """ def __init__( @@ -229,16 +249,45 @@ class Conformer(Transformer): x: torch.Tensor, x_lens: torch.Tensor, decode_states: Optional[DecodeStates] = None, - chunk_size: int = 32, + chunk_size: int = 16, left_context: int = 64, - streaming_data: bool = True, + simulate_streaming: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, DecodeStates]: - # x: [N, T, C] + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + decode_states: + The decode states for previous frames which contains the cached data + and the offset of current chunk in the whole sequence. + chunk_size: + The chunk size for decoding, this will be used to simulate streaming + decoding using masking. + left_context: + How many old frames the attention can see in current chunk, it MUST + be equal to left_context in decode_states. + simulate_streaming: + If setting True, it will use a masking strategy to simulate streaming + fashion (i.e. every chunk data only see limited left context and + right context). The whole sequence is supposed to be send at a time + When using simulate_streaming. + Returns: + Return a tuple containing 2 tensors: + - logits, its shape is (batch_size, output_seq_len, output_dim) + - logit_lens, a tensor of shape (batch_size,) containing the number + of frames in `logits` before padding. + - decode_states, the updated DecodeStates including the information + of current chunk. + """ + # x: [N, T, C] # Caution: We assume the subsampling factor is 4! lengths = ((x_lens - 1) // 2 - 1) // 2 - if streaming_data: + if not simulate_streaming: assert ( decode_states is not None ), "Require cache when sending data in streaming mode" @@ -309,7 +358,9 @@ class ConformerEncoderLayer(nn.Module): dim_feedforward: the dimension of the feedforward network model (default=2048). dropout: the dropout value (default=0.1). cnn_module_kernel (int): Kernel size of convolution module. - normalize_before: whether to use layer_norm before the first block. + normalize_before (bool): whether to use layer_norm before the first block. + causal (bool): Whether to use causal convolution in conformer encoder + layer. This MUST be True when using dynamic_chunk_training and streaming decoding. Examples:: >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) @@ -386,10 +437,14 @@ class ConformerEncoderLayer(nn.Module): pos_emb: Positional embedding tensor (required). src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). - + attn_cache: attention cache for previous frames. + conv_cache: convolution cache for previous frames. + left_context: left context in frames used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. Shape: src: (S, N, E). - pos_emb: (N, 2*S-1, E) + pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E). src_mask: (S, S). src_key_padding_mask: (N, S). S is the source sequence length, N is the batch size, E is the feature number @@ -502,10 +557,16 @@ class ConformerEncoder(nn.Module): pos_emb: Positional embedding tensor (required). mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). + attn_cache: attention cache for previous frames. + conv_cache: convolution cache for previous frames. + left_context: left context in frames used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. + Shape: Shape: src: (S, N, E). - pos_emb: (N, 2*S-1, E) + pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E). mask: (S, S). src_key_padding_mask: (N, S). S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number @@ -606,7 +667,9 @@ class RelPositionalEncoding(torch.nn.Module): Args: x (torch.Tensor): Input tensor (batch, time, `*`). - + context (int): left context in frames used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. Returns: torch.Tensor: Encoded tensor (batch, time, `*`). torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). @@ -699,6 +762,9 @@ class RelPositionMultiheadAttention(nn.Module): need_weights: output attn_output_weights. attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch. + left_context (int): left context in frames used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. Shape: - Inputs: @@ -753,6 +819,9 @@ class RelPositionMultiheadAttention(nn.Module): Args: x: Input tensor (batch, head, time1, 2*time1-1). time1 means the length of query vector. + left_context (int): left context in frames used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. Returns: Tensor: tensor of shape (batch, head, time1, time2) @@ -810,6 +879,9 @@ class RelPositionMultiheadAttention(nn.Module): need_weights: output attn_output_weights. attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch. + left_context (int): left context in frames used during streaming decoding. + this is used only in real streaming decoding, in other circumstances, + it MUST be 0. Shape: Inputs: @@ -1083,7 +1155,7 @@ class ConvolutionModule(nn.Module): channels (int): The number of channels of conv layers. kernel_size (int): Kernerl size of conv layers. bias (bool): Whether to use bias in conv layers (default=True). - + causal (bool): Whether to use causal convolution. """ def __init__(