From 7c3ab28a688ed4e58888677d900e9715788f2ed7 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 17 Nov 2021 13:42:59 +0800 Subject: [PATCH] Use latest APIs from k2's master branch. --- egs/librispeech/ASR/conformer_lm/dataset.py | 15 ++++----------- egs/librispeech/ASR/conformer_lm/test_dataset.py | 2 +- .../ASR/conformer_lm/test_dataset_empty.py | 8 +++----- 3 files changed, 8 insertions(+), 17 deletions(-) diff --git a/egs/librispeech/ASR/conformer_lm/dataset.py b/egs/librispeech/ASR/conformer_lm/dataset.py index 6c28c21ca..8d24873ed 100644 --- a/egs/librispeech/ASR/conformer_lm/dataset.py +++ b/egs/librispeech/ASR/conformer_lm/dataset.py @@ -30,17 +30,10 @@ class LmDataset(torch.utils.data.Dataset): Return the i'th sentence, as a list of ints (representing BPE pieces, without bos or eos symbols). """ - # in future will just do: - #return self.words[self.sentences[i]].tolist() - - # It would be nicer if we could just return self.sentences[i].tolist(), - # but for now that operator on k2.RaggedInt does not support when the - # ragged has only 2 axes. - row_splits = self.sentences.shape.row_splits(1) - (begin, end) = row_splits[i:i+2].tolist() - sentence = self.sentences.data[begin:end] - sentence, _ = self.words.index(sentence, axis=0, need_value_indexes=False) - return sentence.data.tolist() + # self.sentences[i] returns a 1-D tensor containing word indexes + # self.words[self.sentences[i]] returns a ragged tensor with axes + # [word][token]. + return self.words[self.sentences[i]].values.tolist() def load_train_test_lm_dataset(archive_fn: Union[str,Path], diff --git a/egs/librispeech/ASR/conformer_lm/test_dataset.py b/egs/librispeech/ASR/conformer_lm/test_dataset.py index 4cadaa939..3d6c3354b 100644 --- a/egs/librispeech/ASR/conformer_lm/test_dataset.py +++ b/egs/librispeech/ASR/conformer_lm/test_dataset.py @@ -18,7 +18,7 @@ if __name__ == '__main__': dist.init_process_group(backend="nccl", group_name="main", rank=0, world_size=1) - train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt') + train,test = dataset.load_train_test_lm_dataset('../data/lm_training_500/lm_data.pt') sampler = dataset.LmBatchSampler(test, symbols_per_batch=5000, world_size=2, rank=0) print("len(sampler) = ", len(sampler)) diff --git a/egs/librispeech/ASR/conformer_lm/test_dataset_empty.py b/egs/librispeech/ASR/conformer_lm/test_dataset_empty.py index 7e933f07b..573f206c7 100644 --- a/egs/librispeech/ASR/conformer_lm/test_dataset_empty.py +++ b/egs/librispeech/ASR/conformer_lm/test_dataset_empty.py @@ -1,6 +1,6 @@ +#!/usr/bin/env python3 import k2 import torch -import _k2 import dataset from dataset import LmDataset import os @@ -10,8 +10,6 @@ import torch.distributed as dist def local_collate_fn(sentences): return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=False) -x = _k2.RaggedInt('[[1]]') # make sure library initialized? - if __name__ == '__main__': mp.set_start_method('spawn') @@ -21,8 +19,8 @@ if __name__ == '__main__': dist.init_process_group(backend="nccl", group_name="main", rank=0, world_size=1) - words = k2.RaggedInt('[[0][1 2]]') - sentences = k2.RaggedInt('[[1][][][][][]]') + words = k2.RaggedTensor('[[0][1 2]]') + sentences = k2.RaggedTensor('[[1][][][][][]]') train = LmDataset(sentences, words)