mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Use latest APIs from k2's master branch.
This commit is contained in:
parent
b29e4bdd03
commit
7c3ab28a68
@ -30,17 +30,10 @@ class LmDataset(torch.utils.data.Dataset):
|
||||
Return the i'th sentence, as a list of ints (representing BPE pieces, without
|
||||
bos or eos symbols).
|
||||
"""
|
||||
# in future will just do:
|
||||
#return self.words[self.sentences[i]].tolist()
|
||||
|
||||
# It would be nicer if we could just return self.sentences[i].tolist(),
|
||||
# but for now that operator on k2.RaggedInt does not support when the
|
||||
# ragged has only 2 axes.
|
||||
row_splits = self.sentences.shape.row_splits(1)
|
||||
(begin, end) = row_splits[i:i+2].tolist()
|
||||
sentence = self.sentences.data[begin:end]
|
||||
sentence, _ = self.words.index(sentence, axis=0, need_value_indexes=False)
|
||||
return sentence.data.tolist()
|
||||
# self.sentences[i] returns a 1-D tensor containing word indexes
|
||||
# self.words[self.sentences[i]] returns a ragged tensor with axes
|
||||
# [word][token].
|
||||
return self.words[self.sentences[i]].values.tolist()
|
||||
|
||||
|
||||
def load_train_test_lm_dataset(archive_fn: Union[str,Path],
|
||||
|
@ -18,7 +18,7 @@ if __name__ == '__main__':
|
||||
dist.init_process_group(backend="nccl", group_name="main",
|
||||
rank=0, world_size=1)
|
||||
|
||||
train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
|
||||
train,test = dataset.load_train_test_lm_dataset('../data/lm_training_500/lm_data.pt')
|
||||
sampler = dataset.LmBatchSampler(test, symbols_per_batch=5000, world_size=2, rank=0)
|
||||
print("len(sampler) = ", len(sampler))
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
import k2
|
||||
import torch
|
||||
import _k2
|
||||
import dataset
|
||||
from dataset import LmDataset
|
||||
import os
|
||||
@ -10,8 +10,6 @@ import torch.distributed as dist
|
||||
def local_collate_fn(sentences):
|
||||
return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=False)
|
||||
|
||||
x = _k2.RaggedInt('[[1]]') # make sure library initialized?
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
mp.set_start_method('spawn')
|
||||
@ -21,8 +19,8 @@ if __name__ == '__main__':
|
||||
dist.init_process_group(backend="nccl", group_name="main",
|
||||
rank=0, world_size=1)
|
||||
|
||||
words = k2.RaggedInt('[[0][1 2]]')
|
||||
sentences = k2.RaggedInt('[[1][][][][][]]')
|
||||
words = k2.RaggedTensor('[[0][1 2]]')
|
||||
sentences = k2.RaggedTensor('[[1][][][][][]]')
|
||||
|
||||
train = LmDataset(sentences, words)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user