From e6eefeba882763c6c4e2cdab63a3215a1671cd82 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 24 Aug 2021 14:50:49 +0800
Subject: [PATCH] Changes to dataset to prevent OOM on batches with short
 sentences

---
 egs/librispeech/ASR/conformer_lm/dataset.py | 44 +++++++++++++++++----
 1 file changed, 36 insertions(+), 8 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_lm/dataset.py b/egs/librispeech/ASR/conformer_lm/dataset.py
index 3074d1099..fcf7f39f0 100644
--- a/egs/librispeech/ASR/conformer_lm/dataset.py
+++ b/egs/librispeech/ASR/conformer_lm/dataset.py
@@ -2,6 +2,7 @@ import torch
 import torch.distributed as dist
 import k2
 import _k2
+import logging
 import sentencepiece as spm
 from pathlib import Path
 from typing import Optional, List, Tuple, Union
@@ -333,6 +334,7 @@ def collate_fn(sentences: List[List[int]],
     """
     assert blank_sym not in [bos_sym, eos_sym]
     max_sent_len = max([ len(s) for s in sentences])
+    #logging.info(f"Sentence lengths: {[ len(s) for s in sentences]}")
 
     typical_mask_and_pad = int(max_sent_len * (mask_proportion + padding_proportion))
 
@@ -635,17 +637,22 @@ class LmBatchSampler(torch.utils.data.Sampler):
     """
     def __init__(self, dataset: LmDataset,
                  symbols_per_batch: int,
-                 quadratic_constant: float = 0.005,
+                 length_ceil: float = 200.0,
+                 length_floor: float = 4.0,
                  world_size: Optional[int] = None,
                  rank: int = None,
-                 seed: int = 0):
+                 seed: int = 0,
+                 delay_init: bool = False):
         """
         Constructor documentation:
          dataset:  the LmDataset object that we are sampling from.  This
                 class does not retain a reference to the LmDataset.
          symbols_per_batch:  The number of BPE symbols desired in each minibatch
-         quadratic_constant:  After the sentence length gets more than about
-                 1.0/quadratic_constant, the batch size will start decreasing
+         length_floor:  When the sentence length gets less than about this much,
+                 the batch size stops increasing inversely with sentence
+                 length.  Prevent OOM on batches with short sentences.
+         length_ceil:  After the sentence length gets more than about
+                 this much, the batch size will start decreasing
                  as 1/(sentence-length^2).  This is a mechanism to
                  avoid excessive memory consumption in transformers, when
                  sentence length gets long.
@@ -654,10 +661,17 @@ class LmBatchSampler(torch.utils.data.Sampler):
          rank:  The rank of this sampler/process for distributed operation; if None,
                  will be worked out from torch.distributed.
          seed:  The random seed
+         delay_init:  If true, will omit calling self.set_epoch(0) at the
+                 end of the __init__ function.  In this case the caller
+                 must call set_epoch(0).  [Setting this option is necessary
+                 to work with data-loader worker processes plus DDP, since
+                 set_epoch() will use ddp, which I believe is a no-no prior
+                 to initializing data-loaders.]
         """
         self.seed = seed
         self.symbols_per_batch = symbols_per_batch
-        self.quadratic_constant = quadratic_constant
+        self.length_floor = length_floor
+        self.quadratic_constant = 1.0 / length_ceil
         self._maybe_init_distributed(world_size=world_size, rank=rank)
 
         # a configuration constant we don't expose.
@@ -698,8 +712,20 @@ class LmBatchSampler(torch.utils.data.Sampler):
         # `data_indexes` above (this is not stored, as we know the formula).
         self.sentence_lengths = sentence_lengths
 
-        self.set_epoch(0)  # this is responsible for setting self.sorted_data_indexes
+        if not delay_init:
+            self.set_epoch(0)  # this is responsible for setting self.sorted_data_indexes
 
+    def _sync_sizes(self, device: torch.device = torch.device('cuda')):
+        # Calling this on all copies of a DDP setup will sync the sizes so that
+        # all copies have the exact same number of batches.  I think
+        # this needs to be called with the GPU device, not sure if it would
+        # work otherwise.
+        if self.world_size > 1:
+            min_size = torch.tensor([len(self.batch_indices)], device=device, dtype=torch.int64)
+            dist.all_reduce(min_size, op=dist.ReduceOp.MIN)
+            min_size = min_size.to('cpu').item()
+            logging.info(f"world_size={self.world_size}, rank={self.rank}: reducing batch indices from {len(self.batch_indices)} to {min_size}")
+            self.batch_indices = self.batch_indices[0:min_size]
 
     def _maybe_init_distributed(self, world_size: Optional[int], rank: Optional[int]):
         if world_size is not None:
@@ -714,6 +740,7 @@ class LmBatchSampler(torch.utils.data.Sampler):
             self.rank = dist.get_rank() if rank is None else rank
             assert self.rank < self.world_size
 
+
     def set_epoch(self, epoch: int):
         """
         Must be called at the beginning of each epoch, before initializing the DataLoader,
@@ -727,7 +754,7 @@ class LmBatchSampler(torch.utils.data.Sampler):
 
         # This mechanism regulates the batch size so that we don't get OOM in transformers
         # when the sentences are long.
-        sentence_lengths = sentence_lengths + (sentence_lengths ** 2) * self.quadratic_constant
+        sentence_lengths = (sentence_lengths + (sentence_lengths ** 2) * self.quadratic_constant) + self.length_floor
 
         values, indices = torch.sort(sentence_lengths) # values,indices dtypes: torch.float,torch.int64
 
@@ -741,7 +768,7 @@ class LmBatchSampler(torch.utils.data.Sampler):
 
         # now `batch_ids` will be: [0, 0, 0, 0, .., 0, 1, 1, 1, ... 1, 2, ... ],
         # saying which batch each element of values/indices belongs to.
-        batch_ids = (torch.cumsum(values, dim=0) * (1.0 / self.symbols_per_batch)).to(dtype=torch.int32)
+        batch_ids = (torch.cumsum(values.to(dtype=torch.double), dim=0) * (1.0 / self.symbols_per_batch)).to(dtype=torch.int32)
 
         batch_boundaries = torch.nonzero(batch_ids[1:] - batch_ids[:-1], as_tuple=True)[0]
         batch_boundaries.add_(1)
@@ -754,6 +781,7 @@ class LmBatchSampler(torch.utils.data.Sampler):
 
         # necessary to randomize the order of these, to avoid returning batches
         # from shortest to longest sentences.
         self.batch_indices = torch.randperm(num_batches, generator=g, dtype=torch.int32).tolist()
+        self._sync_sizes()
 
     def __len__(self):
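
Illustration of the mechanism this patch adds (not part of the patch): with the new parameters the per-sentence cost used for packing is length + length**2 / length_ceil + length_floor, so the batch size is roughly symbols_per_batch divided by that cost. The helper below is hypothetical and the symbols_per_batch value is only an example; it shows how length_floor stops the batch size from blowing up on very short sentences (the OOM case named in the subject line), while length_ceil shrinks it quadratically for very long ones.

def approx_sentences_per_batch(length: float,
                               symbols_per_batch: float = 5000.0,   # example value, not from the patch
                               length_ceil: float = 200.0,
                               length_floor: float = 4.0) -> float:
    # Effective cost per sentence, as computed in set_epoch() after this patch:
    #   length + length**2 * (1.0 / length_ceil) + length_floor
    effective = length + (length ** 2) / length_ceil + length_floor
    return symbols_per_batch / effective

for length in [1, 2, 5, 20, 100, 400]:
    print(f"len={length:4d}  ~sentences/batch={approx_sentences_per_batch(length):7.1f}")

Without the length_floor term, a batch of length-1 sentences would hold roughly symbols_per_batch sentences (about 5000 in this example); with the default floor of 4.0 it is capped at about a fifth of that.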
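
The packing itself (sort by effective length, take a cumulative sum, start a new batch each time the running total crosses a multiple of symbols_per_batch) can be reproduced standalone. The sketch below uses made-up lengths and a tiny symbols_per_batch; the final torch.tensor_split step is only an illustrative way of materialising the batches, not necessarily what dataset.py does after computing batch_boundaries.

import torch

symbols_per_batch = 50.0
lengths = torch.tensor([30., 3., 12., 7., 25., 5., 18., 9.])

# Effective length as in set_epoch() after this patch (length_ceil=200, length_floor=4).
effective = lengths + (lengths ** 2) / 200.0 + 4.0

values, indices = torch.sort(effective)
# Cumulative sum in double, as in the patch, bucketed into units of symbols_per_batch.
batch_ids = (torch.cumsum(values.to(torch.double), dim=0) / symbols_per_batch).to(torch.int32)
batch_boundaries = torch.nonzero(batch_ids[1:] - batch_ids[:-1], as_tuple=True)[0] + 1
batches = torch.tensor_split(indices, batch_boundaries.tolist())
print([b.tolist() for b in batches])   # several short sentences per batch, long ones nearly alone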
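
On why the cumulative sum is now computed in double precision: the sum runs over the whole sorted corpus, and once a float32 running total passes 2**24 it can no longer represent consecutive integers, so small per-sentence increments get rounded away and the derived batch ids become unreliable. The two lines below only demonstrate the underlying float32 behaviour; they are not taken from the repo, and the link to the values.to(dtype=torch.double) change is an inference from the diff.

import torch

acc32 = torch.tensor(2.0 ** 24, dtype=torch.float32)
acc64 = torch.tensor(2.0 ** 24, dtype=torch.float64)
print(acc32 + 1.0 == acc32)   # tensor(True): the +1 is absorbed by float32 rounding
print(acc64 + 1.0 == acc64)   # tensor(False): float64 still resolves the increment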