diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
index 8c06eb4f9..51b9d19da 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py
@@ -161,6 +161,13 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--giga-prob",
+        type=float,
+        default=0.2,
+        help="The probability to select a batch from the GigaSpeech dataset",
+    )
+
     return parser
 
 
@@ -523,7 +530,7 @@ def train_one_epoch(
     # index 0: for LibriSpeech
     # index 1: for GigaSpeech
     # This sets the probabilities for choosing which datasets
-    dl_weights = [0.8, 0.2]
+    dl_weights = [1 - params.giga_prob, params.giga_prob]
 
     iter_libri = iter(train_dl)
     iter_giga = iter(giga_train_dl)
@@ -861,6 +868,12 @@ def main():
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
 
+    # Both datasets must stay reachable: dl_weights is
+    # [1 - giga_prob, giga_prob]. Report a bad value as a usage error
+    # rather than with `assert`, which is stripped under `python -O`.
+    if not 0 < args.giga_prob < 1:
+        parser.error(f"--giga-prob must be in (0, 1), got {args.giga_prob}")
+
     world_size = args.world_size
     assert world_size >= 1
     if world_size > 1: