mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-13 20:12:24 +00:00)
Updates for new k2 version; change LR decay from 0.85 to 0.9
This commit is contained in:
parent d0e5b9b8a5
commit 3ce1de337d
@@ -14,8 +14,8 @@ class LmDataset(torch.utils.data.Dataset):
     Torch dataset for language modeling data. This is a map-style dataset.
     The indices are integers.
     """
-    def __init__(self, sentences: k2.RaggedInt,
-                 words: k2.RaggedInt):
+    def __init__(self, sentences: k2.RaggedTensor,
+                 words: k2.RaggedTensor):
         super(LmDataset, self).__init__()
         self.sentences = sentences
         self.words = words
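
For reference, a minimal sketch (not from the diff; toy values, assuming the k2 version this commit targets) of building the two ragged inputs the updated constructor now expects:

import k2

# Toy stand-ins for the dataset's two ragged inputs (invented values):
words = k2.RaggedTensor([[5, 7], [9], [2, 3, 4]])   # word-id -> BPE pieces
sentences = k2.RaggedTensor([[0, 2], [1]])          # sentences as word-ids

assert isinstance(words, k2.RaggedTensor)
assert sentences.tot_size(0) == 2   # number of sentences along axis 0
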
@@ -30,12 +30,17 @@ class LmDataset(torch.utils.data.Dataset):
         Return the i'th sentence, as a list of ints (representing BPE pieces, without
         bos or eos symbols).
         """
-        # It would be nicer if we could just return self.sentences[i].tolist(), but
-        # for now that operator on k2.RaggedInt is not implemented.
-        row_splits = self.sentences.row_splits(1)
+        # in future will just do:
+        #return self.words[self.sentences[i]].tolist()
+
+        # It would be nicer if we could just return self.sentences[i].tolist(),
+        # but for now that operator on k2.RaggedInt does not support when the
+        # ragged has only 2 axes.
+        row_splits = self.sentences.shape.row_splits(1)
         (begin, end) = row_splits[i:i+2].tolist()
-        sentence = self.sentences.values()[begin:end]
-        return k2.index(self.words, sentence).values().tolist()
+        sentence = self.sentences.data[begin:end]
+        sentence, _ = self.words.index(sentence, axis=0, need_value_indexes=False)
+        return sentence.data.tolist()


 def load_train_test_lm_dataset(archive_fn: Union[str,Path],
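
For reference, a toy walk-through of the new __getitem__ path above, with invented values and the same k2 calls the diff introduces (shape.row_splits, .data, and .index with axis=0):

import k2

words = k2.RaggedTensor([[5, 7], [9], [2, 3, 4]])   # word-id -> BPE pieces (toy)
sentences = k2.RaggedTensor([[0, 2], [1]])          # sentences as word-ids (toy)

i = 0
row_splits = sentences.shape.row_splits(1)          # row splits on axis 1: [0, 2, 3]
(begin, end) = row_splits[i:i+2].tolist()           # begin=0, end=2
sentence = sentences.data[begin:end]                # word-ids of sentence 0: [0, 2]
pieces, _ = words.index(sentence, axis=0, need_value_indexes=False)
print(pieces.data.tolist())                         # [5, 7, 2, 3, 4]
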
@@ -45,22 +50,21 @@ def load_train_test_lm_dataset(archive_fn: Union[str,Path],
     """

     d = torch.load(archive_fn)
-    words = d['words'] # a k2.RaggedInt with 2 axes, maps from word-ids to sequences of BPE pieces
-    sentences = d['data'] # a k2.RaggedInt
+    words = d['words'] # a k2.RaggedTensor with 2 axes, maps from word-ids to sequences of BPE pieces
+    sentences = d['data'] # a k2.RaggedTensor

     with torch.random.fork_rng(devices=[]):
         g = torch.manual_seed(0)
         num_sentences = sentences.tot_size(0)
         # probably the generator (g) argument to torch.randperm below is not necessary.
         sentence_perm = torch.randperm(num_sentences, generator=g, dtype=torch.int32)
-        sentences = k2.index(sentences, sentence_perm)
+        sentences, _ = sentences.index(sentence_perm, axis=0, need_value_indexes=False)

     num_test_sentences = int(num_sentences * test_proportion)

     axis=0
-    train_sents = _k2.ragged_int_arange(sentences, axis,
-                                        num_test_sentences, num_sentences)
-    test_sents = _k2.ragged_int_arange(sentences, axis, 0, num_test_sentences)
+    train_sents = sentences.arange(axis, num_test_sentences, num_sentences)
+    test_sents = sentences.arange(axis, 0, num_test_sentences)

     return LmDataset(train_sents, words), LmDataset(test_sents, words)

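
For reference, a small sketch (toy data, not from the diff) of the two instance methods the new code relies on here: index() with a permutation to shuffle the sentences, and arange() to split off a test portion:

import torch
import k2

sentences = k2.RaggedTensor([[0, 2], [1], [3, 4], [2]])   # toy sentences as word-ids

# Shuffle the sublists, as load_train_test_lm_dataset now does:
perm = torch.randperm(sentences.tot_size(0), dtype=torch.int32)
shuffled, _ = sentences.index(perm, axis=0, need_value_indexes=False)

# arange(axis, begin, end) keeps sublists begin..end-1 along the given axis:
test_part = shuffled.arange(0, 0, 2)                       # first two sentences
train_part = shuffled.arange(0, 2, shuffled.tot_size(0))   # the remaining ones
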
@@ -683,27 +687,25 @@ class LmBatchSampler(torch.utils.data.Sampler):
         # sampler is reponsible for (all of them, in the non-distributed case).
         data_indexes = torch.arange(self.rank, len(dataset), self.world_size, dtype=torch.int32) # dtype=torch.int32

-        word_row_splits = dataset.words.row_splits(1) # dtype=torch.int32
+        word_row_splits = dataset.words.shape.row_splits(1) # dtype=torch.int32
         word_lengths = word_row_splits[1:] - word_row_splits[:-1] # dtype=torch.int32

         # the sentences this sampler is responsible for, as sequences of words.
         # It's a ragged tensor of int32
-        sentences = k2.index(dataset.sentences, data_indexes)
+        sentences, _ = dataset.sentences.index(data_indexes, axis=0)

-        # sentence_lengths is a k2.RaggedInt like `sentences`, but with the words replaced
+        # sentence_lengths is a k2.RaggedTensor like `sentences`, but with the words replaced
         # with their respective lengths, in BPE pieces.
-        sentence_lengths = k2.index(word_lengths, sentences)
+        sentence_lengths = k2.ragged.index(word_lengths, sentences)
         del sentences # save memory
-        assert isinstance(sentence_lengths, k2.RaggedInt)
+        assert isinstance(sentence_lengths, k2.RaggedTensor)

         # convert to float so sum_per_sublist() will work (TODO: sum_per_sublist() will eventually
         # support int32.)
-        sentence_lengths = k2.RaggedFloat(sentence_lengths.shape(),
-                                          sentence_lengths.values().to(torch.float32))
-        assert isinstance(sentence_lengths, k2.RaggedFloat)
+        sentence_lengths = sentence_lengths.to(dtype=torch.float32)

         # Convert into a simple tensor of float by adding lengths of words.
-        sentence_lengths = k2.ragged.sum_per_sublist(sentence_lengths)
+        sentence_lengths = sentence_lengths.sum()

         assert isinstance(sentence_lengths, torch.Tensor)
         assert sentence_lengths.dtype == torch.float32
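
For reference, a toy sketch (invented values) of the length computation this hunk migrates: k2.ragged.index() replaces each word-id with its length, .to(dtype=...) converts the ragged values to float, and .sum() reduces each sublist to a per-sentence total:

import torch
import k2

word_lengths = torch.tensor([2, 1, 3], dtype=torch.int32)    # BPE pieces per word (toy)
sentences = k2.RaggedTensor([[0, 2], [1]])                   # sentences as word-ids (toy)

sentence_lengths = k2.ragged.index(word_lengths, sentences)  # ragged [[2, 3], [1]]
sentence_lengths = sentence_lengths.to(dtype=torch.float32)
totals = sentence_lengths.sum()                              # tensor([5., 1.])
assert isinstance(totals, torch.Tensor) and totals.dtype == torch.float32
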
@@ -9,8 +9,6 @@ import torch.distributed as dist
 def local_collate_fn(sentences):
     return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=True)

-x = _k2.RaggedInt('[[1]]') # make sure library initialized?
-
 if __name__ == '__main__':

     #mp.set_start_method('spawn')
@@ -136,13 +136,14 @@ def get_params() -> AttributeDict:
             # exp_4, vs. exp_3, is using the Gloam optimizer with
             # in exp_5, vs. exp_4, we change Gloam to have a 1/sqrt(t) factor
             # as well as the exponential part.
-            "exp_dir": Path("conformer_lm/exp_5"),
+            # exp_6, we change the decay from 0.85 to 0.9.
+            "exp_dir": Path("conformer_lm/exp_6"),
             "lm_dataset": Path("data/lm_training_5000/lm_data.pt"),
             "num_tokens": 5000,
             "blank_sym": 0,
             "bos_sym": 1,
             "eos_sym": 1,
-            "start_epoch": 0,
+            "start_epoch": 2,
             "num_epochs": 20,
             "num_valid_batches": 200,
             "symbols_per_batch": 5000,
@@ -529,7 +530,7 @@ def run(rank, world_size, args):
         model.parameters(),
         max_lrate=params.max_lrate,
         first_decrease_epoch=1,
-        decay_per_epoch=0.85
+        decay_per_epoch=0.9
     )

     if checkpoints:
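
For a rough sense of what moving decay_per_epoch from 0.85 to 0.9 does, the snippet below prints the pure exponential factor after a given number of decay steps. This is only an illustration: it ignores the 1/sqrt(t) term and the first_decrease_epoch handling inside the Gloam optimizer, so these are not the actual learning rates.

for decay in (0.85, 0.9):
    # factor multiplying the base learning rate after n decay steps (illustrative only)
    print(decay, round(decay ** 10, 3), round(decay ** 19, 3))
# 0.85 -> 0.197 after 10 steps, 0.046 after 19
# 0.9  -> 0.349 after 10 steps, 0.135 after 19
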
@@ -82,8 +82,8 @@ def main():
         sentences.append([ word2index[w] for w in line_words])

     output = dict()
-    output['words' ] = k2.ragged.create_ragged2(words2bpe)
-    output['data'] = k2.ragged.create_ragged2(sentences)
+    output['words' ] = k2.ragged.RaggedTensor(words2bpe)
+    output['data'] = k2.ragged.RaggedTensor(sentences)

     torch.save(output, args.lm_archive)
     print(f"Saved to {args.lm_archive}")
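
For reference, a self-contained sketch (toy values, hypothetical output path) of the construction path used above: k2.ragged.RaggedTensor built directly from Python lists of lists, saved with torch.save and read back the way load_train_test_lm_dataset does:

import torch
import k2

words2bpe = [[5, 7], [9], [2, 3, 4]]   # toy word-id -> BPE pieces
sents = [[0, 2], [1]]                  # toy sentences as word-ids

output = dict()
output['words'] = k2.ragged.RaggedTensor(words2bpe)
output['data'] = k2.ragged.RaggedTensor(sents)

torch.save(output, 'lm_data_toy.pt')   # hypothetical path, stands in for args.lm_archive
d = torch.load('lm_data_toy.pt')
assert isinstance(d['words'], k2.RaggedTensor)
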
@@ -3,7 +3,7 @@
 set -eou pipefail

 nj=15
-stage=-1
+stage=9
 stop_stage=100

 # We assume dl_dir (download dir) contains the following
@@ -195,7 +195,7 @@ fi
 if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
   for vocab_size in ${vocab_sizes[@]}; do
     lang_dir=data/lang_bpe_${vocab_size}
-    lm_dir=lm_dir=data/lm_training_${vocab_size}
+    lm_dir=data/lm_training_${vocab_size}
     mkdir -p $lm_dir
     log "Stage 9: creating $lm_dir/lm_data.pt"
     ./local/prepare_lm_training_data.py data/lang_bpe_${vocab_size}/bpe.model download/lm/librispeech-lm-norm.txt $lm_dir/lm_data.pt