mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
use quadratic-duration
This commit is contained in:
parent
c75767f600
commit
cd3adad46d
@ -109,6 +109,25 @@ class AsrDataModule:
|
||||
help="The number of buckets for the DynamicBucketingSampler "
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-cuts-for-bins-estimate",
|
||||
type=int,
|
||||
default=10000,
|
||||
help="We will draw this many cuts to estimate the duration "
|
||||
"bins for creating similar-duration buckets. A larger number "
|
||||
"means a better estimate of the data distribution, possibly "
|
||||
"at a longer init cost.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--quadratic-duration",
|
||||
type=float,
|
||||
default=None,
|
||||
help="When set, it adds an extra penalty that's quadratic "
|
||||
"in size w.r.t. a cut's duration. This helps get a more "
|
||||
"even GPU utilization across different input lengths when "
|
||||
"models have quadratic input complexity. Set between 15 "
|
||||
"and 40 for transformers.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--concatenate-cuts",
|
||||
type=str2bool,
|
||||
@ -205,6 +224,8 @@ class AsrDataModule:
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
world_size: Optional[int] = None,
|
||||
rank: Optional[int] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
@ -295,11 +316,15 @@ class AsrDataModule:
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
quadratic_duration=self.args.quadratic_duration,
|
||||
num_cuts_for_bins_estimate=self.args.num_cuts_for_bins_estimate,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
buffer_size=self.args.num_buckets * 2000,
|
||||
shuffle_buffer_size=self.args.num_buckets * 5000,
|
||||
drop_last=self.args.drop_last,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
@ -307,6 +332,8 @@ class AsrDataModule:
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
@ -330,7 +357,12 @@ class AsrDataModule:
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
def valid_dataloaders(
|
||||
self,
|
||||
cuts_valid: CutSet,
|
||||
world_size: Optional[int] = None,
|
||||
rank: Optional[int] = None,
|
||||
) -> DataLoader:
|
||||
transforms = []
|
||||
if self.args.concatenate_cuts:
|
||||
transforms = [
|
||||
@ -355,6 +387,8 @@ class AsrDataModule:
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
|
Loading…
x
Reference in New Issue
Block a user