Changes to dataset to prevent OOM on batches with short sentences

Daniel Povey 2021-08-24 14:50:49 +08:00
parent 9576d6574f
commit e6eefeba88


@@ -2,6 +2,7 @@ import torch
 import torch.distributed as dist
 import k2
 import _k2
+import logging
 import sentencepiece as spm
 from pathlib import Path
 from typing import Optional, List, Tuple, Union
@@ -333,6 +334,7 @@ def collate_fn(sentences: List[List[int]],
     """
     assert blank_sym not in [bos_sym, eos_sym]
     max_sent_len = max([ len(s) for s in sentences])
+    #logging.info(f"Sentence lengths: {[ len(s) for s in sentences]}")
     typical_mask_and_pad = int(max_sent_len * (mask_proportion + padding_proportion))
@@ -635,17 +637,22 @@ class LmBatchSampler(torch.utils.data.Sampler):
     """
     def __init__(self, dataset: LmDataset,
                  symbols_per_batch: int,
-                 quadratic_constant: float = 0.005,
+                 length_ceil: float = 200.0,
+                 length_floor: float = 4.0,
                  world_size: Optional[int] = None,
                  rank: int = None,
-                 seed: int = 0):
+                 seed: int = 0,
+                 delay_init: bool = False):
         """
         Constructor documentation:
            dataset: the LmDataset object that we are sampling from.  This
                  class does not retain a reference to the LmDataset.
            symbols_per_batch: The number of BPE symbols desired in each minibatch
-           quadratic_constant: After the sentence length gets more than about
-                 1.0/quadratic_constant, the batch size will start decreasing
+           length_floor: When the sentence length gets less than about this much,
+                 the batch size stops increasing inversely with sentence
+                 length.  Prevents OOM on batches with short sentences.
+           length_ceil: After the sentence length gets more than about
+                 this much, the batch size will start decreasing
                  as 1/(sentence-length^2).  This is a mechanism to
                  avoid excessive memory consumption in transformers, when
                  sentence length gets long.
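For a sense of what the new defaults do to the batch size, here is a small worked sketch (not part of the commit; the helper name and symbols_per_batch=5000 are assumptions for illustration): the sampler effectively charges each sentence len + len^2/length_ceil + length_floor symbols, so a 2-symbol sentence costs about 6 symbols rather than 2, while a 2000-symbol sentence costs about 22000.

    # Hypothetical illustration of the effective-length formula implied by the new
    # length_ceil / length_floor arguments (defaults taken from this commit).
    def effective_length(n: int, length_ceil: float = 200.0, length_floor: float = 4.0) -> float:
        quadratic_constant = 1.0 / length_ceil
        return n + (n ** 2) * quadratic_constant + length_floor

    symbols_per_batch = 5000  # assumed value, only for the printout
    for n in [2, 20, 200, 2000]:
        eff = effective_length(n)
        # roughly how many sentences of this length would fit in one batch
        print(f"len={n:5d}  effective={eff:9.1f}  ~sentences/batch={symbols_per_batch / eff:8.1f}")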
@@ -654,10 +661,17 @@ class LmBatchSampler(torch.utils.data.Sampler):
            rank: The rank of this sampler/process for distributed operation; if None,
                  will be worked out from torch.distributed.
            seed: The random seed
+           delay_init: If true, will omit calling self.set_epoch(0) at the
+                 end of the __init__ function.  In this case the caller
+                 must call set_epoch(0).  [Setting this option is necessary
+                 to work with data-loader worker processes plus DDP, since
+                 set_epoch() will use ddp, which I believe is a no-no prior
+                 to initializing data-loaders.]
         """
         self.seed = seed
         self.symbols_per_batch = symbols_per_batch
-        self.quadratic_constant = quadratic_constant
+        self.length_floor = length_floor
+        self.quadratic_constant = 1.0 / length_ceil
         self._maybe_init_distributed(world_size=world_size, rank=rank)

         # a configuration constant we don't expose.
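The delay_init flag implies a particular call order when DataLoader worker processes and DDP are both in play; the sketch below shows the order the docstring describes (the DataLoader arguments, my_collate_fn, and the per-epoch loop are assumptions, not code from this commit):

    # Assumed usage pattern: build the sampler without any distributed communication,
    # create the DataLoader first, and only then call set_epoch(), which may use DDP.
    sampler = LmBatchSampler(dataset, symbols_per_batch=5000, delay_init=True)
    train_dl = torch.utils.data.DataLoader(dataset, batch_sampler=sampler,
                                           collate_fn=my_collate_fn, num_workers=4)
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)   # DDP collectives are OK here, after the workers exist
        for batch in train_dl:
            ...  # training step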
@@ -698,8 +712,20 @@ class LmBatchSampler(torch.utils.data.Sampler):
         # `data_indexes` above (this is not stored, as we know the formula).
         self.sentence_lengths = sentence_lengths

-        self.set_epoch(0)  # this is responsible for setting self.sorted_data_indexes
+        if not delay_init:
+            self.set_epoch(0)  # this is responsible for setting self.sorted_data_indexes
+
+    def _sync_sizes(self, device: torch.device = torch.device('cuda')):
+        # Calling this on all copies of a DDP setup will sync the sizes so that
+        # all copies have the exact same number of batches.  I think
+        # this needs to be called with the GPU device, not sure if it would
+        # work otherwise.
+        if self.world_size > 1:
+            min_size = torch.tensor([len(self.batch_indices)], device=device, dtype=torch.int64)
+            dist.all_reduce(min_size, op=dist.ReduceOp.MIN)
+            min_size = min_size.to('cpu').item()
+            logging.info(f"world_size={self.world_size}, rank={self.rank}: reducing batch indices from {len(self.batch_indices)} to {min_size}")
+            self.batch_indices = self.batch_indices[0:min_size]

     def _maybe_init_distributed(self, world_size: Optional[int], rank: Optional[int]):
         if world_size is not None:
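The new _sync_sizes() relies on an all_reduce with ReduceOp.MIN so that every rank ends up iterating exactly the same number of batches; otherwise a rank with more batches would hang at the next collective once the others finish their epoch. A minimal standalone sketch of the same trick (hypothetical function name, assuming the process group is already initialized and each rank has a CUDA device):

    import torch
    import torch.distributed as dist

    def truncate_to_common_length(batch_indices: list,
                                  device: torch.device = torch.device('cuda')) -> list:
        # Each rank contributes its own count; after all_reduce(MIN) every rank holds
        # the global minimum, so all ranks keep exactly the same number of batches.
        n = torch.tensor([len(batch_indices)], device=device, dtype=torch.int64)
        dist.all_reduce(n, op=dist.ReduceOp.MIN)
        return batch_indices[:int(n.item())]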
@@ -714,6 +740,7 @@ class LmBatchSampler(torch.utils.data.Sampler):
         self.rank = dist.get_rank() if rank is None else rank
         assert self.rank < self.world_size

+
     def set_epoch(self, epoch: int):
         """
         Must be called at the beginning of each epoch, before initializing the DataLoader,
@@ -727,7 +754,7 @@ class LmBatchSampler(torch.utils.data.Sampler):

         # This mechanism regulates the batch size so that we don't get OOM in transformers
         # when the sentences are long.
-        sentence_lengths = sentence_lengths + (sentence_lengths ** 2) * self.quadratic_constant
+        sentence_lengths = (sentence_lengths + (sentence_lengths ** 2) * self.quadratic_constant) + self.length_floor

         values, indices = torch.sort(sentence_lengths)  # values,indices dtypes: torch.float,torch.int64
@@ -741,7 +768,7 @@ class LmBatchSampler(torch.utils.data.Sampler):

         # now `batch_ids` will be: [0, 0, 0, 0, .., 0, 1, 1, 1, ... 1, 2, ... ],
         # saying which batch each element of values/indices belongs to.
-        batch_ids = (torch.cumsum(values, dim=0) * (1.0 / self.symbols_per_batch)).to(dtype=torch.int32)
+        batch_ids = (torch.cumsum(values.to(dtype=torch.double), dim=0) * (1.0 / self.symbols_per_batch)).to(dtype=torch.int32)
         batch_boundaries = torch.nonzero(batch_ids[1:] - batch_ids[:-1], as_tuple=True)[0]
         batch_boundaries.add_(1)
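The cast to double before the cumulative sum is worth a note: a float32 running sum can exceed the range (about 2^24) where consecutive integers are exactly representable on a large corpus, which would blur the batch boundaries. The boundary logic itself can be checked on a toy tensor (illustrative values only, not from the commit):

    import torch

    values = torch.tensor([3., 4., 5., 6., 7., 8.])   # sorted effective sentence lengths
    symbols_per_batch = 10.0
    batch_ids = (torch.cumsum(values.to(torch.double), dim=0)
                 * (1.0 / symbols_per_batch)).to(dtype=torch.int32)
    # cumsum = [3, 7, 12, 18, 25, 33] -> batch_ids = [0, 0, 1, 1, 2, 3]
    boundaries = torch.nonzero(batch_ids[1:] - batch_ids[:-1], as_tuple=True)[0] + 1
    print(batch_ids.tolist(), boundaries.tolist())     # [0, 0, 1, 1, 2, 3] [2, 4, 5]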
@@ -754,6 +781,7 @@ class LmBatchSampler(torch.utils.data.Sampler):
         # necessary to randomize the order of these, to avoid returning batches
         # from shortest to longest sentences.
         self.batch_indices = torch.randperm(num_batches, generator=g, dtype=torch.int32).tolist()
+        self._sync_sizes()

     def __len__(self):