diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py
index 60a948a99..34e8e8fb9 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py
@@ -18,36 +18,36 @@
 """
 Usage:
 (1) greedy search
-./pruned_transducer_stateless2/decode.py \
+./lstm_transducer_stateless/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./lstm_transducer_stateless/exp \
     --max-duration 600 \
     --decoding-method greedy_search

 (2) beam search (not recommended)
-./pruned_transducer_stateless2/decode.py \
+./lstm_transducer_stateless/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./lstm_transducer_stateless/exp \
     --max-duration 600 \
     --decoding-method beam_search \
     --beam-size 4

 (3) modified beam search
-./pruned_transducer_stateless2/decode.py \
+./lstm_transducer_stateless/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./lstm_transducer_stateless/exp \
     --max-duration 600 \
     --decoding-method modified_beam_search \
     --beam-size 4

 (4) fast beam search (one best)
-./pruned_transducer_stateless2/decode.py \
+./lstm_transducer_stateless/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./lstm_transducer_stateless/exp \
     --max-duration 600 \
     --decoding-method fast_beam_search \
     --beam 20.0 \
@@ -55,10 +55,10 @@ Usage:
     --max-states 64

 (5) fast beam search (nbest)
-./pruned_transducer_stateless2/decode.py \
+./lstm_transducer_stateless/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./lstm_transducer_stateless/exp \
     --max-duration 600 \
     --decoding-method fast_beam_search_nbest \
     --beam 20.0 \
@@ -68,10 +68,10 @@ Usage:
     --nbest-scale 0.5

 (6) fast beam search (nbest oracle WER)
-./pruned_transducer_stateless2/decode.py \
+./lstm_transducer_stateless/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./lstm_transducer_stateless/exp \
     --max-duration 600 \
     --decoding-method fast_beam_search_nbest_oracle \
     --beam 20.0 \
@@ -81,30 +81,15 @@ Usage:
     --nbest-scale 0.5

 (7) fast beam search (with LG)
-./pruned_transducer_stateless2/decode.py \
+./lstm_transducer_stateless/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./lstm_transducer_stateless/exp \
     --max-duration 600 \
     --decoding-method fast_beam_search_nbest_LG \
     --beam 20.0 \
     --max-contexts 8 \
     --max-states 64
-
-(8) decode in streaming mode (take greedy search as an example)
-./pruned_transducer_stateless2/decode.py \
-    --epoch 28 \
-    --avg 15 \
-    --simulate-streaming 1 \
-    --causal-convolution 1 \
-    --decode-chunk-size 16 \
-    --left-context 64 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
-    --max-duration 600 \
-    --decoding-method greedy_search
-    --beam 20.0 \
-    --max-contexts 8 \
-    --max-states 64
 """

@@ -130,7 +115,7 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
-from train import add_model_arguments, get_params, get_transducer_model
+from train import get_params, get_transducer_model

 from icefall.checkpoint import (
     average_checkpoints,
@@ -142,7 +127,6 @@ from icefall.utils import (
     AttributeDict,
     setup_logger,
     store_transcripts,
-    str2bool,
     write_error_stats,
 )

@@ -286,29 +270,6 @@ def get_parser():
         Used only when --decoding_method is greedy_search""",
     )

-    parser.add_argument(
-        "--simulate-streaming",
-        type=str2bool,
-        default=False,
-        help="""Whether to simulate streaming in decoding, this is a good way to
-        test a streaming model.
-        """,
-    )
-
-    parser.add_argument(
-        "--decode-chunk-size",
-        type=int,
-        default=16,
-        help="The chunk size for decoding (in frames after subsampling)",
-    )
-
-    parser.add_argument(
-        "--left-context",
-        type=int,
-        default=64,
-        help="left context can be seen during decoding (in frames after subsampling)",
-    )
-
     parser.add_argument(
         "--num-paths",
         type=int,
@@ -327,7 +288,6 @@ def get_parser():
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )

-    add_model_arguments(parser)

     return parser

@@ -387,18 +347,9 @@ def decode_one_batch(
         value=LOG_EPS,
     )

-    if params.simulate_streaming:
-        encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
-            x=feature,
-            x_lens=feature_lens,
-            chunk_size=params.decode_chunk_size,
-            left_context=params.left_context,
-            simulate_streaming=True,
-        )
-    else:
-        encoder_out, encoder_out_lens = model.encoder(
-            x=feature, x_lens=feature_lens
-        )
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )

     hyps = []

@@ -674,10 +625,6 @@ def main():
     else:
         params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

-    if params.simulate_streaming:
-        params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
-        params.suffix += f"-left-context-{params.left_context}"
-
     if "fast_beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
@@ -712,11 +659,6 @@ def main():
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()

-    if params.simulate_streaming:
-        assert (
-            params.causal_convolution
-        ), "Decoding in streaming requires causal convolution"
-
     logging.info(params)

     logging.info("About to create model")
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py
index f52cb22ab..21bcf7cfd 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py
@@ -77,9 +77,7 @@ from beam_search import (
     modified_beam_search,
 )
 from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_params, get_transducer_model
-
-from icefall.utils import str2bool
+from train import get_params, get_transducer_model


 def get_parser():
@@ -180,30 +178,6 @@ def get_parser():
         """,
     )

-    parser.add_argument(
-        "--simulate-streaming",
-        type=str2bool,
-        default=False,
-        help="""Whether to simulate streaming in decoding, this is a good way to
-        test a streaming model.
- """, - ) - - parser.add_argument( - "--decode-chunk-size", - type=int, - default=16, - help="The chunk size for decoding (in frames after subsampling)", - ) - parser.add_argument( - "--left-context", - type=int, - default=64, - help="left context can be seen during decoding (in frames after subsampling)", - ) - - add_model_arguments(parser) - return parser @@ -248,11 +222,6 @@ def main(): params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() - if params.simulate_streaming: - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" - logging.info(f"{params}") device = torch.device("cpu") @@ -299,18 +268,9 @@ def main(): feature_lengths = torch.tensor(feature_lengths, device=device) - if params.simulate_streaming: - encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( - x=features, - x_lens=feature_lengths, - chunk_size=params.decode_chunk_size, - left_context=params.left_context, - simulate_streaming=True, - ) - else: - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/test_model.py b/egs/librispeech/ASR/lstm_transducer_stateless/test_model.py index 1858d6bf0..5c49025bd 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/test_model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/test_model.py @@ -34,31 +34,6 @@ def test_model(): params.context_size = 2 params.unk_id = 2 - params.dynamic_chunk_training = False - params.short_chunk_size = 25 - params.num_left_chunks = 4 - params.causal_convolution = False - - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - torch.jit.script(model) - - -def test_model_streaming(): - params = get_params() - params.vocab_size = 500 - params.blank_id = 0 - params.context_size = 2 - params.unk_id = 2 - - params.dynamic_chunk_training = True - params.short_chunk_size = 25 - params.num_left_chunks = 4 - params.causal_convolution = True - model = get_transducer_model(params) num_param = sum([p.numel() for p in model.parameters()]) @@ -69,7 +44,6 @@ def test_model_streaming(): def main(): test_model() - test_model_streaming() if __name__ == "__main__": diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py index 13175c4c2..8ce5bdc54 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -21,37 +21,24 @@ Usage: export CUDA_VISIBLE_DEVICES="0,1,2,3" -./pruned_transducer_stateless2/train.py \ +./lstm_transducer_stateless/train.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 0 \ - --exp-dir pruned_transducer_stateless2/exp \ + --exp-dir lstm_transducer_stateless/exp \ --full-libri 1 \ --max-duration 300 # For mix precision training: -./pruned_transducer_stateless2/train.py \ +./lstm_transducer_stateless/train.py \ --world-size 4 \ --num-epochs 30 \ --start-epoch 0 \ --use-fp16 1 \ - --exp-dir pruned_transducer_stateless2/exp \ + --exp-dir lstm_transducer_stateless/exp \ --full-libri 1 \ --max-duration 550 - -# train a streaming model -./pruned_transducer_stateless2/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 0 \ - 
-  --exp-dir pruned_transducer_stateless/exp \
-  --full-libri 1 \
-  --dynamic-chunk-training 1 \
-  --causal-convolution 1 \
-  --short-chunk-size 25 \
-  --num-left-chunks 4 \
-  --max-duration 300
 """


@@ -69,7 +56,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from conformer import Conformer
+from lstm import RNN
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -95,42 +82,6 @@ LRSchedulerType = Union[
 ]

-def add_model_arguments(parser: argparse.ArgumentParser):
-    parser.add_argument(
-        "--dynamic-chunk-training",
-        type=str2bool,
-        default=False,
-        help="""Whether to use dynamic_chunk_training, if you want a streaming
-        model, this requires to be True.
-        """,
-    )
-
-    parser.add_argument(
-        "--causal-convolution",
-        type=str2bool,
-        default=False,
-        help="""Whether to use causal convolution, this requires to be True when
-        using dynamic_chunk_training.
-        """,
-    )
-
-    parser.add_argument(
-        "--short-chunk-size",
-        type=int,
-        default=25,
-        help="""Chunk length of dynamic training, the chunk size would be either
-        max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
-        """,
-    )
-
-    parser.add_argument(
-        "--num-left-chunks",
-        type=int,
-        default=4,
-        help="How many left context can be seen in chunks when calculating attention.",
-    )
-
-
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
@@ -311,8 +262,6 @@ def get_parser():
         help="Whether to use half precision training.",
     )

-    add_model_arguments(parser)
-
     return parser


@@ -374,7 +323,6 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "subsampling_factor": 4,
             "encoder_dim": 512,
-            "nhead": 8,
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             # parameters for decoder
@@ -392,17 +340,12 @@ def get_params() -> AttributeDict:

 def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
-    encoder = Conformer(
+    encoder = RNN(
         num_features=params.feature_dim,
         subsampling_factor=params.subsampling_factor,
         d_model=params.encoder_dim,
-        nhead=params.nhead,
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
-        dynamic_chunk_training=params.dynamic_chunk_training,
-        short_chunk_size=params.short_chunk_size,
-        num_left_chunks=params.num_left_chunks,
-        causal=params.causal_convolution,
     )
     return encoder
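
Reviewer note, not part of the patch: after this change the encoder is built and invoked exactly as in the hunks above, with a single full-utterance forward and no streaming_forward branch. The snippet below is a minimal smoke-test sketch of that path. It assumes it is run from egs/librispeech/ASR/lstm_transducer_stateless/ so that lstm.py, which defines the RNN encoder, is importable; the constructor arguments mirror get_encoder_model(), and the batch shapes are illustrative.

import torch

from lstm import RNN  # the encoder that replaces Conformer in train.py

# Same settings that get_encoder_model() reads from get_params().
encoder = RNN(
    num_features=80,       # feature_dim
    subsampling_factor=4,
    d_model=512,           # encoder_dim
    dim_feedforward=2048,
    num_encoder_layers=12,
)
encoder.eval()

# Two utterances of 80-dim fbank features, padded to 100 frames.
x = torch.randn(2, 100, 80)
x_lens = torch.tensor([100, 78])

# Mirrors the simplified call sites in decode_one_batch() and pretrained.py.
with torch.no_grad():
    encoder_out, encoder_out_lens = encoder(x=x, x_lens=x_lens)

# encoder_out has shape (N, T', 512), with T' roughly T // subsampling_factor.
print(encoder_out.shape, encoder_out_lens)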