mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 10:44:19 +00:00
Use latest APIs from k2's master branch.
This commit is contained in:
parent
b29e4bdd03
commit
7c3ab28a68
@ -30,17 +30,10 @@ class LmDataset(torch.utils.data.Dataset):
|
|||||||
Return the i'th sentence, as a list of ints (representing BPE pieces, without
|
Return the i'th sentence, as a list of ints (representing BPE pieces, without
|
||||||
bos or eos symbols).
|
bos or eos symbols).
|
||||||
"""
|
"""
|
||||||
# in future will just do:
|
# self.sentences[i] returns a 1-D tensor containing word indexes
|
||||||
#return self.words[self.sentences[i]].tolist()
|
# self.words[self.sentences[i]] returns a ragged tensor with axes
|
||||||
|
# [word][token].
|
||||||
# It would be nicer if we could just return self.sentences[i].tolist(),
|
return self.words[self.sentences[i]].values.tolist()
|
||||||
# but for now that operator on k2.RaggedInt does not support when the
|
|
||||||
# ragged has only 2 axes.
|
|
||||||
row_splits = self.sentences.shape.row_splits(1)
|
|
||||||
(begin, end) = row_splits[i:i+2].tolist()
|
|
||||||
sentence = self.sentences.data[begin:end]
|
|
||||||
sentence, _ = self.words.index(sentence, axis=0, need_value_indexes=False)
|
|
||||||
return sentence.data.tolist()
|
|
||||||
|
|
||||||
|
|
||||||
def load_train_test_lm_dataset(archive_fn: Union[str,Path],
|
def load_train_test_lm_dataset(archive_fn: Union[str,Path],
|
||||||
|
@ -18,7 +18,7 @@ if __name__ == '__main__':
|
|||||||
dist.init_process_group(backend="nccl", group_name="main",
|
dist.init_process_group(backend="nccl", group_name="main",
|
||||||
rank=0, world_size=1)
|
rank=0, world_size=1)
|
||||||
|
|
||||||
train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
|
train,test = dataset.load_train_test_lm_dataset('../data/lm_training_500/lm_data.pt')
|
||||||
sampler = dataset.LmBatchSampler(test, symbols_per_batch=5000, world_size=2, rank=0)
|
sampler = dataset.LmBatchSampler(test, symbols_per_batch=5000, world_size=2, rank=0)
|
||||||
print("len(sampler) = ", len(sampler))
|
print("len(sampler) = ", len(sampler))
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import _k2
|
|
||||||
import dataset
|
import dataset
|
||||||
from dataset import LmDataset
|
from dataset import LmDataset
|
||||||
import os
|
import os
|
||||||
@ -10,8 +10,6 @@ import torch.distributed as dist
|
|||||||
def local_collate_fn(sentences):
|
def local_collate_fn(sentences):
|
||||||
return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=False)
|
return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=False)
|
||||||
|
|
||||||
x = _k2.RaggedInt('[[1]]') # make sure library initialized?
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
mp.set_start_method('spawn')
|
mp.set_start_method('spawn')
|
||||||
@ -21,8 +19,8 @@ if __name__ == '__main__':
|
|||||||
dist.init_process_group(backend="nccl", group_name="main",
|
dist.init_process_group(backend="nccl", group_name="main",
|
||||||
rank=0, world_size=1)
|
rank=0, world_size=1)
|
||||||
|
|
||||||
words = k2.RaggedInt('[[0][1 2]]')
|
words = k2.RaggedTensor('[[0][1 2]]')
|
||||||
sentences = k2.RaggedInt('[[1][][][][][]]')
|
sentences = k2.RaggedTensor('[[1][][][][][]]')
|
||||||
|
|
||||||
train = LmDataset(sentences, words)
|
train = LmDataset(sentences, words)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user