From 6b1ab71dc9c715fe08f5ba7dadc6d7c083be904c Mon Sep 17 00:00:00 2001
From: Zengwei Yao
Date: Fri, 27 Jan 2023 21:24:12 +0800
Subject: [PATCH] hardcode --filter-uneven-sized-batch (#854)

---
 .../ASR/pruned_transducer_stateless7/train.py | 32 ++++++-------------
 1 file changed, 10 insertions(+), 22 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index a806244ff..6022406eb 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -374,21 +374,6 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
-    parser.add_argument(
-        "--filter-uneven-sized-batch",
-        type=str2bool,
-        default=True,
-        help="""Whether to filter uneven-sized minibatch.
-        For the uneven-sized batch, the total duration after padding would possibly
-        cause OOM. Hence, for each batch, which is sorted descendingly by length,
-        we simply drop the last few shortest samples, so that the retained total frames
-        (after padding) would not exceed `allowed_max_frames`:
-        `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
-        where `max_frames = max_duration * 1000 // frame_shift_ms`.
-        We set allowed_excess_duration_ratio=0.1.
-        """,
-    )
-
     add_model_arguments(parser)
 
     return parser
@@ -442,7 +427,6 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "frame_shift_ms": 10.0,
-            # only used when params.filter_uneven_sized_batch is True
             "allowed_excess_duration_ratio": 0.1,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
@@ -666,12 +650,16 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    if params.filter_uneven_sized_batch:
-        max_frames = params.max_duration * 1000 // params.frame_shift_ms
-        allowed_max_frames = int(
-            max_frames * (1.0 + params.allowed_excess_duration_ratio)
-        )
-        batch = filter_uneven_sized_batch(batch, allowed_max_frames)
+    # For the uneven-sized batch, the total duration after padding would possibly
+    # cause OOM. Hence, for each batch, which is sorted descendingly by length,
+    # we simply drop the last few shortest samples, so that the retained total frames
+    # (after padding) would not exceed `allowed_max_frames`:
+    # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
+    # where `max_frames = max_duration * 1000 // frame_shift_ms`.
+    # We set allowed_excess_duration_ratio=0.1.
+    max_frames = params.max_duration * 1000 // params.frame_shift_ms
+    allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
+    batch = filter_uneven_sized_batch(batch, allowed_max_frames)
 
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
 
     feature = batch["inputs"]
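
For reference, below is a minimal standalone sketch of the frame-budget arithmetic that compute_loss now applies unconditionally. The helper names `frame_budget` and `num_to_keep` are made up for this note and do not exist in icefall; the actual filtering is done by the `filter_uneven_sized_batch` helper that train.py calls. The sketch assumes the batch is already sorted by decreasing utterance length, as the comment in the hunk states.

# Illustrative sketch only; `frame_budget` and `num_to_keep` are hypothetical
# names, not icefall APIs. Assumes utterance lengths are given in frames and
# sorted in descending order.

def frame_budget(
    max_duration_s: float,
    frame_shift_ms: float = 10.0,
    allowed_excess_duration_ratio: float = 0.1,
) -> int:
    """allowed_max_frames = max_frames * (1 + excess ratio),
    where max_frames = max_duration * 1000 // frame_shift_ms."""
    max_frames = max_duration_s * 1000 // frame_shift_ms
    return int(max_frames * (1.0 + allowed_excess_duration_ratio))


def num_to_keep(num_frames_desc: list, budget: int) -> int:
    """How many utterances to keep: after padding, every kept utterance is as
    long as the longest one, so kept * longest must stay within the budget."""
    if not num_frames_desc:
        return 0
    longest = num_frames_desc[0]
    # Drop only the shortest samples, i.e. keep a prefix of the sorted batch,
    # and always keep at least one utterance.
    return max(1, min(len(num_frames_desc), budget // longest))


if __name__ == "__main__":
    budget = frame_budget(max_duration_s=600)    # e.g. --max-duration 600
    lengths = [2980, 2975, 2960, 120]            # frames per utterance, descending
    print(budget, num_to_keep(lengths, budget))  # -> 66000 4

With the patch applied, this budget check is no longer optional: the former --filter-uneven-sized-batch flag is gone, and its help text survives as the comment block added in compute_loss.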