add webdataset for data loading

luomingshuang 2022-04-14 21:45:00 +08:00
parent 460ae4cb97
commit 84090511f8


@@ -322,12 +322,10 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
-    import time
-    st1 = time.time()
     encoder_out, encoder_out_lens = model.encoder(
         x=feature, x_lens=feature_lens
     )
-    ed1 = time.time()
     hyps = []
     if params.decoding_method == "fast_beam_search":
@@ -346,15 +344,12 @@ def decode_one_batch(
         params.decoding_method == "greedy_search"
         and params.max_sym_per_frame == 1
     ):
-        st2 = time.time()
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
         )
-        ed2 = time.time()
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-        ed3 = time.time()
     else:
         batch_size = encoder_out.size(0)
@@ -438,8 +433,6 @@ def decode_dataset(
     else:
         log_interval = 2
-    import time
-    ed = time.time()
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
@@ -450,8 +443,6 @@ def decode_dataset(
         texts = [pinyin(text) for text in texts]
         for i in range(len(texts)):
             texts[i] = [token[0] for token in texts[i]]
-        st = time.time()
-        print(f"loading data time: {st - ed}")
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
@@ -459,7 +450,6 @@ def decode_dataset(
             batch=batch,
             decoding_graph=decoding_graph,
         )
-        ed = time.time()
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
@@ -597,8 +587,9 @@ def main():
     # Note: Please use "pip install webdataset==0.1.103"
     # to install webdataset.
-    import os
+    import glob
+    import os
     from lhotse import CutSet
     from lhotse.dataset.webdataset import export_to_webdataset
@@ -606,14 +597,15 @@ def main():
     test_net = "test_net"
     test_meet = "test_meet"
-    if os.path.exists(f"{test_net}/shared-0.tar"):
+    if not os.path.exists(f"{test_net}/shared-0.tar"):
         test_net_cuts = wenetspeech.test_net_cuts()
         export_to_webdataset(
             test_net_cuts,
             output_path=f"{test_net}/shared-%d.tar",
             shard_size=300,
         )
-    if os.path.exists(f"{test_meet}/shared-0.tar"):
+    if not os.path.exists(f"{test_meet}/shared-0.tar"):
         test_meeting_cuts = wenetspeech.test_meeting_cuts()
         export_to_webdataset(
             test_meeting_cuts,
@@ -621,7 +613,10 @@ def main():
             shard_size=300,
         )
-    test_net_shards = [str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))]
+    test_net_shards = [
+        str(path)
+        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+    ]
     cuts_test_net_webdataset = CutSet.from_webdataset(
         test_net_shards,
         split_by_worker=True,
@@ -629,7 +624,10 @@ def main():
         shuffle_shards=True,
     )
-    test_meet_shards = [str(path) for path in sorted(glob.glob(os.path.join(test_meet, "shared-*.tar")))]
+    test_meet_shards = [
+        str(path)
+        for path in sorted(glob.glob(os.path.join(test_meet, "shared-*.tar")))
+    ]
     cuts_test_meet_webdataset = CutSet.from_webdataset(
         test_meet_shards,
         split_by_worker=True,
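
Taken together, the main() changes export each test set to sharded webdataset tars once, then reload the shards lazily for decoding. Below is a minimal, self-contained sketch of that export-then-load flow under the pinned webdataset==0.1.103 noted above. The manifest file path is a hypothetical stand-in for the recipe's wenetspeech.test_net_cuts(); the lhotse calls and their arguments mirror the hunks in this commit.

import glob
import os

from lhotse import CutSet
from lhotse.dataset.webdataset import export_to_webdataset

test_net = "test_net"

# Export once: skip the export if the first shard is already on disk.
if not os.path.exists(f"{test_net}/shared-0.tar"):
    # Hypothetical manifest path, standing in for wenetspeech.test_net_cuts().
    test_net_cuts = CutSet.from_file("data/fbank/cuts_TEST_NET.jsonl.gz")
    export_to_webdataset(
        test_net_cuts,
        output_path=f"{test_net}/shared-%d.tar",  # shared-0.tar, shared-1.tar, ...
        shard_size=300,  # number of cuts per tar shard
    )

# Reload the shards as a lazy CutSet for the decoding dataloader.
test_net_shards = [
    str(path)
    for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
]
cuts_test_net_webdataset = CutSet.from_webdataset(
    test_net_shards,
    split_by_worker=True,  # each dataloader worker reads a disjoint shard subset
    shuffle_shards=True,   # randomize shard order between iterations
)

Writing the shards once and streaming them back lazily avoids re-materializing the full manifest at decode time, and splitting at shard granularity lets multiple dataloader workers read different tar files in parallel.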