Use latest APIs from k2's master branch.

This commit is contained in:
Fangjun Kuang 2021-11-17 13:42:59 +08:00
parent b29e4bdd03
commit 7c3ab28a68
3 changed files with 8 additions and 17 deletions

View File

@ -30,17 +30,10 @@ class LmDataset(torch.utils.data.Dataset):
Return the i'th sentence, as a list of ints (representing BPE pieces, without
bos or eos symbols).
"""
# in future will just do:
#return self.words[self.sentences[i]].tolist()
# It would be nicer if we could just return self.sentences[i].tolist(),
# but for now that operator on k2.RaggedInt does not support when the
# ragged has only 2 axes.
row_splits = self.sentences.shape.row_splits(1)
(begin, end) = row_splits[i:i+2].tolist()
sentence = self.sentences.data[begin:end]
sentence, _ = self.words.index(sentence, axis=0, need_value_indexes=False)
return sentence.data.tolist()
# self.sentences[i] returns a 1-D tensor containing word indexes
# self.words[self.sentences[i]] returns a ragged tensor with axes
# [word][token].
return self.words[self.sentences[i]].values.tolist()
def load_train_test_lm_dataset(archive_fn: Union[str,Path],

View File

@ -18,7 +18,7 @@ if __name__ == '__main__':
dist.init_process_group(backend="nccl", group_name="main",
rank=0, world_size=1)
train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
train,test = dataset.load_train_test_lm_dataset('../data/lm_training_500/lm_data.pt')
sampler = dataset.LmBatchSampler(test, symbols_per_batch=5000, world_size=2, rank=0)
print("len(sampler) = ", len(sampler))

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python3
import k2
import torch
import _k2
import dataset
from dataset import LmDataset
import os
@ -10,8 +10,6 @@ import torch.distributed as dist
def local_collate_fn(sentences):
return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=False)
x = _k2.RaggedInt('[[1]]') # make sure library initialized?
if __name__ == '__main__':
mp.set_start_method('spawn')
@ -21,8 +19,8 @@ if __name__ == '__main__':
dist.init_process_group(backend="nccl", group_name="main",
rank=0, world_size=1)
words = k2.RaggedInt('[[0][1 2]]')
sentences = k2.RaggedInt('[[1][][][][][]]')
words = k2.RaggedTensor('[[0][1 2]]')
sentences = k2.RaggedTensor('[[1][][][][][]]')
train = LmDataset(sentences, words)