mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
parent
3441634f34
commit
19828cbf22
1
egs/librispeech/ASR/conformer_lm/asr_datamodule.py
Symbolic link
1
egs/librispeech/ASR/conformer_lm/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../tdnn_lstm_ctc/asr_datamodule.py
|
1484
egs/librispeech/ASR/conformer_lm/conformer.py
Normal file
1484
egs/librispeech/ASR/conformer_lm/conformer.py
Normal file
File diff suppressed because it is too large
Load Diff
823
egs/librispeech/ASR/conformer_lm/dataset.py
Normal file
823
egs/librispeech/ASR/conformer_lm/dataset.py
Normal file
@ -0,0 +1,823 @@
|
|||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import k2
|
||||||
|
import _k2
|
||||||
|
import logging
|
||||||
|
import sentencepiece as spm
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, List, Tuple, Union
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LmDataset(torch.utils.data.Dataset):
|
||||||
|
"""
|
||||||
|
Torch dataset for language modeling data. This is a map-style dataset.
|
||||||
|
The indices are integers.
|
||||||
|
"""
|
||||||
|
def __init__(self, sentences: k2.RaggedTensor,
|
||||||
|
words: k2.RaggedTensor):
|
||||||
|
super(LmDataset, self).__init__()
|
||||||
|
self.sentences = sentences
|
||||||
|
self.words = words
|
||||||
|
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
# Total size on axis 0, == num sentences
|
||||||
|
return self.sentences.tot_size(0)
|
||||||
|
|
||||||
|
def __getitem__(self, i: int):
|
||||||
|
"""
|
||||||
|
Return the i'th sentence, as a list of ints (representing BPE pieces, without
|
||||||
|
bos or eos symbols).
|
||||||
|
"""
|
||||||
|
# in future will just do:
|
||||||
|
#return self.words[self.sentences[i]].tolist()
|
||||||
|
|
||||||
|
# It would be nicer if we could just return self.sentences[i].tolist(),
|
||||||
|
# but for now that operator on k2.RaggedInt does not support when the
|
||||||
|
# ragged has only 2 axes.
|
||||||
|
row_splits = self.sentences.shape.row_splits(1)
|
||||||
|
(begin, end) = row_splits[i:i+2].tolist()
|
||||||
|
sentence = self.sentences.data[begin:end]
|
||||||
|
sentence, _ = self.words.index(sentence, axis=0, need_value_indexes=False)
|
||||||
|
return sentence.data.tolist()
|
||||||
|
|
||||||
|
|
||||||
|
def load_train_test_lm_dataset(archive_fn: Union[str,Path],
|
||||||
|
test_proportion: float = 0.025) -> Tuple[LmDataset, LmDataset]:
|
||||||
|
"""
|
||||||
|
returns (train_lm_dataset, test_lm_dataset)
|
||||||
|
"""
|
||||||
|
|
||||||
|
d = torch.load(archive_fn)
|
||||||
|
words = d['words'] # a k2.RaggedTensor with 2 axes, maps from word-ids to sequences of BPE pieces
|
||||||
|
sentences = d['data'] # a k2.RaggedTensor
|
||||||
|
|
||||||
|
with torch.random.fork_rng(devices=[]):
|
||||||
|
g = torch.manual_seed(0)
|
||||||
|
num_sentences = sentences.tot_size(0)
|
||||||
|
# probably the generator (g) argument to torch.randperm below is not necessary.
|
||||||
|
sentence_perm = torch.randperm(num_sentences, generator=g, dtype=torch.int32)
|
||||||
|
sentences, _ = sentences.index(sentence_perm, axis=0, need_value_indexes=False)
|
||||||
|
|
||||||
|
num_test_sentences = int(num_sentences * test_proportion)
|
||||||
|
|
||||||
|
axis=0
|
||||||
|
train_sents = sentences.arange(axis, num_test_sentences, num_sentences)
|
||||||
|
test_sents = sentences.arange(axis, 0, num_test_sentences)
|
||||||
|
|
||||||
|
return LmDataset(train_sents, words), LmDataset(test_sents, words)
|
||||||
|
|
||||||
|
|
||||||
|
def mask_and_pad(sentence: List[int],
|
||||||
|
seq_len: int,
|
||||||
|
bos_sym: int,
|
||||||
|
eos_sym: int,
|
||||||
|
blank_sym: int,
|
||||||
|
mask_proportion: float,
|
||||||
|
padding_proportion: float,
|
||||||
|
inv_mask_length: float,
|
||||||
|
unmasked_weight: float) -> Tuple[List[int], List[int], List[int], List[float]]:
|
||||||
|
"""
|
||||||
|
This function contains part of the logic of collate_fn, broken out. It is responsible
|
||||||
|
for inserting masking and padding into the sequence `sentence`. Most of the arguments
|
||||||
|
are documented for `collate_fn` below.
|
||||||
|
Other args:
|
||||||
|
sentence: The original sentence to be masked and padded.
|
||||||
|
seq_len: The desired length of the lists to be returned
|
||||||
|
bos_sym, eos_sym, blank_sym, mask_proportion,
|
||||||
|
padding_proportion, inv_mask_length, unmasked_weight: see their documentation
|
||||||
|
as args to `collate_fn` below.
|
||||||
|
|
||||||
|
|
||||||
|
Return: a tuple (src, masked_src, tgt, weight, randomizable, attn_mask), all lists of length `seq_len`,
|
||||||
|
where:
|
||||||
|
`src` is: [bos] + [the sentence after inserting blanks in place of padding
|
||||||
|
after regions to be masked] + [eos] + [blank padding to seq_len].
|
||||||
|
`src_masked` is as `src` but the masked regions have their values replaced with blank,
|
||||||
|
i.e. they are actually masked.
|
||||||
|
`tgt` is: [the original sentence, without masking] + [eos] + [blank] + [blank padding to seq_len]
|
||||||
|
`weight` is the weight at the nnet output, which is: `unmasked_weight` for un-masked
|
||||||
|
positions, 1.0 for masked and padded positions, and 0.0 for positions that
|
||||||
|
correspond to blank-padding after the final [eos].
|
||||||
|
`randomizable` is a bool that is True for positions where the symbol in
|
||||||
|
in `src_masked` is not bos or eos or blank.
|
||||||
|
`attn_mask` is a bool that is False for positions in `src` and `src_masked` that
|
||||||
|
are between the initial [bos] and final [eos] inclusive; and True for
|
||||||
|
positions after the final [eos].
|
||||||
|
"""
|
||||||
|
sent_len = len(sentence)
|
||||||
|
assert sent_len + 3 <= seq_len
|
||||||
|
|
||||||
|
for w in sentence:
|
||||||
|
assert w not in [bos_sym, eos_sym, blank_sym]
|
||||||
|
|
||||||
|
num_mask = int(torch.binomial(count=torch.tensor([sent_len * 1.0]),
|
||||||
|
prob=torch.tensor([mask_proportion])).item())
|
||||||
|
num_pad = int(torch.poisson(torch.tensor([sent_len * padding_proportion])).item())
|
||||||
|
# Ensure the total length after bos, padding of masked sequences, and eos, is
|
||||||
|
# no greater than seq_len
|
||||||
|
num_pad -= max(0, sent_len + 2 + num_pad - seq_len)
|
||||||
|
|
||||||
|
if num_mask + num_pad == 0:
|
||||||
|
num_pad += 1
|
||||||
|
|
||||||
|
# num_split_points is the number of times we split the (masked+padded)
|
||||||
|
# region, so the total number of (masking+padding) subsequences will be
|
||||||
|
# num_split_points + 1. If num_mask positions are masked, then the
|
||||||
|
# remaining number of words is `sent_len - num_mask`, and any two
|
||||||
|
# masked regions must have at least one non-masked word between them,
|
||||||
|
# so num_split_points == number of masked regions - 1, must be
|
||||||
|
# no greater than `sent_len - num_mask`. The formula about
|
||||||
|
# mask_proportion * inv_mask_length / (1.0 - mask_proportion)
|
||||||
|
# is what's required (I think) so that inv_mask_length is the expected
|
||||||
|
# length of masked regions.
|
||||||
|
num_split_points = int(torch.binomial(count=torch.tensor([float(sent_len - num_mask)]),
|
||||||
|
prob=torch.tensor([mask_proportion * inv_mask_length / (1.0 - mask_proportion)])).item())
|
||||||
|
# Somehow this assertion failed, debugging it below.
|
||||||
|
assert num_split_points <= sent_len - num_mask
|
||||||
|
|
||||||
|
assert isinstance(num_split_points, int)
|
||||||
|
|
||||||
|
def split_into_subseqs(length: int , num_subseqs: int) -> List[int]:
|
||||||
|
"""Splits a sequence of `length` items into `num_subseqs` possibly-empty
|
||||||
|
subsequences. The length distributions are geometric, not Poisson, i.e.
|
||||||
|
we choose the split locations with uniform probability rather than
|
||||||
|
randomly assigning each word to one subsequences. This gives us more
|
||||||
|
shorter/longer subsequences.
|
||||||
|
Require num_subseqs > 0
|
||||||
|
"""
|
||||||
|
boundaries = [0] + sorted(torch.randint(low=0, high=length + 1, size=(num_subseqs - 1,)).tolist()) + [length]
|
||||||
|
return [ boundaries[i + 1] - boundaries[i] for i in range(num_subseqs) ]
|
||||||
|
|
||||||
|
mask_lengths = split_into_subseqs(num_mask, num_split_points + 1)
|
||||||
|
pad_lengths = split_into_subseqs(num_pad, num_split_points + 1)
|
||||||
|
# mask_pad_lengths contains only the (mask, pad) length pairs for which mask + pad > 0.
|
||||||
|
# From this point we only refer to the mask_pad_lengths.
|
||||||
|
mask_pad_lengths = [ (mask, pad) for (mask, pad) in zip(mask_lengths, pad_lengths) if mask+pad > 0 ]
|
||||||
|
num_subseqs = len(mask_pad_lengths)
|
||||||
|
assert num_subseqs > 0
|
||||||
|
|
||||||
|
# Now figure out how to distribute these subsequences throughout the actual
|
||||||
|
# sentence. The subsequences, if there are more than one, must not touch,
|
||||||
|
# i.e. there must be an actual word in between each subsequence, where the
|
||||||
|
# number of such "mandatory" words equals num_subseqs - 1. We also have to
|
||||||
|
# subtract `num_mask` words, since obviously the masked words cannot separate
|
||||||
|
# the masked regions.
|
||||||
|
reduced_len = sent_len - num_mask - (num_subseqs - 1)
|
||||||
|
assert reduced_len >= 0
|
||||||
|
# unmasked_lengths will be the lengths of the un-masked regions between the masked
|
||||||
|
# regions.
|
||||||
|
unmasked_lengths = split_into_subseqs(reduced_len, num_subseqs + 1)
|
||||||
|
for i in range(1, num_subseqs):
|
||||||
|
# Unmasked regions between masked regions must have length at least 1,
|
||||||
|
# we add 1 to unmasked regions that are not initial/final.
|
||||||
|
unmasked_lengths[i] = unmasked_lengths[i] + 1
|
||||||
|
assert sum(unmasked_lengths) + sum(mask_lengths) == sent_len
|
||||||
|
|
||||||
|
|
||||||
|
# src_positions will be: for each position in the masked+padded sentence,
|
||||||
|
# the corresponding position in the source sentence `sentence`; or -1
|
||||||
|
# if this was padding.
|
||||||
|
src_positions = []
|
||||||
|
# `masked` will be: for each position in the masked+padded sentence, True if
|
||||||
|
# it was masked and False otherwise. (Note: it is False for padding
|
||||||
|
# locations, although this will not matter in the end).
|
||||||
|
masked = []
|
||||||
|
|
||||||
|
cur_pos = 0 # current position in source sentence
|
||||||
|
for i in range(num_subseqs + 1):
|
||||||
|
for j in range(unmasked_lengths[i]):
|
||||||
|
src_positions.append(cur_pos)
|
||||||
|
masked.append(False)
|
||||||
|
cur_pos += 1
|
||||||
|
if i < num_subseqs:
|
||||||
|
(mask_len, pad_len) = mask_pad_lengths[i]
|
||||||
|
for j in range(mask_len):
|
||||||
|
src_positions.append(cur_pos)
|
||||||
|
masked.append(True)
|
||||||
|
cur_pos += 1
|
||||||
|
for j in range(pad_len):
|
||||||
|
src_positions.append(-1)
|
||||||
|
masked.append(False)
|
||||||
|
assert cur_pos == len(sentence)
|
||||||
|
|
||||||
|
|
||||||
|
src = []
|
||||||
|
src_masked = []
|
||||||
|
tgt = []
|
||||||
|
weight = []
|
||||||
|
randomizable = []
|
||||||
|
|
||||||
|
src.append(bos_sym)
|
||||||
|
src_masked.append(bos_sym)
|
||||||
|
randomizable.append(False)
|
||||||
|
for i, src_pos in enumerate(src_positions):
|
||||||
|
is_masked = masked[i]
|
||||||
|
if src_pos >= 0:
|
||||||
|
src_word = sentence[src_pos]
|
||||||
|
src_masked.append(blank_sym if masked[i] else src_word)
|
||||||
|
src.append(src_word)
|
||||||
|
tgt.append(src_word)
|
||||||
|
weight.append(1.0 if masked[i] else unmasked_weight)
|
||||||
|
randomizable.append(not masked[i])
|
||||||
|
else:
|
||||||
|
# Padding inside a masked region
|
||||||
|
src_masked.append(blank_sym)
|
||||||
|
src.append(blank_sym)
|
||||||
|
tgt.append(blank_sym)
|
||||||
|
weight.append(1.0)
|
||||||
|
randomizable.append(False)
|
||||||
|
src.append(eos_sym)
|
||||||
|
src_masked.append(eos_sym)
|
||||||
|
tgt.append(eos_sym)
|
||||||
|
weight.append(unmasked_weight)
|
||||||
|
tgt.append(blank_sym)
|
||||||
|
weight.append(0.0)
|
||||||
|
randomizable.append(False)
|
||||||
|
|
||||||
|
attn_mask = ([False] * len(src)) + ([True] * (seq_len - len(src)))
|
||||||
|
|
||||||
|
for i in range(seq_len - len(src)):
|
||||||
|
src.append(blank_sym)
|
||||||
|
src_masked.append(blank_sym)
|
||||||
|
tgt.append(blank_sym)
|
||||||
|
weight.append(0.0)
|
||||||
|
randomizable.append(False)
|
||||||
|
|
||||||
|
return (src, src_masked, tgt, weight, randomizable, attn_mask)
|
||||||
|
|
||||||
|
|
||||||
|
# dataset.mask_and_pad(list(range(10, 20)), seq_len=16, bos_sym=1, eos_sym=2, blank_sym=0, mask_proportion=0.2, padding_proportion=0.2, inv_mask_length=0.33, unmasked_weight=0.444)
|
||||||
|
|
||||||
|
# dataset.collate_fn(sentences=[ list(range(10, 20)), list(range(30, 45))], bos_sym=1, eos_sym=2, blank_sym=0, mask_proportion=0.2, padding_proportion=0.2, randomize_proportion=0.05, inv_mask_length=0.33, unmasked_weight=0.444)
|
||||||
|
|
||||||
|
def collate_fn(sentences: List[List[int]],
|
||||||
|
bos_sym: int,
|
||||||
|
eos_sym: int,
|
||||||
|
blank_sym: int,
|
||||||
|
mask_proportion: float = 0.15,
|
||||||
|
padding_proportion: float = 0.15,
|
||||||
|
randomize_proportion: float = 0.05,
|
||||||
|
inv_mask_length: float = 0.25,
|
||||||
|
unmasked_weight: float = 0.25,
|
||||||
|
debug: bool = False) -> Tuple[torch.Tensor, torch.Tensor,
|
||||||
|
torch.Tensor, torch.Tensor,
|
||||||
|
torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Caution, this is not the collate_fn we give directly to the dataloader,
|
||||||
|
we give it a lambda: collate_fn=(lambda x: dataset.collate_fn(x, [other args]))
|
||||||
|
This formats a list-of-lists-of-int into 5 Tensors, explained below.
|
||||||
|
The key thing is that we mask out subsequences of random length within
|
||||||
|
these sentences, and force the network to predict the masked-out
|
||||||
|
subsequences (which have blanks appended to them to prevent the model
|
||||||
|
from knowing the exact length of the sequences it has to predict).
|
||||||
|
So it's like BERT but at the level of sequences rather than individual
|
||||||
|
words.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bos_sym: the integer id of the beginning-of-sentence symbol, e.g. 2.
|
||||||
|
Is allowed be the same as eos_sym (we are not necessarily
|
||||||
|
saying it will work best that way).
|
||||||
|
eos_sym: the integer id of the end-of-sentence symbol, e.g. 2.
|
||||||
|
blank_sym: the integer id of the blank symbol, e.g. 0 or 1.
|
||||||
|
mask_proportion: The proportion of words in each sentence that
|
||||||
|
are masked, interpreted as (roughly) the probability of any given
|
||||||
|
word being masked, although the masked locations will
|
||||||
|
tend to be in contiguous sequences (they are not independent).
|
||||||
|
padding_proportion: Like mask_proportion, but determines the
|
||||||
|
number of extra, blank symbols that are inserted as padding
|
||||||
|
at the end of masked regions (this ensures that the model
|
||||||
|
cannot know exactly how many words need to be inserted in
|
||||||
|
any given masked region.
|
||||||
|
randomize_proportion: The probability with which we replace
|
||||||
|
words that were not masked with randomly chosen words.
|
||||||
|
Like BERT, this is intended to force the model to predict
|
||||||
|
something reasonable at non-masked positions, and to make
|
||||||
|
this task harder than simply repeating the input.
|
||||||
|
inv_mask_length: This number determines how many separate
|
||||||
|
sub-sequences the (masked + padded) proportion of a sentence is split up
|
||||||
|
into, interpreted as the inverse of the expected length of
|
||||||
|
each *masked* region.
|
||||||
|
unmasked_weight: The weight to be applied to the log-likelihoods of
|
||||||
|
un-masked positions in sentences (predicting un-masked
|
||||||
|
positions is not completely trivial if randomize_proportion > 0).
|
||||||
|
Will be reflected in the returned tgt_weights tensor.
|
||||||
|
|
||||||
|
Returns a tuple (masked_src_symbols, src_symbols,
|
||||||
|
tgt_symbols, src_key_padding_mask,
|
||||||
|
tgt_weights),
|
||||||
|
all with 2 axes and the same shape: (num_sent, seq_len).
|
||||||
|
Their dtypes will be, respectively,
|
||||||
|
(torch.int64, torch.int64,
|
||||||
|
torch.int64, torch.bool,
|
||||||
|
torch.float)
|
||||||
|
masked_src_symbols: The sentences, with bos_symbol prepended and eos_symbol
|
||||||
|
appended, masked regions (including padding) replaced with blank,
|
||||||
|
and `randomize_proportion` non-masked symbols replaced with
|
||||||
|
symbols randomly taken from elsewhere in the sentences of this
|
||||||
|
minibatch. Then padded to a fixed length with blank.
|
||||||
|
src_symbols: Like masked_src_symbols, except with the masked symbols replaced
|
||||||
|
with the original symbols (but the padding that follows each
|
||||||
|
masked sub-sequence will still be blank)
|
||||||
|
tgt_symbols: The original sentences, with eos_symbol appended, and then
|
||||||
|
padded with blank to the same length as masked_symbols and
|
||||||
|
src_symbols.
|
||||||
|
src_key_padding_mask: Masking tensor for masked_src_symbols and src_symbols, to
|
||||||
|
account for all the sentence lengths not being identical
|
||||||
|
(makes each sentence's processing independent of seq_len).
|
||||||
|
Tensor of Bool of shape (num_sent, seq_len), with True
|
||||||
|
for masked positions (these are the blanks that follow the
|
||||||
|
eos_symbol in masked_src_symbols), False for un-masked positions.
|
||||||
|
tgt_weights: Weights that will be applied to the log-probabilities at
|
||||||
|
the output of the network. Will have 1.0 in positions
|
||||||
|
in `tgt_symbols` that were masked (including blank
|
||||||
|
padding at the end of masked regions), `unmasked_weight`
|
||||||
|
in other positions in the original sentences (including
|
||||||
|
terminating eos_symbol); and 0.0 in the remaining positions
|
||||||
|
corresponding to blank padding after the ends of
|
||||||
|
sentences.
|
||||||
|
"""
|
||||||
|
assert blank_sym not in [bos_sym, eos_sym]
|
||||||
|
max_sent_len = max([ len(s) for s in sentences])
|
||||||
|
#logging.info(f"Sentence lengths: {[ len(s) for s in sentences]}")
|
||||||
|
|
||||||
|
typical_mask_and_pad = int(max_sent_len * (mask_proportion + padding_proportion))
|
||||||
|
|
||||||
|
# The following formula gives roughly 1 standard deviation above where we'd
|
||||||
|
# expect the maximum sentence length to be with masking and padding.. we use
|
||||||
|
# this as a hard upper limit, to prevent outliers from affecting the batch
|
||||||
|
# size too much. We use this as the size `seq_len`.
|
||||||
|
# The "+ 4" is to ensure there is always room for the BOS, EOS and at least
|
||||||
|
# two padding symbols.
|
||||||
|
seq_len = max_sent_len + 4 + typical_mask_and_pad + int(typical_mask_and_pad ** 0.5)
|
||||||
|
|
||||||
|
|
||||||
|
# srcs, srcs_masked, tgts and weights will be lists of the lists returned
|
||||||
|
# from `mask_and_pad`, one per sentence.
|
||||||
|
srcs = []
|
||||||
|
srcs_masked = []
|
||||||
|
tgts = []
|
||||||
|
weights = []
|
||||||
|
randomizables = []
|
||||||
|
attn_masks = []
|
||||||
|
for s in sentences:
|
||||||
|
(src, src_masked, tgt,
|
||||||
|
weight, randomizable,
|
||||||
|
attn_mask) = mask_and_pad(s, seq_len, bos_sym, eos_sym,
|
||||||
|
blank_sym, mask_proportion, padding_proportion,
|
||||||
|
inv_mask_length, unmasked_weight)
|
||||||
|
srcs.append(src)
|
||||||
|
srcs_masked.append(src_masked)
|
||||||
|
tgts.append(tgt)
|
||||||
|
weights.append(weight)
|
||||||
|
randomizables.append(randomizable)
|
||||||
|
attn_masks.append(attn_mask)
|
||||||
|
|
||||||
|
src_symbols = torch.tensor(srcs, dtype=torch.int64)
|
||||||
|
masked_src_symbols = torch.tensor(srcs_masked, dtype=torch.int64)
|
||||||
|
tgt_symbols = torch.tensor(tgts, dtype=torch.int64)
|
||||||
|
src_key_padding_mask = torch.tensor(attn_masks, dtype=torch.bool)
|
||||||
|
tgt_weights = torch.tensor(weights, dtype=torch.float)
|
||||||
|
|
||||||
|
attn_mask_sum = torch.sum(torch.logical_not(src_key_padding_mask), dim=0).tolist()
|
||||||
|
while attn_mask_sum[-1] == 0: # Remove always-masked positions at the endof the lists.
|
||||||
|
attn_mask_sum.pop()
|
||||||
|
if len(attn_mask_sum) < seq_len:
|
||||||
|
seq_len = len(attn_mask_sum)
|
||||||
|
(src_symbols, masked_src_symbols,
|
||||||
|
tgt_symbols, src_key_padding_mask, tgt_weights) = (src_symbols[:,:seq_len], masked_src_symbols[:,:seq_len],
|
||||||
|
tgt_symbols[:,:seq_len], src_key_padding_mask[:,:seq_len],
|
||||||
|
tgt_weights[:,:seq_len])
|
||||||
|
|
||||||
|
if randomize_proportion > 0.0:
|
||||||
|
randomizable_tensor = torch.tensor(randomizables, dtype=torch.bool)
|
||||||
|
randomizable_indexes = torch.nonzero(randomizable_tensor) # (num_randomizable, 2)
|
||||||
|
num_randomizable = randomizable_indexes.shape[0]
|
||||||
|
|
||||||
|
to_randomize_indexes = torch.nonzero(torch.rand(num_randomizable) < randomize_proportion, as_tuple=True)[0]
|
||||||
|
num_to_randomize = to_randomize_indexes.numel()
|
||||||
|
|
||||||
|
# older versions of torch don't have tensor_split, so fake a simplified version of it.
|
||||||
|
# we'd be calling it as xxx.tensor_split(dim=1) if really in torc.
|
||||||
|
def tensor_split(t):
|
||||||
|
return (t[:,0], t[:,1])
|
||||||
|
|
||||||
|
random_src_locations = torch.randperm(num_randomizable)[:num_to_randomize]
|
||||||
|
|
||||||
|
random_symbols = src_symbols[tensor_split(randomizable_indexes[random_src_locations])]
|
||||||
|
random_indexes_tuple= tensor_split(randomizable_indexes[to_randomize_indexes])
|
||||||
|
src_symbols[random_indexes_tuple] = random_symbols
|
||||||
|
masked_src_symbols[random_indexes_tuple] = random_symbols
|
||||||
|
|
||||||
|
|
||||||
|
# I set this to true and tested with:
|
||||||
|
# python3 -c 'import dataset; dataset.collate_fn(sentences=[ list(range(100, 200)), list(range(300, 450)), list(range(500,600))], bos_sym=1, eos_sym=2, blank_sym=0, mask_proportion=0.2, padding_proportion=0.2, randomize_proportion=0.05, inv_mask_length=0.33, unmasked_weight=0.444)'
|
||||||
|
#.. and ran a few times to check the values printed looked about right, and that no assertions failed.
|
||||||
|
if debug:
|
||||||
|
check_collated_tensors(sentences, bos_sym, eos_sym, blank_sym,
|
||||||
|
unmasked_weight,
|
||||||
|
masked_src_symbols, src_symbols,
|
||||||
|
tgt_symbols, src_key_padding_mask, tgt_weights)
|
||||||
|
return (masked_src_symbols, src_symbols,
|
||||||
|
tgt_symbols, src_key_padding_mask, tgt_weights)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def check_collated_tensors(sentences: List[List[int]],
|
||||||
|
bos_sym: int,
|
||||||
|
eos_sym: int,
|
||||||
|
blank_sym: int,
|
||||||
|
unmasked_weight: float,
|
||||||
|
masked_src_symbols, src_symbols,
|
||||||
|
tgt_symbols, src_key_padding_mask,
|
||||||
|
tgt_weights):
|
||||||
|
"""
|
||||||
|
This function checks the output of collate_fn, consider it test code. Please see
|
||||||
|
the documentation of collate_fn to understand the args.
|
||||||
|
"""
|
||||||
|
for t in src_symbols, tgt_symbols, src_key_padding_mask, tgt_weights:
|
||||||
|
assert t.shape == masked_src_symbols.shape
|
||||||
|
|
||||||
|
tot_positions = src_symbols.numel()
|
||||||
|
|
||||||
|
masked_src_symbols, src_symbols, tgt_symbols, src_key_padding_mask, tgt_weights = (
|
||||||
|
masked_src_symbols.tolist(), src_symbols.tolist(), tgt_symbols.tolist(),
|
||||||
|
src_key_padding_mask.tolist(), tgt_weights.tolist())
|
||||||
|
assert len(sentences) == len(masked_src_symbols)
|
||||||
|
|
||||||
|
tot_masked_positions = 0
|
||||||
|
tot_padded_positions = 0
|
||||||
|
tot_unmasked_positions = 0 # all un-masked, non-blank postions, including eos
|
||||||
|
tot_randomized_positions = 0
|
||||||
|
num_masked_subseqs = 0
|
||||||
|
tot_symbols = 0 # original symbols in sentences, no bos/eos
|
||||||
|
|
||||||
|
assert unmasked_weight > 0.001 # or this test code won't work..
|
||||||
|
|
||||||
|
for i in range(len(sentences)):
|
||||||
|
reconstructed_sent = list(filter(lambda x: x not in [bos_sym,eos_sym,blank_sym], tgt_symbols[i]))
|
||||||
|
if sentences[i] != reconstructed_sent:
|
||||||
|
print(f"Error: sentence {i}={sentences[i]} differs from {reconstructed_sent}")
|
||||||
|
(masked_src, src, tgt, src_mask, weights) = (masked_src_symbols[i], src_symbols[i],
|
||||||
|
tgt_symbols[i], src_key_padding_mask[i], tgt_weights[i])
|
||||||
|
|
||||||
|
assert src[0] == masked_src[0] == bos_sym
|
||||||
|
for j in range(len(masked_src)):
|
||||||
|
assert masked_src[j] == blank_sym or masked_src[j] == src[j]
|
||||||
|
|
||||||
|
if src[j] not in [bos_sym, eos_sym, blank_sym]:
|
||||||
|
tot_symbols += 1
|
||||||
|
|
||||||
|
if j > 0:
|
||||||
|
assert (src[j] == eos_sym) == (masked_src[j] == eos_sym) == (tgt[j-1] == eos_sym)
|
||||||
|
if masked_src[j] == blank_sym: # masked or padding of masked subseq, or post-eos padding..
|
||||||
|
assert src[j] == tgt[j - 1] # masked symbols are not randomized.
|
||||||
|
assert weights[j - 1] in [0.0, 1.0] # 0.0 for final blank padding
|
||||||
|
if weights[j - 1] == 1.0: # Not final blank padding...
|
||||||
|
if tgt[j - 1] == blank_sym:
|
||||||
|
tot_padded_positions += 1
|
||||||
|
else:
|
||||||
|
tot_masked_positions += 1
|
||||||
|
if masked_src[j + 1] != blank_sym:
|
||||||
|
num_masked_subseqs += 1
|
||||||
|
else:
|
||||||
|
assert weights[j - 1] == 0 or abs(weights[j-1] - unmasked_weight) < 0.001
|
||||||
|
if abs(weights[j - 1]-unmasked_weight) < 0.001:
|
||||||
|
tot_unmasked_positions += 1
|
||||||
|
if tgt[j - 1] != src[j]:
|
||||||
|
tot_randomized_positions += 1
|
||||||
|
|
||||||
|
if src_mask[j]: # if masked..
|
||||||
|
assert src[j] == blank_sym
|
||||||
|
|
||||||
|
assert tot_symbols == sum(len(x) for x in sentences)
|
||||||
|
|
||||||
|
assert tot_unmasked_positions + tot_masked_positions == tot_symbols + len(sentences)
|
||||||
|
|
||||||
|
print(f"{tot_unmasked_positions} + {tot_masked_positions} == {tot_symbols} + {len(sentences)}")
|
||||||
|
print(f"tot_symbols / tot_positions = {tot_symbols/tot_positions} (rest is bos,eos,padding)")
|
||||||
|
|
||||||
|
print(f"Masking/tot_symbols = {tot_masked_positions/tot_symbols}, Padding/tot_symbols = {tot_padded_positions/tot_symbols}")
|
||||||
|
print(f"Randomization/tot_non_masked_symbols = {tot_randomized_positions/(tot_symbols-tot_masked_positions)}")
|
||||||
|
print(f"Mean masking length = {tot_masked_positions/num_masked_subseqs}, Mean padding length = {tot_padded_positions/num_masked_subseqs}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# This shows some useful code about the BPE encoding.
|
||||||
|
# import sentencepiece as spm
|
||||||
|
# sp = spm.SentencePieceProcessor()
|
||||||
|
# sp.load(bpe_model_fn) # bpe.model
|
||||||
|
# sp.GetPieceSize(..)
|
||||||
|
# sp.Decode(...)
|
||||||
|
# sp.Encode(...)
|
||||||
|
|
||||||
|
|
||||||
|
# import dataset
|
||||||
|
# import torch
|
||||||
|
# train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
|
||||||
|
|
||||||
|
|
||||||
|
# train_dl = torch.utils.data.DataLoader(train, batch_size=10, shuffle=True, collate_fn=(lambda x: train.collate_fn(x)))
|
||||||
|
# x = iter(train_dl)
|
||||||
|
# str(next(x))
|
||||||
|
# '[ [ 10 38 651 593 3 1343 31 780 6 4172 112 788 1696 24 289 24 3 403 6 4493 162 92 71 328 417 217 338 14 5 3 1876 154 21 23 2237 43 3 1535 92 71 2816 7 1031 31 2318 92 2528 4806 14 206 3 954 1373 6 525 4 631 447 2639 ] [ 1014 336 171 209 795 10 16 90 27 787 139 53 45 2817 ] [ 11 980 51 22 1748 14 91 105 363 428 6 8 2887 3305 2525 2297 70 3 4651 6 27 282 335 426 134 292 5 193 3 539 2250 584 127 ] [ 9 3 1858 4 18 2257 4 6 41 748 10 304 7 229 83 2793 4 9 981 7 1484 33 3 103 7 539 5 477 3195 18 64 39 82 1034 6 3 4128 ] [ 17 147 22 7 708 60 133 174 105 4111 4 6 3 1384 65 50 1051 9 2953 6 3 461 180 1142 23 5 36 888 8 131 173 390 78 23 266 2822 715 46 182 65 22 1739 33 3 700 1450 14 233 4 ] [ 80 10 16 67 279 7 1827 264 96 3 187 2851 2108 ] [ 1473 48 106 227 9 160 2011 4 674 ] [ 3 954 762 29 85 228 33 8 940 40 4952 36 486 390 595 3 81 225 6 1440 125 346 134 296 126 419 1017 3824 4 8 179 184 11 33 580 1861 ] [ 30 22 245 15 117 8 2892 28 1204 145 7 3 236 3417 6 3 3839 5 3106 155 198 30 228 2555 46 15 32 41 747 72 9 25 977 ] [ 222 466 6 3157 ] ]'
|
||||||
|
#
|
||||||
|
# or:
|
||||||
|
# import k2
|
||||||
|
# k2.ragged.to_list(next(x))
|
||||||
|
# [shows something similar].
|
||||||
|
#
|
||||||
|
# You'd really do something like:
|
||||||
|
# for epoch in range(max_epochs):
|
||||||
|
# for minibatch in train_dl:
|
||||||
|
|
||||||
|
|
||||||
|
# .. How to process data? Suppose we have a sentence like [259, 278, 45, 11, 303, 1319, 34, 15, 396, 3435, 7, 44].
|
||||||
|
#
|
||||||
|
# First: we randomly choose one or more starting positins for a masked segment.
|
||||||
|
# Each sentence must have at least one masked segment (or there is no contribution to the loss function).
|
||||||
|
# We choose to have:
|
||||||
|
# num_masked_segments = max(1, len(sent) // 15)
|
||||||
|
#
|
||||||
|
# The length of the masked segment (this is the target for prediction), we set to the geometric
|
||||||
|
# distribution with the probability of success set to 3:
|
||||||
|
#
|
||||||
|
# g = torch.distributions.geometric.Geometric(probs=0.3) # <-- expected value is 3.333
|
||||||
|
# Example of sampling:
|
||||||
|
# g.sample(sample_shape=torch.Size([10]))
|
||||||
|
#
|
||||||
|
# We now we randomly compute the location of the masked segments (length computed above) as follows:
|
||||||
|
# First, the masked segments must be separated by at least one non-masked word (else they would be
|
||||||
|
# a single segment). So for n masked segments, there are n-1 words required for minimal separation.
|
||||||
|
# If tot-length-of-segments + n-1 is greater than the sentence length, we just have the entire
|
||||||
|
# sentence be masked. Otherwise, we randomly divide the remaining number of words between the n+1
|
||||||
|
# positions where they can appear (e.g. for 2 segments, this would be at the start, between the 2 segments,
|
||||||
|
# and at the end). This is the multinomial distribution, but we can more easily compute this
|
||||||
|
# directly using rand() and cutoffs, rather than creating a torch.distributions.Multinomial().
|
||||||
|
#
|
||||||
|
|
||||||
|
# Next we need to compute a random amount of blank padding (>= 0) for each of the masked regions;
|
||||||
|
# this is done so the model never knows the exact length of the masked region. We can just use the
|
||||||
|
# same distribution as for the length of the masked regions, i.e. geometric with success-prob=0.3
|
||||||
|
# (expected padding length is 3).
|
||||||
|
#
|
||||||
|
# At this point we know where the masked regions are and how much padding they have. We can format
|
||||||
|
# the result as three lists, of the same length:
|
||||||
|
#
|
||||||
|
# sent: contains the words in the sentence with, in masked
|
||||||
|
# positions, the original (target) words, then with
|
||||||
|
# blank in the blank-padding after masked positions.
|
||||||
|
#
|
||||||
|
# sent_augmented: `sent` with, at a small defined percentage of positions
|
||||||
|
# that were *not* masked, the real token replaced with a
|
||||||
|
# token randomly chosen from the tokens in the minibatch.
|
||||||
|
# (like BERT, we use this type of augmentation, so the model
|
||||||
|
# has to predict the original token).
|
||||||
|
#
|
||||||
|
# masked_sent_augmented: List[int], contains the words in `sent_augmented`, except
|
||||||
|
# with masked positions and the blank padding after the masked regions
|
||||||
|
# both replaced with blank.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# The way these will be processed is as follows:
|
||||||
|
#
|
||||||
|
# masked_sent_in = [bos] + masked_sent_augmented + [eos] <-- so we know the sentence ended, distinguish it from truncated ones.
|
||||||
|
# sent_in = [bos] + sent_augmented + [eos]
|
||||||
|
#
|
||||||
|
# sent_out = sent + [eos] + [eos] #<--- the predicted targets at each point, although
|
||||||
|
# # we only really care about this in masked regions.
|
||||||
|
# # The extra eos is so that the length is the same as
|
||||||
|
# # masked_sent_in and sent_in.
|
||||||
|
#
|
||||||
|
# out_scale = (masked_sent==blk ? 1.0 : non_masked_scale) # e.g. non_masked_scale = 1.0 is fine,
|
||||||
|
# # this is a choice; we can perhaps
|
||||||
|
# # report these 2 parts of the loss
|
||||||
|
# # separately though.
|
||||||
|
# # <-- can also set the last element
|
||||||
|
# # of out_scale to a smaller number, since
|
||||||
|
# # it's a repeated eos.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# OK, how do we combine these into a minibatch? Firstly, we truncate sentences to a maximum
|
||||||
|
# length, e.g. 128, if `masked_sent_in`/`sent_in` have length longer than that. We choose randomly
|
||||||
|
# in each case to truncate the beginning or end, truncating both masked_sent_in/sent_in and sent_out
|
||||||
|
# from the same side. Caution: this means that these sentences may lack bos and/or eos symbols.
|
||||||
|
#
|
||||||
|
# Next, we combine shorter utterances by appending them ( all of: masked_sent_in, sent_in, out_scale)
|
||||||
|
# as long as doing so would keep the total length under 128. We then pad (masked_sent_in, sent_in, sent_out, out_scale)
|
||||||
|
# with: (<blk>,<blk>,<eos>, 0) up to the maximum length of any sentence in the minibatch <- or could use
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# # i.e. ones where masked_sent is blank and zeros elsewhere;
|
||||||
|
# # this pertains to positions in `sent_out`.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# torch.distributions.gamma.Gamma(concentration=1.0, rate=1.0/5)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LmBatchSampler(torch.utils.data.Sampler):
|
||||||
|
"""
|
||||||
|
A sampler that returns a batch of integer indexes as a list, intended for use
|
||||||
|
with class LmDataset. The sentences returned in each batch will all be about
|
||||||
|
the same size, and the batch size is specified as a number of words (we also
|
||||||
|
provide an option that allows you to limit the max memory consumed by transformers)
|
||||||
|
|
||||||
|
Has support for distributed operation.
|
||||||
|
"""
|
||||||
|
def __init__(self, dataset: LmDataset,
|
||||||
|
symbols_per_batch: int,
|
||||||
|
length_ceil: float = 200.0,
|
||||||
|
length_floor: float = 4.0,
|
||||||
|
world_size: Optional[int] = None,
|
||||||
|
rank: int = None,
|
||||||
|
seed: int = 0,
|
||||||
|
delay_init: bool = False):
|
||||||
|
"""
|
||||||
|
Constructor documentation:
|
||||||
|
dataset: the LmDataset object that we are sampling from. This
|
||||||
|
class does not retain a reference to the LmDataset.
|
||||||
|
symbols_per_batch: The number of BPE symbols desired in each minibatch
|
||||||
|
length_floor: When the sentence length gets less than about this much,
|
||||||
|
the batch size stops increasing inversely with sentence
|
||||||
|
length. Prevent OOM on batches with short sentences.
|
||||||
|
length_ceil: After the sentence length gets more than about
|
||||||
|
this much, the batch size will start decreasing
|
||||||
|
as 1/(sentence-length^2). This is a mechanism to
|
||||||
|
avoid excessive memory consumption in transformers, when
|
||||||
|
sentence length gets long.
|
||||||
|
world_size: The world size for distributed operation; if None,
|
||||||
|
will be worked out from torch.distributed.
|
||||||
|
rank: The rank of this sampler/process for distributed operation; if None,
|
||||||
|
will be worked out from torch.distributed.
|
||||||
|
seed: The random seed
|
||||||
|
delay_init: If true, will omit calling self.set_epoch(0) at the
|
||||||
|
end of the __init__ function. In this case the caller
|
||||||
|
must call set_epoch(0). [Setting this option is necessary
|
||||||
|
to work with data-loader worker processes plus DDP, since
|
||||||
|
set_epoch() will use ddp, which I believe is a no-no prior
|
||||||
|
to initializing data-loaders.]
|
||||||
|
"""
|
||||||
|
self.seed = seed
|
||||||
|
self.symbols_per_batch = symbols_per_batch
|
||||||
|
self.length_floor = length_floor
|
||||||
|
self.quadratic_constant = 1.0 / length_ceil
|
||||||
|
self._maybe_init_distributed(world_size=world_size, rank=rank)
|
||||||
|
|
||||||
|
# a configuration constant we don't expose.
|
||||||
|
self.multiplicative_random_length = 0.05
|
||||||
|
|
||||||
|
# "indexes" is the subset of indexes into LmDataset that this
|
||||||
|
# sampler is reponsible for (all of them, in the non-distributed case).
|
||||||
|
data_indexes = torch.arange(self.rank, len(dataset), self.world_size, dtype=torch.int32) # dtype=torch.int32
|
||||||
|
|
||||||
|
word_row_splits = dataset.words.shape.row_splits(1) # dtype=torch.int32
|
||||||
|
word_lengths = word_row_splits[1:] - word_row_splits[:-1] # dtype=torch.int32
|
||||||
|
|
||||||
|
# the sentences this sampler is responsible for, as sequences of words.
|
||||||
|
# It's a ragged tensor of int32
|
||||||
|
sentences, _ = dataset.sentences.index(data_indexes, axis=0)
|
||||||
|
|
||||||
|
# sentence_lengths is a k2.RaggedTensor like `sentences`, but with the words replaced
|
||||||
|
# with their respective lengths, in BPE pieces.
|
||||||
|
sentence_lengths = k2.ragged.index(word_lengths, sentences)
|
||||||
|
del sentences # save memory
|
||||||
|
assert isinstance(sentence_lengths, k2.RaggedTensor)
|
||||||
|
|
||||||
|
# convert to float so sum_per_sublist() will work (TODO: sum_per_sublist() will eventually
|
||||||
|
# support int32.)
|
||||||
|
sentence_lengths = sentence_lengths.to(dtype=torch.float32)
|
||||||
|
|
||||||
|
# Convert into a simple tensor of float by adding lengths of words.
|
||||||
|
sentence_lengths = sentence_lengths.sum()
|
||||||
|
|
||||||
|
assert isinstance(sentence_lengths, torch.Tensor)
|
||||||
|
assert sentence_lengths.dtype == torch.float32
|
||||||
|
|
||||||
|
# self.sentence_lengths is a Tensor with dtype=torch.float32. It
|
||||||
|
# contains the lengths, in BPE tokens, of the sentences that this
|
||||||
|
# sampler is responsible for, whose real indexes are in
|
||||||
|
# `data_indexes` above (this is not stored, as we know the formula).
|
||||||
|
self.sentence_lengths = sentence_lengths
|
||||||
|
|
||||||
|
if not delay_init:
|
||||||
|
self.set_epoch(0) # this is responsible for setting self.sorted_data_indexes
|
||||||
|
|
||||||
|
def _sync_sizes(self, device: torch.device = torch.device('cuda')):
|
||||||
|
# Calling this on all copies of a DDP setup will sync the sizes so that
|
||||||
|
# all copies have the exact same number of batches. I think
|
||||||
|
# this needs to be called with the GPU device, not sure if it would
|
||||||
|
# work otherwise.
|
||||||
|
if self.world_size > 1:
|
||||||
|
min_size = torch.tensor([len(self.batch_indices)], device=device, dtype=torch.int64)
|
||||||
|
dist.all_reduce(min_size, op=dist.ReduceOp.MIN)
|
||||||
|
min_size = min_size.to('cpu').item()
|
||||||
|
logging.info(f"world_size={self.world_size}, rank={self.rank}: reducing batch indices from {len(self.batch_indices)} to {min_size}")
|
||||||
|
self.batch_indices = self.batch_indices[0:min_size]
|
||||||
|
|
||||||
|
def _maybe_init_distributed(self, world_size: Optional[int], rank: Optional[int]):
|
||||||
|
if world_size is not None:
|
||||||
|
assert world_size >= 1
|
||||||
|
if rank is not None:
|
||||||
|
assert rank >= 0
|
||||||
|
if not dist.is_available() or not dist.is_initialized():
|
||||||
|
self.world_size = 1 if world_size is None else world_size
|
||||||
|
self.rank = 0 if rank is None else rank
|
||||||
|
return
|
||||||
|
self.world_size = dist.get_world_size() if world_size is None else world_size
|
||||||
|
self.rank = dist.get_rank() if rank is None else rank
|
||||||
|
assert self.rank < self.world_size
|
||||||
|
|
||||||
|
|
||||||
|
def set_epoch(self, epoch: int):
|
||||||
|
"""
|
||||||
|
Must be called at the beginning of each epoch, before initializing the DataLoader,
|
||||||
|
to re-shuffle the data. If this is not done, this sampler will give you the same batches
|
||||||
|
each time it is called.
|
||||||
|
"""
|
||||||
|
g = torch.manual_seed(self.rank + self.seed + epoch)
|
||||||
|
|
||||||
|
sentence_lengths = (self.sentence_lengths *
|
||||||
|
(1.0 + torch.rand(*self.sentence_lengths.shape, generator=g) * self.multiplicative_random_length))
|
||||||
|
|
||||||
|
# This mechanism regulates the batch size so that we don't get OOM in transformers
|
||||||
|
# when the sentences are long.
|
||||||
|
sentence_lengths = (sentence_lengths + (sentence_lengths ** 2) * self.quadratic_constant) + self.length_floor
|
||||||
|
|
||||||
|
values, indices = torch.sort(sentence_lengths) # values,indices dtypes: torch.float,torch.int64
|
||||||
|
|
||||||
|
# map to the original indexes into the dataset (the original sentence
|
||||||
|
# indexes), see torch.arange expression in the constructor. save as
|
||||||
|
# int32 just to save a little memory. self.indices are indexes into the
|
||||||
|
# LmDataset, just including the subset of indices that this sampler is
|
||||||
|
# responsible for (in terms of rank and world_size), and sorted by
|
||||||
|
# length with a small amount of randomization specific to the epoch.
|
||||||
|
self.indices = ((indices * self.world_size) + self.rank).to(dtype=torch.int32)
|
||||||
|
|
||||||
|
# now `batch_ids` will be: [0, 0, 0, 0, .., 0, 1, 1, 1, ... 1, 2, ... ],
|
||||||
|
# saying which batch each element of values/indices belongs to.
|
||||||
|
batch_ids = (torch.cumsum(values.to(dtype=torch.double), dim=0) * (1.0 / self.symbols_per_batch)).to(dtype=torch.int32)
|
||||||
|
|
||||||
|
batch_boundaries = torch.nonzero(batch_ids[1:] - batch_ids[:-1], as_tuple=True)[0]
|
||||||
|
batch_boundaries.add_(1)
|
||||||
|
self.batch_boundaries = torch.cat((torch.zeros(1, dtype=torch.int32), batch_boundaries), dim=0)
|
||||||
|
|
||||||
|
num_batches = self.batch_boundaries.numel() - 1
|
||||||
|
|
||||||
|
# self.batch_indices is a permutation of [0, 1, ... num_batches -
|
||||||
|
# 1]; it determines the order in which we access the batches. It's
|
||||||
|
# necessary to randomize the order of these, to avoid returning batches
|
||||||
|
# from shortest to longest sentences.
|
||||||
|
self.batch_indices = torch.randperm(num_batches, generator=g, dtype=torch.int32).tolist()
|
||||||
|
self._sync_sizes()
|
||||||
|
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.batch_indices)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
"""
|
||||||
|
Iterator that yields lists of indices (i.e., integer indices into the LmDataset)
|
||||||
|
"""
|
||||||
|
for batch_idx in self.batch_indices:
|
||||||
|
batch_start = self.batch_boundaries[batch_idx].item()
|
||||||
|
batch_end = self.batch_boundaries[batch_idx + 1].item()
|
||||||
|
yield self.indices[batch_start:batch_end].tolist()
|
||||||
|
|
||||||
|
|
||||||
|
class CollateFn:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.extra_args = kwargs
|
||||||
|
|
||||||
|
def __call__(self, sentences: List[List[int]]):
|
||||||
|
return collate_fn(sentences, **self.extra_args)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
|
||||||
|
# sampler = dataset.LmBatchSampler(test, symbols_per_batch=1000, world_size=2, rank=0)
|
||||||
|
# a = iter(sampler)
|
||||||
|
# print(str(next(a)))
|
||||||
|
|
||||||
|
# collate_fn=(lambda x:dataset.collate_fn(x, bos_sym=1, eos_sym=1, blank_sym=0, debug=True))
|
||||||
|
# train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler, collate_fn=collate_fn)
|
||||||
|
# x = iter(train_dl)
|
||||||
|
# print(str(next(x)))
|
1256
egs/librispeech/ASR/conformer_lm/madam.py
Normal file
1256
egs/librispeech/ASR/conformer_lm/madam.py
Normal file
File diff suppressed because it is too large
Load Diff
156
egs/librispeech/ASR/conformer_lm/test_conformer.py
Normal file
156
egs/librispeech/ASR/conformer_lm/test_conformer.py
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# run with:
|
||||||
|
# python3 -m pytest test_conformer.py
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import dataset # from .
|
||||||
|
from conformer import (
|
||||||
|
RelPosTransformerDecoder,
|
||||||
|
RelPosTransformerDecoderLayer,
|
||||||
|
MaskedLmConformer,
|
||||||
|
MaskedLmConformerEncoder,
|
||||||
|
MaskedLmConformerEncoderLayer,
|
||||||
|
RelPositionMultiheadAttention,
|
||||||
|
RelPositionalEncoding,
|
||||||
|
generate_square_subsequent_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
|
|
||||||
|
def test_rel_position_multihead_attention():
|
||||||
|
# Also tests RelPositionalEncoding
|
||||||
|
embed_dim = 256
|
||||||
|
num_heads = 4
|
||||||
|
T = 25
|
||||||
|
N = 4
|
||||||
|
C = 256
|
||||||
|
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
|
||||||
|
rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
|
||||||
|
|
||||||
|
x = torch.randn(N, T, C)
|
||||||
|
#pos_emb = torch.randn(1, 2*T-1, C)
|
||||||
|
x, pos_emb = pos_emb_module(x)
|
||||||
|
x = x.transpose(0, 1) # (T, N, C)
|
||||||
|
attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_emb)
|
||||||
|
|
||||||
|
|
||||||
|
def test_masked_lm_conformer_encoder_layer():
|
||||||
|
# Also tests RelPositionalEncoding
|
||||||
|
embed_dim = 256
|
||||||
|
num_heads = 4
|
||||||
|
T = 25
|
||||||
|
N = 4
|
||||||
|
C = 256
|
||||||
|
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
|
||||||
|
encoder_layer = MaskedLmConformerEncoderLayer(embed_dim, num_heads)
|
||||||
|
|
||||||
|
|
||||||
|
x = torch.randn(N, T, C)
|
||||||
|
x, pos_emb = pos_emb_module(x)
|
||||||
|
x = x.transpose(0, 1) # (T, N, C)
|
||||||
|
key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T)
|
||||||
|
y = encoder_layer(x, pos_emb, key_padding_mask=key_padding_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def test_masked_lm_conformer_encoder():
|
||||||
|
# Also tests RelPositionalEncoding
|
||||||
|
embed_dim = 256
|
||||||
|
num_heads = 4
|
||||||
|
T = 25
|
||||||
|
N = 4
|
||||||
|
C = 256
|
||||||
|
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
|
||||||
|
encoder_layer = MaskedLmConformerEncoderLayer(embed_dim, num_heads)
|
||||||
|
norm = torch.nn.LayerNorm(embed_dim)
|
||||||
|
encoder = MaskedLmConformerEncoder(encoder_layer, num_layers=4,
|
||||||
|
norm=norm)
|
||||||
|
|
||||||
|
|
||||||
|
x = torch.randn(N, T, C)
|
||||||
|
x, pos_emb = pos_emb_module(x)
|
||||||
|
x = x.transpose(0, 1) # (T, N, C)
|
||||||
|
key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T)
|
||||||
|
y = encoder(x, pos_emb, key_padding_mask=key_padding_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def test_transformer_decoder_layer_rel_pos():
|
||||||
|
embed_dim = 256
|
||||||
|
num_heads = 4
|
||||||
|
T = 25
|
||||||
|
N = 4
|
||||||
|
C = 256
|
||||||
|
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
|
||||||
|
decoder_layer = RelPosTransformerDecoderLayer(embed_dim, num_heads)
|
||||||
|
|
||||||
|
|
||||||
|
x = torch.randn(N, T, C)
|
||||||
|
x, pos_emb = pos_emb_module(x)
|
||||||
|
x = x.transpose(0, 1) # (T, N, C)
|
||||||
|
key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T)
|
||||||
|
attn_mask = generate_square_subsequent_mask(T)
|
||||||
|
memory = torch.randn(T, N, C)
|
||||||
|
y = decoder_layer(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_transformer_decoder_rel_pos():
|
||||||
|
embed_dim = 256
|
||||||
|
num_heads = 4
|
||||||
|
T = 25
|
||||||
|
N = 4
|
||||||
|
C = 256
|
||||||
|
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
|
||||||
|
decoder_layer = RelPosTransformerDecoderLayer(embed_dim, num_heads)
|
||||||
|
decoder_norm = torch.nn.LayerNorm(embed_dim)
|
||||||
|
decoder = RelPosTransformerDecoder(decoder_layer, num_layers=6, norm=decoder_norm)
|
||||||
|
|
||||||
|
x = torch.randn(N, T, C)
|
||||||
|
x, pos_emb = pos_emb_module(x)
|
||||||
|
x = x.transpose(0, 1) # (T, N, C)
|
||||||
|
key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T)
|
||||||
|
attn_mask = generate_square_subsequent_mask(T)
|
||||||
|
memory = torch.randn(T, N, C)
|
||||||
|
y = decoder(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def test_masked_lm_conformer():
|
||||||
|
|
||||||
|
num_classes = 87
|
||||||
|
d_model = 256
|
||||||
|
|
||||||
|
model = MaskedLmConformer(num_classes,d_model)
|
||||||
|
|
||||||
|
|
||||||
|
N = 31
|
||||||
|
|
||||||
|
|
||||||
|
(masked_src_symbols, src_symbols,
|
||||||
|
tgt_symbols, src_key_padding_mask,
|
||||||
|
tgt_weights) = dataset.collate_fn(sentences=[ list(range(10, 20)), list(range(30, 45)), list(range(50,68))], bos_sym=1, eos_sym=2,
|
||||||
|
blank_sym=0)
|
||||||
|
|
||||||
|
# test forward() of MaskedLmConformer
|
||||||
|
memory, pos_emb = model(masked_src_symbols, src_key_padding_mask)
|
||||||
|
nll = model.decoder_nll(memory, pos_emb, src_symbols, tgt_symbols,
|
||||||
|
src_key_padding_mask)
|
||||||
|
print("nll = ", nll)
|
||||||
|
loss = (nll * tgt_weights).sum()
|
||||||
|
print("loss = ", loss)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_square_subsequent_mask():
|
||||||
|
s = 5
|
||||||
|
mask = generate_square_subsequent_mask(s, torch.device('cpu'))
|
||||||
|
inf = float("inf")
|
||||||
|
expected_mask = torch.tensor(
|
||||||
|
[
|
||||||
|
[0.0, -inf, -inf, -inf, -inf],
|
||||||
|
[0.0, 0.0, -inf, -inf, -inf],
|
||||||
|
[0.0, 0.0, 0.0, -inf, -inf],
|
||||||
|
[0.0, 0.0, 0.0, 0.0, -inf],
|
||||||
|
[0.0, 0.0, 0.0, 0.0, 0.0],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert torch.all(torch.eq(mask, expected_mask))
|
32
egs/librispeech/ASR/conformer_lm/test_dataset.py
Normal file
32
egs/librispeech/ASR/conformer_lm/test_dataset.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
import _k2
|
||||||
|
import dataset
|
||||||
|
import os
|
||||||
|
from torch import multiprocessing as mp
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
def local_collate_fn(sentences):
|
||||||
|
return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=True)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
#mp.set_start_method('spawn')
|
||||||
|
os.environ["MASTER_ADDR"] = "localhost"
|
||||||
|
os.environ["MASTER_PORT"] = "12344"
|
||||||
|
|
||||||
|
dist.init_process_group(backend="nccl", group_name="main",
|
||||||
|
rank=0, world_size=1)
|
||||||
|
|
||||||
|
train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')
|
||||||
|
sampler = dataset.LmBatchSampler(test, symbols_per_batch=5000, world_size=2, rank=0)
|
||||||
|
print("len(sampler) = ", len(sampler))
|
||||||
|
|
||||||
|
a = iter(sampler)
|
||||||
|
print(str(next(a)))
|
||||||
|
|
||||||
|
train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler,
|
||||||
|
collate_fn=local_collate_fn,
|
||||||
|
num_workers=2)
|
||||||
|
x = iter(train_dl)
|
||||||
|
print(str(next(x)))
|
39
egs/librispeech/ASR/conformer_lm/test_dataset_empty.py
Normal file
39
egs/librispeech/ASR/conformer_lm/test_dataset_empty.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
import _k2
|
||||||
|
import dataset
|
||||||
|
from dataset import LmDataset
|
||||||
|
import os
|
||||||
|
from torch import multiprocessing as mp
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
def local_collate_fn(sentences):
|
||||||
|
return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=False)
|
||||||
|
|
||||||
|
x = _k2.RaggedInt('[[1]]') # make sure library initialized?
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
mp.set_start_method('spawn')
|
||||||
|
os.environ["MASTER_ADDR"] = "localhost"
|
||||||
|
os.environ["MASTER_PORT"] = "12344"
|
||||||
|
|
||||||
|
dist.init_process_group(backend="nccl", group_name="main",
|
||||||
|
rank=0, world_size=1)
|
||||||
|
|
||||||
|
words = k2.RaggedInt('[[0][1 2]]')
|
||||||
|
sentences = k2.RaggedInt('[[1][][][][][]]')
|
||||||
|
|
||||||
|
train = LmDataset(sentences, words)
|
||||||
|
|
||||||
|
|
||||||
|
sampler = dataset.LmBatchSampler(train, symbols_per_batch=10, world_size=1, rank=0)
|
||||||
|
|
||||||
|
a = iter(sampler)
|
||||||
|
print(str(next(a)))
|
||||||
|
|
||||||
|
train_dl = torch.utils.data.DataLoader(train, batch_sampler=sampler,
|
||||||
|
collate_fn=local_collate_fn,
|
||||||
|
num_workers=0)
|
||||||
|
x = iter(train_dl)
|
||||||
|
print(str(next(x)))
|
623
egs/librispeech/ASR/conformer_lm/train.py
Executable file
623
egs/librispeech/ASR/conformer_lm/train.py
Executable file
@ -0,0 +1,623 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Daniel Povey)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from shutil import copyfile
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import dataset # from .
|
||||||
|
import madam # from .
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
import torch.nn as nn
|
||||||
|
from conformer import MaskedLmConformer
|
||||||
|
from lhotse.utils import fix_random_seed
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
from torch.nn.utils import clip_grad_norm_
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from madam import Gloam
|
||||||
|
|
||||||
|
from icefall.checkpoint import load_checkpoint
|
||||||
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
|
|
||||||
|
from icefall.utils import (
|
||||||
|
AttributeDict,
|
||||||
|
setup_logger,
|
||||||
|
str2bool,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--world-size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of GPUs for DDP training.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--master-port",
|
||||||
|
type=int,
|
||||||
|
default=12354,
|
||||||
|
help="Master port to use for DDP training.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tensorboard",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Should various information be logged in tensorboard.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def get_params() -> AttributeDict:
|
||||||
|
"""Return a dict containing training parameters.
|
||||||
|
|
||||||
|
All training related parameters that are not passed from the commandline
|
||||||
|
is saved in the variable `params`.
|
||||||
|
|
||||||
|
Commandline options are merged into `params` after they are parsed, so
|
||||||
|
you can also access them via `params`.
|
||||||
|
|
||||||
|
Explanation of options saved in `params`:
|
||||||
|
|
||||||
|
- exp_dir: It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
|
||||||
|
- lr: It specifies the initial learning rate
|
||||||
|
|
||||||
|
- feature_dim: The model input dim. It has to match the one used
|
||||||
|
in computing features.
|
||||||
|
|
||||||
|
- start_epoch: If it is not zero, load checkpoint `start_epoch-1`
|
||||||
|
and continue training from that checkpoint.
|
||||||
|
|
||||||
|
- num_epochs: Number of epochs to train.
|
||||||
|
|
||||||
|
- num_valid_batches: Number of batches of validation data to use each
|
||||||
|
time we compute validation loss
|
||||||
|
|
||||||
|
- symbols_per_batch: Number of symbols in each batch (sampler will
|
||||||
|
choose the number of sentences to satisfy this contraint).
|
||||||
|
|
||||||
|
- best_train_loss: Best training loss so far. It is used to select
|
||||||
|
the model that has the lowest training loss. It is
|
||||||
|
updated during the training.
|
||||||
|
|
||||||
|
- best_valid_loss: Best validation loss so far. It is used to select
|
||||||
|
the model that has the lowest validation loss. It is
|
||||||
|
updated during the training.
|
||||||
|
|
||||||
|
- best_train_epoch: It is the epoch that has the best training loss.
|
||||||
|
|
||||||
|
- best_valid_epoch: It is the epoch that has the best validation loss.
|
||||||
|
|
||||||
|
- batch_idx_train: Used to writing statistics to tensorboard. It
|
||||||
|
contains number of batches trained so far across
|
||||||
|
epochs.
|
||||||
|
|
||||||
|
- log_interval: Print training loss if batch_idx % log_interval` is 0
|
||||||
|
|
||||||
|
- valid_interval: Run validation if batch_idx % valid_interval is 0
|
||||||
|
|
||||||
|
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
|
||||||
|
|
||||||
|
"""
|
||||||
|
params = AttributeDict(
|
||||||
|
{
|
||||||
|
# exp_3, vs. exp_2, is using 5e-04 not 2d-04 as max learning rate.
|
||||||
|
# exp_4, vs. exp_3, is using the Gloam optimizer with
|
||||||
|
# in exp_5, vs. exp_4, we change Gloam to have a 1/sqrt(t) factor
|
||||||
|
# as well as the exponential part.
|
||||||
|
# exp_6, we change the decay from 0.85 to 0.9.
|
||||||
|
"exp_dir": Path("conformer_lm/exp_6"),
|
||||||
|
"lm_dataset": Path("data/lm_training_5000/lm_data.pt"),
|
||||||
|
"num_tokens": 5000,
|
||||||
|
"blank_sym": 0,
|
||||||
|
"bos_sym": 1,
|
||||||
|
"eos_sym": 1,
|
||||||
|
"start_epoch": 2,
|
||||||
|
"num_epochs": 20,
|
||||||
|
"num_valid_batches": 200,
|
||||||
|
"symbols_per_batch": 5000,
|
||||||
|
"best_train_loss": float("inf"),
|
||||||
|
"best_valid_loss": float("inf"),
|
||||||
|
"best_train_epoch": -1,
|
||||||
|
"best_valid_epoch": -1,
|
||||||
|
"batch_idx_train": 0,
|
||||||
|
"log_interval": 10,
|
||||||
|
"reset_interval": 200,
|
||||||
|
"valid_interval": 3000,
|
||||||
|
"beam_size": 10,
|
||||||
|
"accum_grad": 1,
|
||||||
|
"attention_dim": 512,
|
||||||
|
"nhead": 8,
|
||||||
|
"num_decoder_layers": 6,
|
||||||
|
"max_lrate": 5.0e-04
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint_if_available(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||||
|
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Load checkpoint from file.
|
||||||
|
|
||||||
|
If params.start_epoch is positive, it will load the checkpoint from
|
||||||
|
`params.start_epoch - 1`. Otherwise, this function does nothing.
|
||||||
|
|
||||||
|
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
|
||||||
|
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
|
||||||
|
and `best_valid_loss` in `params`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
The return value of :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The training model.
|
||||||
|
optimizer:
|
||||||
|
The optimizer that we are using.
|
||||||
|
scheduler:
|
||||||
|
The learning rate scheduler we are using.
|
||||||
|
Returns:
|
||||||
|
Return None.
|
||||||
|
"""
|
||||||
|
if params.start_epoch <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||||
|
saved_params = load_checkpoint(
|
||||||
|
filename,
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
scheduler=scheduler,
|
||||||
|
)
|
||||||
|
|
||||||
|
keys = [
|
||||||
|
"best_train_epoch",
|
||||||
|
"best_valid_epoch",
|
||||||
|
"batch_idx_train",
|
||||||
|
"best_train_loss",
|
||||||
|
"best_valid_loss",
|
||||||
|
]
|
||||||
|
for k in keys:
|
||||||
|
params[k] = saved_params[k]
|
||||||
|
|
||||||
|
return saved_params
|
||||||
|
|
||||||
|
|
||||||
|
def save_checkpoint(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||||
|
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
|
||||||
|
rank: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Save model, optimizer, scheduler and training stats to file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
It is returned by :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The training model.
|
||||||
|
"""
|
||||||
|
if rank != 0:
|
||||||
|
return
|
||||||
|
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
|
||||||
|
save_checkpoint_impl(
|
||||||
|
filename=filename,
|
||||||
|
model=model,
|
||||||
|
params=params,
|
||||||
|
optimizer=optimizer,
|
||||||
|
scheduler=scheduler,
|
||||||
|
rank=rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.best_train_epoch == params.cur_epoch:
|
||||||
|
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
||||||
|
copyfile(src=filename, dst=best_train_filename)
|
||||||
|
|
||||||
|
if params.best_valid_epoch == params.cur_epoch:
|
||||||
|
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||||
|
copyfile(src=filename, dst=best_valid_filename)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_loss(
|
||||||
|
model: nn.Module,
|
||||||
|
batch: Tuple,
|
||||||
|
is_training: bool,
|
||||||
|
):
|
||||||
|
|
||||||
|
"""
|
||||||
|
Compute training or validation loss given the model and its inputs
|
||||||
|
(this corresponds to log-prob of the targets, with weighting
|
||||||
|
of 1.0 for masked subsequences
|
||||||
|
(including padding blanks), and something smaller, e.g. 0.25,
|
||||||
|
for non-masked positions (this is not totally trivial due to
|
||||||
|
a small amount of randomization of symbols).
|
||||||
|
|
||||||
|
This loss is not normalized; you can divide by batch[4].sum()
|
||||||
|
to get a normalized loss (i.e. divide by soft-count).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
Parameters for training. See :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The model for training. It is an instance of MaskedLmConformer in our case.
|
||||||
|
batch:
|
||||||
|
A batch of data, actually a tuple of 5 tensors (on the device), as returned
|
||||||
|
by collate_fn in ./dataset.py.
|
||||||
|
is_training:
|
||||||
|
True for training. False for validation. When it is True, this
|
||||||
|
function enables autograd during computation; when it is False, it
|
||||||
|
disables autograd.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Returns the loss as a scalar tensor.
|
||||||
|
"""
|
||||||
|
(masked_src_symbols, src_symbols,
|
||||||
|
tgt_symbols, src_key_padding_mask, tgt_weights) = batch
|
||||||
|
|
||||||
|
with torch.set_grad_enabled(is_training):
|
||||||
|
memory, pos_emb = model(masked_src_symbols, src_key_padding_mask)
|
||||||
|
decoder_nll_func = model.module.decoder_nll if isinstance(model, DDP) else model.decoder_nll
|
||||||
|
tgt_nll = decoder_nll_func(memory, pos_emb, src_symbols,
|
||||||
|
tgt_symbols, src_key_padding_mask)
|
||||||
|
loss = (tgt_nll * tgt_weights).sum()
|
||||||
|
|
||||||
|
assert loss.requires_grad == is_training
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def compute_validation_loss(
|
||||||
|
device: torch.device,
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
|
world_size: int = 1,
|
||||||
|
) -> None:
|
||||||
|
"""Run the validation process. The validation loss
|
||||||
|
is saved in `params.valid_loss`.
|
||||||
|
"""
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
tot_loss = 0.0
|
||||||
|
tot_frames = 0.0
|
||||||
|
for batch_idx, batch in enumerate(valid_dl):
|
||||||
|
if batch_idx == params.num_valid_batches:
|
||||||
|
break
|
||||||
|
batch = tuple(x.to(device) for x in batch)
|
||||||
|
|
||||||
|
|
||||||
|
loss = compute_loss(model, batch, is_training=False)
|
||||||
|
num_frames = batch[4].sum()
|
||||||
|
|
||||||
|
assert loss.requires_grad is False
|
||||||
|
|
||||||
|
loss_cpu = loss.detach().cpu().item()
|
||||||
|
num_frames_cpu = num_frames.cpu().item()
|
||||||
|
|
||||||
|
tot_loss += loss_cpu
|
||||||
|
tot_frames += num_frames_cpu
|
||||||
|
|
||||||
|
|
||||||
|
if world_size > 1:
|
||||||
|
s = torch.tensor(
|
||||||
|
[tot_loss, tot_frames],
|
||||||
|
device=loss.device,
|
||||||
|
)
|
||||||
|
dist.all_reduce(s, op=dist.ReduceOp.SUM)
|
||||||
|
(tot_loss, tot_frames) = s.cpu().tolist()
|
||||||
|
|
||||||
|
params.valid_loss = tot_loss / tot_frames
|
||||||
|
|
||||||
|
if params.valid_loss < params.best_valid_loss:
|
||||||
|
params.best_valid_epoch = params.cur_epoch
|
||||||
|
params.best_valid_loss = params.valid_loss
|
||||||
|
|
||||||
|
|
||||||
|
def train_one_epoch(
|
||||||
|
device: torch.device,
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
optimizer: torch.optim.Optimizer,
|
||||||
|
train_dl: torch.utils.data.DataLoader,
|
||||||
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
|
tb_writer: Optional[SummaryWriter] = None,
|
||||||
|
world_size: int = 1,
|
||||||
|
) -> None:
|
||||||
|
"""Train the model for one epoch.
|
||||||
|
|
||||||
|
The training loss from the mean of all frames is saved in
|
||||||
|
`params.train_loss`. It runs the validation process every
|
||||||
|
`params.valid_interval` batches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device:
|
||||||
|
The device to use for training (model must be on this device)
|
||||||
|
params:
|
||||||
|
It is returned by :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The model for training.
|
||||||
|
optimizer:
|
||||||
|
The optimizer we are using.
|
||||||
|
train_dl:
|
||||||
|
Dataloader for the training dataset.
|
||||||
|
valid_dl:
|
||||||
|
Dataloader for the validation dataset.
|
||||||
|
tb_writer:
|
||||||
|
Writer to write log messages to tensorboard.
|
||||||
|
world_size:
|
||||||
|
Number of nodes in DDP training. If it is 1, DDP is disabled.
|
||||||
|
"""
|
||||||
|
model.train() # training mode
|
||||||
|
|
||||||
|
tot_loss = 0.0 # sum of losses over all batches
|
||||||
|
tot_frames = 0.0 # sum of frames over all batches
|
||||||
|
|
||||||
|
params.tot_loss = 0.0
|
||||||
|
params.tot_frames = 0.0
|
||||||
|
for batch_idx, batch in enumerate(train_dl):
|
||||||
|
params.batch_idx_train += 1
|
||||||
|
batch = tuple(x.to(device) for x in batch)
|
||||||
|
|
||||||
|
try:
|
||||||
|
loss = compute_loss(
|
||||||
|
model=model,
|
||||||
|
batch=batch,
|
||||||
|
is_training=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
# We are not normalizing by the num-frames, but Adam/Madam are insensitive to the total
|
||||||
|
# gradient scale so this should not matter.
|
||||||
|
# clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||||
|
optimizer.step()
|
||||||
|
except RuntimeError as e:
|
||||||
|
print(f"Error on batch of shape (N,T) = {batch[0].shape}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
loss_cpu = loss.detach().cpu().item()
|
||||||
|
num_frames_cpu = batch[4].sum().cpu().item()
|
||||||
|
|
||||||
|
tot_loss += loss_cpu
|
||||||
|
tot_frames += num_frames_cpu
|
||||||
|
|
||||||
|
params.tot_frames += num_frames_cpu
|
||||||
|
params.tot_loss += loss_cpu
|
||||||
|
|
||||||
|
tot_avg_loss = tot_loss / tot_frames
|
||||||
|
|
||||||
|
if batch_idx % params.log_interval == 0:
|
||||||
|
logging.info(
|
||||||
|
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
||||||
|
f"batch avg loss {loss_cpu/num_frames_cpu:.4f}, "
|
||||||
|
f"total avg loss: {tot_avg_loss:.4f}, "
|
||||||
|
f"batch shape: {tuple(batch[0].shape)}")
|
||||||
|
|
||||||
|
|
||||||
|
if tb_writer is not None:
|
||||||
|
tb_writer.add_scalar(
|
||||||
|
"train/current_loss",
|
||||||
|
loss_cpu / num_frames_cpu,
|
||||||
|
params.batch_idx_train,
|
||||||
|
)
|
||||||
|
tb_writer.add_scalar(
|
||||||
|
"train/tot_avg_loss",
|
||||||
|
tot_avg_loss,
|
||||||
|
params.batch_idx_train,
|
||||||
|
)
|
||||||
|
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
|
||||||
|
tot_loss = 0.0 # sum of losses over all batches
|
||||||
|
tot_frames = 0.0 # sum of frames over all batches
|
||||||
|
|
||||||
|
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||||
|
compute_validation_loss(
|
||||||
|
device=device,
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
valid_dl=valid_dl,
|
||||||
|
world_size=world_size,
|
||||||
|
)
|
||||||
|
model.train()
|
||||||
|
logging.info(
|
||||||
|
f"Epoch {params.cur_epoch}, "
|
||||||
|
f"valid loss {params.valid_loss:.4f},"
|
||||||
|
f" best valid loss: {params.best_valid_loss:.4f} "
|
||||||
|
f"best valid epoch: {params.best_valid_epoch}"
|
||||||
|
)
|
||||||
|
if tb_writer is not None:
|
||||||
|
tb_writer.add_scalar(
|
||||||
|
"train/valid_loss",
|
||||||
|
params.valid_loss,
|
||||||
|
params.batch_idx_train,
|
||||||
|
)
|
||||||
|
|
||||||
|
params.train_loss = params.tot_loss / params.tot_frames
|
||||||
|
|
||||||
|
if params.train_loss < params.best_train_loss:
|
||||||
|
params.best_train_epoch = params.cur_epoch
|
||||||
|
params.best_train_loss = params.train_loss
|
||||||
|
|
||||||
|
|
||||||
|
def run(rank, world_size, args):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
rank:
|
||||||
|
It is a value between 0 and `world_size-1`, which is
|
||||||
|
passed automatically by `mp.spawn()` in :func:`main`.
|
||||||
|
The node with rank 0 is responsible for saving checkpoint.
|
||||||
|
world_size:
|
||||||
|
Number of GPUs for DDP training.
|
||||||
|
args:
|
||||||
|
The return value of get_parser().parse_args()
|
||||||
|
"""
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
fix_random_seed(42)
|
||||||
|
if world_size > 1:
|
||||||
|
setup_dist(rank, world_size, params.master_port)
|
||||||
|
|
||||||
|
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||||
|
logging.info("Training started")
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
if args.tensorboard and rank == 0:
|
||||||
|
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
||||||
|
else:
|
||||||
|
tb_writer = None
|
||||||
|
|
||||||
|
num_tokens = params.num_tokens
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", rank)
|
||||||
|
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = MaskedLmConformer(
|
||||||
|
num_classes=params.num_tokens,
|
||||||
|
d_model=params.attention_dim,
|
||||||
|
nhead=params.nhead,
|
||||||
|
num_decoder_layers=params.num_decoder_layers,
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
if world_size > 1:
|
||||||
|
model = DDP(model, device_ids=[rank])
|
||||||
|
|
||||||
|
# Caution: don't forget to do optimizer.set_epoch() with Gloam!
|
||||||
|
# Don't remove this warning!
|
||||||
|
optimizer = Gloam(
|
||||||
|
model.parameters(),
|
||||||
|
max_lrate=params.max_lrate,
|
||||||
|
first_decrease_epoch=1,
|
||||||
|
decay_per_epoch=0.9
|
||||||
|
)
|
||||||
|
|
||||||
|
if checkpoints:
|
||||||
|
optimizer.load_state_dict(checkpoints["optimizer"])
|
||||||
|
|
||||||
|
train,test = dataset.load_train_test_lm_dataset(params.lm_dataset)
|
||||||
|
|
||||||
|
collate_fn=dataset.CollateFn(bos_sym=params.bos_sym,
|
||||||
|
eos_sym=params.eos_sym,
|
||||||
|
blank_sym=params.blank_sym,
|
||||||
|
mask_proportion=0.15,
|
||||||
|
padding_proportion=0.15,
|
||||||
|
randomize_proportion=0.05,
|
||||||
|
inv_mask_length=0.25,
|
||||||
|
unmasked_weight=0.25)
|
||||||
|
|
||||||
|
train_sampler = dataset.LmBatchSampler(train,
|
||||||
|
symbols_per_batch=params.symbols_per_batch,
|
||||||
|
world_size=world_size, rank=rank)
|
||||||
|
test_sampler = dataset.LmBatchSampler(test,
|
||||||
|
symbols_per_batch=params.symbols_per_batch,
|
||||||
|
world_size=world_size, rank=rank)
|
||||||
|
|
||||||
|
train_dl = torch.utils.data.DataLoader(train,
|
||||||
|
batch_sampler=train_sampler,
|
||||||
|
collate_fn=collate_fn)
|
||||||
|
|
||||||
|
valid_dl = torch.utils.data.DataLoader(test,
|
||||||
|
batch_sampler=test_sampler,
|
||||||
|
collate_fn=collate_fn)
|
||||||
|
|
||||||
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
|
train_sampler.set_epoch(epoch)
|
||||||
|
optimizer.set_epoch(epoch) # Caution: this is specific to the Gloam
|
||||||
|
# optimizer.
|
||||||
|
|
||||||
|
cur_lr = optimizer._rate
|
||||||
|
if tb_writer is not None:
|
||||||
|
tb_writer.add_scalar(
|
||||||
|
"train/learning_rate", cur_lr, params.batch_idx_train
|
||||||
|
)
|
||||||
|
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
|
||||||
|
|
||||||
|
params.cur_epoch = epoch
|
||||||
|
|
||||||
|
train_one_epoch(
|
||||||
|
device=device,
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
train_dl=train_dl,
|
||||||
|
valid_dl=valid_dl,
|
||||||
|
tb_writer=tb_writer,
|
||||||
|
world_size=world_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
save_checkpoint(
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
rank=rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Done!")
|
||||||
|
|
||||||
|
if world_size > 1:
|
||||||
|
torch.distributed.barrier()
|
||||||
|
cleanup_dist()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
world_size = args.world_size
|
||||||
|
assert world_size >= 1
|
||||||
|
if world_size > 1:
|
||||||
|
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
|
||||||
|
else:
|
||||||
|
run(rank=0, world_size=1, args=args)
|
||||||
|
|
||||||
|
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
Loading…
x
Reference in New Issue
Block a user