use quadratic-duration

This commit is contained in:
yfyeung 2025-05-10 17:47:05 +00:00
parent c75767f600
commit cd3adad46d

View File

@ -109,6 +109,25 @@ class AsrDataModule:
help="The number of buckets for the DynamicBucketingSampler" help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).", "(you might want to increase it for larger datasets).",
) )
group.add_argument(
"--num-cuts-for-bins-estimate",
type=int,
default=10000,
help="We will draw this many cuts to estimate the duration "
"bins for creating similar-duration buckets. Larger number "
"means a better estimate to the data distribution, possibly "
"at a longer init cost.",
)
group.add_argument(
"--quadratic-duration",
type=float,
default=None,
help="When set, it adds an extra penalty that's quadratic "
"in size w.r.t. a cut's duration. This helps get a more "
"even GPU utilization across different input lengths when "
"models have quadratic input complexity. Set between 15 "
"and 40 for transformers.",
)
group.add_argument( group.add_argument(
"--concatenate-cuts", "--concatenate-cuts",
type=str2bool, type=str2bool,
@ -205,6 +224,8 @@ class AsrDataModule:
self, self,
cuts_train: CutSet, cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None, sampler_state_dict: Optional[Dict[str, Any]] = None,
world_size: Optional[int] = None,
rank: Optional[int] = None,
) -> DataLoader: ) -> DataLoader:
""" """
Args: Args:
@ -295,11 +316,15 @@ class AsrDataModule:
train_sampler = DynamicBucketingSampler( train_sampler = DynamicBucketingSampler(
cuts_train, cuts_train,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
quadratic_duration=self.args.quadratic_duration,
num_cuts_for_bins_estimate=self.args.num_cuts_for_bins_estimate,
shuffle=self.args.shuffle, shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets, num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000, buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000, shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last, drop_last=self.args.drop_last,
world_size=world_size,
rank=rank,
) )
else: else:
logging.info("Using SimpleCutSampler.") logging.info("Using SimpleCutSampler.")
@ -307,6 +332,8 @@ class AsrDataModule:
cuts_train, cuts_train,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=self.args.shuffle, shuffle=self.args.shuffle,
world_size=world_size,
rank=rank,
) )
logging.info("About to create train dataloader") logging.info("About to create train dataloader")
@ -330,7 +357,12 @@ class AsrDataModule:
return train_dl return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: def valid_dataloaders(
self,
cuts_valid: CutSet,
world_size: Optional[int] = None,
rank: Optional[int] = None,
) -> DataLoader:
transforms = [] transforms = []
if self.args.concatenate_cuts: if self.args.concatenate_cuts:
transforms = [ transforms = [
@ -355,6 +387,8 @@ class AsrDataModule:
cuts_valid, cuts_valid,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=False, shuffle=False,
world_size=world_size,
rank=rank,
) )
logging.info("About to create dev dataloader") logging.info("About to create dev dataloader")
valid_dl = DataLoader( valid_dl = DataLoader(