Updates for new k2 version; change LR decay from 0.85 to 0.9

This commit is contained in:
Daniel Povey 2021-09-13 20:57:02 +08:00
parent d0e5b9b8a5
commit 3ce1de337d
5 changed files with 32 additions and 31 deletions
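For readers making the same update elsewhere, the k2 API renamings this commit relies on can be summarized in a short before/after sketch. It only covers the calls that appear in the diffs below, under the assumption that they behave as used here; it is not a complete migration guide for k2.

    # Sketch of the old -> new k2 calls touched in this commit (toy data).
    import torch
    import k2

    r = k2.RaggedTensor([[1, 2], [3]])        # was built via k2.ragged.create_ragged2 / k2.RaggedInt
    row_splits = r.shape.row_splits(1)        # was r.row_splits(1)
    flat = r.data                             # was r.values()
    sub, _ = r.index(torch.tensor([0], dtype=torch.int32), axis=0)  # was k2.index(r, indexes)
    part = r.arange(0, 0, 1)                  # was _k2.ragged_int_arange(r, 0, 0, 1)
    total = r.to(dtype=torch.float32).sum()   # was k2.ragged.sum_per_sublist(k2.RaggedFloat(...))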

View File

@ -14,8 +14,8 @@ class LmDataset(torch.utils.data.Dataset):
Torch dataset for language modeling data. This is a map-style dataset.
The indices are integers.
"""
- def __init__(self, sentences: k2.RaggedInt,
- words: k2.RaggedInt):
+ def __init__(self, sentences: k2.RaggedTensor,
+ words: k2.RaggedTensor):
super(LmDataset, self).__init__()
self.sentences = sentences
self.words = words
@ -30,12 +30,17 @@ class LmDataset(torch.utils.data.Dataset):
Return the i'th sentence, as a list of ints (representing BPE pieces, without
bos or eos symbols).
"""
- # It would be nicer if we could just return self.sentences[i].tolist(), but
- # for now that operator on k2.RaggedInt is not implemented.
- row_splits = self.sentences.row_splits(1)
+ # in future will just do:
+ #return self.words[self.sentences[i]].tolist()
+ # It would be nicer if we could just return self.sentences[i].tolist(),
+ # but for now that operator on k2.RaggedInt does not support when the
+ # ragged has only 2 axes.
+ row_splits = self.sentences.shape.row_splits(1)
(begin, end) = row_splits[i:i+2].tolist()
- sentence = self.sentences.values()[begin:end]
- return k2.index(self.words, sentence).values().tolist()
+ sentence = self.sentences.data[begin:end]
+ sentence, _ = self.words.index(sentence, axis=0, need_value_indexes=False)
+ return sentence.data.tolist()
def load_train_test_lm_dataset(archive_fn: Union[str,Path],
@ -45,22 +50,21 @@ def load_train_test_lm_dataset(archive_fn: Union[str,Path],
"""
d = torch.load(archive_fn)
- words = d['words'] # a k2.RaggedInt with 2 axes, maps from word-ids to sequences of BPE pieces
- sentences = d['data'] # a k2.RaggedInt
+ words = d['words'] # a k2.RaggedTensor with 2 axes, maps from word-ids to sequences of BPE pieces
+ sentences = d['data'] # a k2.RaggedTensor
with torch.random.fork_rng(devices=[]):
g = torch.manual_seed(0)
num_sentences = sentences.tot_size(0)
# probably the generator (g) argument to torch.randperm below is not necessary.
sentence_perm = torch.randperm(num_sentences, generator=g, dtype=torch.int32)
- sentences = k2.index(sentences, sentence_perm)
+ sentences, _ = sentences.index(sentence_perm, axis=0, need_value_indexes=False)
num_test_sentences = int(num_sentences * test_proportion)
axis=0
- train_sents = _k2.ragged_int_arange(sentences, axis,
- num_test_sentences, num_sentences)
- test_sents = _k2.ragged_int_arange(sentences, axis, 0, num_test_sentences)
+ train_sents = sentences.arange(axis, num_test_sentences, num_sentences)
+ test_sents = sentences.arange(axis, 0, num_test_sentences)
return LmDataset(train_sents, words), LmDataset(test_sents, words)
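After the update, the train/test split boils down to shuffling sentence indices and taking two ranges along axis 0 with arange(); a toy sketch with made-up sizes and data:

    import torch
    import k2

    sentences = k2.RaggedTensor([[0, 1], [1], [0], [1, 1]])
    num_sentences = sentences.tot_size(0)                        # 4

    perm = torch.randperm(num_sentences, dtype=torch.int32)
    sentences, _ = sentences.index(perm, axis=0, need_value_indexes=False)

    num_test = 1
    test_sents = sentences.arange(0, 0, num_test)                # first num_test sentences
    train_sents = sentences.arange(0, num_test, num_sentences)   # the remainder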
@ -683,27 +687,25 @@ class LmBatchSampler(torch.utils.data.Sampler):
# sampler is responsible for (all of them, in the non-distributed case).
data_indexes = torch.arange(self.rank, len(dataset), self.world_size, dtype=torch.int32) # dtype=torch.int32
- word_row_splits = dataset.words.row_splits(1) # dtype=torch.int32
+ word_row_splits = dataset.words.shape.row_splits(1) # dtype=torch.int32
word_lengths = word_row_splits[1:] - word_row_splits[:-1] # dtype=torch.int32
# the sentences this sampler is responsible for, as sequences of words.
# It's a ragged tensor of int32
- sentences = k2.index(dataset.sentences, data_indexes)
+ sentences, _ = dataset.sentences.index(data_indexes, axis=0)
- # sentence_lengths is a k2.RaggedInt like `sentences`, but with the words replaced
+ # sentence_lengths is a k2.RaggedTensor like `sentences`, but with the words replaced
# with their respective lengths, in BPE pieces.
- sentence_lengths = k2.index(word_lengths, sentences)
+ sentence_lengths = k2.ragged.index(word_lengths, sentences)
del sentences # save memory
- assert isinstance(sentence_lengths, k2.RaggedInt)
+ assert isinstance(sentence_lengths, k2.RaggedTensor)
# convert to float so sum_per_sublist() will work (TODO: sum_per_sublist() will eventually
# support int32.)
- sentence_lengths = k2.RaggedFloat(sentence_lengths.shape(),
- sentence_lengths.values().to(torch.float32))
- assert isinstance(sentence_lengths, k2.RaggedFloat)
+ sentence_lengths = sentence_lengths.to(dtype=torch.float32)
# Convert into a simple tensor of float by adding lengths of words.
- sentence_lengths = k2.ragged.sum_per_sublist(sentence_lengths)
+ sentence_lengths = sentence_lengths.sum()
assert isinstance(sentence_lengths, torch.Tensor)
assert sentence_lengths.dtype == torch.float32
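Taken together, the updated sampler code turns each sentence into a total length in BPE pieces roughly as follows (toy data; k2.ragged.index with a plain tensor source and ragged indexes is the call used above):

    import torch
    import k2

    words = k2.RaggedTensor([[10, 11], [12], [13, 14, 15]])    # BPE pieces per word-id
    sentences = k2.RaggedTensor([[0, 2], [1]])                 # word-ids per sentence

    word_row_splits = words.shape.row_splits(1)
    word_lengths = word_row_splits[1:] - word_row_splits[:-1]  # lengths 2, 1, 3

    # Replace each word-id with its length, keeping the sentence structure.
    sentence_lengths = k2.ragged.index(word_lengths, sentences)
    sentence_lengths = sentence_lengths.to(dtype=torch.float32)

    # Sum over the last axis -> one float per sentence, here 2+3=5 and 1.
    print(sentence_lengths.sum())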

View File

@ -9,8 +9,6 @@ import torch.distributed as dist
def local_collate_fn(sentences):
return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=True)
- x = _k2.RaggedInt('[[1]]') # make sure library initialized?
if __name__ == '__main__':
#mp.set_start_method('spawn')

View File

@ -136,13 +136,14 @@ def get_params() -> AttributeDict:
# exp_4, vs. exp_3, is using the Gloam optimizer with
# in exp_5, vs. exp_4, we change Gloam to have a 1/sqrt(t) factor
# as well as the exponential part.
"exp_dir": Path("conformer_lm/exp_5"),
# exp_6, we change the decay from 0.85 to 0.9.
"exp_dir": Path("conformer_lm/exp_6"),
"lm_dataset": Path("data/lm_training_5000/lm_data.pt"),
"num_tokens": 5000,
"blank_sym": 0,
"bos_sym": 1,
"eos_sym": 1,
"start_epoch": 0,
"start_epoch": 2,
"num_epochs": 20,
"num_valid_batches": 200,
"symbols_per_batch": 5000,
@ -529,7 +530,7 @@ def run(rank, world_size, args):
model.parameters(),
max_lrate=params.max_lrate,
first_decrease_epoch=1,
- decay_per_epoch=0.85
+ decay_per_epoch=0.9
)
if checkpoints:
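For intuition about the decay_per_epoch change only: the comments in get_params() describe Gloam as combining a 1/sqrt(t) factor with an exponential per-epoch decay. The sketch below illustrates the exponential part in isolation with an assumed epoch offset; it is not Gloam's actual implementation.

    def epoch_decay_factor(epoch: int,
                           first_decrease_epoch: int = 1,
                           decay_per_epoch: float = 0.9) -> float:
        # Assumed form: no decay before first_decrease_epoch, then one
        # factor of decay_per_epoch for every subsequent epoch.
        return decay_per_epoch ** max(0, epoch - first_decrease_epoch)

    # With the new value 0.9, the factor after 6 epochs is 0.9**5 ~= 0.59;
    # with the old 0.85 it would be 0.85**5 ~= 0.44, i.e. the decay is now slower.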

View File

@ -82,8 +82,8 @@ def main():
sentences.append([ word2index[w] for w in line_words])
output = dict()
- output['words' ] = k2.ragged.create_ragged2(words2bpe)
- output['data'] = k2.ragged.create_ragged2(sentences)
+ output['words' ] = k2.ragged.RaggedTensor(words2bpe)
+ output['data'] = k2.ragged.RaggedTensor(sentences)
torch.save(output, args.lm_archive)
print(f"Saved to {args.lm_archive}")

View File

@ -3,7 +3,7 @@
set -eou pipefail
nj=15
- stage=-1
+ stage=9
stop_stage=100
# We assume dl_dir (download dir) contains the following
@ -195,7 +195,7 @@ fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
- lm_dir=lm_dir=data/lm_training_${vocab_size}
+ lm_dir=data/lm_training_${vocab_size}
mkdir -p $lm_dir
log "Stage 9: creating $lm_dir/lm_data.pt"
./local/prepare_lm_training_data.py data/lang_bpe_${vocab_size}/bpe.model download/lm/librispeech-lm-norm.txt $lm_dir/lm_data.pt
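To sanity-check the archive that stage 9 writes, it can be loaded the same way load_train_test_lm_dataset does above (the path shown is the 5000-token one referenced in train.py's params; adjust for other vocab sizes):

    import torch

    d = torch.load('data/lm_training_5000/lm_data.pt')
    words, sentences = d['words'], d['data']   # both 2-axis k2 RaggedTensors
    print('num words:', words.tot_size(0))
    print('num sentences:', sentences.tot_size(0))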