diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index a806244ff..6022406eb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -374,21 +374,6 @@ def get_parser(): help="Whether to use half precision training.", ) - parser.add_argument( - "--filter-uneven-sized-batch", - type=str2bool, - default=True, - help="""Whether to filter uneven-sized minibatch. - For the uneven-sized batch, the total duration after padding would possibly - cause OOM. Hence, for each batch, which is sorted descendingly by length, - we simply drop the last few shortest samples, so that the retained total frames - (after padding) would not exceed `allowed_max_frames`: - `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, - where `max_frames = max_duration * 1000 // frame_shift_ms`. - We set allowed_excess_duration_ratio=0.1. - """, - ) - add_model_arguments(parser) return parser @@ -442,7 +427,6 @@ def get_params() -> AttributeDict: params = AttributeDict( { "frame_shift_ms": 10.0, - # only used when params.filter_uneven_sized_batch is True "allowed_excess_duration_ratio": 0.1, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), @@ -666,12 +650,16 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - if params.filter_uneven_sized_batch: - max_frames = params.max_duration * 1000 // params.frame_shift_ms - allowed_max_frames = int( - max_frames * (1.0 + params.allowed_excess_duration_ratio) - ) - batch = filter_uneven_sized_batch(batch, allowed_max_frames) + # For the uneven-sized batch, the total duration after padding would possibly + # cause OOM. Hence, for each batch, which is sorted descendingly by length, + # we simply drop the last few shortest samples, so that the retained total frames + # (after padding) would not exceed `allowed_max_frames`: + # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, + # where `max_frames = max_duration * 1000 // frame_shift_ms`. + # We set allowed_excess_duration_ratio=0.1. + max_frames = params.max_duration * 1000 // params.frame_shift_ms + allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) + batch = filter_uneven_sized_batch(batch, allowed_max_frames) device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"]