mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-04 22:54:18 +00:00)

webdataset dataload for dev

parent 80b2cfee23
commit c6b2c3b038
@@ -360,15 +360,22 @@ class WenetSpeechAsrDataModule:
         valid_sampler = DynamicBucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
             buffer_size=30000,
             rank=0,
             world_size=1,
             shuffle=False,
         )
         logging.info("About to create dev dataloader")
-        valid_dl = DataLoader(
-            validate,
+
+        from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
+
+        dev_iter_dataset = IterableDatasetWrapper(
+            dataset=validate,
             sampler=valid_sampler,
+        )
+        valid_dl = DataLoader(
+            dev_iter_dataset,
             batch_size=None,
-            num_workers=2,
+            num_workers=self.args.num_workers,
             persistent_workers=False,
         )
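The hunk above swaps the plain sampler-driven DataLoader for lhotse's IterableDatasetWrapper, which moves the sampling loop into the DataLoader workers so it also works with webdataset-backed (purely iterable) cuts. Below is a self-contained sketch of that pattern, not part of the commit; the dataset class and the max_duration default are assumptions standing in for what the data module actually configures.

# Not part of the commit: minimal sketch of the webdataset-friendly loader pattern.
from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler, K2SpeechRecognitionDataset
from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
from torch.utils.data import DataLoader


def make_dev_dataloader(cuts_valid: CutSet, max_duration: float = 200.0) -> DataLoader:
    validate = K2SpeechRecognitionDataset(return_cuts=True)
    # rank=0 / world_size=1: the shards are already split across workers/nodes by
    # CutSet.from_webdataset, so the sampler must not split the data again.
    valid_sampler = DynamicBucketingSampler(
        cuts_valid,
        max_duration=max_duration,
        buffer_size=30000,
        rank=0,
        world_size=1,
        shuffle=False,
    )
    dev_iter_dataset = IterableDatasetWrapper(dataset=validate, sampler=valid_sampler)
    # batch_size=None: the sampler already yields whole batches of cuts.
    return DataLoader(
        dev_iter_dataset,
        batch_size=None,
        num_workers=2,
        persistent_workers=False,
    )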
@@ -410,13 +417,13 @@ class WenetSpeechAsrDataModule:
             logging.info("use lazy cuts")
             cuts_train = CutSet.from_jsonl_lazy(
                 self.args.manifest_dir
-                / "cuts_L_50_pieces.jsonl.gz"
+                / "cuts_L.jsonl.gz"
+                # use cuts_L_50_pieces.jsonl.gz for original experiments
             )
         else:
             cuts_train = CutSet.from_file(
                 self.args.manifest_dir
-                / "cuts_L_50_pieces.jsonl.gz"
+                / "cuts_L.jsonl.gz"
+                # use cuts_L_50_pieces.jsonl.gz for original experiments
             )
         return cuts_train
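For context, the two branches above differ only in how the L manifest is opened: CutSet.from_jsonl_lazy streams cuts from disk, while CutSet.from_file loads the whole manifest into memory. A small illustrative sketch (the manifest directory here is a placeholder, not the repo's actual path):

# Sketch only: lazy vs. eager opening of a cuts manifest with lhotse.
from pathlib import Path

from lhotse import CutSet

manifest_dir = Path("data/fbank")  # placeholder for self.args.manifest_dir

# Streams cuts on demand; suited to the full "L" manifest, which is large.
cuts_lazy = CutSet.from_jsonl_lazy(manifest_dir / "cuts_L.jsonl.gz")

# Reads the entire manifest eagerly into memory.
cuts_eager = CutSet.from_file(manifest_dir / "cuts_L.jsonl.gz")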
@@ -531,9 +531,18 @@ def main():

     wenetspeech = WenetSpeechAsrDataModule(args)

+    dev = "dev"
     test_net = "test_net"
     test_meet = "test_meet"

+    if not os.path.exists(f"{dev}/shared-0.tar"):
+        dev_cuts = wenetspeech.valid_cuts()
+        export_to_webdataset(
+            dev_cuts,
+            output_path=f"{dev}/shared-%d.tar",
+            shard_size=300,
+        )
+
     if not os.path.exists(f"{test_net}/shared-0.tar"):
         test_net_cuts = wenetspeech.test_net_cuts()
         export_to_webdataset(
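The added block exports the dev cuts to sharded tar archives once and skips the export when the first shard already exists. A standalone sketch of that step, with the valid_cuts() call replaced by loading a manifest from a placeholder path:

# Sketch only: shard a CutSet into webdataset tar files, 300 cuts per shard.
import os

from lhotse import CutSet
from lhotse.dataset.webdataset import export_to_webdataset

dev = "dev"
if not os.path.exists(f"{dev}/shared-0.tar"):
    os.makedirs(dev, exist_ok=True)
    dev_cuts = CutSet.from_file("data/fbank/cuts_DEV.jsonl.gz")  # placeholder manifest
    export_to_webdataset(
        dev_cuts,
        output_path=f"{dev}/shared-%d.tar",  # %d is replaced by the shard index
        shard_size=300,
    )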
@@ -550,6 +559,17 @@ def main():
             shard_size=300,
         )

+    dev_shards = [
+        str(path)
+        for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
+    ]
+    cuts_dev_webdataset = CutSet.from_webdataset(
+        dev_shards,
+        split_by_worker=True,
+        split_by_node=True,
+        shuffle_shards=True,
+    )
+
     test_net_shards = [
         str(path)
         for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
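The shards written above are then reopened as a lazy CutSet whose work is partitioned by shard rather than by the sampler. A sketch of that read path, mirroring the flags in the hunk:

# Sketch only: open webdataset shards as a CutSet split across workers/nodes.
import glob
import os

from lhotse import CutSet

dev = "dev"
dev_shards = [
    str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
]
cuts_dev_webdataset = CutSet.from_webdataset(
    dev_shards,
    split_by_worker=True,  # each DataLoader worker reads a disjoint set of shards
    split_by_node=True,    # each node in DDP reads a disjoint set of shards
    shuffle_shards=True,   # shuffles shard order, not individual cuts
)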
@@ -572,11 +592,12 @@ def main():
         shuffle_shards=True,
     )

+    dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
     test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
     test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meet_webdataset)

-    test_sets = ["TEST_NET", "TEST_MEETING"]
-    test_dl = [test_net_dl, test_meeting_dl]
+    test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
+    test_dl = [dev_dl, test_net_dl, test_meeting_dl]

     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
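Before decoding, it can be worth pulling a single batch from the new dev loader to confirm it yields the expected structure. A sketch assuming dev_dl was built as above and yields standard K2SpeechRecognitionDataset batches (the key names come from lhotse, not from this commit):

# Sketch only: sanity-check one batch from dev_dl before decode_dataset runs.
batch = next(iter(dev_dl))
features = batch["inputs"]            # (N, T, F) tensor of acoustic features
supervisions = batch["supervisions"]  # dict with "text", frame offsets, cuts, ...
print(features.shape, len(supervisions["text"]))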