mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 18:42:19 +00:00
14 lines
462 B
Python
14 lines
462 B
Python
import dataset
|
|
import torch
|
|
|
|
|
|
# Manual smoke test for the LM data pipeline: load the dataset, pull one batch
# from the batch sampler, then pull one collated batch through a DataLoader and
# print both for visual inspection.
#
# NOTE: this needs the project-local `dataset` module and the serialized LM
# data at the relative path below, so it must be run from a directory where
# that path resolves.

# Only the `test` split is exercised below; `train` is loaded but unused.
train, test = dataset.load_train_test_lm_dataset(
    "../data/lm_training_5000/lm_data.pt"
)

# NOTE(review): world_size=2 / rank=0 presumably selects the share of batches
# for worker 0 of a 2-worker run -- confirm against LmBatchSampler.
sampler = dataset.LmBatchSampler(
    test, symbols_per_batch=1000, world_size=2, rank=0
)

# Show the first item the sampler yields.
sampler_iter = iter(sampler)
print(next(sampler_iter))


def collate_fn(batch):
    """Collate `batch` via dataset.collate_fn with fixed symbol ids and debug on."""
    return dataset.collate_fn(
        batch, bos_sym=1, eos_sym=1, blank_sym=0, debug=True
    )


# The loader wraps the *test* split (hence the name), reusing the sampler above
# as its batch_sampler.
test_dl = torch.utils.data.DataLoader(
    test, batch_sampler=sampler, collate_fn=collate_fn
)

# Show the first collated batch.
dl_iter = iter(test_dl)
print(next(dl_iter))
|