webdataset dataload for dev

This commit is contained in:
luomingshuang 2022-04-18 12:33:59 +08:00
parent 80b2cfee23
commit c6b2c3b038
2 changed files with 36 additions and 8 deletions

View File

@ -360,15 +360,22 @@ class WenetSpeechAsrDataModule:
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
buffer_size=30000,
rank=0,
world_size=1,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
dev_iter_dataset = IterableDatasetWrapper(
dataset=validate,
sampler=valid_sampler,
)
valid_dl = DataLoader(
dev_iter_dataset,
batch_size=None,
num_workers=2,
num_workers=self.args.num_workers,
persistent_workers=False,
)
@ -410,13 +417,13 @@ class WenetSpeechAsrDataModule:
logging.info("use lazy cuts")
cuts_train = CutSet.from_jsonl_lazy(
self.args.manifest_dir
/ "cuts_L_50_pieces.jsonl.gz"
/ "cuts_L.jsonl.gz"
# use cuts_L_50_pieces.jsonl.gz for original experiments
)
else:
cuts_train = CutSet.from_file(
self.args.manifest_dir
/ "cuts_L_50_pieces.jsonl.gz"
/ "cuts_L.jsonl.gz"
# use cuts_L_50_pieces.jsonl.gz for original experiments
)
return cuts_train

View File

@ -531,9 +531,18 @@ def main():
wenetspeech = WenetSpeechAsrDataModule(args)
dev = "dev"
test_net = "test_net"
test_meet = "test_meet"
if not os.path.exists(f"{dev}/shared-0.tar"):
dev_cuts = wenetspeech.valid_cuts()
export_to_webdataset(
dev_cuts,
output_path=f"{dev}/shared-%d.tar",
shard_size=300,
)
if not os.path.exists(f"{test_net}/shared-0.tar"):
test_net_cuts = wenetspeech.test_net_cuts()
export_to_webdataset(
@ -550,6 +559,17 @@ def main():
shard_size=300,
)
dev_shards = [
str(path)
for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
]
cuts_dev_webdataset = CutSet.from_webdataset(
dev_shards,
split_by_worker=True,
split_by_node=True,
shuffle_shards=True,
)
test_net_shards = [
str(path)
for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
@ -572,11 +592,12 @@ def main():
shuffle_shards=True,
)
dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meet_webdataset)
test_sets = ["TEST_NET", "TEST_MEETING"]
test_dl = [test_net_dl, test_meeting_dl]
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
test_dl = [dev_dl, test_net_dl, test_meeting_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(