mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
use quadratic-duration
This commit is contained in:
parent
c75767f600
commit
cd3adad46d
@ -109,6 +109,25 @@ class AsrDataModule:
|
|||||||
help="The number of buckets for the DynamicBucketingSampler"
|
help="The number of buckets for the DynamicBucketingSampler"
|
||||||
"(you might want to increase it for larger datasets).",
|
"(you might want to increase it for larger datasets).",
|
||||||
)
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--num-cuts-for-bins-estimate",
|
||||||
|
type=int,
|
||||||
|
default=10000,
|
||||||
|
help="We will draw this many cuts to estimate the duration"
|
||||||
|
"bins for creating similar-duration buckets. Larger number"
|
||||||
|
"means a better estimate to the data distribution, possibly"
|
||||||
|
"at a longer init cost.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--quadratic-duration",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="When set, it adds an extra penalty that's quadratic"
|
||||||
|
"in size w.r.t. a cuts duration. This helps get a more"
|
||||||
|
"even GPU utilization across different input lengths when"
|
||||||
|
"models have quadratic input complexity.0 Set between 15"
|
||||||
|
"and 40 for transformers.",
|
||||||
|
)
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--concatenate-cuts",
|
"--concatenate-cuts",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -205,6 +224,8 @@ class AsrDataModule:
|
|||||||
self,
|
self,
|
||||||
cuts_train: CutSet,
|
cuts_train: CutSet,
|
||||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||||
|
world_size: Optional[int] = None,
|
||||||
|
rank: Optional[int] = None,
|
||||||
) -> DataLoader:
|
) -> DataLoader:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -295,11 +316,15 @@ class AsrDataModule:
|
|||||||
train_sampler = DynamicBucketingSampler(
|
train_sampler = DynamicBucketingSampler(
|
||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
|
quadratic_duration=self.args.quadratic_duration,
|
||||||
|
num_cuts_for_bins_estimate=self.args.num_cuts_for_bins_estimate,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
num_buckets=self.args.num_buckets,
|
num_buckets=self.args.num_buckets,
|
||||||
buffer_size=self.args.num_buckets * 2000,
|
buffer_size=self.args.num_buckets * 2000,
|
||||||
shuffle_buffer_size=self.args.num_buckets * 5000,
|
shuffle_buffer_size=self.args.num_buckets * 5000,
|
||||||
drop_last=self.args.drop_last,
|
drop_last=self.args.drop_last,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SimpleCutSampler.")
|
logging.info("Using SimpleCutSampler.")
|
||||||
@ -307,6 +332,8 @@ class AsrDataModule:
|
|||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
logging.info("About to create train dataloader")
|
logging.info("About to create train dataloader")
|
||||||
|
|
||||||
@ -330,7 +357,12 @@ class AsrDataModule:
|
|||||||
|
|
||||||
return train_dl
|
return train_dl
|
||||||
|
|
||||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
def valid_dataloaders(
|
||||||
|
self,
|
||||||
|
cuts_valid: CutSet,
|
||||||
|
world_size: Optional[int] = None,
|
||||||
|
rank: Optional[int] = None,
|
||||||
|
) -> DataLoader:
|
||||||
transforms = []
|
transforms = []
|
||||||
if self.args.concatenate_cuts:
|
if self.args.concatenate_cuts:
|
||||||
transforms = [
|
transforms = [
|
||||||
@ -355,6 +387,8 @@ class AsrDataModule:
|
|||||||
cuts_valid,
|
cuts_valid,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
logging.info("About to create dev dataloader")
|
logging.info("About to create dev dataloader")
|
||||||
valid_dl = DataLoader(
|
valid_dl = DataLoader(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user