mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-04 22:54:18 +00:00
add webdataset for dataload
This commit is contained in:
parent
460ae4cb97
commit
84090511f8
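
This change moves the WenetSpeech test sets in the recipe's decode script onto webdataset-backed loading: the test_net and test_meeting cuts are exported once to tar shards with lhotse's export_to_webdataset, the existence checks are fixed from "if os.path.exists(...)" to "if not os.path.exists(...)" so shards are only written when missing, and the shards are read back lazily with CutSet.from_webdataset. The ad-hoc time.time() instrumentation and the "loading data time" print are dropped from decode_one_batch and decode_dataset, and the long shard-globbing one-liners are reflowed.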
@@ -322,12 +322,10 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
     import time
-    st1 = time.time()
 
     encoder_out, encoder_out_lens = model.encoder(
         x=feature, x_lens=feature_lens
     )
-    ed1 = time.time()
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -346,15 +344,12 @@ def decode_one_batch(
         params.decoding_method == "greedy_search"
         and params.max_sym_per_frame == 1
     ):
-        st2 = time.time()
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
         )
-        ed2 = time.time()
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
-        ed3 = time.time()
     else:
         batch_size = encoder_out.size(0)
 
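
A side note on the greedy-search branch above: hyp_tokens holds one list of integer token ids per utterance, and lexicon.token_table maps each id back to its symbol. A minimal self-contained sketch of that mapping, with a hypothetical plain dict standing in for the real token table:

    # Hypothetical stand-ins; real values come from greedy_search_batch
    # and lexicon.token_table.
    hyp_tokens = [[23, 7], [56]]
    token_table = {23: "ni3", 7: "hao3", 56: "zai4"}

    hyps = []
    for ids in hyp_tokens:
        # Same per-utterance id -> symbol mapping as in the diff.
        hyps.append([token_table[idx] for idx in ids])

    print(hyps)  # [['ni3', 'hao3'], ['zai4']]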
@@ -438,8 +433,6 @@ def decode_dataset(
     else:
         log_interval = 2
 
-    import time
-    ed = time.time()
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
@@ -450,8 +443,6 @@ def decode_dataset(
         texts = [pinyin(text) for text in texts]
         for i in range(len(texts)):
             texts[i] = [token[0] for token in texts[i]]
-        st = time.time()
-        print(f"loading data time: {st - ed}")
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
@@ -459,7 +450,6 @@ def decode_dataset(
             batch=batch,
             decoding_graph=decoding_graph,
         )
-        ed = time.time()
         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
@@ -469,14 +459,14 @@ def decode_dataset(
             results[name].extend(this_batch)
 
         num_cuts += len(texts)
 
         if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"
 
            logging.info(
                f"batch {batch_str}, cuts processed until now is {num_cuts}"
            )
 
    return results
 
@@ -594,11 +584,12 @@ def main():
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    # Note: Please use "pip install webdataset==0.1.103"
-
-    import os
+    # Note: Please use "pip install webdataset==0.1.103"
+    # for installing the webdataset.
+    import glob
+    import os
 
     from lhotse import CutSet
     from lhotse.dataset.webdataset import export_to_webdataset
 
@@ -606,14 +597,15 @@ def main():
 
     test_net = "test_net"
     test_meet = "test_meet"
-    if os.path.exists(f"{test_net}/shared-0.tar"):
+    if not os.path.exists(f"{test_net}/shared-0.tar"):
         test_net_cuts = wenetspeech.test_net_cuts()
         export_to_webdataset(
             test_net_cuts,
             output_path=f"{test_net}/shared-%d.tar",
             shard_size=300,
         )
-    if os.path.exists(f"{test_meet}/shared-0.tar"):
+
+    if not os.path.exists(f"{test_meet}/shared-0.tar"):
         test_meeting_cuts = wenetspeech.test_meeting_cuts()
         export_to_webdataset(
             test_meeting_cuts,
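
For context, the guarded export above in isolation: a minimal sketch, assuming a lhotse CutSet named cuts and an output directory out_dir (both hypothetical names). Writing is skipped when shared-0.tar, the first file produced by the "%d" template, already exists, which lets repeated decode runs reuse the shards.

    import os

    from lhotse import CutSet
    from lhotse.dataset.webdataset import export_to_webdataset


    def export_shards_once(cuts: CutSet, out_dir: str, shard_size: int = 300) -> None:
        # shared-0.tar is the first shard the "shared-%d.tar" template creates,
        # so its presence signals that a previous export already completed.
        if not os.path.exists(f"{out_dir}/shared-0.tar"):
            os.makedirs(out_dir, exist_ok=True)
            export_to_webdataset(
                cuts,
                output_path=f"{out_dir}/shared-%d.tar",
                shard_size=shard_size,
            )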
@@ -621,15 +613,21 @@ def main():
             shard_size=300,
         )
 
-    test_net_shards = [str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))]
+    test_net_shards = [
+        str(path)
+        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+    ]
     cuts_test_net_webdataset = CutSet.from_webdataset(
         test_net_shards,
         split_by_worker=True,
         split_by_node=True,
         shuffle_shards=True,
     )
 
-    test_meet_shards = [str(path) for path in sorted(glob.glob(os.path.join(test_meet, "shared-*.tar")))]
+    test_meet_shards = [
+        str(path)
+        for path in sorted(glob.glob(os.path.join(test_meet, "shared-*.tar")))
+    ]
     cuts_test_meet_webdataset = CutSet.from_webdataset(
         test_meet_shards,
         split_by_worker=True,
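
And the matching read side, under the same assumptions: rebuilding a lazily-iterated CutSet from the tar shards with the same from_webdataset flags the diff uses (load_shard_cuts and shard_dir are hypothetical names).

    import glob
    import os

    from lhotse import CutSet


    def load_shard_cuts(shard_dir: str) -> CutSet:
        shards = [
            str(path)
            for path in sorted(glob.glob(os.path.join(shard_dir, "shared-*.tar")))
        ]
        return CutSet.from_webdataset(
            shards,
            split_by_worker=True,  # each dataloader worker gets its own shard subset
            split_by_node=True,    # each distributed node gets its own shard subset
            shuffle_shards=True,   # randomize shard order between epochs
        )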