From 91f7b1ce6f70b34f72a385702c77edb421f793ec Mon Sep 17 00:00:00 2001
From: JinZr
Date: Fri, 6 Sep 2024 18:07:50 +0800
Subject: [PATCH] sort of fixed DDP training issue

---
 .../CODEC/encodec/codec_datamodule.py | 24 +++++++++++++++----
 egs/libritts/CODEC/encodec/encodec.py | 15 +++++++++---
 egs/libritts/CODEC/encodec/loss.py    |  6 ++---
 egs/libritts/CODEC/encodec/train.py   | 24 +++++++++++++++----
 4 files changed, 54 insertions(+), 15 deletions(-)

diff --git a/egs/libritts/CODEC/encodec/codec_datamodule.py b/egs/libritts/CODEC/encodec/codec_datamodule.py
index b547e8513..e84f08e70 100644
--- a/egs/libritts/CODEC/encodec/codec_datamodule.py
+++ b/egs/libritts/CODEC/encodec/codec_datamodule.py
@@ -139,7 +139,7 @@ class LibriTTSCodecDataModule:
         group.add_argument(
             "--num-workers",
             type=int,
-            default=8,
+            default=2,
             help="The number of training dataloader workers that "
             "collect the batches.",
         )
@@ -155,6 +155,8 @@ class LibriTTSCodecDataModule:
         self,
         cuts_train: CutSet,
         sampler_state_dict: Optional[Dict[str, Any]] = None,
+        world_size: Optional[int] = None,
+        rank: Optional[int] = None,
     ) -> DataLoader:
         """
         Args:
@@ -182,6 +184,8 @@ class LibriTTSCodecDataModule:
                 buffer_size=self.args.num_buckets * 2000,
                 shuffle_buffer_size=self.args.num_buckets * 5000,
                 drop_last=self.args.drop_last,
+                world_size=world_size,
+                rank=rank,
             )
         else:
             logging.info("Using SimpleCutSampler.")
@@ -189,6 +193,8 @@ class LibriTTSCodecDataModule:
                 cuts_train,
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
+                world_size=world_size,
+                rank=rank,
             )
 
         logging.info("About to create train dataloader")
@@ -206,13 +212,18 @@ class LibriTTSCodecDataModule:
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
-            persistent_workers=False,
+            persistent_workers=True,
             worker_init_fn=worker_init_fn,
         )
 
         return train_dl
 
-    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+    def valid_dataloaders(
+        self,
+        cuts_valid: CutSet,
+        world_size: Optional[int] = None,
+        rank: Optional[int] = None,
+    ) -> DataLoader:
         logging.info("About to create dev dataset")
 
         validate = SpeechSynthesisDataset(
@@ -226,14 +237,17 @@ class LibriTTSCodecDataModule:
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
+            world_size=world_size,
+            rank=rank,
         )
         logging.info("About to create valid dataloader")
         valid_dl = DataLoader(
             validate,
             sampler=valid_sampler,
             batch_size=None,
-            num_workers=2,
-            persistent_workers=False,
+            num_workers=1,
+            drop_last=False,
+            persistent_workers=True,
         )
 
         return valid_dl
diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py
index 071dc19ba..32d80eb38 100644
--- a/egs/libritts/CODEC/encodec/encodec.py
+++ b/egs/libritts/CODEC/encodec/encodec.py
@@ -74,14 +74,18 @@ class Encodec(nn.Module):
         if not self.cache_generator_outputs or self._cache is None:
             reuse_cache = False
             e = self.encoder(speech)
-            bw = random.choice(self.target_bandwidths)
+            index = torch.tensor(
+                random.randint(0, len(self.target_bandwidths) - 1), device=speech.device,
+            )
+            if torch.distributed.is_initialized():
+                torch.distributed.broadcast(index, src=0)
+            bw = self.target_bandwidths[index.item()]
             quantized, codes, bandwidth, commit_loss = self.quantizer(
                 e, self.frame_rate, bw
             )
             speech_hat = self.decoder(quantized)
         else:
             speech_hat = self._cache
-
         # store cache
         if self.training and self.cache_generator_outputs and not reuse_cache:
             self._cache = speech_hat
@@ -169,7 +173,12 @@ class Encodec(nn.Module):
         if not self.cache_generator_outputs or self._cache is None:
             reuse_cache = False
             e = self.encoder(speech)
-            bw = random.choice(self.target_bandwidths)
+            index = torch.tensor(
+                random.randint(0, len(self.target_bandwidths) - 1), device=speech.device,
+            )
+            if torch.distributed.is_initialized():
+                torch.distributed.broadcast(index, src=0)
+            bw = self.target_bandwidths[index.item()]
             quantized, codes, bandwidth, commit_loss = self.quantizer(
                 e, self.frame_rate, bw
             )
diff --git a/egs/libritts/CODEC/encodec/loss.py b/egs/libritts/CODEC/encodec/loss.py
index 9ec80f536..0614abf92 100644
--- a/egs/libritts/CODEC/encodec/loss.py
+++ b/egs/libritts/CODEC/encodec/loss.py
@@ -78,10 +78,10 @@ def reconstruction_loss(x, x_hat, args, eps=1e-7):
             wkwargs={"device": x_hat.device},
         ).to(x_hat.device)
         S_x = melspec(x)
-        S_G_x = melspec(x_hat)
-        l1_loss = (S_x - S_G_x).abs().mean()
+        S_x_hat = melspec(x_hat)
+        l1_loss = (S_x - S_x_hat).abs().mean()
         l2_loss = (
-            ((torch.log(S_x.abs() + eps) - torch.log(S_G_x.abs() + eps)) ** 2).mean(
+            ((torch.log(S_x.abs() + eps) - torch.log(S_x_hat.abs() + eps)) ** 2).mean(
                 dim=-2
             )
             ** 0.5
diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py
index 11c845856..7dfbef2b6 100644
--- a/egs/libritts/CODEC/encodec/train.py
+++ b/egs/libritts/CODEC/encodec/train.py
@@ -552,13 +552,14 @@ def train_one_epoch(
                 model=model,
                 valid_dl=valid_dl,
                 world_size=world_size,
+                rank=rank,
             )
             model.train()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
             logging.info(
                 f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
             )
-            if tb_writer is not None:
+            if tb_writer is not None and rank == 0 and speech_hat is not None:
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
@@ -647,6 +648,8 @@ def compute_validation_loss(
                 inner_model = model.module if isinstance(model, DDP) else model
                 audio_pred = inner_model.inference(x=audio, target_bw=params.target_bw)
                 returned_sample = (audio_pred, audio)
+            else:
+                returned_sample = (None, None)
 
     if world_size > 1:
         tot_loss.reduce(device)
@@ -796,7 +799,12 @@ def run(rank, world_size, args):
     if world_size > 1:
         logging.info("Using DDP")
         model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
-        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+        model = DDP(
+            model,
+            device_ids=[rank],
+            find_unused_parameters=True,
+            broadcast_buffers=False,
+        )
 
     optimizer_g = torch.optim.AdamW(
         itertools.chain(
@@ -846,10 +854,18 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)
 
-    train_dl = libritts.train_dataloaders(train_cuts)
+    train_dl = libritts.train_dataloaders(
+        train_cuts,
+        world_size=world_size,
+        rank=rank,
+    )
 
     valid_cuts = libritts.dev_clean_cuts()
-    valid_dl = libritts.valid_dataloaders(valid_cuts)
+    valid_dl = libritts.valid_dataloaders(
+        valid_cuts,
+        world_size=world_size,
+        rank=rank,
+    )
 
     # if not params.print_diagnostics:
     #     scan_pessimistic_batches_for_oom(
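
Note on the encodec.py hunks, separate from the diff above: before this patch each DDP rank drew its own random bandwidth with random.choice, so ranks could feed the quantizer different target bandwidths in the same step and run different code paths, which can stall or break the gradient all-reduce. The patch addresses this by broadcasting rank 0's draw to every rank. Below is a minimal, self-contained sketch of that pattern; the bandwidth values and the helper name pick_bandwidth are illustrative only, not taken from the repository.

    import random

    import torch
    import torch.distributed as dist

    # Example bandwidth table; in the recipe this lives on the Encodec model.
    target_bandwidths = [1.5, 3.0, 6.0, 12.0, 24.0]

    def pick_bandwidth(device: torch.device) -> float:
        # Every rank draws its own index ...
        index = torch.tensor(
            random.randint(0, len(target_bandwidths) - 1), device=device
        )
        # ... but rank 0's draw overwrites it on all ranks, so every DDP
        # replica quantizes with the same target bandwidth in this step.
        if dist.is_available() and dist.is_initialized():
            dist.broadcast(index, src=0)
        return target_bandwidths[index.item()]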