diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index ee7bfac53..d2345f6ad 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -99,7 +99,7 @@ from beam_search import ( greedy_search_batch, modified_beam_search, ) -from train import get_params, get_transducer_model +from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, @@ -262,38 +262,6 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) - parser.add_argument( - "--dynamic-chunk-training", - type=str2bool, - default=False, - help="""Whether to use dynamic_chunk_training, if you want a streaming - model, this requires to be True. - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=25, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="""How many left context can be seen in chunks when calculating attention. - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - parser.add_argument( "--simulate-streaming", type=str2bool, @@ -303,15 +271,6 @@ def get_parser(): """, ) - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - using dynamic_chunk_training. 
- """, - ) - parser.add_argument( "--decode-chunk-size", type=int, @@ -325,6 +284,8 @@ def get_parser(): help="left context can be seen during decoding (in frames after subsampling)", ) + add_model_arguments(parser) + return parser diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py index 8f3ec3a20..b5a151878 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py @@ -49,7 +49,7 @@ from pathlib import Path import sentencepiece as spm import torch -from train import get_params, get_transducer_model +from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.utils import str2bool @@ -109,38 +109,6 @@ def get_parser(): "2 means tri-gram", ) - parser.add_argument( - "--dynamic-chunk-training", - type=str2bool, - default=False, - help="""Whether to use dynamic_chunk_training, if you want a streaming - model, this requires to be True. - Note: not needed here, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=25, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - Note: not needed for here, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="""How many left context can be seen in chunks when calculating attention. - Note: not needed here, adding it here to construct transducer model, - as we reuse the code in train.py. 
- """, - ) - parser.add_argument( "--streaming-model", type=str2bool, @@ -150,14 +118,7 @@ def get_parser(): """, ) - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - exporting a streaming model. - """, - ) + add_model_arguments(parser) return parser diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index 87a35082a..6867dedb8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -44,7 +44,7 @@ from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions from lhotse import CutSet from torch.nn.utils.rnn import pad_sequence -from train import get_params, get_transducer_model +from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, @@ -57,7 +57,6 @@ from icefall.utils import ( get_texts, setup_logger, store_transcripts, - str2bool, write_error_stats, ) @@ -153,38 +152,6 @@ def get_parser(): "2 means tri-gram", ) - parser.add_argument( - "--dynamic-chunk-training", - type=str2bool, - default=False, - help="""Whether to use dynamic_chunk_training, if you want a streaming - model, this requires to be True. - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=25, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. 
- """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="""How many left context can be seen in chunks when calculating attention. - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - parser.add_argument( "--decode-chunk-size", type=int, @@ -213,6 +180,8 @@ def get_parser(): help="The number of streams that can be decoded parallel.", ) + add_model_arguments(parser) + return parser diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index 0740f64a5..3708c17ef 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -86,6 +86,42 @@ from icefall.utils import ( ) +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=False, + help="""Whether to use dynamic_chunk_training, if you want a streaming + model, this requires to be True. + """, + ) + + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=25, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). 
+ """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -235,39 +271,7 @@ def get_parser(): """, ) - parser.add_argument( - "--dynamic-chunk-training", - type=str2bool, - default=False, - help="""Whether to use dynamic_chunk_training, if you want a streaming - model, this requires to be True. - """, - ) - - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - using dynamic_chunk_training. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=25, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="How many left context can be seen in chunks when calculating attention.", - ) + add_model_arguments(parser) return parser diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 9cd494b05..1f79b93cf 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -87,7 +87,7 @@ from beam_search import ( greedy_search_batch, modified_beam_search, ) -from train import get_params, get_transducer_model +from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, @@ -215,38 +215,6 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) - parser.add_argument( - "--dynamic-chunk-training", - type=str2bool, - default=False, - help="""Whether to use dynamic_chunk_training, if you want a 
streaming - model, this requires to be True. - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=25, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="""How many left context can be seen in chunks when calculating attention. - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - parser.add_argument( "--simulate-streaming", type=str2bool, @@ -256,15 +224,6 @@ def get_parser(): """, ) - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - using dynamic_chunk_training. 
- """, - ) - parser.add_argument( "--decode-chunk-size", type=int, @@ -279,6 +238,8 @@ def get_parser(): help="left context can be seen during decoding (in frames after subsampling)", ) + add_model_arguments(parser) + return parser diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index 97f2facb2..f1a8ea589 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -49,7 +49,7 @@ from pathlib import Path import sentencepiece as spm import torch -from train import get_params, get_transducer_model +from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, @@ -124,38 +124,6 @@ def get_parser(): "2 means tri-gram", ) - parser.add_argument( - "--dynamic-chunk-training", - type=str2bool, - default=False, - help="""Whether to use dynamic_chunk_training, if you want a streaming - model, this requires to be True. - Note: not needed here, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=25, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - Note: not needed for here, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="""How many left context can be seen in chunks when calculating attention. - Note: not needed here, adding it here to construct transducer model, - as we reuse the code in train.py. 
- """, - ) - parser.add_argument( "--streaming-model", type=str2bool, @@ -165,15 +133,7 @@ def get_parser(): """, ) - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - exporting a streaming model. - """, - ) - + add_model_arguments(parser) return parser diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index 8b5650ad7..850a574ae 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -44,7 +44,7 @@ from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions from lhotse import CutSet from torch.nn.utils.rnn import pad_sequence -from train import get_params, get_transducer_model +from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, @@ -57,7 +57,6 @@ from icefall.utils import ( get_texts, setup_logger, store_transcripts, - str2bool, write_error_stats, ) @@ -153,38 +152,6 @@ def get_parser(): "2 means tri-gram", ) - parser.add_argument( - "--dynamic-chunk-training", - type=str2bool, - default=False, - help="""Whether to use dynamic_chunk_training, if you want a streaming - model, this requires to be True. - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=25, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. 
- """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="""How many left context can be seen in chunks when calculating attention. - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - parser.add_argument( "--decode-chunk-size", type=int, @@ -213,6 +180,8 @@ def get_parser(): help="The number of streams that can be decoded parallel.", ) + add_model_arguments(parser) + return parser diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index ec0e26e05..a03712643 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -95,6 +95,42 @@ LRSchedulerType = Union[ ] +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=False, + help="""Whether to use dynamic_chunk_training, if you want a streaming + model, this requires to be True. + """, + ) + + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=25, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). 
+ """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -275,39 +311,7 @@ def get_parser(): help="Whether to use half precision training.", ) - parser.add_argument( - "--dynamic-chunk-training", - type=str2bool, - default=False, - help="""Whether to use dynamic_chunk_training, if you want a streaming - model, this requires to be True. - """, - ) - - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - using dynamic_chunk_training. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=25, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). 
- """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="How many left context can be seen in chunks when calculating attention.", - ) + add_model_arguments(parser) return parser diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 154e1f074..c37e1685e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -77,7 +77,7 @@ from beam_search import ( modified_beam_search, ) from librispeech import LibriSpeech -from train import get_params, get_transducer_model +from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, @@ -224,38 +224,6 @@ def get_parser(): """, ) - parser.add_argument( - "--dynamic-chunk-training", - type=str2bool, - default=False, - help="""Whether to use dynamic_chunk_training, if you want a streaming - model, this requires to be True. - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=25, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="""How many left context can be seen in chunks when calculating attention. - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. 
- """, - ) - parser.add_argument( "--simulate-streaming", type=str2bool, @@ -265,15 +233,6 @@ def get_parser(): """, ) - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - using dynamic_chunk_training. - """, - ) - parser.add_argument( "--decode-chunk-size", type=int, @@ -287,6 +246,9 @@ def get_parser(): default=64, help="left context can be seen during decoding (in frames after subsampling)", ) + + add_model_arguments(parser) + return parser diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index bab8a9910..53ea306ff 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -50,7 +50,7 @@ from pathlib import Path import sentencepiece as spm import torch -from train import get_params, get_transducer_model +from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, @@ -125,38 +125,6 @@ def get_parser(): "2 means tri-gram", ) - parser.add_argument( - "--dynamic-chunk-training", - type=str2bool, - default=False, - help="""Whether to use dynamic_chunk_training, if you want a streaming - model, this requires to be True. - Note: not needed here, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=25, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - Note: not needed for here, adding it here to construct transducer model, - as we reuse the code in train.py. 
- """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="""How many left context can be seen in chunks when calculating attention. - Note: not needed here, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - parser.add_argument( "--streaming-model", type=str2bool, @@ -166,14 +134,7 @@ def get_parser(): """, ) - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - exporting a streaming model. - """, - ) + add_model_arguments(parser) return parser diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index 18776c763..490ca54da 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -45,7 +45,7 @@ from kaldifeat import Fbank, FbankOptions from lhotse import CutSet from librispeech import LibriSpeech from torch.nn.utils.rnn import pad_sequence -from train import get_params, get_transducer_model +from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, @@ -58,7 +58,6 @@ from icefall.utils import ( get_texts, setup_logger, store_transcripts, - str2bool, write_error_stats, ) @@ -154,38 +153,6 @@ def get_parser(): "2 means tri-gram", ) - parser.add_argument( - "--dynamic-chunk-training", - type=str2bool, - default=False, - help="""Whether to use dynamic_chunk_training, if you want a streaming - model, this requires to be True. - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. 
- """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=25, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="""How many left context can be seen in chunks when calculating attention. - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - parser.add_argument( "--decode-chunk-size", type=int, @@ -214,6 +181,8 @@ def get_parser(): help="The number of streams that can be decoded parallel.", ) + add_model_arguments(parser) + return parser diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 9793f02e5..70ce73504 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -91,6 +91,42 @@ LRSchedulerType = Union[ ] +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=False, + help="""Whether to use dynamic_chunk_training, if you want a streaming + model, this requires to be True. + """, + ) + + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=25, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). 
+ """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -286,40 +322,8 @@ def get_parser(): help="The probability to select a batch from the GigaSpeech dataset", ) - parser.add_argument( - "--dynamic-chunk-training", - type=str2bool, - default=False, - help="""Whether to use dynamic_chunk_training, if you want a streaming - model, this requires to be True. - """, - ) - - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - using dynamic_chunk_training. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=25, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). 
- """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="How many left context can be seen in chunks when calculating attention.", - ) + add_model_arguments(parser) return parser diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 0ab41a7ef..1adab7f71 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -88,7 +88,7 @@ from beam_search import ( greedy_search_batch, modified_beam_search, ) -from train import get_params, get_transducer_model +from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, @@ -227,38 +227,6 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) - parser.add_argument( - "--dynamic-chunk-training", - type=str2bool, - default=False, - help="""Whether to use dynamic_chunk_training, if you want a streaming - model, this requires to be True. - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=25, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="""How many left context can be seen in chunks when calculating attention. - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. 
- """, - ) - parser.add_argument( "--simulate-streaming", type=str2bool, @@ -268,15 +236,6 @@ def get_parser(): """, ) - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - using dynamic_chunk_training. - """, - ) - parser.add_argument( "--decode-chunk-size", type=int, @@ -291,6 +250,8 @@ def get_parser(): help="left context can be seen during decoding (in frames after subsampling)", ) + add_model_arguments(parser) + return parser diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py index 270b68b0a..ce7518ceb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py @@ -50,7 +50,7 @@ from pathlib import Path import sentencepiece as spm import torch -from train import get_params, get_transducer_model +from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, @@ -137,38 +137,6 @@ def get_parser(): "2 means tri-gram", ) - parser.add_argument( - "--dynamic-chunk-training", - type=str2bool, - default=False, - help="""Whether to use dynamic_chunk_training, if you want a streaming - model, this requires to be True. - Note: not needed here, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=25, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - Note: not needed for here, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="""How many left context can be seen in chunks when calculating attention. 
- Note: not needed here, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - parser.add_argument( "--streaming-model", type=str2bool, @@ -178,14 +146,7 @@ def get_parser(): """, ) - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - exporting a streaming model. - """, - ) + add_model_arguments(parser) return parser diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index 2ca00ec40..ed14fc056 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -44,7 +44,7 @@ from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions from lhotse import CutSet from torch.nn.utils.rnn import pad_sequence -from train import get_params, get_transducer_model +from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, @@ -165,38 +165,6 @@ def get_parser(): "2 means tri-gram", ) - parser.add_argument( - "--dynamic-chunk-training", - type=str2bool, - default=False, - help="""Whether to use dynamic_chunk_training, if you want a streaming - model, this requires to be True. - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=25, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. 
- """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="""How many left context can be seen in chunks when calculating attention. - Note: not needed for decoding, adding it here to construct transducer model, - as we reuse the code in train.py. - """, - ) - parser.add_argument( "--decode-chunk-size", type=int, @@ -225,6 +193,8 @@ def get_parser(): help="The number of streams that can be decoded parallel.", ) + add_model_arguments(parser) + return parser @@ -654,6 +624,7 @@ def main(): params.blank_id = sp.piece_to_id("<blk>") params.unk_id = sp.piece_to_id("<unk>") params.vocab_size = sp.get_piece_size() + # Decoding in streaming requires causal convolution params.causal_convolution = True diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 741ccf071..aa065eb0c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -100,6 +100,42 @@ LRSchedulerType = Union[ ] +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=False, + help="""Whether to use dynamic_chunk_training, if you want a streaming + model, this requires to be True. + """, + ) + + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=25, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). 
+ """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -293,39 +329,7 @@ def get_parser(): help="Whether to use half precision training.", ) - parser.add_argument( - "--dynamic-chunk-training", - type=str2bool, - default=False, - help="""Whether to use dynamic_chunk_training, if you want a streaming - model, this requires to be True. - """, - ) - - parser.add_argument( - "--causal-convolution", - type=str2bool, - default=False, - help="""Whether to use causal convolution, this requires to be True when - using dynamic_chunk_training. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=25, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="How many left context can be seen in chunks when calculating attention.", - ) + add_model_arguments(parser) return parser