Mirror of https://github.com/k2-fsa/icefall.git

Commit e6eefeba88 (parent 9576d6574f)
Changes to dataset to prevent OOM on batches with short sentences
@@ -2,6 +2,7 @@ import torch
+import torch.distributed as dist
 import k2
 import _k2
 import logging
 import sentencepiece as spm
 from pathlib import Path
 from typing import Optional, List, Tuple, Union
@@ -333,6 +334,7 @@ def collate_fn(sentences: List[List[int]],
     """
     assert blank_sym not in [bos_sym, eos_sym]
     max_sent_len = max([ len(s) for s in sentences])
+    #logging.info(f"Sentence lengths: {[ len(s) for s in sentences]}")

     typical_mask_and_pad = int(max_sent_len * (mask_proportion + padding_proportion))

@@ -635,17 +637,22 @@ class LmBatchSampler(torch.utils.data.Sampler):
     """
     def __init__(self, dataset: LmDataset,
                  symbols_per_batch: int,
-                 quadratic_constant: float = 0.005,
+                 length_ceil: float = 200.0,
+                 length_floor: float = 4.0,
                  world_size: Optional[int] = None,
                  rank: int = None,
-                 seed: int = 0):
+                 seed: int = 0,
+                 delay_init: bool = False):
         """
         Constructor documentation:
            dataset: the LmDataset object that we are sampling from.  This
               class does not retain a reference to the LmDataset.
            symbols_per_batch: The number of BPE symbols desired in each minibatch
-           quadratic_constant: After the sentence length gets more than about
-               1.0/quadratic_constant, the batch size will start decreasing
+           length_floor: When the sentence length gets less than about this much,
+               the batch size stops increasing inversely with sentence
+               length.  Prevent OOM on batches with short sentences.
+           length_ceil: After the sentence length gets more than about
+               this much, the batch size will start decreasing
                as 1/(sentence-length^2).  This is a mechanism to
                avoid excessive memory consumption in transformers, when
                sentence length gets long.
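The constructor turns length_ceil into quadratic_constant = 1.0 / length_ceil (next hunk), and set_epoch() later maps each sentence length to an effective length of length + length**2 * quadratic_constant + length_floor; a batch then holds roughly symbols_per_batch / effective_length sentences. A small sketch of the resulting batch sizes, assuming a hypothetical symbols_per_batch of 5000 and the default length_ceil/length_floor (illustrative only, not code from the repo):

```python
def approx_sentences_per_batch(sent_len: float,
                               symbols_per_batch: float = 5000.0,   # hypothetical value
                               length_ceil: float = 200.0,
                               length_floor: float = 4.0) -> float:
    # Mirrors the effective-length formula used in set_epoch() further down.
    quadratic_constant = 1.0 / length_ceil
    effective_len = sent_len + sent_len ** 2 * quadratic_constant + length_floor
    return symbols_per_batch / effective_len

# sent_len=2   -> ~830 sentences per batch (length_floor dominates for short sentences)
# sent_len=50  -> ~75
# sent_len=500 -> ~3  (the quadratic term dominates for long sentences)
```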
@@ -654,10 +661,17 @@ class LmBatchSampler(torch.utils.data.Sampler):
            rank: The rank of this sampler/process for distributed operation; if None,
               will be worked out from torch.distributed.
            seed: The random seed
+           delay_init: If true, will omit calling self.set_epoch(0) at the
+              end of the __init__ function.  In this case the caller
+              must call set_epoch(0).  [Setting this option is necessary
+              to work with data-loader worker processes plus DDP, since
+              set_epoch() will use ddp, which I believe is a no-no prior
+              to initializing data-loaders.]
         """
         self.seed = seed
         self.symbols_per_batch = symbols_per_batch
-        self.quadratic_constant = quadratic_constant
+        self.length_floor = length_floor
+        self.quadratic_constant = 1.0 / length_ceil
         self._maybe_init_distributed(world_size=world_size, rank=rank)

         # a configuration constant we don't expose.
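A sketch of the calling pattern that delay_init enables, with placeholder names (train_dataset, my_collate_fn) and a hypothetical symbols_per_batch value; these are illustrative, not icefall APIs:

```python
# delay_init=True keeps __init__ free of torch.distributed calls; the caller
# invokes set_epoch() itself, e.g. once the DataLoader workers exist.
sampler = LmBatchSampler(train_dataset,              # placeholder dataset object
                         symbols_per_batch=5000,     # hypothetical value
                         delay_init=True)
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_sampler=sampler,
                                           collate_fn=my_collate_fn,  # placeholder
                                           num_workers=2)
sampler.set_epoch(0)  # caller's responsibility; may use torch.distributed via _sync_sizes()
```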
@@ -698,8 +712,20 @@ class LmBatchSampler(torch.utils.data.Sampler):
         # `data_indexes` above (this is not stored, as we know the formula).
         self.sentence_lengths = sentence_lengths

-        self.set_epoch(0)  # this is responsible for setting self.sorted_data_indexes
+        if not delay_init:
+            self.set_epoch(0)  # this is responsible for setting self.sorted_data_indexes

+    def _sync_sizes(self, device: torch.device = torch.device('cuda')):
+        # Calling this on all copies of a DDP setup will sync the sizes so that
+        # all copies have the exact same number of batches.  I think
+        # this needs to be called with the GPU device, not sure if it would
+        # work otherwise.
+        if self.world_size > 1:
+            min_size = torch.tensor([len(self.batch_indices)], device=device, dtype=torch.int64)
+            dist.all_reduce(min_size, op=dist.ReduceOp.MIN)
+            min_size = min_size.to('cpu').item()
+            logging.info(f"world_size={self.world_size}, rank={self.rank}: reducing batch indices from {len(self.batch_indices)} to {min_size}")
+            self.batch_indices = self.batch_indices[0:min_size]

     def _maybe_init_distributed(self, world_size: Optional[int], rank: Optional[int]):
         if world_size is not None:
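The MIN-reduction matters because, under DDP, every rank must run the same number of training steps: a rank that runs out of batches early would leave its peers blocked inside collective operations such as the gradient all-reduce. Truncating every rank's batch list to the smallest count across ranks avoids that. A standalone sketch of the same idea (assumes torch.distributed is already initialized; names are illustrative):

```python
import torch
import torch.distributed as dist

def truncate_to_common_length(batch_indices: list, device: torch.device) -> list:
    # Every rank contributes its own batch count; after the MIN all-reduce each
    # rank holds the smallest count, and keeps only that many batches.
    n = torch.tensor([len(batch_indices)], device=device, dtype=torch.int64)
    dist.all_reduce(n, op=dist.ReduceOp.MIN)
    return batch_indices[: int(n.item())]
```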
@@ -714,6 +740,7 @@ class LmBatchSampler(torch.utils.data.Sampler):
         self.rank = dist.get_rank() if rank is None else rank
         assert self.rank < self.world_size

+
     def set_epoch(self, epoch: int):
         """
         Must be called at the beginning of each epoch, before initializing the DataLoader,
@@ -727,7 +754,7 @@ class LmBatchSampler(torch.utils.data.Sampler):

         # This mechanism regulates the batch size so that we don't get OOM in transformers
         # when the sentences are long.
-        sentence_lengths = sentence_lengths + (sentence_lengths ** 2) * self.quadratic_constant
+        sentence_lengths = (sentence_lengths + (sentence_lengths ** 2) * self.quadratic_constant) + self.length_floor

         values, indices = torch.sort(sentence_lengths)  # values,indices dtypes: torch.float,torch.int64

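The net effect of the one-line change above on a toy tensor, using the default length_ceil=200 and length_floor=4 from the constructor (illustrative only):

```python
import torch

lens = torch.tensor([2.0, 5.0, 100.0])
quadratic_constant = 1.0 / 200.0   # i.e. 1.0 / length_ceil
old = lens + lens ** 2 * quadratic_constant            # effective lengths ~ [2.02, 5.12, 150.0]
new = (lens + lens ** 2 * quadratic_constant) + 4.0    # effective lengths ~ [6.02, 9.12, 154.0]
# A length-2 sentence now contributes an effective length of ~6 rather than ~2,
# so roughly 3x fewer such sentences get packed into one batch.
```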
@@ -741,7 +768,7 @@ class LmBatchSampler(torch.utils.data.Sampler):

         # now `batch_ids` will be: [0, 0, 0, 0, .., 0, 1, 1, 1, ... 1, 2, ... ],
         # saying which batch each element of values/indices belongs to.
-        batch_ids = (torch.cumsum(values, dim=0) * (1.0 / self.symbols_per_batch)).to(dtype=torch.int32)
+        batch_ids = (torch.cumsum(values.to(dtype=torch.double), dim=0) * (1.0 / self.symbols_per_batch)).to(dtype=torch.int32)

         batch_boundaries = torch.nonzero(batch_ids[1:] - batch_ids[:-1], as_tuple=True)[0]
         batch_boundaries.add_(1)
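The cast to double matters because values is float32: once the running total of symbols exceeds 2**24, float32 can no longer represent every integer, so boundaries computed from a float32 cumulative sum drift or stall on a large corpus. A standalone illustration, not taken from the repo:

```python
import torch

# 20M "sentences" of length 1; the true running total should reach 20,000,000.
values = torch.ones(20_000_000)
f32_total = torch.cumsum(values, dim=0)[-1].item()
f64_total = torch.cumsum(values.to(torch.double), dim=0)[-1].item()
print(f32_total)   # typically stalls near 16777216.0 (2**24): adding 1 no longer changes the sum
print(f64_total)   # 20000000.0
```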
@@ -754,6 +781,7 @@ class LmBatchSampler(torch.utils.data.Sampler):
         # necessary to randomize the order of these, to avoid returning batches
         # from shortest to longest sentences.
         self.batch_indices = torch.randperm(num_batches, generator=g, dtype=torch.int32).tolist()
+        self._sync_sizes()


     def __len__(self):