diff --git a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py index 341579acb..c74d212d4 100644 --- a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py +++ b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py @@ -109,6 +109,25 @@ class AsrDataModule: help="The number of buckets for the DynamicBucketingSampler" "(you might want to increase it for larger datasets).", ) + group.add_argument( + "--num-cuts-for-bins-estimate", + type=int, + default=10000, + help="We will draw this many cuts to estimate the duration " + "bins for creating similar-duration buckets. Larger number " + "means a better estimate of the data distribution, possibly " + "at a longer init cost.", + ) + group.add_argument( + "--quadratic-duration", + type=float, + default=None, + help="When set, it adds an extra penalty that's quadratic " + "in size w.r.t. a cut's duration. This helps get a more " + "even GPU utilization across different input lengths when " + "models have quadratic input complexity. Set between 15 " + "and 40 for transformers.", + ) group.add_argument( + "--concatenate-cuts", + type=str2bool, @@ -205,6 +224,8 @@ class AsrDataModule: self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None, + world_size: Optional[int] = None, + rank: Optional[int] = None, ) -> DataLoader: """ Args: @@ -295,11 +316,15 @@ class AsrDataModule: train_sampler = DynamicBucketingSampler( cuts_train, max_duration=self.args.max_duration, + quadratic_duration=self.args.quadratic_duration, + num_cuts_for_bins_estimate=self.args.num_cuts_for_bins_estimate, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, buffer_size=self.args.num_buckets * 2000, shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, + world_size=world_size, + rank=rank, ) else: logging.info("Using SimpleCutSampler.") @@ -307,6 +332,8 @@ class AsrDataModule: cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, + 
world_size=world_size, + rank=rank, ) logging.info("About to create train dataloader") @@ -330,7 +357,12 @@ class AsrDataModule: return train_dl - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + def valid_dataloaders( + self, + cuts_valid: CutSet, + world_size: Optional[int] = None, + rank: Optional[int] = None, + ) -> DataLoader: transforms = [] if self.args.concatenate_cuts: transforms = [ @@ -355,6 +387,8 @@ class AsrDataModule: cuts_valid, max_duration=self.args.max_duration, shuffle=False, + world_size=world_size, + rank=rank, ) logging.info("About to create dev dataloader") valid_dl = DataLoader(