Changes to dataset to prevent OOM on batches with short sentences

Daniel Povey 2021-08-24 14:50:49 +08:00
parent 9576d6574f
commit e6eefeba88


@@ -2,6 +2,7 @@ import torch
 import torch.distributed as dist
 import k2
 import _k2
+import logging
 import sentencepiece as spm
 from pathlib import Path
 from typing import Optional, List, Tuple, Union
@@ -333,6 +334,7 @@ def collate_fn(sentences: List[List[int]],
     """
     assert blank_sym not in [bos_sym, eos_sym]
     max_sent_len = max([ len(s) for s in sentences])
+    #logging.info(f"Sentence lengths: {[ len(s) for s in sentences]}")
     typical_mask_and_pad = int(max_sent_len * (mask_proportion + padding_proportion))
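For intuition, a tiny worked example of the line above (a standalone sketch, not part of this commit; the proportion values are made up, not the project's defaults):

    # With made-up proportions, a batch whose longest sentence has 100 BPE tokens
    # reserves int(100 * (0.15 + 0.05)) == 20 extra positions for masking/padding.
    max_sent_len = 100
    mask_proportion, padding_proportion = 0.15, 0.05   # hypothetical values
    typical_mask_and_pad = int(max_sent_len * (mask_proportion + padding_proportion))
    assert typical_mask_and_pad == 20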
@@ -635,17 +637,22 @@ class LmBatchSampler(torch.utils.data.Sampler):
     """
     def __init__(self, dataset: LmDataset,
                  symbols_per_batch: int,
-                 quadratic_constant: float = 0.005,
+                 length_ceil: float = 200.0,
+                 length_floor: float = 4.0,
                  world_size: Optional[int] = None,
                  rank: int = None,
-                 seed: int = 0):
+                 seed: int = 0,
+                 delay_init: bool = False):
         """
         Constructor documentation:
           dataset: the LmDataset object that we are sampling from.  This
                  class does not retain a reference to the LmDataset.
           symbols_per_batch: The number of BPE symbols desired in each minibatch
-          quadratic_constant: After the sentence length gets more than about
-                 1.0/quadratic_constant, the batch size will start decreasing
+          length_floor: When the sentence length gets less than about this much,
+                 the batch size stops increasing inversely with sentence
+                 length.  This prevents OOM on batches with short sentences.
+          length_ceil: After the sentence length gets more than about
+                 this much, the batch size will start decreasing
                  as 1/(sentence-length^2).  This is a mechanism to
                  avoid excessive memory consumption in transformers, when
                  sentence length gets long.
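To make these two knobs concrete, here is a standalone sketch (not part of this commit) that mirrors the effective-length formula set_epoch() applies further down; the helper name and the numeric values are illustrative only:

    # Approximate number of sentences that fit in one batch, given the effective
    # length the sampler assigns to each sentence.
    def approx_sentences_per_batch(sent_len: float,
                                   symbols_per_batch: int = 5000,
                                   length_ceil: float = 200.0,
                                   length_floor: float = 4.0) -> float:
        quadratic_constant = 1.0 / length_ceil
        effective_len = sent_len + (sent_len ** 2) * quadratic_constant + length_floor
        return symbols_per_batch / effective_len

    print(approx_sentences_per_batch(2.0))    # ~830; without length_floor it would be ~2475
    print(approx_sentences_per_batch(400.0))  # ~4; the quadratic term dominates for long sentences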
@@ -654,10 +661,17 @@ class LmBatchSampler(torch.utils.data.Sampler):
           rank: The rank of this sampler/process for distributed operation; if None,
                  will be worked out from torch.distributed.
           seed: The random seed
+          delay_init: If true, will omit calling self.set_epoch(0) at the
+                 end of the __init__ function.  In this case the caller
+                 must call set_epoch(0).  [Setting this option is necessary
+                 to work with data-loader worker processes plus DDP, since
+                 set_epoch() will use ddp, which I believe is a no-no prior
+                 to initializing data-loaders.]
         """
         self.seed = seed
         self.symbols_per_batch = symbols_per_batch
-        self.quadratic_constant = quadratic_constant
+        self.length_floor = length_floor
+        self.quadratic_constant = 1.0 / length_ceil
         self._maybe_init_distributed(world_size=world_size, rank=rank)
         # a configuration constant we don't expose.
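A hedged usage sketch of the new delay_init flag under DDP with DataLoader worker processes (how the dataset and data-loader are built is an assumption, not taken from this commit; the project's own collate_fn and the numeric settings are omitted or made up):

    # `train_dataset` is an LmDataset constructed elsewhere.
    sampler = LmBatchSampler(train_dataset, symbols_per_batch=5000,
                             delay_init=True)    # skip set_epoch(0) inside __init__
    train_dl = torch.utils.data.DataLoader(train_dataset, batch_sampler=sampler,
                                           num_workers=2)  # project's collate_fn omitted here
    # Per the docstring above, set_epoch(0) is deferred until after the data-loader
    # exists, since it may invoke DDP collectives via _sync_sizes().
    sampler.set_epoch(0)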
@@ -698,8 +712,20 @@ class LmBatchSampler(torch.utils.data.Sampler):
         # `data_indexes` above (this is not stored, as we know the formula).
         self.sentence_lengths = sentence_lengths
-        self.set_epoch(0)  # this is responsible for setting self.sorted_data_indexes
+        if not delay_init:
+            self.set_epoch(0)  # this is responsible for setting self.sorted_data_indexes
+
+    def _sync_sizes(self, device: torch.device = torch.device('cuda')):
+        # Calling this on all copies of a DDP setup will sync the sizes so that
+        # all copies have the exact same number of batches.  I think
+        # this needs to be called with the GPU device, not sure if it would
+        # work otherwise.
+        if self.world_size > 1:
+            min_size = torch.tensor([len(self.batch_indices)], device=device, dtype=torch.int64)
+            dist.all_reduce(min_size, op=dist.ReduceOp.MIN)
+            min_size = min_size.to('cpu').item()
+            logging.info(f"world_size={self.world_size}, rank={self.rank}: reducing batch indices from {len(self.batch_indices)} to {min_size}")
+            self.batch_indices = self.batch_indices[0:min_size]
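For intuition, a standalone sketch (made-up counts, no torch.distributed) of what the MIN all-reduce achieves: every rank keeps only as many batches as the smallest rank formed, so no rank exhausts its batches early and leaves the others waiting in a DDP collective:

    # Simulated per-rank batch lists; rank 1 happened to form two fewer batches.
    per_rank_batches = {0: list(range(1003)), 1: list(range(1001))}
    min_size = min(len(b) for b in per_rank_batches.values())          # == 1001
    per_rank_batches = {r: b[0:min_size] for r, b in per_rank_batches.items()}
    assert all(len(b) == 1001 for b in per_rank_batches.values())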
     def _maybe_init_distributed(self, world_size: Optional[int], rank: Optional[int]):
         if world_size is not None:
@@ -714,6 +740,7 @@ class LmBatchSampler(torch.utils.data.Sampler):
         self.rank = dist.get_rank() if rank is None else rank
         assert self.rank < self.world_size

     def set_epoch(self, epoch: int):
         """
         Must be called at the beginning of each epoch, before initializing the DataLoader,
@@ -727,7 +754,7 @@ class LmBatchSampler(torch.utils.data.Sampler):
         # This mechanism regulates the batch size so that we don't get OOM in transformers
         # when the sentences are long.
-        sentence_lengths = sentence_lengths + (sentence_lengths ** 2) * self.quadratic_constant
+        sentence_lengths = (sentence_lengths + (sentence_lengths ** 2) * self.quadratic_constant) + self.length_floor
         values, indices = torch.sort(sentence_lengths) # values,indices dtypes: torch.float,torch.int64
@@ -741,7 +768,7 @@ class LmBatchSampler(torch.utils.data.Sampler):
         # now `batch_ids` will be: [0, 0, 0, 0, .., 0, 1, 1, 1, ... 1, 2, ... ],
         # saying which batch each element of values/indices belongs to.
-        batch_ids = (torch.cumsum(values, dim=0) * (1.0 / self.symbols_per_batch)).to(dtype=torch.int32)
+        batch_ids = (torch.cumsum(values.to(dtype=torch.double), dim=0) * (1.0 / self.symbols_per_batch)).to(dtype=torch.int32)
         batch_boundaries = torch.nonzero(batch_ids[1:] - batch_ids[:-1], as_tuple=True)[0]
         batch_boundaries.add_(1)
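A small standalone sketch of this bucketing step, with tiny made-up numbers; the double-precision cast presumably matters because a float32 cumulative sum starts losing whole-number precision once the running total exceeds roughly 2^24 symbols, which could mis-assign batch ids on a large corpus:

    import torch

    values = torch.tensor([4., 5., 6., 7., 9., 11.])   # effective lengths, sorted ascending
    symbols_per_batch = 15.0                            # tiny illustrative budget

    batch_ids = (torch.cumsum(values.to(torch.double), dim=0)
                 * (1.0 / symbols_per_batch)).to(torch.int32)
    # batch_ids == tensor([0, 0, 1, 1, 2, 2], dtype=torch.int32)

    batch_boundaries = torch.nonzero(batch_ids[1:] - batch_ids[:-1], as_tuple=True)[0]
    batch_boundaries.add_(1)
    # batch_boundaries == tensor([2, 4]); the batches are values[0:2], values[2:4], values[4:6],
    # each holding roughly symbols_per_batch symbols.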
@@ -754,6 +781,7 @@ class LmBatchSampler(torch.utils.data.Sampler):
         # necessary to randomize the order of these, to avoid returning batches
         # from shortest to longest sentences.
         self.batch_indices = torch.randperm(num_batches, generator=g, dtype=torch.int32).tolist()
+        self._sync_sizes()

     def __len__(self):