Get dataset.py working..

This commit is contained in:
Daniel Povey 2021-08-21 18:23:46 +08:00
parent 1c3b13c7eb
commit 421a41027a
3 changed files with 789 additions and 0 deletions


@@ -0,0 +1,783 @@
import torch
import torch.distributed as dist
import k2
import _k2
import sentencepiece as spm
from typing import Optional, List, Tuple
class LmDataset(torch.utils.data.Dataset):
"""
Torch dataset for language modeling data. This is a map-style dataset.
The indices are integers.
"""
def __init__(self, sentences: k2.RaggedInt,
words: k2.RaggedInt):
super(LmDataset, self).__init__()
self.sentences = sentences
self.words = words
def __len__(self):
# Total size on axis 0, == num sentences
return self.sentences.tot_size(0)
def __getitem__(self, i: int):
"""
Return the i'th sentence, as a list of ints (representing BPE pieces, without
bos or eos symbols).
"""
# It would be nicer if we could just return self.sentences[i].tolist(), but
# for now that operator on k2.RaggedInt is not implemented.
row_splits = self.sentences.row_splits(1)
(begin, end) = row_splits[i:i+2].tolist()
sentence = self.sentences.values()[begin:end]
return k2.index(self.words, sentence).values().tolist()
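# For illustration (hypothetical contents): if the i'th sentence is the word-ids
# [5, 9], and words[5] == [13, 14] and words[9] == [20] as BPE pieces, then
# __getitem__(i) returns [13, 14, 20].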
def load_train_test_lm_dataset(archive_fn: str,
test_proportion: float = 0.025) -> Tuple[LmDataset, LmDataset]:
"""
returns (train_lm_dataset, test_lm_dataset)
"""
d = torch.load(archive_fn)
words = d['words'] # a k2.RaggedInt with 2 axes, maps from word-ids to sequences of BPE pieces
sentences = d['data'] # a k2.RaggedInt with 2 axes, maps from sentence indexes to sequences of word-ids
with torch.random.fork_rng(devices=[]):
g = torch.manual_seed(0)
num_sentences = sentences.tot_size(0)
# probably the generator (g) argument to torch.randperm below is not necessary.
sentence_perm = torch.randperm(num_sentences, generator=g, dtype=torch.int32)
sentences = k2.index(sentences, sentence_perm)
num_test_sentences = int(num_sentences * test_proportion)
axis=0
train_sents = _k2.ragged_int_arange(sentences, axis,
num_test_sentences, num_sentences)
test_sents = _k2.ragged_int_arange(sentences, axis, 0, num_test_sentences)
return LmDataset(train_sents, words), LmDataset(test_sents, words)
def mask_and_pad(sentence: List[int],
seq_len: int,
bos_sym: int,
eos_sym: int,
blank_sym: int,
mask_proportion: float,
padding_proportion: float,
inv_mask_length: float,
unmasked_weight: float) -> Tuple[List[int], List[int], List[int], List[float], List[bool], List[bool]]:
"""
This function contains part of the logic of collate_fn, broken out. It is responsible
for inserting masking and padding into the sequence `sentence`. Most of the arguments
are documented for `collate_fn` below.
Other args:
sentence: The original sentence to be masked and padded.
seq_len: The desired length of the lists to be returned
bos_sym, eos_sym, blank_sym, mask_proportion,
padding_proportion, inv_mask_length, unmasked_weight: see their documentation
as args to `collate_fn` below.
Return: a tuple (src, src_masked, tgt, weight, randomizable, attn_mask), all lists of length `seq_len`,
where:
`src` is: [bos] + [the sentence after inserting blanks in place of padding
after regions to be masked] + [eos] + [blank padding to seq_len].
`src_masked` is as `src` but the masked regions have their values replaced with blank,
i.e. they are actually masked.
`tgt` is: [the original sentence, without masking, but with blank at the padding
positions inside masked regions] + [eos] + [blank] + [blank padding to seq_len].
`weight` is the weight at the nnet output, which is: `unmasked_weight` for un-masked
positions, 1.0 for masked and padded positions, and 0.0 for positions that
correspond to blank-padding after the final [eos].
`randomizable` is a list of bool that is True for positions where the symbol
in `src_masked` is not bos or eos or blank.
`attn_mask` is a list of bool that is False for positions in `src` and `src_masked` that
are between the initial [bos] and final [eos] inclusive; and True for
positions after the final [eos].
"""
sent_len = len(sentence)
assert sent_len + 3 <= seq_len
for w in sentence:
assert w not in [bos_sym, eos_sym, blank_sym]
num_mask = int(torch.binomial(count=torch.tensor([sent_len * 1.0]),
prob=torch.tensor([mask_proportion])).item())
num_pad = int(torch.poisson(torch.tensor([sent_len * padding_proportion])).item())
# Ensure the total length after bos, padding of masked sequences, and eos, is
# no greater than seq_len
num_pad -= max(0, sent_len + 2 + num_pad - seq_len)
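# For illustration (hypothetical numbers): if sent_len = 10, seq_len = 14 and the
# sampled num_pad is 5, then max(0, 10 + 2 + 5 - 14) = 3, so num_pad is reduced
# to 2 and bos + sentence + padding + eos exactly fills seq_len.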
if num_mask + num_pad == 0:
num_mask += 1
# num_split_points is the number of times we split the (masked+padded)
# region, so the total number of (masking+padding) subsequences will be
# num_split_points + 1. If num_mask positions are masked, then the
# remaining number of words is `sent_len - num_mask`, and any two
# masked regions must have at least one non-masked word between them,
# so num_split_points == number of masked regions - 1, must be
# no greater than `sent_len - num_mask`. The formula
# mask_proportion * inv_mask_length / (1.0 - mask_proportion)
# is what's required (I think) so that 1.0 / inv_mask_length is the expected
# length of the masked regions.
num_split_points = int(torch.binomial(count=torch.tensor([float(sent_len - num_mask)]),
prob=torch.tensor([mask_proportion * inv_mask_length / (1.0 - mask_proportion)])).item())
assert num_split_points <= sent_len - num_mask
assert isinstance(num_split_points, int)
def split_into_subseqs(length: int, num_subseqs: int) -> List[int]:
"""Splits a sequence of `length` items into `num_subseqs` possibly-empty
subsequences. The length distributions are geometric, not Poisson, i.e.
we choose the split locations with uniform probability rather than
randomly assigning each word to one subsequences. This gives us more
shorter/longer subsequences.
Require num_subseqs > 0
"""
boundaries = [0] + sorted(torch.randint(low=0, high=length + 1, size=(num_subseqs - 1,)).tolist()) + [length]
return [ boundaries[i + 1] - boundaries[i] for i in range(num_subseqs) ]
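# For illustration (these are random draws, so just one possibility each):
# split_into_subseqs(10, 3) might draw boundaries [0, 2, 7, 10] and return [2, 5, 3];
# split_into_subseqs(5, 4) might return [0, 1, 3, 1] (empty subsequences are allowed).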
mask_lengths = split_into_subseqs(num_mask, num_split_points + 1)
pad_lengths = split_into_subseqs(num_pad, num_split_points + 1)
# mask_pad_lengths contains only the (mask, pad) length pairs for which mask + pad > 0.
# From this point we only refer to the mask_pad_lengths.
mask_pad_lengths = [ (mask, pad) for (mask, pad) in zip(mask_lengths, pad_lengths) if mask+pad > 0 ]
num_subseqs = len(mask_pad_lengths)
assert num_subseqs > 0
# Now figure out how to distribute these subsequences throughout the actual
# sentence. The subsequences, if there are more than one, must not touch,
# i.e. there must be an actual word in between each subsequence, where the
# number of such "mandatory" words equals num_subseqs - 1. We also have to
# subtract `num_mask` words, since obviously the masked words cannot separate
# the masked regions.
reduced_len = sent_len - num_mask - (num_subseqs - 1)
assert reduced_len >= 0
# unmasked_lengths will be the lengths of the un-masked regions between the masked
# regions.
unmasked_lengths = split_into_subseqs(reduced_len, num_subseqs + 1)
for i in range(1, num_subseqs):
# Unmasked regions between masked regions must have length at least 1,
# so we add 1 to the unmasked regions that are not initial/final.
unmasked_lengths[i] = unmasked_lengths[i] + 1
assert sum(unmasked_lengths) + sum(mask_lengths) == sent_len
# src_positions will be: for each position in the masked+padded sentence,
# the corresponding position in the source sentence `sentence`; or -1
# if this was padding.
src_positions = []
# `masked` will be: for each position in the masked+padded sentence, True if
# it was masked and False otherwise. (Note: it is False for padding
# locations, although this will not matter in the end).
masked = []
cur_pos = 0 # current position in source sentence
for i in range(num_subseqs + 1):
for j in range(unmasked_lengths[i]):
src_positions.append(cur_pos)
masked.append(False)
cur_pos += 1
if i < num_subseqs:
(mask_len, pad_len) = mask_pad_lengths[i]
for j in range(mask_len):
src_positions.append(cur_pos)
masked.append(True)
cur_pos += 1
for j in range(pad_len):
src_positions.append(-1)
masked.append(False)
assert cur_pos == len(sentence)
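# For illustration (hypothetical numbers): for a sentence of length 5 with a single
# masked region of length 2 starting at position 2, followed by 1 position of padding,
# we would get src_positions = [0, 1, 2, 3, -1, 4] and
# masked = [False, False, True, True, False, False].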
src = []
src_masked = []
tgt = []
weight = []
randomizable = []
src.append(bos_sym)
src_masked.append(bos_sym)
randomizable.append(False)
for i, src_pos in enumerate(src_positions):
is_masked = masked[i]
if src_pos >= 0:
src_word = sentence[src_pos]
src_masked.append(blank_sym if is_masked else src_word)
src.append(src_word)
tgt.append(src_word)
weight.append(1.0 if is_masked else unmasked_weight)
randomizable.append(not is_masked)
else:
# Padding inside a masked region
src_masked.append(blank_sym)
src.append(blank_sym)
tgt.append(blank_sym)
weight.append(1.0)
randomizable.append(False)
src.append(eos_sym)
src_masked.append(eos_sym)
tgt.append(eos_sym)
weight.append(unmasked_weight)
tgt.append(blank_sym)
weight.append(0.0)
randomizable.append(False)
attn_mask = ([False] * len(src)) + ([True] * (seq_len - len(src)))
for i in range(seq_len - len(src)):
src.append(blank_sym)
src_masked.append(blank_sym)
tgt.append(blank_sym)
weight.append(0.0)
randomizable.append(False)
return (src, src_masked, tgt, weight, randomizable, attn_mask)
# dataset.mask_and_pad(list(range(10, 20)), seq_len=16, bos_sym=1, eos_sym=2, blank_sym=0, mask_proportion=0.2, padding_proportion=0.2, inv_mask_length=0.33, unmasked_weight=0.444)
# dataset.collate_fn(sentences=[ list(range(10, 20)), list(range(30, 45))], bos_sym=1, eos_sym=2, blank_sym=0, mask_proportion=0.2, padding_proportion=0.2, randomize_proportion=0.05, inv_mask_length=0.33, unmasked_weight=0.444)
def collate_fn(sentences: List[List[int]],
bos_sym: int,
eos_sym: int,
blank_sym: int,
mask_proportion: float = 0.15,
padding_proportion: float = 0.15,
randomize_proportion: float = 0.05,
inv_mask_length: float = 0.25,
unmasked_weight: float = 0.25,
debug: bool = False) -> Tuple[torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor,
torch.Tensor]:
"""
Caution: this is not the collate_fn we give directly to the dataloader;
instead we give it a lambda: collate_fn=(lambda x: dataset.collate_fn(x, [other args])).
This formats a list-of-lists-of-int into 5 Tensors, explained below.
The key thing is that we mask out subsequences of random length within
these sentences, and force the network to predict the masked-out
subsequences (which have blanks appended to them to prevent the model
from knowing the exact length of the sequences it has to predict).
So it's like BERT but at the level of sequences rather than individual
words.
Args:
bos_sym: the integer id of the beginning-of-sentence symbol, e.g. 2.
It is allowed to be the same as eos_sym (we are not necessarily
saying it will work best that way).
eos_sym: the integer id of the end-of-sentence symbol, e.g. 2.
blank_sym: the integer id of the blank symbol, e.g. 0 or 1.
mask_proportion: The proportion of words in each sentence that
are masked, interpreted as (roughly) the probability of any given
word being masked, although the masked locations will
tend to be in contiguous sequences (they are not independent).
padding_proportion: Like mask_proportion, but determines the
number of extra, blank symbols that are inserted as padding
at the end of masked regions (this ensures that the model
cannot know exactly how many words need to be inserted in
any given masked region).
randomize_proportion: The probability with which we replace
words that were not masked with randomly chosen words.
Like BERT, this is intended to force the model to predict
something reasonable at non-masked positions, and to make
this task harder than simply repeating the input.
inv_mask_length: This number determines how many separate
sub-sequences the (masked + padded) proportion of a sentence is split up
into, interpreted as the inverse of the expected length of
each *masked* region.
unmasked_weight: The weight to be applied to the log-likelihoods of
un-masked positions in sentences (predicting un-masked
positions is not completely trivial if randomize_proportion > 0).
Will be reflected in the returned tgt_weights tensor.
Returns a tuple (masked_src_symbols, src_symbols,
tgt_symbols, src_attn_mask,
tgt_weights),
all with 2 axes and the same shape: (num_sent, seq_len).
Their dtypes will be, respectively,
(torch.int64, torch.int64,
torch.int64, torch.bool,
torch.float)
masked_src_symbols: The sentences, with bos_symbol prepended and eos_symbol
appended, masked regions (including padding) replaced with blank,
and `randomize_proportion` non-masked symbols replaced with
symbols randomly taken from elsewhere in the sentences of this
minibatch. Then padded to a fixed length with blank.
src_symbols: Like masked_src_symbols, except with the masked symbols replaced
with the original symbols (but the padding that follows each
masked sub-sequence will still be blank)
tgt_symbols: The original sentences, with eos_symbol appended, and then
padded with blank to the same length as masked_symbols and
src_symbols.
src_attn_mask: Masking tensor for masked_src_symbols and src_symbols, to
account for all the sentence lengths not being identical
(makes each sentence's processing independent of seq_len).
Tensor of Bool of shape (num_sent, seq_len), with True
for masked positions (these are the blanks that follow the
eos_symbol in masked_src_symbols), False for un-masked positions.
tgt_weights: Weights that will be applied to the log-probabilities at
the output of the network. Will have 1.0 in positions
in `tgt_symbols` that were masked (including blank
padding at the end of masked regions), `unmasked_weight`
in other positions in the original sentences (including
terminating eos_symbol); and 0.0 in the remaining positions
corresponding to blank padding after the ends of
sentences.
"""
assert blank_sym not in [bos_sym, eos_sym]
max_sent_len = max([ len(s) for s in sentences])
typical_mask_and_pad = int(max_sent_len * (mask_proportion + padding_proportion))
# The following formula gives roughly 1 standard deviation above where we'd
# expect the maximum sentence length to be with masking and padding; we use
# this as a hard upper limit, to prevent outliers from affecting the batch
# size too much. We use this as the size `seq_len`.
# The "+ 4" is to ensure there is always room for the BOS, EOS and at least
# two padding symbols.
seq_len = max_sent_len + 4 + typical_mask_and_pad + int(typical_mask_and_pad ** 0.5)
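# For illustration (hypothetical numbers): with max_sent_len = 20 and
# mask_proportion = padding_proportion = 0.15, typical_mask_and_pad = int(20 * 0.3) = 6,
# so seq_len = 20 + 4 + 6 + int(6 ** 0.5) = 20 + 4 + 6 + 2 = 32.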
# srcs, srcs_masked, tgts and weights will be lists of the lists returned
# from `mask_and_pad`, one per sentence.
srcs = []
srcs_masked = []
tgts = []
weights = []
randomizables = []
attn_masks = []
for s in sentences:
(src, src_masked, tgt,
weight, randomizable,
attn_mask) = mask_and_pad(s, seq_len, bos_sym, eos_sym,
blank_sym, mask_proportion, padding_proportion,
inv_mask_length, unmasked_weight)
srcs.append(src)
srcs_masked.append(src_masked)
tgts.append(tgt)
weights.append(weight)
randomizables.append(randomizable)
attn_masks.append(attn_mask)
src_symbols = torch.tensor(srcs, dtype=torch.int64)
masked_src_symbols = torch.tensor(srcs_masked, dtype=torch.int64)
tgt_symbols = torch.tensor(tgts, dtype=torch.int64)
src_attn_mask = torch.tensor(attn_masks, dtype=torch.bool)
tgt_weights = torch.tensor(weights, dtype=torch.float)
attn_mask_sum = torch.sum(torch.logical_not(src_attn_mask), dim=0).tolist()
while attn_mask_sum[-1] == 0: # Remove always-masked positions at the end of the lists.
attn_mask_sum.pop()
if len(attn_mask_sum) < seq_len:
seq_len = len(attn_mask_sum)
(src_symbols, masked_src_symbols,
tgt_symbols, src_attn_mask, tgt_weights) = (src_symbols[:,:seq_len], masked_src_symbols[:,:seq_len],
tgt_symbols[:,:seq_len], src_attn_mask[:,:seq_len],
tgt_weights[:,:seq_len])
if randomize_proportion > 0.0:
randomizable_tensor = torch.tensor(randomizables, dtype=torch.bool)
randomizable_indexes = torch.nonzero(randomizable_tensor) # (num_randomizable, 2)
num_randomizable = randomizable_indexes.shape[0]
to_randomize_indexes = torch.nonzero(torch.rand(num_randomizable) < randomize_proportion, as_tuple=True)[0]
num_to_randomize = to_randomize_indexes.numel()
# Older versions of torch don't have tensor_split, so fake a simplified version of it.
# We'd be calling it as xxx.tensor_split(dim=1) if it were really in torch.
def tensor_split(t):
return (t[:,0], t[:,1])
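# For illustration: if randomizable_indexes were tensor([[0, 3], [2, 5]]) (each row is
# (sentence_index, position)), tensor_split() returns (tensor([0, 2]), tensor([3, 5])),
# which when used as an index, e.g. src_symbols[rows, cols], selects exactly those
# two positions.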
random_src_locations = torch.randperm(num_randomizable)[:num_to_randomize]
random_symbols = src_symbols[tensor_split(randomizable_indexes[random_src_locations])]
random_indexes_tuple= tensor_split(randomizable_indexes[to_randomize_indexes])
src_symbols[random_indexes_tuple] = random_symbols
masked_src_symbols[random_indexes_tuple] = random_symbols
# I set `debug` to True and tested with:
# python3 -c 'import dataset; dataset.collate_fn(sentences=[ list(range(100, 200)), list(range(300, 450)), list(range(500,600))], bos_sym=1, eos_sym=2, blank_sym=0, mask_proportion=0.2, padding_proportion=0.2, randomize_proportion=0.05, inv_mask_length=0.33, unmasked_weight=0.444)'
#.. and ran a few times to check the values printed looked about right, and that no assertions failed.
if debug:
check_collated_tensors(sentences, bos_sym, eos_sym, blank_sym,
unmasked_weight,
masked_src_symbols, src_symbols,
tgt_symbols, src_attn_mask, tgt_weights)
return (masked_src_symbols, src_symbols,
tgt_symbols, src_attn_mask, tgt_weights)
def check_collated_tensors(sentences: List[List[int]],
bos_sym: int,
eos_sym: int,
blank_sym: int,
unmasked_weight: float,
masked_src_symbols, src_symbols,
tgt_symbols, src_attn_mask,
tgt_weights):
"""
This function checks the output of collate_fn; consider it test code. Please see
the documentation of collate_fn to understand the args.
"""
for t in src_symbols, tgt_symbols, src_attn_mask, tgt_weights:
assert t.shape == masked_src_symbols.shape
tot_positions = src_symbols.numel()
masked_src_symbols, src_symbols, tgt_symbols, src_attn_mask, tgt_weights = (
masked_src_symbols.tolist(), src_symbols.tolist(), tgt_symbols.tolist(),
src_attn_mask.tolist(), tgt_weights.tolist())
assert len(sentences) == len(masked_src_symbols)
tot_masked_positions = 0
tot_padded_positions = 0
tot_unmasked_positions = 0 # all un-masked, non-blank positions, including eos
tot_randomized_positions = 0
num_masked_subseqs = 0
tot_symbols = 0 # original symbols in sentences, no bos/eos
assert unmasked_weight > 0.001 # or this test code won't work..
for i in range(len(sentences)):
reconstructed_sent = list(filter(lambda x: x not in [bos_sym,eos_sym,blank_sym], tgt_symbols[i]))
if sentences[i] != reconstructed_sent:
print(f"Error: sentence {i}={sentences[i]} differs from {reconstructed_sent}")
(masked_src, src, tgt, src_mask, weights) = (masked_src_symbols[i], src_symbols[i],
tgt_symbols[i], src_attn_mask[i], tgt_weights[i])
assert src[0] == masked_src[0] == bos_sym
for j in range(len(masked_src)):
assert masked_src[j] == blank_sym or masked_src[j] == src[j]
if src[j] not in [bos_sym, eos_sym, blank_sym]:
tot_symbols += 1
if j > 0:
assert (src[j] == eos_sym) == (masked_src[j] == eos_sym) == (tgt[j-1] == eos_sym)
if masked_src[j] == blank_sym: # masked or padding of masked subseq, or post-eos padding..
assert src[j] == tgt[j - 1] # masked symbols are not randomized.
assert weights[j - 1] in [0.0, 1.0] # 0.0 for final blank padding
if weights[j - 1] == 1.0: # Not final blank padding...
if tgt[j - 1] == blank_sym:
tot_padded_positions += 1
else:
tot_masked_positions += 1
if masked_src[j + 1] != blank_sym:
num_masked_subseqs += 1
else:
assert weights[j - 1] == 0 or abs(weights[j-1] - unmasked_weight) < 0.001
if abs(weights[j - 1]-unmasked_weight) < 0.001:
tot_unmasked_positions += 1
if tgt[j - 1] != src[j]:
tot_randomized_positions += 1
if src_mask[j]: # attention-masked position, i.e. blank padding after the final eos
assert src[j] == blank_sym
assert tot_symbols == sum(len(x) for x in sentences)
assert tot_unmasked_positions + tot_masked_positions == tot_symbols + len(sentences)
print(f"{tot_unmasked_positions} + {tot_masked_positions} == {tot_symbols} + {len(sentences)}")
print(f"tot_symbols / tot_positions = {tot_symbols/tot_positions} (rest is bos,eos,padding)")
print(f"Masking/tot_symbols = {tot_masked_positions/tot_symbols}, Padding/tot_symbols = {tot_padded_positions/tot_symbols}")
print(f"Randomization/tot_non_masked_symbols = {tot_randomized_positions/(tot_symbols-tot_masked_positions)}")
print(f"Mean masking length = {tot_masked_positions/num_masked_subseqs}, Mean padding length = {tot_padded_positions/num_masked_subseqs}")
# This shows some useful code about the BPE encoding.
# import sentencepiece as spm
# sp = spm.SentencePieceProcessor()
# sp.load(bpe_model_fn) # bpe.model
# sp.GetPieceSize(..)
# sp.Decode(...)
# sp.Encode(...)
# import dataset
# import torch
# train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
# train_dl = torch.utils.data.DataLoader(train, batch_size=10, shuffle=True, collate_fn=(lambda x: train.collate_fn(x)))
# x = iter(train_dl)
# str(next(x))
# '[ [ 10 38 651 593 3 1343 31 780 6 4172 112 788 1696 24 289 24 3 403 6 4493 162 92 71 328 417 217 338 14 5 3 1876 154 21 23 2237 43 3 1535 92 71 2816 7 1031 31 2318 92 2528 4806 14 206 3 954 1373 6 525 4 631 447 2639 ] [ 1014 336 171 209 795 10 16 90 27 787 139 53 45 2817 ] [ 11 980 51 22 1748 14 91 105 363 428 6 8 2887 3305 2525 2297 70 3 4651 6 27 282 335 426 134 292 5 193 3 539 2250 584 127 ] [ 9 3 1858 4 18 2257 4 6 41 748 10 304 7 229 83 2793 4 9 981 7 1484 33 3 103 7 539 5 477 3195 18 64 39 82 1034 6 3 4128 ] [ 17 147 22 7 708 60 133 174 105 4111 4 6 3 1384 65 50 1051 9 2953 6 3 461 180 1142 23 5 36 888 8 131 173 390 78 23 266 2822 715 46 182 65 22 1739 33 3 700 1450 14 233 4 ] [ 80 10 16 67 279 7 1827 264 96 3 187 2851 2108 ] [ 1473 48 106 227 9 160 2011 4 674 ] [ 3 954 762 29 85 228 33 8 940 40 4952 36 486 390 595 3 81 225 6 1440 125 346 134 296 126 419 1017 3824 4 8 179 184 11 33 580 1861 ] [ 30 22 245 15 117 8 2892 28 1204 145 7 3 236 3417 6 3 3839 5 3106 155 198 30 228 2555 46 15 32 41 747 72 9 25 977 ] [ 222 466 6 3157 ] ]'
#
# or:
# import k2
# k2.ragged.to_list(next(x))
# [shows something similar].
#
# You'd really do something like:
# for epoch in range(max_epochs):
# for minibatch in train_dl:
# .. How to process data? Suppose we have a sentence like [259, 278, 45, 11, 303, 1319, 34, 15, 396, 3435, 7, 44].
#
# First: we randomly choose one or more starting positions for a masked segment.
# Each sentence must have at least one masked segment (or there is no contribution to the loss function).
# We choose to have:
# num_masked_segments = max(1, len(sent) // 15)
#
# The length of the masked segment (this is the target for prediction) we draw from a geometric
# distribution with the probability of success set to 0.3:
#
# g = torch.distributions.geometric.Geometric(probs=0.3) # <-- mean is (1-p)/p = 2.33; add 1 per segment if counting the success trial, giving 3.33
# Example of sampling:
# g.sample(sample_shape=torch.Size([10]))
#
# We now randomly compute the locations of the masked segments (lengths computed above) as follows:
# First, the masked segments must be separated by at least one non-masked word (else they would be
# a single segment). So for n masked segments, there are n-1 words required for minimal separation.
# If tot-length-of-segments + n-1 is greater than the sentence length, we just have the entire
# sentence be masked. Otherwise, we randomly divide the remaining number of words between the n+1
# positions where they can appear (e.g. for 2 segments, this would be at the start, between the 2 segments,
# and at the end). This is the multinomial distribution, but we can more easily compute this
# directly using rand() and cutoffs, rather than creating a torch.distributions.Multinomial().
#
# Next we need to compute a random amount of blank padding (>= 0) for each of the masked regions;
# this is done so the model never knows the exact length of the masked region. We can just use the
# same distribution as for the length of the masked regions, i.e. geometric with success-prob=0.3
# (expected padding length is 3).
#
# At this point we know where the masked regions are and how much padding they have. We can format
# the result as three lists, of the same length:
#
# sent: contains the words in the sentence with, in masked
# positions, the original (target) words, then with
# blank in the blank-padding after masked positions.
#
# sent_augmented: `sent` with, at a small defined percentage of positions
# that were *not* masked, the real token replaced with a
# token randomly chosen from the tokens in the minibatch.
# (like BERT, we use this type of augmentation, so the model
# has to predict the original token).
#
# masked_sent_augmented: List[int], contains the words in `sent_augmented`, except
# with masked positions and the blank padding after the masked regions
# both replaced with blank.
#
#
#
# The way these will be processed is as follows:
#
# masked_sent_in = [bos] + masked_sent_augmented + [eos] <-- so we know the sentence ended, distinguish it from truncated ones.
# sent_in = [bos] + sent_augmented + [eos]
#
# sent_out = sent + [eos] + [eos] #<--- the predicted targets at each point, although
# # we only really care about this in masked regions.
# # The extra eos is so that the length is the same as
# # masked_sent_in and sent_in.
#
# out_scale = (masked_sent==blk ? 1.0 : non_masked_scale) # e.g. non_masked_scale = 1.0 is fine,
# # this is a choice; we can perhaps
# # report these 2 parts of the loss
# # separately though.
# # <-- can also set the last element
# # of out_scale to a smaller number, since
# # it's a repeated eos.
#
#
# OK, how do we combine these into a minibatch? Firstly, we truncate sentences to a maximum
# length, e.g. 128, if `masked_sent_in`/`sent_in` have length longer than that. We choose randomly
# in each case to truncate the beginning or end, truncating both masked_sent_in/sent_in and sent_out
# from the same side. Caution: this means that these sentences may lack bos and/or eos symbols.
#
# Next, we combine shorter utterances by appending them ( all of: masked_sent_in, sent_in, out_scale)
# as long as doing so would keep the total length under 128. We then pad (masked_sent_in, sent_in, sent_out, out_scale)
# with: (<blk>,<blk>,<eos>, 0) up to the maximum length of any sentence in the minibatch <- or could use
#
#
#
#
#
#
#
# # i.e. ones where masked_sent is blank and zeros elsewhere;
# # this pertains to positions in `sent_out`.
#
#
#
#
#
#
#
#
#
#
# torch.distributions.gamma.Gamma(concentration=1.0, rate=1.0/5)
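# A rough, illustrative sketch of the segment-length sampling described above; this is
# not what mask_and_pad() above actually implements, and the +1 on the segment lengths
# (to keep each masked segment non-empty) is an assumption, not spelled out above:
#   g = torch.distributions.geometric.Geometric(probs=0.3)
#   num_masked_segments = max(1, len(sent) // 15)
#   seg_lens = (g.sample(sample_shape=torch.Size([num_masked_segments])) + 1).to(torch.int64)
#   pad_lens = g.sample(sample_shape=torch.Size([num_masked_segments])).to(torch.int64)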
class LmBatchSampler(torch.utils.data.Sampler):
"""
A sampler that returns a batch of integer indexes as a list, intended for use
with class LmDataset. The sentences returned in each batch will all be about
the same size, and the batch size is specified as a number of BPE symbols (we also
provide an option, quadratic_constant, that limits the max memory consumed by transformers).
Has support for distributed operation.
"""
def __init__(self, dataset: LmDataset,
symbols_per_batch: int,
quadratic_constant: float = 0.005,
world_size: Optional[int] = None,
rank: Optional[int] = None,
seed: int = 0):
"""
Constructor documentation:
dataset: the LmDataset object that we are sampling from. This
class does not retain a reference to the LmDataset.
symbols_per_batch: The number of BPE symbols desired in each minibatch
quadratic_constant: After the sentence length gets more than about
1.0/quadratic_constant, the batch size will start decreasing
as 1/(sentence-length^2). This is a mechanism to
avoid excessive memory consumption in transformers, when
sentence length gets long.
world_size: The world size for distributed operation; if None,
will be worked out from torch.distributed.
rank: The rank of this sampler/process for distributed operation; if None,
will be worked out from torch.distributed.
seed: The random seed
"""
self.seed = seed
self.symbols_per_batch = symbols_per_batch
self.quadratic_constant = quadratic_constant
self._maybe_init_distributed(world_size=world_size, rank=rank)
# a configuration constant we don't expose.
self.multiplicative_random_length = 0.05
# "indexes" is the subset of indexes into LmDataset that this
# sampler is reponsible for (all of them, in the non-distributed case).
data_indexes = torch.arange(self.rank, len(dataset), self.world_size, dtype=torch.int32) # dtype=torch.int32
word_row_splits = dataset.words.row_splits(1) # dtype=torch.int32
word_lengths = word_row_splits[1:] - word_row_splits[:-1] # dtype=torch.int32
# the sentences this sampler is responsible for, as sequences of words.
# It's a ragged tensor of int32
sentences = k2.index(dataset.sentences, data_indexes)
# sentence_lengths is a k2.RaggedInt like `sentences`, but with the words replaced
# with their respective lengths, in BPE pieces.
sentence_lengths = k2.index(word_lengths, sentences)
del sentences # save memory
assert isinstance(sentence_lengths, k2.RaggedInt)
# convert to float so sum_per_sublist() will work (TODO: sum_per_sublist() will eventually
# support int32.)
sentence_lengths = k2.RaggedFloat(sentence_lengths.shape(),
sentence_lengths.values().to(torch.float32))
assert isinstance(sentence_lengths, k2.RaggedFloat)
# Convert into a simple tensor of float by adding lengths of words.
sentence_lengths = k2.ragged.sum_per_sublist(sentence_lengths)
assert isinstance(sentence_lengths, torch.Tensor)
assert sentence_lengths.dtype == torch.float32
# self.sentence_lengths is a Tensor with dtype=torch.float32. It
# contains the lengths, in BPE tokens, of the sentences that this
# sampler is responsible for, whose real indexes are in
# `data_indexes` above (this is not stored, as we know the formula).
self.sentence_lengths = sentence_lengths
self.set_epoch(0) # this is responsible for setting self.sorted_data_indexes
def _maybe_init_distributed(self, world_size: Optional[int], rank: Optional[int]):
if world_size is not None:
assert world_size >= 1
if rank is not None:
assert rank >= 0
if not dist.is_available() or not dist.is_initialized():
self.world_size = 1 if world_size is None else world_size
self.rank = 0 if rank is None else rank
return
self.world_size = dist.get_world_size() if world_size is None else world_size
self.rank = dist.get_rank() if rank is None else rank
assert self.rank < self.world_size
def set_epoch(self, epoch: int):
"""
Must be called at the beginning of each epoch, before initializing the DataLoader,
to re-shuffle the data. If this is not done, this sampler will give you the same batches
each time it is called.
"""
g = torch.manual_seed(self.rank + self.seed + epoch)
sentence_lengths = (self.sentence_lengths *
(1.0 + torch.rand(*self.sentence_lengths.shape, generator=g) * self.multiplicative_random_length))
# This mechanism regulates the batch size so that we don't get OOM in transformers
# when the sentences are long.
sentence_lengths = sentence_lengths + (sentence_lengths ** 2) * self.quadratic_constant
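# For illustration (ignoring the small multiplicative jitter above): with
# quadratic_constant = 0.005, a sentence of length 100 counts as
# 100 + 100**2 * 0.005 = 150 effective symbols, while one of length 400 counts as
# 400 + 400**2 * 0.005 = 1200, so long sentences end up in proportionally smaller batches.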
values, indices = torch.sort(sentence_lengths) # values,indices dtypes: torch.float,torch.int64
# map to the original indexes into the dataset (the original sentence
# indexes), see torch.arange expression in the constructor. save as
# int32 just to save a little memory. self.indices are indexes into the
# LmDataset, just including the subset of indices that this sampler is
# responsible for (in terms of rank and world_size), and sorted by
# length with a small amount of randomization specific to the epoch.
self.indices = ((indices * self.world_size) + self.rank).to(dtype=torch.int32)
# now `batch_ids` will be: [0, 0, 0, 0, .., 0, 1, 1, 1, ... 1, 2, ... ],
# saying which batch each element of values/indices belongs to.
batch_ids = (torch.cumsum(values, dim=0) * (1.0 / self.symbols_per_batch)).to(dtype=torch.int32)
batch_boundaries = torch.nonzero(batch_ids[1:] - batch_ids[:-1], as_tuple=True)[0]
batch_boundaries.add_(1)
self.batch_boundaries = torch.cat((torch.zeros(1, dtype=torch.int32), batch_boundaries), dim=0)
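# For illustration (hypothetical numbers): if values = [3., 4., 5., 6.] and
# symbols_per_batch = 8, then the cumulative sums are [3, 7, 12, 18],
# batch_ids = [0, 0, 1, 2], the nonzero differences are at positions 1 and 2,
# and batch_boundaries = [0, 2, 3], so the batches are indices[0:2] and indices[2:3].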
num_batches = self.batch_boundaries.numel() - 1
# self.batch_indices is a permutation of [0, 1, ... num_batches -
# 1]; it determines the order in which we access the batches. It's
# necessary to randomize the order of these, to avoid returning batches
# from shortest to longest sentences.
self.batch_indices = torch.randperm(num_batches, generator=g, dtype=torch.int32).tolist()
def __len__(self):
return len(self.batch_indices)
def __iter__(self):
"""
Iterator that yields lists of indices (i.e., integer indices into the LmDataset)
"""
for batch_idx in self.batch_indices:
batch_start = self.batch_boundaries[batch_idx].item()
batch_end = self.batch_boundaries[batch_idx + 1].item()
yield self.indices[batch_start:batch_end].tolist()
# train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
# sampler = dataset.LmBatchSampler(test, symbols_per_batch=1000, world_size=2, rank=0)
# a = iter(sampler)
# print(str(next(a)))
# collate_fn=(lambda x:dataset.collate_fn(x, bos_sym=1, eos_sym=1, blank_sym=0, debug=True))
# train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler, collate_fn=collate_fn)
# x = iter(train_dl)
# print(str(next(x)))
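# A hedged sketch of per-epoch usage (num_epochs and the training step are assumed, not
# part of this file); note that set_epoch() must be called before constructing the DataLoader,
# as documented above:
# for epoch in range(num_epochs):
#     sampler.set_epoch(epoch)
#     train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler, collate_fn=collate_fn)
#     for (masked_src, src, tgt, attn_mask, weights) in train_dl:
#         pass  # forward/backward pass would go here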


@@ -46,6 +46,7 @@ def main(out_dir: str):
"4-gram.arpa.gz",
"librispeech-vocab.txt",
"librispeech-lexicon.txt",
"librispeech-lm-norm.txt.gz"
)
for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"):


@@ -191,4 +191,9 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
done
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
  cd data && ln -sfv lang_bpe_5000 lang_bpe
fi