diff --git a/egs/librispeech/ASR/conformer_lm/dataset.py b/egs/librispeech/ASR/conformer_lm/dataset.py
index 4f466a9e1..6c28c21ca 100644
--- a/egs/librispeech/ASR/conformer_lm/dataset.py
+++ b/egs/librispeech/ASR/conformer_lm/dataset.py
@@ -14,8 +14,8 @@ class LmDataset(torch.utils.data.Dataset):
     Torch dataset for language modeling data. This is a map-style dataset.
     The indices are integers.
     """
-    def __init__(self, sentences: k2.RaggedInt,
-                 words: k2.RaggedInt):
+    def __init__(self, sentences: k2.RaggedTensor,
+                 words: k2.RaggedTensor):
         super(LmDataset, self).__init__()
         self.sentences = sentences
         self.words = words
@@ -30,12 +30,17 @@ class LmDataset(torch.utils.data.Dataset):
         Return the i'th sentence, as a list of ints (representing BPE pieces, without
         bos or eos symbols).
         """
-        # It would be nicer if we could just return self.sentences[i].tolist(), but
-        # for now that operator on k2.RaggedInt is not implemented.
-        row_splits = self.sentences.row_splits(1)
+        # in future will just do:
+        #return self.words[self.sentences[i]].tolist()
+
+        # It would be nicer if we could just return self.sentences[i].tolist(),
+        # but for now that operator on k2.RaggedInt does not support when the
+        # ragged has only 2 axes.
+        row_splits = self.sentences.shape.row_splits(1)
         (begin, end) = row_splits[i:i+2].tolist()
-        sentence = self.sentences.values()[begin:end]
-        return k2.index(self.words, sentence).values().tolist()
+        sentence = self.sentences.data[begin:end]
+        sentence, _ = self.words.index(sentence, axis=0, need_value_indexes=False)
+        return sentence.data.tolist()


 def load_train_test_lm_dataset(archive_fn: Union[str,Path],
@@ -45,22 +50,21 @@ def load_train_test_lm_dataset(archive_fn: Union[str,Path],
     """

     d = torch.load(archive_fn)
-    words = d['words'] # a k2.RaggedInt with 2 axes, maps from word-ids to sequences of BPE pieces
-    sentences = d['data'] # a k2.RaggedInt
+    words = d['words'] # a k2.RaggedTensor with 2 axes, maps from word-ids to sequences of BPE pieces
+    sentences = d['data'] # a k2.RaggedTensor

     with torch.random.fork_rng(devices=[]):
         g = torch.manual_seed(0)
         num_sentences = sentences.tot_size(0)
         # probably the generator (g) argument to torch.randperm below is not necessary.
         sentence_perm = torch.randperm(num_sentences, generator=g, dtype=torch.int32)
-        sentences = k2.index(sentences, sentence_perm)
+        sentences, _ = sentences.index(sentence_perm, axis=0, need_value_indexes=False)

     num_test_sentences = int(num_sentences * test_proportion)

     axis=0
-    train_sents = _k2.ragged_int_arange(sentences, axis,
-                                        num_test_sentences, num_sentences)
-    test_sents = _k2.ragged_int_arange(sentences, axis, 0, num_test_sentences)
+    train_sents = sentences.arange(axis, num_test_sentences, num_sentences)
+    test_sents = sentences.arange(axis, 0, num_test_sentences)

     return LmDataset(train_sents, words), LmDataset(test_sents, words)

@@ -683,27 +687,25 @@ class LmBatchSampler(torch.utils.data.Sampler):
         # sampler is reponsible for (all of them, in the non-distributed case).
         data_indexes = torch.arange(self.rank, len(dataset), self.world_size, dtype=torch.int32) # dtype=torch.int32

-        word_row_splits = dataset.words.row_splits(1) # dtype=torch.int32
+        word_row_splits = dataset.words.shape.row_splits(1) # dtype=torch.int32
         word_lengths = word_row_splits[1:] - word_row_splits[:-1] # dtype=torch.int32

         # the sentences this sampler is responsible for, as sequences of words.
         # It's a ragged tensor of int32
-        sentences = k2.index(dataset.sentences, data_indexes)
+        sentences, _ = dataset.sentences.index(data_indexes, axis=0)

-        # sentence_lengths is a k2.RaggedInt like `sentences`, but with the words replaced
+        # sentence_lengths is a k2.RaggedTensor like `sentences`, but with the words replaced
         # with their respective lengths, in BPE pieces.
-        sentence_lengths = k2.index(word_lengths, sentences)
+        sentence_lengths = k2.ragged.index(word_lengths, sentences)
         del sentences # save memory
-        assert isinstance(sentence_lengths, k2.RaggedInt)
+        assert isinstance(sentence_lengths, k2.RaggedTensor)

         # convert to float so sum_per_sublist() will work (TODO: sum_per_sublist() will eventually
         # support int32.)
-        sentence_lengths = k2.RaggedFloat(sentence_lengths.shape(),
-                                          sentence_lengths.values().to(torch.float32))
-        assert isinstance(sentence_lengths, k2.RaggedFloat)
+        sentence_lengths = sentence_lengths.to(dtype=torch.float32)

         # Convert into a simple tensor of float by adding lengths of words.
-        sentence_lengths = k2.ragged.sum_per_sublist(sentence_lengths)
+        sentence_lengths = sentence_lengths.sum()

         assert isinstance(sentence_lengths, torch.Tensor)
         assert sentence_lengths.dtype == torch.float32
diff --git a/egs/librispeech/ASR/conformer_lm/test_dataset.py b/egs/librispeech/ASR/conformer_lm/test_dataset.py
index b82da7899..4cadaa939 100644
--- a/egs/librispeech/ASR/conformer_lm/test_dataset.py
+++ b/egs/librispeech/ASR/conformer_lm/test_dataset.py
@@ -9,8 +9,6 @@ import torch.distributed as dist
 def local_collate_fn(sentences):
     return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=True)

-x = _k2.RaggedInt('[[1]]') # make sure library initialized?
-
 if __name__ == '__main__':
     #mp.set_start_method('spawn')

diff --git a/egs/librispeech/ASR/conformer_lm/train.py b/egs/librispeech/ASR/conformer_lm/train.py
index 04c3f8ccd..2d1c1a4c3 100755
--- a/egs/librispeech/ASR/conformer_lm/train.py
+++ b/egs/librispeech/ASR/conformer_lm/train.py
@@ -136,13 +136,14 @@ def get_params() -> AttributeDict:
             # exp_4, vs. exp_3, is using the Gloam optimizer with
             # in exp_5, vs. exp_4, we change Gloam to have a 1/sqrt(t) factor
             # as well as the exponential part.
-            "exp_dir": Path("conformer_lm/exp_5"),
+            # exp_6, we change the decay from 0.85 to 0.9.
+            "exp_dir": Path("conformer_lm/exp_6"),
             "lm_dataset": Path("data/lm_training_5000/lm_data.pt"),
             "num_tokens": 5000,
             "blank_sym": 0,
             "bos_sym": 1,
             "eos_sym": 1,
-            "start_epoch": 0,
+            "start_epoch": 2,
             "num_epochs": 20,
             "num_valid_batches": 200,
             "symbols_per_batch": 5000,
@@ -529,7 +530,7 @@ def run(rank, world_size, args):
         model.parameters(),
         max_lrate=params.max_lrate,
         first_decrease_epoch=1,
-        decay_per_epoch=0.85
+        decay_per_epoch=0.9
     )

     if checkpoints:
diff --git a/egs/librispeech/ASR/local/prepare_lm_training_data.py b/egs/librispeech/ASR/local/prepare_lm_training_data.py
index b6e0931f4..a836bb017 100755
--- a/egs/librispeech/ASR/local/prepare_lm_training_data.py
+++ b/egs/librispeech/ASR/local/prepare_lm_training_data.py
@@ -82,8 +82,8 @@ def main():
             sentences.append([ word2index[w] for w in line_words])

     output = dict()
-    output['words' ] = k2.ragged.create_ragged2(words2bpe)
-    output['data'] = k2.ragged.create_ragged2(sentences)
+    output['words' ] = k2.ragged.RaggedTensor(words2bpe)
+    output['data'] = k2.ragged.RaggedTensor(sentences)

     torch.save(output, args.lm_archive)
     print(f"Saved to {args.lm_archive}")
diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh
index 94c408c6e..0e7dc510f 100755
--- a/egs/librispeech/ASR/prepare.sh
+++ b/egs/librispeech/ASR/prepare.sh
@@ -3,7 +3,7 @@
 set -eou pipefail

 nj=15
-stage=-1
+stage=9
 stop_stage=100

 # We assume dl_dir (download dir) contains the following
@@ -195,7 +195,7 @@ fi
 if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
   for vocab_size in ${vocab_sizes[@]}; do
     lang_dir=data/lang_bpe_${vocab_size}
-    lm_dir=lm_dir=data/lm_training_${vocab_size}
+    lm_dir=data/lm_training_${vocab_size}
     mkdir -p $lm_dir
     log "Stage 9: creating $lm_dir/lm_data.pt"
     ./local/prepare_lm_training_data.py data/lang_bpe_${vocab_size}/bpe.model download/lm/librispeech-lm-norm.txt $lm_dir/lm_data.pt
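
Note (not part of the patch): the snippet below is a minimal, self-contained sketch of the old-to-new k2 ragged API mapping that this diff applies, using only the calls that appear in the hunks above. The toy values and variable names are illustrative only, and the sketch assumes a k2 version that provides k2.ragged.RaggedTensor (the same assumption the patch itself makes).

# Sketch of the API migration performed in this patch (assumed k2 version
# with k2.ragged.RaggedTensor); values are toy data for illustration.
import torch
import k2

# Old: k2.ragged.create_ragged2([[...], ...]); new: construct directly.
words = k2.ragged.RaggedTensor([[5, 6], [7], [8, 9, 10]])

# Old: words.row_splits(1); new: row splits live on the .shape attribute.
row_splits = words.shape.row_splits(1)
word_lengths = row_splits[1:] - row_splits[:-1]  # pieces per word

# Old: k2.index(words, indexes); new: the .index() method returns a pair.
indexes = torch.tensor([2, 0], dtype=torch.int32)
subset, _ = words.index(indexes, axis=0, need_value_indexes=False)

# Old: ragged.values(); new: the flat values are exposed as .data.
piece_ids = subset.data.tolist()

# Old: k2.ragged.sum_per_sublist() on a k2.RaggedFloat; new: cast with
# .to(dtype=...) and call .sum(), which returns a plain torch.Tensor
# with one entry per sublist.
sums = subset.to(dtype=torch.float32).sum()
assert isinstance(sums, torch.Tensor)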