From 4f2ef237ebababe189a6fa996c0ea7af0aeb6f03 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Thu, 26 May 2022 10:22:05 +0800
Subject: [PATCH] streaming for pruned_transducer_stateless4

---
 .../pruned_transducer_stateless4/decode.py    | 108 +++++++++++++++++-
 .../ASR/pruned_transducer_stateless4/train.py |  86 +++++++++++++-
 2 files changed, 190 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
index 9982cc530..50e5c5a09 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
@@ -54,6 +54,18 @@ Usage:
     --beam 4 \
     --max-contexts 4 \
     --max-states 8
+
+(5) decode in streaming mode (take greedy search as an example)
+./pruned_transducer_stateless4/decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --simulate-streaming 1 \
+    --causal-convolution 1 \
+    --right-chunk-size 16 \
+    --left-context 64 \
+    --exp-dir ./pruned_transducer_stateless4/exp \
+    --max-duration 600 \
+    --decoding-method greedy_search
 """
 
 
@@ -91,6 +103,7 @@ from icefall.utils import (
     write_error_stats,
 )
 
+LOG_EPS = math.log(1e-10)
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -212,6 +225,70 @@
         Used only when --decoding_method is greedy_search""",
     )
 
+    parser.add_argument(
+        "--dynamic-chunk-training",
+        type=str2bool,
+        default=False,
+        help="""Whether to use dynamic_chunk_training; if you want a streaming
+        model, this must be True.
+        Note: not needed for decoding; it is added here to construct the
+        transducer model, since we reuse the code in train.py.
+        """,
+    )
+
+    parser.add_argument(
+        "--short-chunk-size",
+        type=int,
+        default=25,
+        help="""Chunk length for dynamic training; the chunk size will be either the
+        max sequence length of the current batch or uniformly sampled from (1, short_chunk_size).
+        Note: not needed for decoding; it is added here to construct the
+        transducer model, since we reuse the code in train.py.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-left-chunks",
+        type=int,
+        default=4,
+        help="""How many left chunks can be seen when calculating attention.
+        Note: not needed for decoding; it is added here to construct the
+        transducer model, since we reuse the code in train.py.
+        """,
+    )
+
+    parser.add_argument(
+        "--simulate-streaming",
+        type=str2bool,
+        default=False,
+        help="""Whether to simulate streaming in decoding; this is a good way
+        to test a streaming model.
+        """,
+    )
+
+    parser.add_argument(
+        "--causal-convolution",
+        type=str2bool,
+        default=False,
+        help="""Whether to use causal convolution; this must be True when
+        using dynamic_chunk_training.
+ """, + ) + + parser.add_argument( + "--right-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + return parser @@ -260,9 +337,26 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, ) + + if params.simulate_streaming: + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.right_chunk_size, + left_context=params.left_context, + simulate_streaming=True + ) + else: + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) + hyps = [] if params.decoding_method == "fast_beam_search": @@ -475,6 +569,10 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" @@ -507,6 +605,11 @@ def main(): params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + logging.info(params) logging.info("About to create model") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 4ff69d521..89428110e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -41,6 +41,19 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --full-libri 1 \ --max-duration 550 +# train a streaming model +./pruned_transducer_stateless4/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless4/exp \ + --full-libri 1 \ + --dynamic-chunk-training 1 \ + --causal-convolution 1 \ + --short-chunk-size 25 \ + --num-left-chunks 4 \ + --max-duration 300 + """ @@ -281,6 +294,59 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=False, + help="""Whether to use dynamic_chunk_training, if you want a streaming + model, this requires to be True. + """, + ) + + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=25, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). 
+ """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--delay-penalty", + type=float, + default=0.0, + help="""A constant value to penalize symbol delay, this may be + needed when training with time masking, to avoid the time masking + encouraging the network to delay symbols. + """, + ) + + parser.add_argument( + "--return-sym-delay", + type=str2bool, + default=False, + help="""Whether to return `sym_delay` during training, this is a stat + to measure symbols emission delay, especially for time masking training. + """, + ) + return parser @@ -367,6 +433,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: nhead=params.nhead, dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, + dynamic_chunk_training=params.dynamic_chunk_training, + short_chunk_size=params.short_chunk_size, + num_left_chunks=params.num_left_chunks, + causal=params.causal_convolution, ) return encoder @@ -571,7 +641,7 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( + simple_loss, pruned_loss, sym_delay = model( x=feature, x_lens=feature_lens, y=y, @@ -579,6 +649,8 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, warmup=warmup, + delay_penalty=params.delay_penalty, + return_sym_delay=params.return_sym_delay, ) # after the main warmup step, we keep pruned_loss_scale small # for the same amount of time (model_warm_step), to avoid @@ -608,6 +680,9 @@ def compute_loss( info["simple_loss"] = simple_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.return_sym_delay: + info["sym_delay"] = sym_delay.detach().cpu().item() + return loss, info @@ -847,6 +922,15 @@ def run(rank, world_size, args): params.blank_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() + if params.dynamic_chunk_training: + assert ( + params.causal_convolution + ), "dynamic_chunk_training requires causal convolution" + else: + assert ( + params.delay_penalty == 0.0 + ), "delay_penalty is intended for dynamic_chunk_training" + logging.info(params) logging.info("About to create model")