Repository: https://github.com/k2-fsa/icefall.git

commit 460ae4cb97
parent 8a854d2130

    add webdataset for dataload
@@ -34,12 +34,14 @@ from lhotse.cut import Cut
 from lhotse.dataset import (
     CutConcatenate,
     CutMix,
+    BucketingSampler,
     DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
     PrecomputedFeatures,
     SingleCutSampler,
     SpecAugment,
 )
+from lhotse.dataset.webdataset import export_to_webdataset
 from lhotse.dataset.input_strategies import OnTheFlyFeatures
 from torch.utils.data import DataLoader
 
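Note: this hunk and the next one touch the WenetSpeech data module (presumably the recipe's asr_datamodule.py, judging by the WenetSpeechAsrDataModule context lines). The import block gains BucketingSampler, and the webdataset export helper export_to_webdataset is imported at module level.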
@@ -361,10 +363,15 @@ class WenetSpeechAsrDataModule:
         sampler = DynamicBucketingSampler(
             cuts, max_duration=self.args.max_duration, shuffle=False
         )
-        test_dl = DataLoader(
-            test,
-            batch_size=None,
+        from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
+        test_iter_dataset = IterableDatasetWrapper(
+            dataset=test,
             sampler=sampler,
+        )
+        test_dl = DataLoader(
+            test_iter_dataset,
+            batch_size=None,
             num_workers=self.args.num_workers,
         )
         return test_dl
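Background on the rewrite above: a webdataset-backed CutSet can only be read sequentially, so the sampler can no longer index into the dataset from the main process. lhotse's IterableDatasetWrapper fuses the map-style dataset and the sampler into a single iterable-style dataset that runs inside each dataloader worker; batch_size=None stays because the sampler already emits whole batches. A minimal sketch of the resulting flow (max_duration and num_workers values are illustrative; `cuts` is assumed to be a webdataset-backed CutSet as built in main() further below):

    from lhotse.dataset import (
        DynamicBucketingSampler,
        K2SpeechRecognitionDataset,
    )
    from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
    from torch.utils.data import DataLoader

    test = K2SpeechRecognitionDataset()
    sampler = DynamicBucketingSampler(cuts, max_duration=300, shuffle=False)

    # Fuse dataset + sampler: each dataloader worker iterates its own
    # subset of shards sequentially and forms batches locally.
    test_iter_dataset = IterableDatasetWrapper(dataset=test, sampler=sampler)
    test_dl = DataLoader(
        test_iter_dataset,
        batch_size=None,  # batches are already formed by the sampler
        num_workers=2,
    )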
@@ -206,7 +206,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame.
         Used only when --decoding_method is greedy_search""",
     )
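The remaining hunks appear to be in the recipe's decoding script (the context lines show get_parser, decode_one_batch, decode_dataset, and main). Lowering the --max-sym-per-frame default from 3 to 1 matters because the batched greedy_search_batch fast path, two hunks below, is only taken when params.max_sym_per_frame == 1.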
@@ -322,10 +322,12 @@ def decode_one_batch(
 
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
+    import time
+    st1 = time.time()
     encoder_out, encoder_out_lens = model.encoder(
         x=feature, x_lens=feature_lens
     )
+    ed1 = time.time()
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -344,12 +346,15 @@ def decode_one_batch(
         params.decoding_method == "greedy_search"
         and params.max_sym_per_frame == 1
     ):
+        st2 = time.time()
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
         )
+        ed2 = time.time()
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+        ed3 = time.time()
     else:
         batch_size = encoder_out.size(0)
 
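The st1/ed1 and st2/ed2/ed3 stamps bracket the encoder forward pass and the batched greedy search so the two stages can be timed separately from dataloading (instrumented in the next hunks). If this instrumentation were to be kept, a small helper would keep it tidier; a sketch, not part of this commit:

    import time
    from contextlib import contextmanager

    @contextmanager
    def timed(label: str):
        # Print the wall-clock duration of the enclosed block.
        start = time.time()
        yield
        print(f"{label}: {time.time() - start:.3f}s")

    # e.g. around the encoder call instrumented above:
    # with timed("encoder forward"):
    #     encoder_out, encoder_out_lens = model.encoder(
    #         x=feature, x_lens=feature_lens
    #     )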
@@ -433,6 +438,8 @@ def decode_dataset(
     else:
         log_interval = 2
 
+    import time
+    ed = time.time()
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
@@ -443,7 +450,8 @@ def decode_dataset(
         texts = [pinyin(text) for text in texts]
         for i in range(len(texts)):
             texts[i] = [token[0] for token in texts[i]]
+        st = time.time()
+        print(f"loading data time: {st - ed}")
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
@@ -451,6 +459,7 @@ def decode_dataset(
             batch=batch,
             decoding_graph=decoding_graph,
         )
+        ed = time.time()
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
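The ed/st pair spans successive loop iterations: ed is stamped right after decode_one_batch() returns and st just before the next call, so the printed "loading data time" is the time the loop spent waiting on the dataloader to produce the next batch, which is the latency the webdataset-backed loader is intended to reduce.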
@@ -460,13 +469,14 @@ def decode_dataset(
             results[name].extend(this_batch)
 
         num_cuts += len(texts)
 
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
             logging.info(
                 f"batch {batch_str}, cuts processed until now is {num_cuts}"
             )
 
     return results
 
@@ -584,13 +594,51 @@ def main():
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
+    # Note: Please use "pip install webdataset==0.1.103"
+    # to install webdataset.
+    import os
+    import glob
+    from lhotse import CutSet
+    from lhotse.dataset.webdataset import export_to_webdataset
+
     wenetspeech = WenetSpeechAsrDataModule(args)
-    test_net_cuts = wenetspeech.test_net_cuts()
-    test_meeting_cuts = wenetspeech.test_meeting_cuts()
 
-    test_net_dl = wenetspeech.valid_dataloaders(test_net_cuts)
-    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
+    test_net = "test_net"
+    test_meet = "test_meet"
+    if not os.path.exists(f"{test_net}/shared-0.tar"):
+        test_net_cuts = wenetspeech.test_net_cuts()
+        export_to_webdataset(
+            test_net_cuts,
+            output_path=f"{test_net}/shared-%d.tar",
+            shard_size=300,
+        )
+    if not os.path.exists(f"{test_meet}/shared-0.tar"):
+        test_meeting_cuts = wenetspeech.test_meeting_cuts()
+        export_to_webdataset(
+            test_meeting_cuts,
+            output_path=f"{test_meet}/shared-%d.tar",
+            shard_size=300,
+        )
+
+    test_net_shards = [
+        str(path)
+        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+    ]
+    cuts_test_net_webdataset = CutSet.from_webdataset(
+        test_net_shards,
+        split_by_worker=True,
+        split_by_node=True,
+        shuffle_shards=True,
+    )
+
+    test_meet_shards = [
+        str(path)
+        for path in sorted(glob.glob(os.path.join(test_meet, "shared-*.tar")))
+    ]
+    cuts_test_meet_webdataset = CutSet.from_webdataset(
+        test_meet_shards,
+        split_by_worker=True,
+        split_by_node=True,
+        shuffle_shards=True,
+    )
+
+    test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
+    test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meet_webdataset)
 
     test_sets = ["TEST_NET", "TEST_MEETING"]
     test_dl = [test_net_dl, test_meeting_dl]
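Taken together, main() now performs a one-time export of each test CutSet into sharded tar archives (skipped when shared-0.tar already exists) and then reopens the shards as a webdataset-backed CutSet that the reworked test_dataloaders() can stream. A condensed sketch of that round trip; the helper name cuts_via_webdataset is illustrative, not part of the commit:

    import glob
    import os

    from lhotse import CutSet
    from lhotse.dataset.webdataset import export_to_webdataset

    def cuts_via_webdataset(cuts: CutSet, out_dir: str) -> CutSet:
        # One-time export: write the cuts and their data into tar shards
        # of 300 cuts each, e.g. out_dir/shared-0.tar, out_dir/shared-1.tar, ...
        if not os.path.exists(f"{out_dir}/shared-0.tar"):
            os.makedirs(out_dir, exist_ok=True)
            export_to_webdataset(
                cuts,
                output_path=f"{out_dir}/shared-%d.tar",
                shard_size=300,
            )
        shards = sorted(glob.glob(os.path.join(out_dir, "shared-*.tar")))
        # Reopen as a sequentially readable CutSet; shards are divided
        # among dataloader workers and DDP nodes so each tar is read once.
        return CutSet.from_webdataset(
            shards,
            split_by_worker=True,
            split_by_node=True,
            shuffle_shards=True,
        )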