diff --git a/egs/librispeech/ASR/conformer_lm/dataset.py b/egs/librispeech/ASR/conformer_lm/dataset.py new file mode 100644 index 000000000..75f603d9d --- /dev/null +++ b/egs/librispeech/ASR/conformer_lm/dataset.py @@ -0,0 +1,783 @@ +import torch +import torch.distributed as dist +import k2 +import _k2 +import sentencepiece as spm +from typing import Optional, List, Tuple + + + +class LmDataset(torch.utils.data.Dataset): + """ + Torch dataset for language modeling data. This is a map-style dataset. + The indices are integers. + """ + def __init__(self, sentences: k2.RaggedInt, + words: k2.RaggedInt): + super(LmDataset, self).__init__() + self.sentences = sentences + self.words = words + + + def __len__(self): + # Total size on axis 0, == num sentences + return self.sentences.tot_size(0) + + def __getitem__(self, i: int): + """ + Return the i'th sentence, as a list of ints (representing BPE pieces, without + bos or eos symbols). + """ + # It would be nicer if we could just return self.sentences[i].tolist(), but + # for now that operator on k2.RaggedInt is not implemented. + row_splits = self.sentences.row_splits(1) + (begin, end) = row_splits[i:i+2].tolist() + sentence = self.sentences.values()[begin:end] + return k2.index(self.words, sentence).values().tolist() + + +def load_train_test_lm_dataset(archive_fn: str, + test_proportion: float = 0.025) -> Tuple[LmDataset, LmDataset]: + """ + returns (train_lm_dataset, test_lm_dataset) + """ + + d = torch.load(archive_fn) + words = d['words'] # a k2.RaggedInt with 2 axes, maps from word-ids to sequences of BPE pieces + sentences = d['data'] # a k2.RaggedInt + + with torch.random.fork_rng(devices=[]): + g = torch.manual_seed(0) + num_sentences = sentences.tot_size(0) + # probably the generator (g) argument to torch.randperm below is not necessary. + sentence_perm = torch.randperm(num_sentences, generator=g, dtype=torch.int32) + sentences = k2.index(sentences, sentence_perm) + + num_test_sentences = int(num_sentences * test_proportion) + + axis=0 + train_sents = _k2.ragged_int_arange(sentences, axis, + num_test_sentences, num_sentences) + test_sents = _k2.ragged_int_arange(sentences, axis, 0, num_test_sentences) + + return LmDataset(train_sents, words), LmDataset(test_sents, words) + + +def mask_and_pad(sentence: List[int], + seq_len: int, + bos_sym: int, + eos_sym: int, + blank_sym: int, + mask_proportion: float, + padding_proportion: float, + inv_mask_length: float, + unmasked_weight: float) -> Tuple[List[int], List[int], List[int], List[float]]: + """ + This function contains part of the logic of collate_fn, broken out. It is responsible + for inserting masking and padding into the sequence `sentence`. Most of the arguments + are documented for `collate_fn` below. + Other args: + sentence: The original sentence to be masked and padded. + seq_len: The desired length of the lists to be returned + bos_sym, eos_sym, blank_sym, mask_proportion, + padding_proportion, inv_mask_length, unmasked_weight: see their documentation + as args to `collate_fn` below. + + + Return: a tuple (src, masked_src, tgt, weight, randomizable, attn_mask), all lists of length `seq_len`, + where: + `src` is: [bos] + [the sentence after inserting blanks in place of padding + after regions to be masked] + [eos] + [blank padding to seq_len]. + `src_masked` is as `src` but the masked regions have their values replaced with blank, + i.e. they are actually masked. + `tgt` is: [the original sentence, without masking] + [eos] + [blank] + [blank padding to seq_len] + `weight` is the weight at the nnet output, which is: `unmasked_weight` for un-masked + positions, 1.0 for masked and padded positions, and 0.0 for positions that + correspond to blank-padding after the final [eos]. + `randomizable` is a bool that is True for positions where the symbol in + in `src_masked` is not bos or eos or blank. + `attn_mask` is a bool that is False for positions in `src` and `src_masked` that + are between the initial [bos] and final [eos] inclusive; and True for + positions after the final [eos]. + """ + sent_len = len(sentence) + assert sent_len + 3 <= seq_len + + for w in sentence: + assert w not in [bos_sym, eos_sym, blank_sym] + + num_mask = int(torch.binomial(count=torch.tensor([sent_len * 1.0]), + prob=torch.tensor([mask_proportion])).item()) + num_pad = int(torch.poisson(torch.tensor([sent_len * padding_proportion])).item()) + # Ensure the total length after bos, padding of masked sequences, and eos, is + # no greater than seq_len + num_pad -= max(0, sent_len + 2 + num_pad - seq_len) + + if num_mask + num_pad == 0: + num_mask += 1 + + # num_split_points is the number of times we split the (masked+padded) + # region, so the total number of (masking+padding) subsequences will be + # num_split_points + 1. If num_mask positions are masked, then the + # remaining number of words is `sent_len - num_mask`, and any two + # masked regions must have at least one non-masked word between them, + # so num_split_points == number of masked regions - 1, must be + # no greater than `sent_len - num_mask`. The formula about + # mask_proportion * inv_mask_length / (1.0 - mask_proportion) + # is what's required (I think) so that inv_mask_length is the expected + # length of masked regions. + num_split_points = int(torch.binomial(count=torch.tensor([float(sent_len - num_mask)]), + prob=torch.tensor([mask_proportion * inv_mask_length / (1.0 - mask_proportion)])).item()) + assert num_split_points <= sent_len - num_mask + assert isinstance(num_split_points, int) + + def split_into_subseqs(length: int , num_subseqs: int) -> List[int]: + """Splits a sequence of `length` items into `num_subseqs` possibly-empty + subsequences. The length distributions are geometric, not Poisson, i.e. + we choose the split locations with uniform probability rather than + randomly assigning each word to one subsequences. This gives us more + shorter/longer subsequences. + Require num_subseqs > 0 + """ + boundaries = [0] + sorted(torch.randint(low=0, high=length + 1, size=(num_subseqs - 1,)).tolist()) + [length] + return [ boundaries[i + 1] - boundaries[i] for i in range(num_subseqs) ] + + mask_lengths = split_into_subseqs(num_mask, num_split_points + 1) + pad_lengths = split_into_subseqs(num_pad, num_split_points + 1) + # mask_pad_lengths contains only the (mask, pad) length pairs for which mask + pad > 0. + # From this point we only refer to the mask_pad_lengths. + mask_pad_lengths = [ (mask, pad) for (mask, pad) in zip(mask_lengths, pad_lengths) if mask+pad > 0 ] + num_subseqs = len(mask_pad_lengths) + assert num_subseqs > 0 + + # Now figure out how to distribute these subsequences throughout the actual + # sentence. The subsequences, if there are more than one, must not touch, + # i.e. there must be an actual word in between each subsequence, where the + # number of such "mandatory" words equals num_subseqs - 1. We also have to + # subtract `num_mask` words, since obviously the masked words cannot separate + # the masked regions. + reduced_len = sent_len - num_mask - (num_subseqs - 1) + assert reduced_len >= 0 + # unmasked_lengths will be the lengths of the un-masked regions between the masked + # regions. + unmasked_lengths = split_into_subseqs(reduced_len, num_subseqs + 1) + for i in range(1, num_subseqs): + # Unmasked regions between masked regions must have length at least 1, + # we add 1 to unmasked regions that are not initial/final. + unmasked_lengths[i] = unmasked_lengths[i] + 1 + assert sum(unmasked_lengths) + sum(mask_lengths) == sent_len + + + # src_positions will be: for each position in the masked+padded sentence, + # the corresponding position in the source sentence `sentence`; or -1 + # if this was padding. + src_positions = [] + # `masked` will be: for each position in the masked+padded sentence, True if + # it was masked and False otherwise. (Note: it is False for padding + # locations, although this will not matter in the end). + masked = [] + + cur_pos = 0 # current position in source sentence + for i in range(num_subseqs + 1): + for j in range(unmasked_lengths[i]): + src_positions.append(cur_pos) + masked.append(False) + cur_pos += 1 + if i < num_subseqs: + (mask_len, pad_len) = mask_pad_lengths[i] + for j in range(mask_len): + src_positions.append(cur_pos) + masked.append(True) + cur_pos += 1 + for j in range(pad_len): + src_positions.append(-1) + masked.append(False) + assert cur_pos == len(sentence) + + + src = [] + src_masked = [] + tgt = [] + weight = [] + randomizable = [] + + src.append(bos_sym) + src_masked.append(bos_sym) + randomizable.append(False) + for i, src_pos in enumerate(src_positions): + is_masked = masked[i] + if src_pos >= 0: + src_word = sentence[src_pos] + src_masked.append(blank_sym if masked[i] else src_word) + src.append(src_word) + tgt.append(src_word) + weight.append(1.0 if masked[i] else unmasked_weight) + randomizable.append(not masked[i]) + else: + # Padding inside a masked region + src_masked.append(blank_sym) + src.append(blank_sym) + tgt.append(blank_sym) + weight.append(1.0) + randomizable.append(False) + src.append(eos_sym) + src_masked.append(eos_sym) + tgt.append(eos_sym) + weight.append(unmasked_weight) + tgt.append(blank_sym) + weight.append(0.0) + randomizable.append(False) + + attn_mask = ([False] * len(src)) + ([True] * (seq_len - len(src))) + + for i in range(seq_len - len(src)): + src.append(blank_sym) + src_masked.append(blank_sym) + tgt.append(blank_sym) + weight.append(0.0) + randomizable.append(False) + + return (src, src_masked, tgt, weight, randomizable, attn_mask) + + +# dataset.mask_and_pad(list(range(10, 20)), seq_len=16, bos_sym=1, eos_sym=2, blank_sym=0, mask_proportion=0.2, padding_proportion=0.2, inv_mask_length=0.33, unmasked_weight=0.444) + +# dataset.collate_fn(sentences=[ list(range(10, 20)), list(range(30, 45))], bos_sym=1, eos_sym=2, blank_sym=0, mask_proportion=0.2, padding_proportion=0.2, randomize_proportion=0.05, inv_mask_length=0.33, unmasked_weight=0.444) + +def collate_fn(sentences: List[List[int]], + bos_sym: int, + eos_sym: int, + blank_sym: int, + mask_proportion: float = 0.15, + padding_proportion: float = 0.15, + randomize_proportion: float = 0.05, + inv_mask_length: float = 0.25, + unmasked_weight: float = 0.25, + debug: bool = False) -> Tuple[torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor, + torch.Tensor]: + """ + Caution, this is not the collate_fn we give directly to the dataloader, + we give it a lambda: collate_fn=(lambda x: dataset.collate_fn(x, [other args])) + This formats a list-of-lists-of-int into 5 Tensors, explained below. + The key thing is that we mask out subsequences of random length within + these sentences, and force the network to predict the masked-out + subsequences (which have blanks appended to them to prevent the model + from knowing the exact length of the sequences it has to predict). + So it's like BERT but at the level of sequences rather than individual + words. + + Args: + bos_sym: the integer id of the beginning-of-sentence symbol, e.g. 2. + Is allowed be the same as eos_sym (we are not necessarily + saying it will work best that way). + eos_sym: the integer id of the end-of-sentence symbol, e.g. 2. + blank_sym: the integer id of the blank symbol, e.g. 0 or 1. + mask_proportion: The proportion of words in each sentence that + are masked, interpreted as (roughly) the probability of any given + word being masked, although the masked locations will + tend to be in contiguous sequences (they are not independent). + padding_proportion: Like mask_proportion, but determines the + number of extra, blank symbols that are inserted as padding + at the end of masked regions (this ensures that the model + cannot know exactly how many words need to be inserted in + any given masked region. + randomize_proportion: The probability with which we replace + words that were not masked with randomly chosen words. + Like BERT, this is intended to force the model to predict + something reasonable at non-masked positions, and to make + this task harder than simply repeating the input. + inv_mask_length: This number determines how many separate + sub-sequences the (masked + padded) proportion of a sentence is split up + into, interpreted as the inverse of the expected length of + each *masked* region. + unmasked_weight: The weight to be applied to the log-likelihoods of + un-masked positions in sentences (predicting un-masked + positions is not completely trivial if randomize_proportion > 0). + Will be reflected in the returned tgt_weights tensor. + + Returns a tuple (masked_src_symbols, src_symbols, + tgt_symbols, src_attn_mask, + tgt_weights), + all with 2 axes and the same shape: (num_sent, seq_len). + Their dtypes will be, respectively, + (torch.int64, torch.int64, + torch.int64, torch.bool, + torch.float) + masked_src_symbols: The sentences, with bos_symbol prepended and eos_symbol + appended, masked regions (including padding) replaced with blank, + and `randomize_proportion` non-masked symbols replaced with + symbols randomly taken from elsewhere in the sentences of this + minibatch. Then padded to a fixed length with blank. + src_symbols: Like masked_src_symbols, except with the masked symbols replaced + with the original symbols (but the padding that follows each + masked sub-sequence will still be blank) + tgt_symbols: The original sentences, with eos_symbol appended, and then + padded with blank to the same length as masked_symbols and + src_symbols. + src_attn_mask: Masking tensor for masked_src_symbols and src_symbols, to + account for all the sentence lengths not being identical + (makes each sentence's processing independent of seq_len). + Tensor of Bool of shape (num_sent, seq_len), with True + for masked positions (these are the blanks that follow the + eos_symbol in masked_src_symbols), False for un-masked positions. + tgt_weights: Weights that will be applied to the log-probabilities at + the output of the network. Will have 1.0 in positions + in `tgt_symbols` that were masked (including blank + padding at the end of masked regions), `unmasked_weight` + in other positions in the original sentences (including + terminating eos_symbol); and 0.0 in the remaining positions + corresponding to blank padding after the ends of + sentences. + """ + assert blank_sym not in [bos_sym, eos_sym] + max_sent_len = max([ len(s) for s in sentences]) + + typical_mask_and_pad = int(max_sent_len * (mask_proportion + padding_proportion)) + + # The following formula gives roughly 1 standard deviation above where we'd + # expect the maximum sentence length to be with masking and padding.. we use + # this as a hard upper limit, to prevent outliers from affecting the batch + # size too much. We use this as the size `seq_len`. + # The "+ 4" is to ensure there is always room for the BOS, EOS and at least + # two padding symbols. + seq_len = max_sent_len + 4 + typical_mask_and_pad + int(typical_mask_and_pad ** 0.5) + + + # srcs, srcs_masked, tgts and weights will be lists of the lists returned + # from `mask_and_pad`, one per sentence. + srcs = [] + srcs_masked = [] + tgts = [] + weights = [] + randomizables = [] + attn_masks = [] + for s in sentences: + (src, src_masked, tgt, + weight, randomizable, + attn_mask) = mask_and_pad(s, seq_len, bos_sym, eos_sym, + blank_sym, mask_proportion, padding_proportion, + inv_mask_length, unmasked_weight) + srcs.append(src) + srcs_masked.append(src_masked) + tgts.append(tgt) + weights.append(weight) + randomizables.append(randomizable) + attn_masks.append(attn_mask) + + src_symbols = torch.tensor(srcs, dtype=torch.int64) + masked_src_symbols = torch.tensor(srcs_masked, dtype=torch.int64) + tgt_symbols = torch.tensor(tgts, dtype=torch.int64) + src_attn_mask = torch.tensor(attn_masks, dtype=torch.bool) + tgt_weights = torch.tensor(weights, dtype=torch.float) + + attn_mask_sum = torch.sum(torch.logical_not(src_attn_mask), dim=0).tolist() + while attn_mask_sum[-1] == 0: # Remove always-masked positions at the endof the lists. + attn_mask_sum.pop() + if len(attn_mask_sum) < seq_len: + seq_len = len(attn_mask_sum) + (src_symbols, masked_src_symbols, + tgt_symbols, src_attn_mask, tgt_weights) = (src_symbols[:,:seq_len], masked_src_symbols[:,:seq_len], + tgt_symbols[:,:seq_len], src_attn_mask[:,:seq_len], + tgt_weights[:,:seq_len]) + + if randomize_proportion > 0.0: + randomizable_tensor = torch.tensor(randomizables, dtype=torch.bool) + randomizable_indexes = torch.nonzero(randomizable_tensor) # (num_randomizable, 2) + num_randomizable = randomizable_indexes.shape[0] + + to_randomize_indexes = torch.nonzero(torch.rand(num_randomizable) < randomize_proportion, as_tuple=True)[0] + num_to_randomize = to_randomize_indexes.numel() + + # older versions of torch don't have tensor_split, so fake a simplified version of it. + # we'd be calling it as xxx.tensor_split(dim=1) if really in torc. + def tensor_split(t): + return (t[:,0], t[:,1]) + + random_src_locations = torch.randperm(num_randomizable)[:num_to_randomize] + + random_symbols = src_symbols[tensor_split(randomizable_indexes[random_src_locations])] + random_indexes_tuple= tensor_split(randomizable_indexes[to_randomize_indexes]) + src_symbols[random_indexes_tuple] = random_symbols + masked_src_symbols[random_indexes_tuple] = random_symbols + + + # I set this to true and tested with: + # python3 -c 'import dataset; dataset.collate_fn(sentences=[ list(range(100, 200)), list(range(300, 450)), list(range(500,600))], bos_sym=1, eos_sym=2, blank_sym=0, mask_proportion=0.2, padding_proportion=0.2, randomize_proportion=0.05, inv_mask_length=0.33, unmasked_weight=0.444)' + #.. and ran a few times to check the values printed looked about right, and that no assertions failed. + if debug: + check_collated_tensors(sentences, bos_sym, eos_sym, blank_sym, + unmasked_weight, + masked_src_symbols, src_symbols, + tgt_symbols, src_attn_mask, tgt_weights) + return (masked_src_symbols, src_symbols, + tgt_symbols, src_attn_mask, tgt_weights) + + + +def check_collated_tensors(sentences: List[List[int]], + bos_sym: int, + eos_sym: int, + blank_sym: int, + unmasked_weight: float, + masked_src_symbols, src_symbols, + tgt_symbols, src_attn_mask, + tgt_weights): + """ + This function checks the output of collate_fn, consider it test code. Please see + the documentation of collate_fn to understand the args. + """ + for t in src_symbols, tgt_symbols, src_attn_mask, tgt_weights: + assert t.shape == masked_src_symbols.shape + + tot_positions = src_symbols.numel() + + masked_src_symbols, src_symbols, tgt_symbols, src_attn_mask, tgt_weights = ( + masked_src_symbols.tolist(), src_symbols.tolist(), tgt_symbols.tolist(), + src_attn_mask.tolist(), tgt_weights.tolist()) + assert len(sentences) == len(masked_src_symbols) + + tot_masked_positions = 0 + tot_padded_positions = 0 + tot_unmasked_positions = 0 # all un-masked, non-blank postions, including eos + tot_randomized_positions = 0 + num_masked_subseqs = 0 + tot_symbols = 0 # original symbols in sentences, no bos/eos + + assert unmasked_weight > 0.001 # or this test code won't work.. + + for i in range(len(sentences)): + reconstructed_sent = list(filter(lambda x: x not in [bos_sym,eos_sym,blank_sym], tgt_symbols[i])) + if sentences[i] != reconstructed_sent: + print(f"Error: sentence {i}={sentences[i]} differs from {reconstructed_sent}") + (masked_src, src, tgt, src_mask, weights) = (masked_src_symbols[i], src_symbols[i], + tgt_symbols[i], src_attn_mask[i], tgt_weights[i]) + + assert src[0] == masked_src[0] == bos_sym + for j in range(len(masked_src)): + assert masked_src[j] == blank_sym or masked_src[j] == src[j] + + if src[j] not in [bos_sym, eos_sym, blank_sym]: + tot_symbols += 1 + + if j > 0: + assert (src[j] == eos_sym) == (masked_src[j] == eos_sym) == (tgt[j-1] == eos_sym) + if masked_src[j] == blank_sym: # masked or padding of masked subseq, or post-eos padding.. + assert src[j] == tgt[j - 1] # masked symbols are not randomized. + assert weights[j - 1] in [0.0, 1.0] # 0.0 for final blank padding + if weights[j - 1] == 1.0: # Not final blank padding... + if tgt[j - 1] == blank_sym: + tot_padded_positions += 1 + else: + tot_masked_positions += 1 + if masked_src[j + 1] != blank_sym: + num_masked_subseqs += 1 + else: + assert weights[j - 1] == 0 or abs(weights[j-1] - unmasked_weight) < 0.001 + if abs(weights[j - 1]-unmasked_weight) < 0.001: + tot_unmasked_positions += 1 + if tgt[j - 1] != src[j]: + tot_randomized_positions += 1 + + if src_mask[j]: # if masked.. + assert src[j] == blank_sym + + assert tot_symbols == sum(len(x) for x in sentences) + + assert tot_unmasked_positions + tot_masked_positions == tot_symbols + len(sentences) + + print(f"{tot_unmasked_positions} + {tot_masked_positions} == {tot_symbols} + {len(sentences)}") + print(f"tot_symbols / tot_positions = {tot_symbols/tot_positions} (rest is bos,eos,padding)") + + print(f"Masking/tot_symbols = {tot_masked_positions/tot_symbols}, Padding/tot_symbols = {tot_padded_positions/tot_symbols}") + print(f"Randomization/tot_non_masked_symbols = {tot_randomized_positions/(tot_symbols-tot_masked_positions)}") + print(f"Mean masking length = {tot_masked_positions/num_masked_subseqs}, Mean padding length = {tot_padded_positions/num_masked_subseqs}") + + + +# This shows some useful code about the BPE encoding. +# import sentencepiece as spm +# sp = spm.SentencePieceProcessor() +# sp.load(bpe_model_fn) # bpe.model +# sp.GetPieceSize(..) +# sp.Decode(...) +# sp.Encode(...) + + +# import dataset +# import torch +# train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt') + + +# train_dl = torch.utils.data.DataLoader(train, batch_size=10, shuffle=True, collate_fn=(lambda x: train.collate_fn(x))) +# x = iter(train_dl) +# str(next(x)) +# '[ [ 10 38 651 593 3 1343 31 780 6 4172 112 788 1696 24 289 24 3 403 6 4493 162 92 71 328 417 217 338 14 5 3 1876 154 21 23 2237 43 3 1535 92 71 2816 7 1031 31 2318 92 2528 4806 14 206 3 954 1373 6 525 4 631 447 2639 ] [ 1014 336 171 209 795 10 16 90 27 787 139 53 45 2817 ] [ 11 980 51 22 1748 14 91 105 363 428 6 8 2887 3305 2525 2297 70 3 4651 6 27 282 335 426 134 292 5 193 3 539 2250 584 127 ] [ 9 3 1858 4 18 2257 4 6 41 748 10 304 7 229 83 2793 4 9 981 7 1484 33 3 103 7 539 5 477 3195 18 64 39 82 1034 6 3 4128 ] [ 17 147 22 7 708 60 133 174 105 4111 4 6 3 1384 65 50 1051 9 2953 6 3 461 180 1142 23 5 36 888 8 131 173 390 78 23 266 2822 715 46 182 65 22 1739 33 3 700 1450 14 233 4 ] [ 80 10 16 67 279 7 1827 264 96 3 187 2851 2108 ] [ 1473 48 106 227 9 160 2011 4 674 ] [ 3 954 762 29 85 228 33 8 940 40 4952 36 486 390 595 3 81 225 6 1440 125 346 134 296 126 419 1017 3824 4 8 179 184 11 33 580 1861 ] [ 30 22 245 15 117 8 2892 28 1204 145 7 3 236 3417 6 3 3839 5 3106 155 198 30 228 2555 46 15 32 41 747 72 9 25 977 ] [ 222 466 6 3157 ] ]' +# +# or: +# import k2 +# k2.ragged.to_list(next(x)) +# [shows something similar]. +# +# You'd really do something like: +# for epoch in range(max_epochs): +# for minibatch in train_dl: + + +# .. How to process data? Suppose we have a sentence like [259, 278, 45, 11, 303, 1319, 34, 15, 396, 3435, 7, 44]. +# +# First: we randomly choose one or more starting positins for a masked segment. +# Each sentence must have at least one masked segment (or there is no contribution to the loss function). +# We choose to have: +# num_masked_segments = max(1, len(sent) // 15) +# +# The length of the masked segment (this is the target for prediction), we set to the geometric +# distribution with the probability of success set to 3: +# +# g = torch.distributions.geometric.Geometric(probs=0.3) # <-- expected value is 3.333 +# Example of sampling: +# g.sample(sample_shape=torch.Size([10])) +# +# We now we randomly compute the location of the masked segments (length computed above) as follows: +# First, the masked segments must be separated by at least one non-masked word (else they would be +# a single segment). So for n masked segments, there are n-1 words required for minimal separation. +# If tot-length-of-segments + n-1 is greater than the sentence length, we just have the entire +# sentence be masked. Otherwise, we randomly divide the remaining number of words between the n+1 +# positions where they can appear (e.g. for 2 segments, this would be at the start, between the 2 segments, +# and at the end). This is the multinomial distribution, but we can more easily compute this +# directly using rand() and cutoffs, rather than creating a torch.distributions.Multinomial(). +# + +# Next we need to compute a random amount of blank padding (>= 0) for each of the masked regions; +# this is done so the model never knows the exact length of the masked region. We can just use the +# same distribution as for the length of the masked regions, i.e. geometric with success-prob=0.3 +# (expected padding length is 3). +# +# At this point we know where the masked regions are and how much padding they have. We can format +# the result as three lists, of the same length: +# +# sent: contains the words in the sentence with, in masked +# positions, the original (target) words, then with +# blank in the blank-padding after masked positions. +# +# sent_augmented: `sent` with, at a small defined percentage of positions +# that were *not* masked, the real token replaced with a +# token randomly chosen from the tokens in the minibatch. +# (like BERT, we use this type of augmentation, so the model +# has to predict the original token). +# +# masked_sent_augmented: List[int], contains the words in `sent_augmented`, except +# with masked positions and the blank padding after the masked regions +# both replaced with blank. +# +# +# +# The way these will be processed is as follows: +# +# masked_sent_in = [bos] + masked_sent_augmented + [eos] <-- so we know the sentence ended, distinguish it from truncated ones. +# sent_in = [bos] + sent_augmented + [eos] +# +# sent_out = sent + [eos] + [eos] #<--- the predicted targets at each point, although +# # we only really care about this in masked regions. +# # The extra eos is so that the length is the same as +# # masked_sent_in and sent_in. +# +# out_scale = (masked_sent==blk ? 1.0 : non_masked_scale) # e.g. non_masked_scale = 1.0 is fine, +# # this is a choice; we can perhaps +# # report these 2 parts of the loss +# # separately though. +# # <-- can also set the last element +# # of out_scale to a smaller number, since +# # it's a repeated eos. +# +# +# OK, how do we combine these into a minibatch? Firstly, we truncate sentences to a maximum +# length, e.g. 128, if `masked_sent_in`/`sent_in` have length longer than that. We choose randomly +# in each case to truncate the beginning or end, truncating both masked_sent_in/sent_in and sent_out +# from the same side. Caution: this means that these sentences may lack bos and/or eos symbols. +# +# Next, we combine shorter utterances by appending them ( all of: masked_sent_in, sent_in, out_scale) +# as long as doing so would keep the total length under 128. We then pad (masked_sent_in, sent_in, sent_out, out_scale) +# with: (,,, 0) up to the maximum length of any sentence in the minibatch <- or could use +# +# +# +# +# +# +# +# # i.e. ones where masked_sent is blank and zeros elsewhere; +# # this pertains to positions in `sent_out`. +# +# +# +# +# +# +# +# +# +# +# torch.distributions.gamma.Gamma(concentration=1.0, rate=1.0/5) + + + + +class LmBatchSampler(torch.utils.data.Sampler): + """ + A sampler that returns a batch of integer indexes as a list, intended for use + with class LmDataset. The sentences returned in each batch will all be about + the same size, and the batch size is specified as a number of words (we also + provide an option that allows you to limit the max memory consumed by transformers) + + Has support for distributed operation. + """ + def __init__(self, dataset: LmDataset, + symbols_per_batch: int, + quadratic_constant: float = 0.005, + world_size: Optional[int] = None, + rank: int = None, + seed: int = 0): + """ + Constructor documentation: + dataset: the LmDataset object that we are sampling from. This + class does not retain a reference to the LmDataset. + symbols_per_batch: The number of BPE symbols desired in each minibatch + quadratic_constant: After the sentence length gets more than about + 1.0/quadratic_constant, the batch size will start decreasing + as 1/(sentence-length^2). This is a mechanism to + avoid excessive memory consumption in transformers, when + sentence length gets long. + world_size: The world size for distributed operation; if None, + will be worked out from torch.distributed. + rank: The rank of this sampler/process for distributed operation; if None, + will be worked out from torch.distributed. + seed: The random seed + """ + self.seed = seed + self.symbols_per_batch = symbols_per_batch + self.quadratic_constant = quadratic_constant + self._maybe_init_distributed(world_size=world_size, rank=rank) + + # a configuration constant we don't expose. + self.multiplicative_random_length = 0.05 + + # "indexes" is the subset of indexes into LmDataset that this + # sampler is reponsible for (all of them, in the non-distributed case). + data_indexes = torch.arange(self.rank, len(dataset), self.world_size, dtype=torch.int32) # dtype=torch.int32 + + word_row_splits = dataset.words.row_splits(1) # dtype=torch.int32 + word_lengths = word_row_splits[1:] - word_row_splits[:-1] # dtype=torch.int32 + + # the sentences this sampler is responsible for, as sequences of words. + # It's a ragged tensor of int32 + sentences = k2.index(dataset.sentences, data_indexes) + + # sentence_lengths is a k2.RaggedInt like `sentences`, but with the words replaced + # with their respective lengths, in BPE pieces. + sentence_lengths = k2.index(word_lengths, sentences) + del sentences # save memory + assert isinstance(sentence_lengths, k2.RaggedInt) + + # convert to float so sum_per_sublist() will work (TODO: sum_per_sublist() will eventually + # support int32.) + sentence_lengths = k2.RaggedFloat(sentence_lengths.shape(), + sentence_lengths.values().to(torch.float32)) + assert isinstance(sentence_lengths, k2.RaggedFloat) + + # Convert into a simple tensor of float by adding lengths of words. + sentence_lengths = k2.ragged.sum_per_sublist(sentence_lengths) + + assert isinstance(sentence_lengths, torch.Tensor) + assert sentence_lengths.dtype == torch.float32 + + # self.sentence_lengths is a Tensor with dtype=torch.float32. It + # contains the lengths, in BPE tokens, of the sentences that this + # sampler is responsible for, whose real indexes are in + # `data_indexes` above (this is not stored, as we know the formula). + self.sentence_lengths = sentence_lengths + + self.set_epoch(0) # this is responsible for setting self.sorted_data_indexes + + + def _maybe_init_distributed(self, world_size: Optional[int], rank: Optional[int]): + if world_size is not None: + assert world_size >= 1 + if rank is not None: + assert rank >= 0 + if not dist.is_available() or not dist.is_initialized(): + self.world_size = 1 if world_size is None else world_size + self.rank = 0 if rank is None else rank + return + self.world_size = dist.get_world_size() if world_size is None else world_size + self.rank = dist.get_rank() if rank is None else rank + assert self.rank < self.world_size + + def set_epoch(self, epoch: int): + """ + Must be called at the beginning of each epoch, before initializing the DataLoader, + to re-shuffle the data. If this is not done, this sampler will give you the same batches + each time it is called. + """ + g = torch.manual_seed(self.rank + self.seed + epoch) + + sentence_lengths = (self.sentence_lengths * + (1.0 + torch.rand(*self.sentence_lengths.shape, generator=g) * self.multiplicative_random_length)) + + # This mechanism regulates the batch size so that we don't get OOM in transformers + # when the sentences are long. + sentence_lengths = sentence_lengths + (sentence_lengths ** 2) * self.quadratic_constant + + values, indices = torch.sort(sentence_lengths) # values,indices dtypes: torch.float,torch.int64 + + # map to the original indexes into the dataset (the original sentence + # indexes), see torch.arange expression in the constructor. save as + # int32 just to save a little memory. self.indices are indexes into the + # LmDataset, just including the subset of indices that this sampler is + # responsible for (in terms of rank and world_size), and sorted by + # length with a small amount of randomization specific to the epoch. + self.indices = ((indices * self.world_size) + self.rank).to(dtype=torch.int32) + + # now `batch_ids` will be: [0, 0, 0, 0, .., 0, 1, 1, 1, ... 1, 2, ... ], + # saying which batch each element of values/indices belongs to. + batch_ids = (torch.cumsum(values, dim=0) * (1.0 / self.symbols_per_batch)).to(dtype=torch.int32) + + batch_boundaries = torch.nonzero(batch_ids[1:] - batch_ids[:-1], as_tuple=True)[0] + batch_boundaries.add_(1) + self.batch_boundaries = torch.cat((torch.zeros(1, dtype=torch.int32), batch_boundaries), dim=0) + + num_batches = self.batch_boundaries.numel() - 1 + + # self.batch_indices is a permutation of [0, 1, ... num_batches - + # 1]; it determines the order in which we access the batches. It's + # necessary to randomize the order of these, to avoid returning batches + # from shortest to longest sentences. + self.batch_indices = torch.randperm(num_batches, generator=g, dtype=torch.int32).tolist() + + + def __len__(self): + return len(self.batch_indices) + + def __iter__(self): + """ + Iterator that yields lists of indices (i.e., integer indices into the LmDataset) + """ + for batch_idx in self.batch_indices: + batch_start = self.batch_boundaries[batch_idx].item() + batch_end = self.batch_boundaries[batch_idx + 1].item() + yield self.indices[batch_start:batch_end].tolist() + + + + + + +# train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt') +# sampler = dataset.LmBatchSampler(test, symbols_per_batch=1000, world_size=2, rank=0) +# a = iter(sampler) +# print(str(next(a))) + +# collate_fn=(lambda x:dataset.collate_fn(x, bos_sym=1, eos_sym=1, blank_sym=0, debug=True)) +# train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler, collate_fn=collate_fn) +# x = iter(train_dl) +# print(str(next(x))) diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py index 5c9e2a675..e78c6d9f3 100755 --- a/egs/librispeech/ASR/local/download_lm.py +++ b/egs/librispeech/ASR/local/download_lm.py @@ -46,6 +46,7 @@ def main(out_dir: str): "4-gram.arpa.gz", "librispeech-vocab.txt", "librispeech-lexicon.txt", + "librispeech-lm-norm.txt.gz" ) for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"): diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index f06e013f6..798a30631 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -191,4 +191,9 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then done fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + +fi + cd data && ln -sfv lang_bpe_5000 lang_bpe