Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-15 20:22:42 +00:00
Updates for new k2 version; change LR decay from 0.85 to 0.9
This commit is contained in:
parent d0e5b9b8a5
commit 3ce1de337d
@@ -14,8 +14,8 @@ class LmDataset(torch.utils.data.Dataset):
     Torch dataset for language modeling data. This is a map-style dataset.
     The indices are integers.
     """
-    def __init__(self, sentences: k2.RaggedInt,
-                 words: k2.RaggedInt):
+    def __init__(self, sentences: k2.RaggedTensor,
+                 words: k2.RaggedTensor):
         super(LmDataset, self).__init__()
         self.sentences = sentences
         self.words = words
@@ -30,12 +30,17 @@ class LmDataset(torch.utils.data.Dataset):
         Return the i'th sentence, as a list of ints (representing BPE pieces, without
         bos or eos symbols).
         """
-        # It would be nicer if we could just return self.sentences[i].tolist(), but
-        # for now that operator on k2.RaggedInt is not implemented.
-        row_splits = self.sentences.row_splits(1)
+        # in future will just do:
+        #return self.words[self.sentences[i]].tolist()
+
+        # It would be nicer if we could just return self.sentences[i].tolist(),
+        # but for now that operator on k2.RaggedInt does not support when the
+        # ragged has only 2 axes.
+        row_splits = self.sentences.shape.row_splits(1)
         (begin, end) = row_splits[i:i+2].tolist()
-        sentence = self.sentences.values()[begin:end]
-        return k2.index(self.words, sentence).values().tolist()
+        sentence = self.sentences.data[begin:end]
+        sentence, _ = self.words.index(sentence, axis=0, need_value_indexes=False)
+        return sentence.data.tolist()
 
 
 def load_train_test_lm_dataset(archive_fn: Union[str,Path],
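For reference, a minimal sketch (not part of this commit, toy data made up) of the indexing pattern the rewritten __getitem__ uses; it reuses the same attribute and method names as the hunk above (shape.row_splits, .data, RaggedTensor.index), which may differ in other k2 versions:

    import k2
    import torch

    # Hypothetical toy data: each word is a list of BPE-piece IDs,
    # each sentence is a list of word IDs.
    words = k2.ragged.RaggedTensor([[10, 11], [12]])
    sentences = k2.ragged.RaggedTensor([[0, 1], [1]])

    i = 0
    row_splits = sentences.shape.row_splits(1)   # sentence boundaries into .data
    begin, end = row_splits[i:i + 2].tolist()
    word_ids = sentences.data[begin:end]         # word IDs of the i'th sentence
    pieces, _ = words.index(word_ids, axis=0, need_value_indexes=False)
    print(pieces.data.tolist())                  # [10, 11, 12]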
@@ -45,22 +50,21 @@ def load_train_test_lm_dataset(archive_fn: Union[str,Path],
     """
 
     d = torch.load(archive_fn)
-    words = d['words']  # a k2.RaggedInt with 2 axes, maps from word-ids to sequences of BPE pieces
-    sentences = d['data']  # a k2.RaggedInt
+    words = d['words']  # a k2.RaggedTensor with 2 axes, maps from word-ids to sequences of BPE pieces
+    sentences = d['data']  # a k2.RaggedTensor
 
     with torch.random.fork_rng(devices=[]):
         g = torch.manual_seed(0)
         num_sentences = sentences.tot_size(0)
         # probably the generator (g) argument to torch.randperm below is not necessary.
         sentence_perm = torch.randperm(num_sentences, generator=g, dtype=torch.int32)
-        sentences = k2.index(sentences, sentence_perm)
+        sentences, _ = sentences.index(sentence_perm, axis=0, need_value_indexes=False)
 
     num_test_sentences = int(num_sentences * test_proportion)
 
     axis=0
-    train_sents = _k2.ragged_int_arange(sentences, axis,
-                                        num_test_sentences, num_sentences)
-    test_sents = _k2.ragged_int_arange(sentences, axis, 0, num_test_sentences)
+    train_sents = sentences.arange(axis, num_test_sentences, num_sentences)
+    test_sents = sentences.arange(axis, 0, num_test_sentences)
 
     return LmDataset(train_sents, words), LmDataset(test_sents, words)
 
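A small sketch (illustrative only; the toy tensor and the 50/50 split are assumptions) of the shuffle-and-split idiom above: .index() permutes the sentences along axis 0, and .arange() takes a contiguous range of sublists:

    import k2
    import torch

    sentences = k2.ragged.RaggedTensor([[1, 2], [3], [4, 5, 6], [7]])
    num_sentences = sentences.tot_size(0)                # 4

    perm = torch.randperm(num_sentences, dtype=torch.int32)
    shuffled, _ = sentences.index(perm, axis=0, need_value_indexes=False)

    num_test = num_sentences // 2                        # assumed split for illustration
    test_part = shuffled.arange(0, 0, num_test)          # sublists [0, num_test)
    train_part = shuffled.arange(0, num_test, num_sentences)
    print(train_part, test_part)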
@@ -683,27 +687,25 @@ class LmBatchSampler(torch.utils.data.Sampler):
         # sampler is reponsible for (all of them, in the non-distributed case).
         data_indexes = torch.arange(self.rank, len(dataset), self.world_size, dtype=torch.int32)  # dtype=torch.int32
 
-        word_row_splits = dataset.words.row_splits(1)  # dtype=torch.int32
+        word_row_splits = dataset.words.shape.row_splits(1)  # dtype=torch.int32
         word_lengths = word_row_splits[1:] - word_row_splits[:-1]  # dtype=torch.int32
 
         # the sentences this sampler is responsible for, as sequences of words.
         # It's a ragged tensor of int32
-        sentences = k2.index(dataset.sentences, data_indexes)
+        sentences, _ = dataset.sentences.index(data_indexes, axis=0)
 
-        # sentence_lengths is a k2.RaggedInt like `sentences`, but with the words replaced
+        # sentence_lengths is a k2.RaggedTensor like `sentences`, but with the words replaced
         # with their respective lengths, in BPE pieces.
-        sentence_lengths = k2.index(word_lengths, sentences)
+        sentence_lengths = k2.ragged.index(word_lengths, sentences)
         del sentences  # save memory
-        assert isinstance(sentence_lengths, k2.RaggedInt)
+        assert isinstance(sentence_lengths, k2.RaggedTensor)
 
         # convert to float so sum_per_sublist() will work (TODO: sum_per_sublist() will eventually
         # support int32.)
-        sentence_lengths = k2.RaggedFloat(sentence_lengths.shape(),
-                                          sentence_lengths.values().to(torch.float32))
-        assert isinstance(sentence_lengths, k2.RaggedFloat)
+        sentence_lengths = sentence_lengths.to(dtype=torch.float32)
 
         # Convert into a simple tensor of float by adding lengths of words.
-        sentence_lengths = k2.ragged.sum_per_sublist(sentence_lengths)
+        sentence_lengths = sentence_lengths.sum()
 
         assert isinstance(sentence_lengths, torch.Tensor)
         assert sentence_lengths.dtype == torch.float32
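A sketch (toy data assumed; same calls as the hunk above) of the new length computation: k2.ragged.index replaces each word ID with that word's length in BPE pieces, the ragged result is cast to float, and .sum() reduces each sentence to a single total:

    import k2
    import torch

    words = k2.ragged.RaggedTensor([[10, 11], [12], [13, 14, 15]])  # 3 words as BPE pieces
    sentences = k2.ragged.RaggedTensor([[0, 2], [1]])               # sentences as word IDs

    word_row_splits = words.shape.row_splits(1)
    word_lengths = word_row_splits[1:] - word_row_splits[:-1]       # [2, 1, 3]

    sentence_lengths = k2.ragged.index(word_lengths, sentences)     # ragged [[2, 3], [1]]
    sentence_lengths = sentence_lengths.to(dtype=torch.float32)
    print(sentence_lengths.sum())                                   # tensor([5., 1.])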
@@ -9,8 +9,6 @@ import torch.distributed as dist
 def local_collate_fn(sentences):
     return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=True)
 
-x = _k2.RaggedInt('[[1]]') # make sure library initialized?
-
 if __name__ == '__main__':
 
     #mp.set_start_method('spawn')
@@ -136,13 +136,14 @@ def get_params() -> AttributeDict:
             # exp_4, vs. exp_3, is using the Gloam optimizer with
             # in exp_5, vs. exp_4, we change Gloam to have a 1/sqrt(t) factor
             # as well as the exponential part.
-            "exp_dir": Path("conformer_lm/exp_5"),
+            # exp_6, we change the decay from 0.85 to 0.9.
+            "exp_dir": Path("conformer_lm/exp_6"),
             "lm_dataset": Path("data/lm_training_5000/lm_data.pt"),
             "num_tokens": 5000,
             "blank_sym": 0,
             "bos_sym": 1,
             "eos_sym": 1,
-            "start_epoch": 0,
+            "start_epoch": 2,
             "num_epochs": 20,
             "num_valid_batches": 200,
             "symbols_per_batch": 5000,
@@ -529,7 +530,7 @@ def run(rank, world_size, args):
         model.parameters(),
         max_lrate=params.max_lrate,
         first_decrease_epoch=1,
-        decay_per_epoch=0.85
+        decay_per_epoch=0.9
         )
 
     if checkpoints:
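Not part of the commit, but to get a rough feel for the headline change: assuming the exponential part of the schedule multiplies the learning rate by decay_per_epoch once per epoch (and ignoring the 1/sqrt(t) factor mentioned in the comments above), the factor after n epochs is decay_per_epoch ** n:

    # Assumed form of the exponential factor only; per the comments in this
    # commit, the real Gloam schedule also includes a 1/sqrt(t) term.
    for decay in (0.85, 0.9):
        factors = [round(decay ** n, 3) for n in range(0, 21, 5)]
        print(f"{decay}: {factors}")
    # 0.85: [1.0, 0.444, 0.197, 0.087, 0.039]
    # 0.9: [1.0, 0.59, 0.349, 0.206, 0.122]

With 0.9, the exponential factor at epoch 20 is roughly three times what it would have been with 0.85, i.e. the decay is noticeably gentler.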
@@ -82,8 +82,8 @@ def main():
         sentences.append([ word2index[w] for w in line_words])
 
     output = dict()
-    output['words' ] = k2.ragged.create_ragged2(words2bpe)
-    output['data'] = k2.ragged.create_ragged2(sentences)
+    output['words' ] = k2.ragged.RaggedTensor(words2bpe)
+    output['data'] = k2.ragged.RaggedTensor(sentences)
 
     torch.save(output, args.lm_archive)
     print(f"Saved to {args.lm_archive}")
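For the last Python change, a minimal sketch (toy lists assumed) of the new construction path: k2.ragged.RaggedTensor accepts a list of lists directly, where k2.ragged.create_ragged2 was needed before, and the result can be saved with torch.save as this script does:

    import k2
    import torch

    words2bpe = [[5, 7], [9]]         # hypothetical: word ID -> BPE-piece IDs
    sentences = [[0, 1, 1], [1]]      # hypothetical: sentence -> word IDs

    output = dict()
    output['words'] = k2.ragged.RaggedTensor(words2bpe)
    output['data'] = k2.ragged.RaggedTensor(sentences)
    torch.save(output, 'lm_data_toy.pt')    # hypothetical output path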
@@ -3,7 +3,7 @@
 set -eou pipefail
 
 nj=15
-stage=-1
+stage=9
 stop_stage=100
 
 # We assume dl_dir (download dir) contains the following
@@ -195,7 +195,7 @@ fi
 if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
   for vocab_size in ${vocab_sizes[@]}; do
     lang_dir=data/lang_bpe_${vocab_size}
-    lm_dir=lm_dir=data/lm_training_${vocab_size}
+    lm_dir=data/lm_training_${vocab_size}
     mkdir -p $lm_dir
     log "Stage 9: creating $lm_dir/lm_data.pt"
     ./local/prepare_lm_training_data.py data/lang_bpe_${vocab_size}/bpe.model download/lm/librispeech-lm-norm.txt $lm_dir/lm_data.pt