mirror of https://github.com/k2-fsa/icefall.git

Changes to dataset to prevent OOM on batches with short sentences

commit e6eefeba88 (parent 9576d6574f)
@@ -2,6 +2,7 @@ import torch
 import torch.distributed as dist
 import k2
 import _k2
+import logging
 import sentencepiece as spm
 from pathlib import Path
 from typing import Optional, List, Tuple, Union
@@ -333,6 +334,7 @@ def collate_fn(sentences: List[List[int]],
     """
     assert blank_sym not in [bos_sym, eos_sym]
     max_sent_len = max([ len(s) for s in sentences])
+    #logging.info(f"Sentence lengths: {[ len(s) for s in sentences]}")

     typical_mask_and_pad = int(max_sent_len * (mask_proportion + padding_proportion))

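As a rough illustration of the quantity computed in the hunk above (the sentence lists and proportions below are made up for the example; mask_proportion and padding_proportion are parameters of collate_fn):

# Illustrative arithmetic only; these values are hypothetical.
sentences = [[5, 9, 2], [7, 1], [3, 4, 6, 8, 10]]    # BPE token ids
max_sent_len = max([ len(s) for s in sentences])      # == 5
mask_proportion, padding_proportion = 0.15, 0.05
typical_mask_and_pad = int(max_sent_len * (mask_proportion + padding_proportion))
# int(5 * 0.2) == 1 position reserved per sentence for masking plus padding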
@@ -635,17 +637,22 @@ class LmBatchSampler(torch.utils.data.Sampler):
     """
     def __init__(self, dataset: LmDataset,
                  symbols_per_batch: int,
-                 quadratic_constant: float = 0.005,
+                 length_ceil: float = 200.0,
+                 length_floor: float = 4.0,
                  world_size: Optional[int] = None,
                  rank: int = None,
-                 seed: int = 0):
+                 seed: int = 0,
+                 delay_init: bool = False):
         """
         Constructor documentation:
          dataset: the LmDataset object that we are sampling from.  This
               class does not retain a reference to the LmDataset.
          symbols_per_batch: The number of BPE symbols desired in each minibatch
-         quadratic_constant: After the sentence length gets more than about
-              1.0/quadratic_constant, the batch size will start decreasing
+         length_floor: When the sentence length gets less than about this much,
+              the batch size stops increasing inversely with sentence
+              length.  Prevent OOM on batches with short sentences.
+         length_ceil: After the sentence length gets more than about
+              this much, the batch size will start decreasing
               as 1/(sentence-length^2).  This is a mechanism to
               avoid excessive memory consumption in transformers, when
               sentence length gets long.
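A minimal sketch of how the new constructor arguments might be passed; the numeric values are assumptions for illustration, only the parameter names come from this diff:

# Hypothetical usage; `dataset` is an LmDataset built elsewhere.
sampler = LmBatchSampler(dataset,
                         symbols_per_batch=5000,
                         length_ceil=200.0,   # batch size decays ~1/len^2 past this
                         length_floor=4.0)    # batch size stops growing below this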
@@ -654,10 +661,17 @@ class LmBatchSampler(torch.utils.data.Sampler):
          rank: The rank of this sampler/process for distributed operation; if None,
               will be worked out from torch.distributed.
          seed: The random seed
+         delay_init: If true, will omit calling self.set_epoch(0) at the
+              end of the __init__ function.  In this case the caller
+              must call set_epoch(0).  [Setting this option is necessary
+              to work with data-loader worker processes plus DDP, since
+              set_epoch() will use ddp, which I believe is a no-no prior
+              to initializing data-loaders.]
         """
         self.seed = seed
         self.symbols_per_batch = symbols_per_batch
-        self.quadratic_constant = quadratic_constant
+        self.length_floor = length_floor
+        self.quadratic_constant = 1.0 / length_ceil
         self._maybe_init_distributed(world_size=world_size, rank=rank)

         # a configuration constant we don't expose.
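To see concretely why length_floor prevents OOM on batches of short sentences, here is an illustrative computation using the regulated length that set_epoch() applies below (the numbers are made up; the batch size is only approximately symbols_per_batch / regulated_length, since batches are actually formed by a cumulative sum over sorted lengths):

# Illustrative only: effect of the regulated sentence length on batch size.
symbols_per_batch = 5000
length_ceil, length_floor = 200.0, 4.0
quadratic_constant = 1.0 / length_ceil

def regulated_len(n):
    return n + (n ** 2) * quadratic_constant + length_floor

for n in (2, 20, 400):
    print(n, regulated_len(n), round(symbols_per_batch / regulated_len(n)))
# n=2:   regulated ~6.0   -> ~831 sentences per batch; without length_floor it
#        would be ~2475 very short sentences in one batch, which can OOM
# n=20:  regulated ~26.0  -> ~192 sentences per batch
# n=400: regulated ~1204  -> ~4 sentences per batch (quadratic term dominates)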
@@ -698,8 +712,20 @@ class LmBatchSampler(torch.utils.data.Sampler):
         # `data_indexes` above (this is not stored, as we know the formula).
         self.sentence_lengths = sentence_lengths

-        self.set_epoch(0)  # this is responsible for setting self.sorted_data_indexes
+        if not delay_init:
+            self.set_epoch(0)  # this is responsible for setting self.sorted_data_indexes

+    def _sync_sizes(self, device: torch.device = torch.device('cuda')):
+        # Calling this on all copies of a DDP setup will sync the sizes so that
+        # all copies have the exact same number of batches.  I think
+        # this needs to be called with the GPU device, not sure if it would
+        # work otherwise.
+        if self.world_size > 1:
+            min_size = torch.tensor([len(self.batch_indices)], device=device, dtype=torch.int64)
+            dist.all_reduce(min_size, op=dist.ReduceOp.MIN)
+            min_size = min_size.to('cpu').item()
+            logging.info(f"world_size={self.world_size}, rank={self.rank}: reducing batch indices from {len(self.batch_indices)} to {min_size}")
+            self.batch_indices = self.batch_indices[0:min_size]

     def _maybe_init_distributed(self, world_size: Optional[int], rank: Optional[int]):
         if world_size is not None:
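A sketch of the DDP workflow that delay_init and _sync_sizes() are designed for, per the docstring above; the training-script details (process-group backend, worker count, collate function) are assumptions, only the sampler calls reflect this diff:

# Hypothetical DDP setup sketch.
import torch.distributed as dist
from torch.utils.data import DataLoader

dist.init_process_group(backend='nccl')

sampler = LmBatchSampler(dataset, symbols_per_batch=5000, delay_init=True)
train_dl = DataLoader(dataset,
                      batch_sampler=sampler,
                      collate_fn=my_collate_fn,   # placeholder wrapper around collate_fn
                      num_workers=4)

# Only now, after the DataLoader worker processes exist, is the DDP
# communication performed: set_epoch() ends with _sync_sizes(), an
# all_reduce(MIN) over per-rank batch counts so every rank iterates the
# same number of batches.
sampler.set_epoch(0)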
@@ -714,6 +740,7 @@ class LmBatchSampler(torch.utils.data.Sampler):
         self.rank = dist.get_rank() if rank is None else rank
         assert self.rank < self.world_size

+
     def set_epoch(self, epoch: int):
         """
         Must be called at the beginning of each epoch, before initializing the DataLoader,
@@ -727,7 +754,7 @@ class LmBatchSampler(torch.utils.data.Sampler):

         # This mechanism regulates the batch size so that we don't get OOM in transformers
         # when the sentences are long.
-        sentence_lengths = sentence_lengths + (sentence_lengths ** 2) * self.quadratic_constant
+        sentence_lengths = (sentence_lengths + (sentence_lengths ** 2) * self.quadratic_constant) + self.length_floor

         values, indices = torch.sort(sentence_lengths) # values,indices dtypes: torch.float,torch.int64

@@ -741,7 +768,7 @@ class LmBatchSampler(torch.utils.data.Sampler):

         # now `batch_ids` will be: [0, 0, 0, 0, .., 0, 1, 1, 1, ... 1, 2, ... ],
         # saying which batch each element of values/indices belongs to.
-        batch_ids = (torch.cumsum(values, dim=0) * (1.0 / self.symbols_per_batch)).to(dtype=torch.int32)
+        batch_ids = (torch.cumsum(values.to(dtype=torch.double), dim=0) * (1.0 / self.symbols_per_batch)).to(dtype=torch.int32)

         batch_boundaries = torch.nonzero(batch_ids[1:] - batch_ids[:-1], as_tuple=True)[0]
         batch_boundaries.add_(1)
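The cast to torch.double before the cumulative sum matters because the running total over the whole corpus can reach hundreds of millions of symbols; in float32 the per-sentence increments are then smaller than the spacing between representable values, so consecutive cumsum entries can become equal and batch boundaries get lost. A small demonstration of the effect (illustrative only, not code from the repo):

import torch

# Simulate a cumsum that has already accumulated ~3e8 symbols; float32 spacing
# near 3e8 is 32, so adding per-sentence lengths of ~30 is partly lost.
lengths = torch.full((1000,), 30.0)
offset = 3.0e8
f32 = torch.cumsum(lengths, dim=0) + offset
f64 = torch.cumsum(lengths.to(torch.double), dim=0) + offset
print((f32[1:] - f32[:-1]).min())   # can be 0.0: adjacent entries collide
print((f64[1:] - f64[:-1]).min())   # 30.0, as expected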
@@ -754,6 +781,7 @@ class LmBatchSampler(torch.utils.data.Sampler):
         # necessary to randomize the order of these, to avoid returning batches
         # from shortest to longest sentences.
         self.batch_indices = torch.randperm(num_batches, generator=g, dtype=torch.int32).tolist()
+        self._sync_sizes()


     def __len__(self):