Fix wenetspeech decoding speed (#953)

Wei Kang, 2023-03-21 21:35:32 +08:00 (committed by GitHub)
parent 7948624a22
commit d74822d07b
4 changed files with 19 additions and 164 deletions

[File 1/4: fbank computation script for the WenetSpeech dev/test sets]

@@ -20,7 +20,7 @@ import logging
 from pathlib import Path

 import torch
-from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomHdf5Writer
+from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter

 # Torch's multithreaded behavior needs to be disabled or
 # it wastes a lot of CPU and slows things down.
@@ -69,7 +69,7 @@ def compute_fbank_wenetspeech_dev_test():
             storage_path=f"{in_out_dir}/feats_{partition}",
             num_workers=num_workers,
             batch_duration=batch_duration,
-            storage_type=LilcomHdf5Writer,
+            storage_type=LilcomChunkyWriter,
             overwrite=True,
         )
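
Both hunks above make the same swap: LilcomChunkyWriter stores the compressed
feature arrays in a few large, sequentially laid-out archive files, while
LilcomHdf5Writer goes through HDF5, which tends to be much slower to read back
from multiple dataloader workers during decoding. A minimal sketch of the new
pattern follows; the manifest path, output prefix, and parameter values are
placeholders, not taken from this commit.

from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter

# Placeholder manifest; any CutSet with recordings attached works here.
cuts = CutSet.from_file("data/manifests/cuts_DEV_raw.jsonl.gz")

extractor = KaldifeatFbank(KaldifeatFbankConfig(device="cuda"))

cuts = cuts.compute_and_store_features_batch(
    extractor=extractor,
    storage_path="data/fbank/feats_DEV",  # placeholder output prefix
    num_workers=4,
    batch_duration=600,
    storage_type=LilcomChunkyWriter,  # chunked lilcom archive instead of HDF5
    overwrite=True,
)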

[File 2/4: WenetSpeech ASR data module]

@@ -46,9 +46,6 @@ from torch.utils.data import DataLoader
 from icefall.utils import str2bool

-set_caching_enabled(False)
-torch.set_num_threads(1)
-

 class _SeedWorkers:
     def __init__(self, seed: int):
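
The module-level calls deleted above had pinned lhotse's caching off and
PyTorch to a single CPU thread for every importer of the data module. A script
that still wants either behavior can opt in locally; a minimal sketch using
the same two calls:

import torch

from lhotse import set_caching_enabled

# Opt in per script instead of forcing it on every user of the module.
set_caching_enabled(False)  # disable lhotse's dynamic caching
torch.set_num_threads(1)  # keep torch from oversubscribing CPU threads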
@@ -348,24 +345,18 @@ class WenetSpeechAsrDataModule:
             cut_transforms=transforms,
             return_cuts=self.args.return_cuts,
         )
         valid_sampler = DynamicBucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
-            rank=0,
-            world_size=1,
             shuffle=False,
         )
         logging.info("About to create dev dataloader")
-        from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
-
-        dev_iter_dataset = IterableDatasetWrapper(
-            dataset=validate,
-            sampler=valid_sampler,
-        )
         valid_dl = DataLoader(
-            dev_iter_dataset,
+            validate,
             batch_size=None,
+            sampler=valid_sampler,
             num_workers=self.args.num_workers,
             persistent_workers=False,
         )
@@ -383,19 +374,13 @@ class WenetSpeechAsrDataModule:
         sampler = DynamicBucketingSampler(
             cuts,
             max_duration=self.args.max_duration,
-            rank=0,
-            world_size=1,
             shuffle=False,
         )
-        from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
-
-        test_iter_dataset = IterableDatasetWrapper(
-            dataset=test,
-            sampler=sampler,
-        )
         test_dl = DataLoader(
-            test_iter_dataset,
+            test,
             batch_size=None,
+            sampler=sampler,
             num_workers=self.args.num_workers,
         )
         return test_dl
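
Both dataloader hunks follow one pattern: a lhotse sampler already yields
whole batches of cuts, so the map-style dataset can be handed to DataLoader
directly with batch_size=None and the sampler passed via sampler=, with no
IterableDatasetWrapper in between. Dropping rank=0, world_size=1 also lets the
sampler infer them from the distributed environment instead of forcing every
process to iterate the full cut set. A self-contained sketch, with placeholder
manifest and parameters:

from torch.utils.data import DataLoader

from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler, K2SpeechRecognitionDataset

# Placeholder manifest; in the data module this comes from valid_cuts().
cuts_valid = CutSet.from_file("data/manifests/cuts_DEV.jsonl.gz")

validate = K2SpeechRecognitionDataset(return_cuts=True)
valid_sampler = DynamicBucketingSampler(
    cuts_valid,
    max_duration=200,  # placeholder; the recipe reads this from args
    shuffle=False,
)
valid_dl = DataLoader(
    validate,
    batch_size=None,  # the sampler already emits full batches
    sampler=valid_sampler,
    num_workers=2,
)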

[File 3/4: WenetSpeech decode script]

@@ -651,83 +651,18 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    # Note: Please use "pip install webdataset==0.1.103"
-    # to install webdataset.
-    import glob
-    import os
-
-    from lhotse import CutSet
-    from lhotse.dataset.webdataset import export_to_webdataset
-
     # We need cut ids to display recognition results.
     args.return_cuts = True
     wenetspeech = WenetSpeechAsrDataModule(args)

-    dev = "dev"
-    test_net = "test_net"
-    test_meeting = "test_meeting"
+    dev_cuts = wenetspeech.valid_cuts()
+    dev_dl = wenetspeech.valid_dataloaders(dev_cuts)

-    if not os.path.exists(f"{dev}/shared-0.tar"):
-        os.makedirs(dev)
-        dev_cuts = wenetspeech.valid_cuts()
-        export_to_webdataset(
-            dev_cuts,
-            output_path=f"{dev}/shared-%d.tar",
-            shard_size=300,
-        )
+    test_net_cuts = wenetspeech.test_net_cuts()
+    test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)

-    if not os.path.exists(f"{test_net}/shared-0.tar"):
-        os.makedirs(test_net)
-        test_net_cuts = wenetspeech.test_net_cuts()
-        export_to_webdataset(
-            test_net_cuts,
-            output_path=f"{test_net}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    if not os.path.exists(f"{test_meeting}/shared-0.tar"):
-        os.makedirs(test_meeting)
-        test_meeting_cuts = wenetspeech.test_meeting_cuts()
-        export_to_webdataset(
-            test_meeting_cuts,
-            output_path=f"{test_meeting}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    dev_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
-    ]
-    cuts_dev_webdataset = CutSet.from_webdataset(
-        dev_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    test_net_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
-    ]
-    cuts_test_net_webdataset = CutSet.from_webdataset(
-        test_net_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    test_meeting_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
-    ]
-    cuts_test_meeting_webdataset = CutSet.from_webdataset(
-        test_meeting_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
-    test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
-    test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset)
+    test_meeting_cuts = wenetspeech.test_meeting_cuts()
+    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)

     test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
     test_dl = [dev_dl, test_net_dl, test_meeting_dl]
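
With args.return_cuts = True set above, every batch these dataloaders produce
carries its originating cuts, which is how the decode loop recovers utterance
ids for the recognition results. A schematic sketch of consuming them,
continuing from the names defined in this hunk and assuming lhotse's
K2SpeechRecognitionDataset batch layout:

for test_set, dl in zip(test_sets, test_dl):
    for batch in dl:
        feature = batch["inputs"]  # (N, T, C) fbank features
        texts = batch["supervisions"]["text"]
        # Present because return_cuts=True was set above.
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
        # ... run the model and collect hypotheses keyed by cut_ids ...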

[File 4/4: a second WenetSpeech decode script, with the same change]

@@ -661,83 +661,18 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    # Note: Please use "pip install webdataset==0.1.103"
-    # to install webdataset.
-    import glob
-    import os
-
-    from lhotse import CutSet
-    from lhotse.dataset.webdataset import export_to_webdataset
-
     # We need cut ids to display recognition results.
     args.return_cuts = True
     wenetspeech = WenetSpeechAsrDataModule(args)

-    dev = "dev"
-    test_net = "test_net"
-    test_meeting = "test_meeting"
+    dev_cuts = wenetspeech.valid_cuts()
+    dev_dl = wenetspeech.valid_dataloaders(dev_cuts)

-    if not os.path.exists(f"{dev}/shared-0.tar"):
-        os.makedirs(dev)
-        dev_cuts = wenetspeech.valid_cuts()
-        export_to_webdataset(
-            dev_cuts,
-            output_path=f"{dev}/shared-%d.tar",
-            shard_size=300,
-        )
+    test_net_cuts = wenetspeech.test_net_cuts()
+    test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)

-    if not os.path.exists(f"{test_net}/shared-0.tar"):
-        os.makedirs(test_net)
-        test_net_cuts = wenetspeech.test_net_cuts()
-        export_to_webdataset(
-            test_net_cuts,
-            output_path=f"{test_net}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    if not os.path.exists(f"{test_meeting}/shared-0.tar"):
-        os.makedirs(test_meeting)
-        test_meeting_cuts = wenetspeech.test_meeting_cuts()
-        export_to_webdataset(
-            test_meeting_cuts,
-            output_path=f"{test_meeting}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    dev_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
-    ]
-    cuts_dev_webdataset = CutSet.from_webdataset(
-        dev_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    test_net_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
-    ]
-    cuts_test_net_webdataset = CutSet.from_webdataset(
-        test_net_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    test_meeting_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
-    ]
-    cuts_test_meeting_webdataset = CutSet.from_webdataset(
-        test_meeting_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
-    test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
-    test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset)
+    test_meeting_cuts = wenetspeech.test_meeting_cuts()
+    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)

     test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
     test_dl = [dev_dl, test_net_dl, test_meeting_dl]