mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
add webdataset for dataloading
This commit is contained in:
parent
5319429d76
commit
3fe3a0c492
@ -383,12 +383,22 @@ class WenetSpeechAsrDataModule:
|
|||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
sampler = DynamicBucketingSampler(
|
sampler = DynamicBucketingSampler(
|
||||||
cuts, max_duration=self.args.max_duration, shuffle=False
|
cuts,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
rank=0,
|
||||||
|
world_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
|
||||||
|
|
||||||
|
test_iter_dataset = IterableDatasetWrapper(
|
||||||
|
dataset=test,
|
||||||
|
sampler=sampler,
|
||||||
)
|
)
|
||||||
test_dl = DataLoader(
|
test_dl = DataLoader(
|
||||||
test,
|
test_iter_dataset,
|
||||||
batch_size=None,
|
batch_size=None,
|
||||||
sampler=sampler,
|
|
||||||
num_workers=self.args.num_workers,
|
num_workers=self.args.num_workers,
|
||||||
)
|
)
|
||||||
return test_dl
|
return test_dl
|
||||||
|
@ -129,10 +129,22 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--lang-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_char",
|
||||||
help="Path to the BPE model",
|
help="""The lang dir
|
||||||
|
It contains language related input files such as
|
||||||
|
"lexicon.txt"
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--token-type",
|
||||||
|
type=str,
|
||||||
|
default="char",
|
||||||
|
help="""The type of token
|
||||||
|
It must be in ["char", "pinyin", "lazy_pinyin"]
|
||||||
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -268,8 +280,10 @@ def decode_one_batch(
|
|||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
)
|
)
|
||||||
|
# print(hyp_tokens)
|
||||||
|
# print(lexicon.token_table)
|
||||||
for i in range(encoder_out.size(0)):
|
for i in range(encoder_out.size(0)):
|
||||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens])
|
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||||
elif params.decoding_method == "modified_beam_search":
|
elif params.decoding_method == "modified_beam_search":
|
||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
@ -277,7 +291,7 @@ def decode_one_batch(
|
|||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
for i in range(encoder_out.size(0)):
|
for i in range(encoder_out.size(0)):
|
||||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens])
|
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||||
else:
|
else:
|
||||||
batch_size = encoder_out.size(0)
|
batch_size = encoder_out.size(0)
|
||||||
|
|
||||||
@ -358,6 +372,7 @@ def decode_dataset(
|
|||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
|
texts = [list(str(text)) for text in texts]
|
||||||
|
|
||||||
hyps_dict = decode_one_batch(
|
hyps_dict = decode_one_batch(
|
||||||
params=params,
|
params=params,
|
||||||
@ -371,8 +386,7 @@ def decode_dataset(
|
|||||||
this_batch = []
|
this_batch = []
|
||||||
assert len(hyps) == len(texts)
|
assert len(hyps) == len(texts)
|
||||||
for hyp_words, ref_text in zip(hyps, texts):
|
for hyp_words, ref_text in zip(hyps, texts):
|
||||||
ref_words = ref_text.split()
|
this_batch.append((ref_text, hyp_words))
|
||||||
this_batch.append((ref_words, hyp_words))
|
|
||||||
|
|
||||||
results[name].extend(this_batch)
|
results[name].extend(this_batch)
|
||||||
|
|
||||||
@ -507,12 +521,59 @@ def main():
|
|||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
wenetspeech = WenetSpeechAsrDataModule(args)
|
# Note: Please use "pip install webdataset==0.1.103"
|
||||||
test_net_cuts = wenetspeech.test_net_cuts()
|
# for installing the webdataset.
|
||||||
test_meeting_cuts = wenetspeech.test_meeting_cuts()
|
import glob
|
||||||
|
import os
|
||||||
|
|
||||||
test_net_dl = wenetspeech.valid_dataloaders(test_net_cuts)
|
from lhotse import CutSet
|
||||||
test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
|
from lhotse.dataset.webdataset import export_to_webdataset
|
||||||
|
|
||||||
|
wenetspeech = WenetSpeechAsrDataModule(args)
|
||||||
|
|
||||||
|
test_net = "test_net"
|
||||||
|
test_meet = "test_meet"
|
||||||
|
|
||||||
|
if not os.path.exists(f"{test_net}/shared-0.tar"):
|
||||||
|
test_net_cuts = wenetspeech.test_net_cuts()
|
||||||
|
export_to_webdataset(
|
||||||
|
test_net_cuts,
|
||||||
|
output_path=f"{test_net}/shared-%d.tar",
|
||||||
|
shard_size=300,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not os.path.exists(f"{test_meet}/shared-0.tar"):
|
||||||
|
test_meeting_cuts = wenetspeech.test_meeting_cuts()
|
||||||
|
export_to_webdataset(
|
||||||
|
test_meeting_cuts,
|
||||||
|
output_path=f"{test_meet}/shared-%d.tar",
|
||||||
|
shard_size=300,
|
||||||
|
)
|
||||||
|
|
||||||
|
test_net_shards = [
|
||||||
|
str(path)
|
||||||
|
for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
|
||||||
|
]
|
||||||
|
cuts_test_net_webdataset = CutSet.from_webdataset(
|
||||||
|
test_net_shards,
|
||||||
|
split_by_worker=True,
|
||||||
|
split_by_node=True,
|
||||||
|
shuffle_shards=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
test_meet_shards = [
|
||||||
|
str(path)
|
||||||
|
for path in sorted(glob.glob(os.path.join(test_meet, "shared-*.tar")))
|
||||||
|
]
|
||||||
|
cuts_test_meet_webdataset = CutSet.from_webdataset(
|
||||||
|
test_meet_shards,
|
||||||
|
split_by_worker=True,
|
||||||
|
split_by_node=True,
|
||||||
|
shuffle_shards=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
|
||||||
|
test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meet_webdataset)
|
||||||
|
|
||||||
test_sets = ["TEST_NET", "TEST_MEETING"]
|
test_sets = ["TEST_NET", "TEST_MEETING"]
|
||||||
test_dl = [test_net_dl, test_meeting_dl]
|
test_dl = [test_net_dl, test_meeting_dl]
|
||||||
|
@ -45,6 +45,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
@ -84,6 +85,8 @@ LRSchedulerType = Union[
|
|||||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||||
]
|
]
|
||||||
|
|
||||||
|
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -332,7 +335,7 @@ def get_params() -> AttributeDict:
|
|||||||
"batch_idx_train": 0,
|
"batch_idx_train": 0,
|
||||||
"log_interval": 50,
|
"log_interval": 50,
|
||||||
"reset_interval": 200,
|
"reset_interval": 200,
|
||||||
"valid_interval": 2000,
|
"valid_interval": 3000,
|
||||||
# parameters for conformer
|
# parameters for conformer
|
||||||
"feature_dim": 80,
|
"feature_dim": 80,
|
||||||
"subsampling_factor": 4,
|
"subsampling_factor": 4,
|
||||||
@ -867,6 +870,7 @@ def run(rank, world_size, args):
|
|||||||
wenetspeech = WenetSpeechAsrDataModule(args)
|
wenetspeech = WenetSpeechAsrDataModule(args)
|
||||||
|
|
||||||
train_cuts = wenetspeech.train_cuts()
|
train_cuts = wenetspeech.train_cuts()
|
||||||
|
valid_cuts = wenetspeech.valid_cuts()
|
||||||
|
|
||||||
def remove_short_and_long_utt(c: Cut):
|
def remove_short_and_long_utt(c: Cut):
|
||||||
# Keep only utterances with duration between 1 second and 20 seconds
|
# Keep only utterances with duration between 1 second and 20 seconds
|
||||||
@ -890,8 +894,8 @@ def run(rank, world_size, args):
|
|||||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||||
if params.token_type == "pinyin":
|
if params.token_type == "pinyin":
|
||||||
train_cuts = train_cuts.map(text_to_words)
|
train_cuts = train_cuts.map(text_to_words)
|
||||||
|
# valid_cuts = valid_cuts.map(text_to_words)
|
||||||
|
|
||||||
valid_cuts = wenetspeech.valid_cuts()
|
|
||||||
valid_dl = wenetspeech.valid_dataloaders(valid_cuts)
|
valid_dl = wenetspeech.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
|
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user