mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Add WebDataset data loading for the dev set
This commit is contained in:
parent
80b2cfee23
commit
c6b2c3b038
@ -360,15 +360,22 @@ class WenetSpeechAsrDataModule:
|
|||||||
valid_sampler = DynamicBucketingSampler(
|
valid_sampler = DynamicBucketingSampler(
|
||||||
cuts_valid,
|
cuts_valid,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
buffer_size=30000,
|
rank=0,
|
||||||
|
world_size=1,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
)
|
)
|
||||||
logging.info("About to create dev dataloader")
|
logging.info("About to create dev dataloader")
|
||||||
valid_dl = DataLoader(
|
|
||||||
validate,
|
from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
|
||||||
|
|
||||||
|
dev_iter_dataset = IterableDatasetWrapper(
|
||||||
|
dataset=validate,
|
||||||
sampler=valid_sampler,
|
sampler=valid_sampler,
|
||||||
|
)
|
||||||
|
valid_dl = DataLoader(
|
||||||
|
dev_iter_dataset,
|
||||||
batch_size=None,
|
batch_size=None,
|
||||||
num_workers=2,
|
num_workers=self.args.num_workers,
|
||||||
persistent_workers=False,
|
persistent_workers=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -410,13 +417,13 @@ class WenetSpeechAsrDataModule:
|
|||||||
logging.info("use lazy cuts")
|
logging.info("use lazy cuts")
|
||||||
cuts_train = CutSet.from_jsonl_lazy(
|
cuts_train = CutSet.from_jsonl_lazy(
|
||||||
self.args.manifest_dir
|
self.args.manifest_dir
|
||||||
/ "cuts_L_50_pieces.jsonl.gz"
|
/ "cuts_L.jsonl.gz"
|
||||||
# use cuts_L_50_pieces.jsonl.gz for original experiments
|
# use cuts_L_50_pieces.jsonl.gz for original experiments
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
cuts_train = CutSet.from_file(
|
cuts_train = CutSet.from_file(
|
||||||
self.args.manifest_dir
|
self.args.manifest_dir
|
||||||
/ "cuts_L_50_pieces.jsonl.gz"
|
/ "cuts_L.jsonl.gz"
|
||||||
# use cuts_L_50_pieces.jsonl.gz for original experiments
|
# use cuts_L_50_pieces.jsonl.gz for original experiments
|
||||||
)
|
)
|
||||||
return cuts_train
|
return cuts_train
|
||||||
|
@ -531,9 +531,18 @@ def main():
|
|||||||
|
|
||||||
wenetspeech = WenetSpeechAsrDataModule(args)
|
wenetspeech = WenetSpeechAsrDataModule(args)
|
||||||
|
|
||||||
|
dev = "dev"
|
||||||
test_net = "test_net"
|
test_net = "test_net"
|
||||||
test_meet = "test_meet"
|
test_meet = "test_meet"
|
||||||
|
|
||||||
|
if not os.path.exists(f"{dev}/shared-0.tar"):
|
||||||
|
dev_cuts = wenetspeech.valid_cuts()
|
||||||
|
export_to_webdataset(
|
||||||
|
dev_cuts,
|
||||||
|
output_path=f"{dev}/shared-%d.tar",
|
||||||
|
shard_size=300,
|
||||||
|
)
|
||||||
|
|
||||||
if not os.path.exists(f"{test_net}/shared-0.tar"):
|
if not os.path.exists(f"{test_net}/shared-0.tar"):
|
||||||
test_net_cuts = wenetspeech.test_net_cuts()
|
test_net_cuts = wenetspeech.test_net_cuts()
|
||||||
export_to_webdataset(
|
export_to_webdataset(
|
||||||
@ -550,6 +559,17 @@ def main():
|
|||||||
shard_size=300,
|
shard_size=300,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
dev_shards = [
|
||||||
|
str(path)
|
||||||
|
for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
|
||||||
|
]
|
||||||
|
cuts_dev_webdataset = CutSet.from_webdataset(
|
||||||
|
dev_shards,
|
||||||
|
split_by_worker=True,
|
||||||
|
split_by_node=True,
|
||||||
|
shuffle_shards=True,
|
||||||
|
)
|
||||||
|
|
||||||
test_net_shards = [
|
test_net_shards = [
|
||||||
str(path)
|
str(path)
|
||||||
for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
|
for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
|
||||||
@ -572,11 +592,12 @@ def main():
|
|||||||
shuffle_shards=True,
|
shuffle_shards=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
|
||||||
test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
|
test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
|
||||||
test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meet_webdataset)
|
test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meet_webdataset)
|
||||||
|
|
||||||
test_sets = ["TEST_NET", "TEST_MEETING"]
|
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
|
||||||
test_dl = [test_net_dl, test_meeting_dl]
|
test_dl = [dev_dl, test_net_dl, test_meeting_dl]
|
||||||
|
|
||||||
for test_set, test_dl in zip(test_sets, test_dl):
|
for test_set, test_dl in zip(test_sets, test_dl):
|
||||||
results_dict = decode_dataset(
|
results_dict = decode_dataset(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user