add webdataset for dataloading

luomingshuang 2022-04-15 16:50:56 +08:00
parent 5319429d76
commit 3fe3a0c492
3 changed files with 92 additions and 17 deletions


@@ -383,12 +383,22 @@ class WenetSpeechAsrDataModule:
             return_cuts=self.args.return_cuts,
         )
         sampler = DynamicBucketingSampler(
-            cuts, max_duration=self.args.max_duration, shuffle=False
+            cuts,
+            max_duration=self.args.max_duration,
+            rank=0,
+            world_size=1,
+            shuffle=False,
+        )
+
+        from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
+
+        test_iter_dataset = IterableDatasetWrapper(
+            dataset=test,
+            sampler=sampler,
         )
         test_dl = DataLoader(
-            test,
+            test_iter_dataset,
             batch_size=None,
-            sampler=sampler,
             num_workers=self.args.num_workers,
         )
         return test_dl
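For context, the wrapper above moves the sampler inside an iterable dataset, so the DataLoader is built without a sampler argument and with batch_size=None. A minimal consumption sketch (assuming test_dl comes from the modified test_dataloaders() and that batches follow the usual K2SpeechRecognitionDataset layout):

    for batch_idx, batch in enumerate(test_dl):
        features = batch["inputs"]             # (N, T, C) fbank features
        texts = batch["supervisions"]["text"]  # reference transcripts
        if batch_idx == 0:
            print(features.shape, len(texts))
            break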


@@ -129,10 +129,22 @@ def get_parser():
     )
     parser.add_argument(
-        "--bpe-model",
+        "--lang-dir",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_char",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+    parser.add_argument(
+        "--token-type",
+        type=str,
+        default="char",
+        help="""The type of token
+        It must be in ["char", "pinyin", "lazy_pinyin"]
+        """,
     )
     parser.add_argument(
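As a rough illustration of the three --token-type choices (this snippet assumes the pypinyin package for the pinyin variants; the recipe's own tokenizer may differ):

    from pypinyin import lazy_pinyin, pinyin

    text = "中文语音识别"
    chars = list(text)                    # "char": one token per character
    toned = [p[0] for p in pinyin(text)]  # "pinyin": syllables with tone marks
    plain = lazy_pinyin(text)             # "lazy_pinyin": syllables without tones
    print(chars, toned, plain)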
@@ -268,8 +280,10 @@ def decode_one_batch(
             model=model,
             encoder_out=encoder_out,
         )
+        # print(hyp_tokens)
+        # print(lexicon.token_table)
         for i in range(encoder_out.size(0)):
-            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens])
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
     elif params.decoding_method == "modified_beam_search":
         hyp_tokens = modified_beam_search(
             model=model,
@@ -277,7 +291,7 @@ def decode_one_batch(
             beam=params.beam_size,
         )
         for i in range(encoder_out.size(0)):
-            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens])
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
     else:
         batch_size = encoder_out.size(0)
@@ -358,6 +372,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
+        texts = [list(str(text)) for text in texts]

         hyps_dict = decode_one_batch(
             params=params,
@@ -371,8 +386,7 @@ def decode_dataset(
             this_batch = []
             assert len(hyps) == len(texts)
             for hyp_words, ref_text in zip(hyps, texts):
-                ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
+                this_batch.append((ref_text, hyp_words))

             results[name].extend(this_batch)
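Note that with texts = [list(str(text)) for text in texts], each reference becomes a list of characters, matching the character-level hypotheses built from lexicon.token_table, so for the "char" token type the reported error rate is effectively a character error rate. A tiny illustration (hypothetical strings, not from the recipe):

    ref_text = "好好学习"
    ref = list(str(ref_text))       # ['好', '好', '学', '习']
    hyp = ["好", "好", "学", "习"]  # e.g. [lexicon.token_table[idx] for idx in hyp_tokens[i]]
    assert ref == hyp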
@@ -507,12 +521,59 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    wenetspeech = WenetSpeechAsrDataModule(args)
-    test_net_cuts = wenetspeech.test_net_cuts()
-    test_meeting_cuts = wenetspeech.test_meeting_cuts()
-
-    test_net_dl = wenetspeech.valid_dataloaders(test_net_cuts)
-    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
+    # Note: Please use "pip install webdataset==0.1.103"
+    # for installing the webdataset.
+    import glob
+    import os
+
+    from lhotse import CutSet
+    from lhotse.dataset.webdataset import export_to_webdataset
+
+    wenetspeech = WenetSpeechAsrDataModule(args)
+
+    test_net = "test_net"
+    test_meet = "test_meet"
+
+    if not os.path.exists(f"{test_net}/shared-0.tar"):
+        test_net_cuts = wenetspeech.test_net_cuts()
+        export_to_webdataset(
+            test_net_cuts,
+            output_path=f"{test_net}/shared-%d.tar",
+            shard_size=300,
+        )
+
+    if not os.path.exists(f"{test_meet}/shared-0.tar"):
+        test_meeting_cuts = wenetspeech.test_meeting_cuts()
+        export_to_webdataset(
+            test_meeting_cuts,
+            output_path=f"{test_meet}/shared-%d.tar",
+            shard_size=300,
+        )
+
+    test_net_shards = [
+        str(path)
+        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+    ]
+    cuts_test_net_webdataset = CutSet.from_webdataset(
+        test_net_shards,
+        split_by_worker=True,
+        split_by_node=True,
+        shuffle_shards=True,
+    )
+
+    test_meet_shards = [
+        str(path)
+        for path in sorted(glob.glob(os.path.join(test_meet, "shared-*.tar")))
+    ]
+    cuts_test_meet_webdataset = CutSet.from_webdataset(
+        test_meet_shards,
+        split_by_worker=True,
+        split_by_node=True,
+        shuffle_shards=True,
+    )
+
+    test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
+    test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meet_webdataset)

     test_sets = ["TEST_NET", "TEST_MEETING"]
     test_dl = [test_net_dl, test_meeting_dl]
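The shards are written only once (guarded by the os.path.exists checks) and are then re-read lazily on every decoding run. A hypothetical sanity check for the exported TEST_NET shards (directory and shard names assumed to match the code above; not part of the commit):

    import glob
    import os

    from lhotse import CutSet

    shards = sorted(glob.glob(os.path.join("test_net", "shared-*.tar")))
    cuts = CutSet.from_webdataset(shards, shuffle_shards=False)
    first = next(iter(cuts))
    print(len(shards), first.id, first.duration)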


@@ -45,6 +45,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 import argparse
 import logging
+import os
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -84,6 +85,8 @@ LRSchedulerType = Union[
     torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
 ]

+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+

 def get_parser():
     parser = argparse.ArgumentParser(
@@ -332,7 +335,7 @@ def get_params() -> AttributeDict:
             "batch_idx_train": 0,
             "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 2000,
+            "valid_interval": 3000,
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
@@ -867,6 +870,7 @@ def run(rank, world_size, args):
     wenetspeech = WenetSpeechAsrDataModule(args)

     train_cuts = wenetspeech.train_cuts()
+    valid_cuts = wenetspeech.valid_cuts()

     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -890,8 +894,8 @@ def run(rank, world_size, args):
     train_cuts = train_cuts.filter(remove_short_and_long_utt)

     if params.token_type == "pinyin":
         train_cuts = train_cuts.map(text_to_words)
+        # valid_cuts = valid_cuts.map(text_to_words)
-    valid_cuts = wenetspeech.valid_cuts()
     valid_dl = wenetspeech.valid_dataloaders(valid_cuts)

     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
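text_to_words is referenced here but not shown in this diff. A hypothetical sketch of such a cut-mapping function for the pinyin token type (assuming pypinyin; the recipe's real implementation may differ):

    from pypinyin import lazy_pinyin


    def text_to_words(cut):
        # Rewrite each supervision's transcript as space-separated pinyin syllables.
        for sup in cut.supervisions:
            sup.text = " ".join(lazy_pinyin(sup.text))
        return cut


    # Applied as in the diff: train_cuts = train_cuts.map(text_to_words)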