mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 20:12:24 +00:00
35 lines
1.1 KiB
Python
35 lines
1.1 KiB
Python
import k2
|
|
import torch
|
|
import _k2
|
|
import dataset
|
|
import os
|
|
from torch import multiprocessing as mp
|
|
import torch.distributed as dist
|
|
|
|
def local_collate_fn(sentences):
    """Collate a list of sentences into one batch for the LM DataLoader.

    Thin wrapper around ``dataset.collate_fn`` that fixes the symbol ids used
    by this script: 1 for both BOS and EOS, 0 for blank, with debug enabled.
    Presumably kept as a named module-level function (not a lambda) so it
    stays picklable for DataLoader worker processes.
    """
    symbol_ids = {"bos_sym": 1, "eos_sym": 1, "blank_sym": 0}
    return dataset.collate_fn(sentences, debug=True, **symbol_ids)
|
|
|
|
# Construct a trivial RaggedInt ([[1]]) at import time. The author's own note
# suggests this is meant to force the _k2 extension to initialize early;
# NOTE(review): confirm whether this side effect is still required.
x = _k2.RaggedInt('[[1]]') # make sure library initialized? (unverified)
|
|
|
|
def main():
    """Smoke-test dataset.LmBatchSampler together with local_collate_fn.

    Sets up a single-process torch.distributed group, loads the LM train/test
    split, builds a batch sampler over the *test* split, and prints the
    sampler's length, its first batch of indices, and the first collated batch
    produced by a 2-worker DataLoader. The process group is always torn down,
    even if loading or sampling raises.
    """
    # Respect an externally configured rendezvous address/port; fall back to
    # the previously hard-coded values otherwise.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12344")

    # NCCL requires CUDA; use gloo so the script also runs on CPU-only hosts.
    # (`group_name` is deprecated and ignored by torch, so it is omitted.)
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend, rank=0, world_size=1)
    try:
        _train, test = dataset.load_train_test_lm_dataset(
            '../data/lm_training_5000/lm_data.pt')

        # NOTE(review): the sampler is told world_size=2 while the process
        # group has world_size=1 -- presumably deliberate, to inspect rank 0's
        # shard of a 2-way split. Confirm before relying on it.
        sampler = dataset.LmBatchSampler(test, symbols_per_batch=5000,
                                         world_size=2, rank=0)
        print("len(sampler) = ", len(sampler))
        print(next(iter(sampler)))

        # This loader iterates the *test* split (it was misleadingly named
        # `train_dl` before).
        test_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler,
                                              collate_fn=local_collate_fn,
                                              num_workers=2)
        print(next(iter(test_dl)))
    finally:
        # Release the process group so repeated runs don't leak it.
        dist.destroy_process_group()


if __name__ == '__main__':
    # If DataLoader workers misbehave under the default start method, the
    # author previously experimented with: mp.set_start_method('spawn')
    main()
|