diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 31a3a0505..a806244ff 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -82,7 +82,13 @@ from icefall.checkpoint import (
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    filter_uneven_sized_batch,
+    setup_logger,
+    str2bool,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -368,6 +374,21 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--filter-uneven-sized-batch",
+        type=str2bool,
+        default=True,
+        help="""Whether to filter uneven-sized minibatches.
+        For an uneven-sized batch, the total duration after padding could
+        cause OOM. Hence, for each batch, which is sorted in descending order
+        by length, we simply drop the last few shortest samples, so that the
+        retained total frames (after padding) do not exceed `allowed_max_frames`:
+        `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
+        where `max_frames = max_duration * 1000 // frame_shift_ms`.
+        We set allowed_excess_duration_ratio=0.1.
+        """,
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -420,6 +441,9 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
+            "frame_shift_ms": 10.0,
+            # only used when params.filter_uneven_sized_batch is True
+            "allowed_excess_duration_ratio": 0.1,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@@ -642,6 +666,13 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
+    if params.filter_uneven_sized_batch:
+        max_frames = params.max_duration * 1000 // params.frame_shift_ms
+        allowed_max_frames = int(
+            max_frames * (1.0 + params.allowed_excess_duration_ratio)
+        )
+        batch = filter_uneven_sized_batch(batch, allowed_max_frames)
+
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
diff --git a/icefall/utils.py b/icefall/utils.py
index 99e51a2a9..ba0b7fe43 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -1395,3 +1395,39 @@ def is_module_available(*modules: str) -> bool:
     import importlib
 
     return all(importlib.util.find_spec(m) is not None for m in modules)
+
+
+def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int):
+    """For an uneven-sized batch, the total duration after padding could
+    cause OOM. Hence, for each batch, which is sorted in descending order by
+    length, we simply drop the last few shortest samples, so that the retained
+    total frames (after padding) do not exceed the given allowed_max_frames.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      allowed_max_frames:
+        The allowed maximum number of frames in a batch.
+ """ + features = batch["inputs"] + supervisions = batch["supervisions"] + + N, T, _ = features.size() + assert T == supervisions["num_frames"].max(), (T, supervisions["num_frames"].max()) + keep_num_utt = allowed_max_frames // T + + if keep_num_utt >= N: + return batch + + # Note: we assume the samples in batch is sorted descendingly by length + logging.info( + f"Filtering uneven-sized batch, original batch size is {N}, " + f"retained batch size is {keep_num_utt}." + ) + batch["inputs"] = features[:keep_num_utt] + for k, v in supervisions.items(): + assert len(v) == N, (len(v), N) + batch["supervisions"][k] = v[:keep_num_utt] + + return batch