More fixes to use lazy CutSet.

Author: Fangjun Kuang
Date:   2022-06-05 22:23:32 +08:00
Parent: 0040ff2157
Commit: 113818fd00

16 changed files with 157 additions and 103 deletions
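
Every hunk below follows the same theme: manifests move from monolithic cuts_*.json.gz files to per-corpus {prefix}_cuts_*.jsonl.gz files, and load_manifest calls become load_manifest_lazy. A minimal sketch of why this matters (the path is illustrative; lazy loading requires the JSON Lines format, which is what motivates the extension change):

# Eager vs. lazy CutSet loading in lhotse (path is illustrative).
from lhotse import load_manifest, load_manifest_lazy

# Eager: parses the entire manifest into memory before returning.
cuts_eager = load_manifest("data/fbank/alimeeting_cuts_train.jsonl.gz")

# Lazy: opens the file and yields cuts one JSON line at a time while
# iterating, so memory stays flat even for very large corpora. This
# requires .jsonl / .jsonl.gz; a monolithic .json.gz cannot be streamed.
cuts_lazy = load_manifest_lazy("data/fbank/alimeeting_cuts_train.jsonl.gz")
for cut in cuts_lazy:
    print(cut.duration)
    break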

View File

@@ -43,7 +43,7 @@ torch.set_num_interop_threads(1)
 def compute_fbank_alimeeting(num_mel_bins: int = 80):
-    src_dir = Path("data/manifests/alimeeting")
+    src_dir = Path("data/manifests")
     output_dir = Path("data/fbank")
     num_jobs = min(15, os.cpu_count())
@@ -52,11 +52,14 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
         "eval",
         "test",
     )
+    prefix = "alimeeting"
+    suffix = "jsonl.gz"
     manifests = read_manifests_if_cached(
         dataset_parts=dataset_parts,
         output_dir=src_dir,
-        prefix="alimeeting",
-        suffix="jsonl.gz",
+        prefix=prefix,
+        suffix=suffix,
     )
     assert manifests is not None
@@ -64,7 +67,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
-            if (output_dir / f"cuts_{partition}.json.gz").is_file():
+            if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
                 logging.info(f"{partition} already exists - skipping.")
                 continue
             logging.info(f"Processing {partition}")
@@ -83,7 +86,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
-                storage_path=f"{output_dir}/feats_{partition}",
+                storage_path=f"{output_dir}/{prefix}_feats_{partition}",
                 # when an executor is specified, make more partitions
                 num_jobs=cur_num_jobs,
                 executor=ex,
@@ -95,7 +98,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80):
                 keep_overlapping=False,
                 min_duration=None,
             )
-            cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")
+            cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")

 def get_args():

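The new prefix/suffix variables keep the input lookup and the output names in sync: read_manifests_if_cached uses them to locate the prepared manifests, and the script reuses them for the cut manifests it writes. A sketch of how the lookup resolves, assuming lhotse's {prefix}_{kind}_{part}.{suffix} naming convention for prepared manifests:

from lhotse.recipes.utils import read_manifests_if_cached

manifests = read_manifests_if_cached(
    dataset_parts=("train", "eval", "test"),
    output_dir="data/manifests",
    prefix="alimeeting",
    suffix="jsonl.gz",
)
# Resolves files such as:
#   data/manifests/alimeeting_recordings_train.jsonl.gz
#   data/manifests/alimeeting_supervisions_train.jsonl.gz
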
View File

@@ -25,19 +25,19 @@ for usage.
 """

-from lhotse import load_manifest
+from lhotse import load_manifest_lazy

 def main():
     paths = [
-        "./data/fbank/cuts_train.json.gz",
-        "./data/fbank/cuts_eval.json.gz",
-        "./data/fbank/cuts_test.json.gz",
+        "./data/fbank/alimeeting_cuts_train.jsonl.gz",
+        "./data/fbank/alimeeting_cuts_eval.jsonl.gz",
+        "./data/fbank/alimeeting_cuts_test.jsonl.gz",
     ]

     for path in paths:
         print(f"Starting display the statistics for {path}")
-        cuts = load_manifest(path)
+        cuts = load_manifest_lazy(path)
         cuts.describe()
@@ -45,7 +45,7 @@ if __name__ == "__main__":
     main()

 """
-Starting display the statistics for ./data/fbank/cuts_train.json.gz
+Starting display the statistics for ./data/fbank/alimeeting_cuts_train.jsonl.gz
 Cuts count: 559092
 Total duration (hours): 424.6
 Speech duration (hours): 424.6 (100.0%)
@@ -61,7 +61,7 @@ min 0.0
 99.5% 14.7
 99.9% 16.2
 max 284.3
-Starting display the statistics for ./data/fbank/cuts_eval.json.gz
+Starting display the statistics for ./data/fbank/alimeeting_cuts_eval.jsonl.gz
 Cuts count: 6457
 Total duration (hours): 4.9
 Speech duration (hours): 4.9 (100.0%)
@@ -77,7 +77,7 @@ min 0.1
 99.5% 14.1
 99.9% 14.7
 max 15.8
-Starting display the statistics for ./data/fbank/cuts_test.json.gz
+Starting display the statistics for ./data/fbank/alimeeting_cuts_test.jsonl.gz
 Cuts count: 16358
 Total duration (hours): 12.5
 Speech duration (hours): 12.5 (100.0%)

View File

@@ -27,7 +27,7 @@ from lhotse import (
     CutSet,
     Fbank,
     FbankConfig,
-    load_manifest,
+    load_manifest_lazy,
     set_caching_enabled,
 )
 from lhotse.dataset import (
@@ -204,8 +204,8 @@ class AlimeetingAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "cuts_musan.json.gz"
+        cuts_musan = load_manifest_lazy(
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
         )

         transforms = []
@@ -401,14 +401,20 @@ class AlimeetingAsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        return load_manifest(self.args.manifest_dir / "cuts_train.json.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "alimeeting_cuts_train.jsonl.gz"
+        )

     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest(self.args.manifest_dir / "cuts_eval.json.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "alimeeting_cuts_eval.jsonl.gz"
+        )

     @lru_cache()
     def test_cuts(self) -> List[CutSet]:
         logging.info("About to get test cuts")
-        return load_manifest(self.args.manifest_dir / "cuts_test.json.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "alimeeting_cuts_test.jsonl.gz"
+        )

View File

@@ -20,7 +20,7 @@ import logging
 from functools import lru_cache
 from pathlib import Path

-from lhotse import CutSet, Fbank, FbankConfig, load_manifest
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
 from lhotse.dataset import (
     CutConcatenate,
     CutMix,
@@ -189,7 +189,7 @@ class GigaSpeechAsrDataModule:
     def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
+        cuts_musan = load_manifest_lazy(
             self.args.manifest_dir / "musan_cuts.jsonl.gz"
         )
@@ -362,7 +362,9 @@ class GigaSpeechAsrDataModule:
     @lru_cache()
     def dev_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        cuts_valid = load_manifest(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
+        cuts_valid = load_manifest_lazy(
+            self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz"
+        )
         if self.args.small_dev:
             return cuts_valid.subset(first=1000)
         else:
@@ -371,4 +373,6 @@ class GigaSpeechAsrDataModule:
     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
-        return load_manifest(self.args.manifest_dir / "cuts_TEST.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
+        )

View File

@@ -23,7 +23,7 @@ from pathlib import Path
 from typing import Any, Dict, Optional

 import torch
-from lhotse import CutSet, Fbank, FbankConfig, load_manifest
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
 from lhotse.dataset import (
     CutConcatenate,
     CutMix,
@@ -216,7 +216,7 @@ class GigaSpeechAsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(
+            cuts_musan = load_manifest_lazy(
                 self.args.manifest_dir / "musan_cuts.jsonl.gz"
             )
             transforms.append(
@@ -405,7 +405,9 @@ class GigaSpeechAsrDataModule:
     @lru_cache()
     def dev_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        cuts_valid = load_manifest(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
+        cuts_valid = load_manifest_lazy(
+            self.args.manifest_dir / "cuts_DEV.jsonl.gz"
+        )
         if self.args.small_dev:
             return cuts_valid.subset(first=1000)
         else:
@@ -414,4 +416,4 @@ class GigaSpeechAsrDataModule:
     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
-        return load_manifest(self.args.manifest_dir / "cuts_TEST.jsonl.gz")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz")

View File

@@ -96,14 +96,14 @@ def get_parser():
             - labels_xxx.h5
             - aux_labels_xxx.h5
-            - cuts_xxx.json.gz
+            - librispeech_cuts_xxx.jsonl.gz

         where xxx is the value of `--dataset`. For instance, if
         `--dataset` is `train-clean-100`, it will contain 3 files:

             - `labels_train-clean-100.h5`
             - `aux_labels_train-clean-100.h5`
-            - `cuts_train-clean-100.json.gz`
+            - `librispeech_cuts_train-clean-100.jsonl.gz`

         Note: Both labels_xxx.h5 and aux_labels_xxx.h5 contain framewise
         alignment. The difference is that labels_xxx.h5 contains repeats.
@@ -289,7 +289,9 @@ def main():
     out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5"
     out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5"
-    out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz"
+    out_manifest_filename = (
+        out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz"
+    )

     for f in (
         out_labels_ali_filename,

View File

@@ -22,7 +22,7 @@ from pathlib import Path
 from typing import Any, Dict, Optional

 import torch
-from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
 from lhotse.dataset import (
     CutConcatenate,
     CutMix,
@@ -176,7 +176,7 @@ class SPGISpeechAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
+        cuts_musan = load_manifest_lazy(
             self.args.manifest_dir / "cuts_musan.jsonl.gz"
         )

View File

@@ -52,8 +52,13 @@ def compute_fbank_tedlium():
         "test",
     )

+    prefix = "tedlium"
+    suffix = "jsonl.gz"
     manifests = read_manifests_if_cached(
-        prefix="tedlium", dataset_parts=dataset_parts, output_dir=src_dir
+        dataset_parts=dataset_parts,
+        output_dir=src_dir,
+        prefix=prefix,
+        suffix=suffix,
     )
     assert manifests is not None
@@ -61,7 +66,7 @@ def compute_fbank_tedlium():
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
-            if (output_dir / f"cuts_{partition}.json.gz").is_file():
+            if (output_dir / f"{prefix}_cuts_{partition}.{suffix}").is_file():
                 logging.info(f"{partition} already exists - skipping.")
                 continue
             logging.info(f"Processing {partition}")
@@ -80,7 +85,7 @@ def compute_fbank_tedlium():
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
-                storage_path=f"{output_dir}/feats_{partition}",
+                storage_path=f"{output_dir}/{prefix}_feats_{partition}",
                 # when an executor is specified, make more partitions
                 num_jobs=cur_num_jobs,
                 executor=ex,
@@ -88,7 +93,7 @@ def compute_fbank_tedlium():
             )
             # Split long cuts into many short and un-overlapping cuts
             cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
-            cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")
+            cut_set.to_file(output_dir / f"{prefix}_cuts_{partition}.{suffix}")

 if __name__ == "__main__":

View File

@@ -27,15 +27,15 @@ for usage.
 """

-from lhotse import load_manifest
+from lhotse import load_manifest_lazy

 def main():
-    path = "./data/fbank/cuts_train.json.gz"
-    path = "./data/fbank/cuts_dev.json.gz"
-    path = "./data/fbank/cuts_test.json.gz"
+    path = "./data/fbank/tedlium_cuts_train.jsonl.gz"
+    path = "./data/fbank/tedlium_cuts_dev.jsonl.gz"
+    path = "./data/fbank/tedlium_cuts_test.jsonl.gz"

-    cuts = load_manifest(path)
+    cuts = load_manifest_lazy(path)
     cuts.describe()

View File

@@ -22,11 +22,11 @@ import logging
 from functools import lru_cache
 from pathlib import Path

-from lhotse import CutSet, Fbank, FbankConfig, load_manifest
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
 from lhotse.dataset import (
-    BucketingSampler,
     CutConcatenate,
     CutMix,
+    DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
     PrecomputedFeatures,
     SingleCutSampler,
@@ -92,7 +92,7 @@ class TedLiumAsrDataModule:
             "--num-buckets",
             type=int,
             default=30,
-            help="The number of buckets for the BucketingSampler"
+            help="The number of buckets for the DynamicBucketingSampler"
             "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
@@ -179,8 +179,8 @@ class TedLiumAsrDataModule:
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
-            cuts_musan = load_manifest(
-                self.args.manifest_dir / "cuts_musan.json.gz"
+            cuts_musan = load_manifest_lazy(
+                self.args.manifest_dir / "musan_cuts.jsonl.gz"
             )
             transforms.append(
                 CutMix(
@@ -261,13 +261,12 @@ class TedLiumAsrDataModule:
         )

         if self.args.bucketing_sampler:
-            logging.info("Using BucketingSampler.")
-            train_sampler = BucketingSampler(
+            logging.info("Using DynamicBucketingSampler.")
+            train_sampler = DynamicBucketingSampler(
                 cuts_train,
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
-                bucket_method="equal_duration",
                 drop_last=True,
             )
         else:
@@ -311,7 +310,7 @@ class TedLiumAsrDataModule:
             cut_transforms=transforms,
             return_cuts=self.args.return_cuts,
         )
-        valid_sampler = BucketingSampler(
+        valid_sampler = DynamicBucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
@@ -335,8 +334,10 @@ class TedLiumAsrDataModule:
             else PrecomputedFeatures(),
             return_cuts=self.args.return_cuts,
         )
-        sampler = BucketingSampler(
-            cuts, max_duration=self.args.max_duration, shuffle=False
+        sampler = DynamicBucketingSampler(
+            cuts,
+            max_duration=self.args.max_duration,
+            shuffle=False,
         )
         logging.debug("About to create test dataloader")
         test_dl = DataLoader(
@@ -350,14 +351,20 @@ class TedLiumAsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        return load_manifest(self.args.manifest_dir / "cuts_train.json.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "tedlium_cuts_train.jsonl.gz"
+        )

     @lru_cache()
     def dev_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest(self.args.manifest_dir / "cuts_dev.json.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz"
+        )

     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
-        return load_manifest(self.args.manifest_dir / "cuts_test.json.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz"
+        )

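The sampler swap in this and the following datamodules is what makes lazy cuts usable for training: BucketingSampler needs every cut's duration up front (hence its now-dropped bucket_method="equal_duration" option), while DynamicBucketingSampler estimates bucket boundaries from the stream and batches on the fly. A minimal sketch with illustrative values:

from lhotse import load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler

cuts = load_manifest_lazy("data/fbank/tedlium_cuts_train.jsonl.gz")
sampler = DynamicBucketingSampler(
    cuts,
    max_duration=200.0,  # total seconds of audio per batch
    shuffle=True,
    num_buckets=30,
    drop_last=True,
)
for batch in sampler:  # yields CutSet mini-batches of similar durations
    ...
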
View File

@@ -29,7 +29,7 @@ import os
 from pathlib import Path

 import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor
@@ -53,8 +53,13 @@ def compute_fbank_timit():
         "DEV",
         "TEST",
     )
+    prefix = "timit"
+    suffix = "jsonl.gz"
     manifests = read_manifests_if_cached(
-        prefix="timit", dataset_parts=dataset_parts, output_dir=src_dir
+        dataset_parts=dataset_parts,
+        output_dir=src_dir,
+        prefix=prefix,
+        suffix=suffix,
     )
     assert manifests is not None
@@ -62,7 +67,8 @@ def compute_fbank_timit():
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
-            if (output_dir / f"cuts_{partition}.json.gz").is_file():
+            cuts_file = output_dir / f"{prefix}_cuts_{partition}.{suffix}"
+            if cuts_file.is_file():
                 logging.info(f"{partition} already exists - skipping.")
                 continue
             logging.info(f"Processing {partition}")
@@ -78,13 +84,13 @@ def compute_fbank_timit():
             )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
-                storage_path=f"{output_dir}/feats_{partition}",
+                storage_path=f"{output_dir}/{prefix}_feats_{partition}",
                 # when an executor is specified, make more partitions
                 num_jobs=num_jobs if ex is None else 80,
                 executor=ex,
-                storage_type=LilcomHdf5Writer,
+                storage_type=LilcomChunkyWriter,
             )
-            cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")
+            cut_set.to_file(cuts_file)

 if __name__ == "__main__":

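The storage changes follow the same theme: LilcomChunkyWriter appends compressed feature arrays into one chunked archive per split, and to_file dispatches on the file extension, so a .jsonl.gz path produces exactly the kind of manifest load_manifest_lazy can stream. A condensed sketch of the resulting pipeline (paths and job count are illustrative):

from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached

manifests = read_manifests_if_cached(
    dataset_parts=("TRAIN",),
    output_dir="data/manifests",
    prefix="timit",
    suffix="jsonl.gz",
)
cut_set = CutSet.from_manifests(
    recordings=manifests["TRAIN"]["recordings"],
    supervisions=manifests["TRAIN"]["supervisions"],
)
cut_set = cut_set.compute_and_store_features(
    extractor=Fbank(FbankConfig(num_mel_bins=80)),
    storage_path="data/fbank/timit_feats_TRAIN",
    num_jobs=4,
    storage_type=LilcomChunkyWriter,  # one chunked archive per split
)
# A .jsonl.gz manifest can later be re-opened with load_manifest_lazy().
cut_set.to_file("data/fbank/timit_cuts_TRAIN.jsonl.gz")
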
View File

@@ -23,11 +23,11 @@ from functools import lru_cache
 from pathlib import Path
 from typing import List, Union

-from lhotse import CutSet, Fbank, FbankConfig, load_manifest
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
 from lhotse.dataset import (
-    BucketingSampler,
     CutConcatenate,
     CutMix,
+    DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
     PrecomputedFeatures,
     SingleCutSampler,
@@ -92,7 +92,7 @@ class TimitAsrDataModule(DataModule):
             "--num-buckets",
             type=int,
             default=30,
-            help="The number of buckets for the BucketingSampler"
+            help="The number of buckets for the DynamicBucketingSampler"
             "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
@@ -154,7 +154,9 @@ class TimitAsrDataModule(DataModule):
         cuts_train = self.train_cuts()

         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")
+        cuts_musan = load_manifest_lazy(
+            self.args.feature_dir / "cuts_musan.jsonl.gz"
+        )

         logging.info("About to create train dataset")
         transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
@@ -218,13 +220,12 @@ class TimitAsrDataModule(DataModule):
         )

         if self.args.bucketing_sampler:
-            logging.info("Using BucketingSampler.")
-            train_sampler = BucketingSampler(
+            logging.info("Using DynamicBucketingSampler.")
+            train_sampler = DynamicBucketingSampler(
                 cuts_train,
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
-                bucket_method="equal_duration",
                 drop_last=True,
             )
         else:
@@ -322,20 +323,26 @@ class TimitAsrDataModule(DataModule):
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        cuts_train = load_manifest(self.args.feature_dir / "cuts_TRAIN.json.gz")
+        cuts_train = load_manifest_lazy(
+            self.args.feature_dir / "timit_cuts_TRAIN.jsonl.gz"
+        )

         return cuts_train

     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        cuts_valid = load_manifest(self.args.feature_dir / "cuts_DEV.json.gz")
+        cuts_valid = load_manifest_lazy(
+            self.args.feature_dir / "timit_cuts_DEV.jsonl.gz"
+        )

         return cuts_valid

     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.debug("About to get test cuts")
-        cuts_test = load_manifest(self.args.feature_dir / "cuts_TEST.json.gz")
+        cuts_test = load_manifest_lazy(
+            self.args.feature_dir / "timit_cuts_TEST.jsonl.gz"
+        )

         return cuts_test

View File

@@ -26,7 +26,7 @@ for usage.
 """

-from lhotse import load_manifest
+from lhotse import load_manifest_lazy

 def main():
@@ -40,7 +40,7 @@ def main():
     for path in paths:
         print(f"Starting display the statistics for {path}")
-        cuts = load_manifest(path)
+        cuts = load_manifest_lazy(path)
         cuts.describe()

View File

@@ -27,7 +27,7 @@ from lhotse import (
     CutSet,
     Fbank,
     FbankConfig,
-    load_manifest,
+    load_manifest_lazy,
     set_caching_enabled,
 )
 from lhotse.dataset import (
@@ -218,8 +218,8 @@ class WenetSpeechAsrDataModule:
             The state dict for the training sampler.
         """
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "cuts_musan.json.gz"
+        cuts_musan = load_manifest_lazy(
+            self.args.manifest_dir / "musan_cuts.jsonl.gz"
         )

         transforms = []
@@ -435,16 +435,18 @@ class WenetSpeechAsrDataModule:
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")

     @lru_cache()
     def test_net_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_NET cuts")
-        return load_manifest(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz"
+        )

     @lru_cache()
     def test_meeting_cuts(self) -> List[CutSet]:
         logging.info("About to get TEST_MEETING cuts")
-        return load_manifest(
+        return load_manifest_lazy(
             self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz"
         )

View File

@@ -12,7 +12,7 @@ import os
 from pathlib import Path

 import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor
@@ -37,10 +37,13 @@ def compute_fbank_yesno():
         "train",
         "test",
     )
+    prefix = "yesno"
+    suffix = "jsonl.gz"
     manifests = read_manifests_if_cached(
         dataset_parts=dataset_parts,
         output_dir=src_dir,
-        prefix="yesno",
+        prefix=prefix,
+        suffix=suffix,
     )
     assert manifests is not None
@@ -50,7 +53,8 @@ def compute_fbank_yesno():
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
-            if (output_dir / f"cuts_{partition}.json.gz").is_file():
+            cuts_file = output_dir / f"{prefix}_cuts_{partition}.{suffix}"
+            if cuts_file.is_file():
                 logging.info(f"{partition} already exists - skipping.")
                 continue
             logging.info(f"Processing {partition}")
@@ -66,13 +70,13 @@ def compute_fbank_yesno():
             )
             cut_set = cut_set.compute_and_store_features(
                 extractor=extractor,
-                storage_path=f"{output_dir}/feats_{partition}",
+                storage_path=f"{output_dir}/{prefix}_feats_{partition}",
                 # when an executor is specified, make more partitions
                 num_jobs=num_jobs if ex is None else 1,  # use one job
                 executor=ex,
-                storage_type=LilcomHdf5Writer,
+                storage_type=LilcomChunkyWriter,
             )
-            cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")
+            cut_set.to_file(cuts_file)

 if __name__ == "__main__":

View File

@@ -20,18 +20,19 @@ from functools import lru_cache
 from pathlib import Path
 from typing import List

+from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
+from lhotse.dataset import (
+    CutConcatenate,
+    DynamicBucketingSampler,
+    K2SpeechRecognitionDataset,
+    PrecomputedFeatures,
+    SingleCutSampler,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from torch.utils.data import DataLoader
+
 from icefall.dataset.datamodule import DataModule
 from icefall.utils import str2bool
-from lhotse import CutSet, Fbank, FbankConfig, load_manifest
-from lhotse.dataset import (
-    BucketingSampler,
-    CutConcatenate,
-    K2SpeechRecognitionDataset,
-    PrecomputedFeatures,
-)
-from lhotse.dataset.input_strategies import OnTheFlyFeatures
@@ -84,7 +85,7 @@ class YesNoAsrDataModule(DataModule):
             "--num-buckets",
             type=int,
             default=10,
-            help="The number of buckets for the BucketingSampler"
+            help="The number of buckets for the DynamicBucketingSampler"
             "(you might want to increase it for larger datasets).",
         )
         group.add_argument(
@@ -186,18 +187,17 @@ class YesNoAsrDataModule(DataModule):
         )

         if self.args.bucketing_sampler:
-            logging.info("Using BucketingSampler.")
-            train_sampler = BucketingSampler(
+            logging.info("Using DynamicBucketingSampler.")
+            train_sampler = DynamicBucketingSampler(
                 cuts_train,
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
-                bucket_method="equal_duration",
                 drop_last=True,
             )
         else:
             logging.info("Using SingleCutSampler.")
-            train_sampler = BucketingSampler(
+            train_sampler = SingleCutSampler(
                 cuts_train,
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
@@ -225,8 +225,10 @@ class YesNoAsrDataModule(DataModule):
             else PrecomputedFeatures(),
             return_cuts=self.args.return_cuts,
         )
-        sampler = BucketingSampler(
-            cuts_test, max_duration=self.args.max_duration, shuffle=False
+        sampler = DynamicBucketingSampler(
+            cuts_test,
+            max_duration=self.args.max_duration,
+            shuffle=False,
        )
         logging.debug("About to create test dataloader")
         test_dl = DataLoader(
@@ -240,11 +242,15 @@ class YesNoAsrDataModule(DataModule):
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        cuts_train = load_manifest(self.args.feature_dir / "cuts_train.json.gz")
+        cuts_train = load_manifest_lazy(
+            self.args.feature_dir / "yesno_cuts_train.jsonl.gz"
+        )
         return cuts_train

     @lru_cache()
     def test_cuts(self) -> List[CutSet]:
         logging.info("About to get test cuts")
-        cuts_test = load_manifest(self.args.feature_dir / "cuts_test.json.gz")
+        cuts_test = load_manifest_lazy(
+            self.args.feature_dir / "yesno_cuts_test.jsonl.gz"
+        )
         return cuts_test
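
Note the else-branch fix above: the code previously logged "Using SingleCutSampler." but constructed a BucketingSampler; it now builds the sampler it names. For reference, a minimal sketch of the non-bucketing fallback (values are illustrative):

from lhotse import load_manifest_lazy
from lhotse.dataset import SingleCutSampler

cuts = load_manifest_lazy("data/fbank/yesno_cuts_train.jsonl.gz")
# No bucketing: cuts are batched in (shuffled) manifest order until
# the max_duration budget is reached.
sampler = SingleCutSampler(cuts, max_duration=30.0, shuffle=True)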