sort of fixed DDP training issue

JinZr 2024-09-06 18:07:50 +08:00
parent 2df992f98a
commit 91f7b1ce6f
4 changed files with 54 additions and 15 deletions

View File

@@ -139,7 +139,7 @@ class LibriTTSCodecDataModule:
         group.add_argument(
             "--num-workers",
             type=int,
-            default=8,
+            default=2,
             help="The number of training dataloader workers that "
             "collect the batches.",
         )
@@ -155,6 +155,8 @@ class LibriTTSCodecDataModule:
         self,
         cuts_train: CutSet,
         sampler_state_dict: Optional[Dict[str, Any]] = None,
+        world_size: Optional[int] = None,
+        rank: Optional[int] = None,
     ) -> DataLoader:
         """
         Args:
@@ -182,6 +184,8 @@ class LibriTTSCodecDataModule:
                 buffer_size=self.args.num_buckets * 2000,
                 shuffle_buffer_size=self.args.num_buckets * 5000,
                 drop_last=self.args.drop_last,
+                world_size=world_size,
+                rank=rank,
             )
         else:
             logging.info("Using SimpleCutSampler.")
@@ -189,6 +193,8 @@ class LibriTTSCodecDataModule:
                 cuts_train,
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
+                world_size=world_size,
+                rank=rank,
             )

         logging.info("About to create train dataloader")
@@ -206,13 +212,18 @@ class LibriTTSCodecDataModule:
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
-            persistent_workers=False,
+            persistent_workers=True,
             worker_init_fn=worker_init_fn,
         )

         return train_dl

-    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+    def valid_dataloaders(
+        self,
+        cuts_valid: CutSet,
+        world_size: Optional[int] = None,
+        rank: Optional[int] = None,
+    ) -> DataLoader:
         logging.info("About to create dev dataset")
         validate = SpeechSynthesisDataset(
@@ -226,14 +237,17 @@ class LibriTTSCodecDataModule:
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
+            world_size=world_size,
+            rank=rank,
         )
         logging.info("About to create valid dataloader")
         valid_dl = DataLoader(
             validate,
             sampler=valid_sampler,
             batch_size=None,
-            num_workers=2,
-            persistent_workers=False,
+            num_workers=1,
+            drop_last=False,
+            persistent_workers=True,
         )

         return valid_dl
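
The thread running through all of the datamodule hunks above is that each lhotse sampler now receives the DDP world_size and rank explicitly instead of inferring them lazily. A minimal sketch of the pattern, assuming lhotse's DynamicBucketingSampler (the make_train_sampler helper below is hypothetical, not icefall code):

import torch.distributed as dist
from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler

def make_train_sampler(cuts: CutSet, max_duration: float) -> DynamicBucketingSampler:
    # Forward the process group's rank/world size when DDP is active;
    # None falls back to single-process defaults.
    world_size = dist.get_world_size() if dist.is_initialized() else None
    rank = dist.get_rank() if dist.is_initialized() else None
    return DynamicBucketingSampler(
        cuts,
        max_duration=max_duration,
        shuffle=True,
        drop_last=True,
        world_size=world_size,  # each rank then samples a disjoint shard
        rank=rank,
    )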

View File

@@ -74,14 +74,18 @@ class Encodec(nn.Module):
         if not self.cache_generator_outputs or self._cache is None:
             reuse_cache = False
             e = self.encoder(speech)
-            bw = random.choice(self.target_bandwidths)
+            index = torch.tensor(
+                random.randint(0, len(self.target_bandwidths) - 1), device=speech.device,
+            )
+            if torch.distributed.is_initialized():
+                torch.distributed.broadcast(index, src=0)
+            bw = self.target_bandwidths[index.item()]
             quantized, codes, bandwidth, commit_loss = self.quantizer(
                 e, self.frame_rate, bw
             )
             speech_hat = self.decoder(quantized)
         else:
             speech_hat = self._cache

         # store cache
         if self.training and self.cache_generator_outputs and not reuse_cache:
             self._cache = speech_hat
@@ -169,7 +173,12 @@ class Encodec(nn.Module):
         if not self.cache_generator_outputs or self._cache is None:
             reuse_cache = False
             e = self.encoder(speech)
-            bw = random.choice(self.target_bandwidths)
+            index = torch.tensor(
+                random.randint(0, len(self.target_bandwidths) - 1), device=speech.device,
+            )
+            if torch.distributed.is_initialized():
+                torch.distributed.broadcast(index, src=0)
+            bw = self.target_bandwidths[index.item()]
             quantized, codes, bandwidth, commit_loss = self.quantizer(
                 e, self.frame_rate, bw
             )
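
Both forward paths get the same fix: instead of each rank calling random.choice independently, rank 0 draws the bandwidth index and broadcasts it, so every DDP process quantizes at the same bandwidth in a given step; mismatched bandwidths would have different ranks using different numbers of quantizer codebooks, desynchronizing parameter usage across processes. A standalone sketch of the pattern (the pick_shared helper is illustrative):

import random
import torch
import torch.distributed as dist

def pick_shared(options, device):
    # Every rank draws locally, then rank 0's draw overwrites the rest,
    # so all processes end up with the same choice.
    index = torch.tensor(random.randint(0, len(options) - 1), device=device)
    if dist.is_initialized():
        dist.broadcast(index, src=0)
    return options[index.item()]

# Usage inside the model's forward (illustrative):
# bw = pick_shared(self.target_bandwidths, speech.device)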

View File

@@ -78,10 +78,10 @@ def reconstruction_loss(x, x_hat, args, eps=1e-7):
             wkwargs={"device": x_hat.device},
         ).to(x_hat.device)
         S_x = melspec(x)
-        S_G_x = melspec(x_hat)
-        l1_loss = (S_x - S_G_x).abs().mean()
+        S_x_hat = melspec(x_hat)
+        l1_loss = (S_x - S_x_hat).abs().mean()
         l2_loss = (
-            ((torch.log(S_x.abs() + eps) - torch.log(S_G_x.abs() + eps)) ** 2).mean(
+            ((torch.log(S_x.abs() + eps) - torch.log(S_x_hat.abs() + eps)) ** 2).mean(
                 dim=-2
             )
             ** 0.5
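
The loss hunk is a pure rename (S_G_x to S_x_hat) for readability; the computation is unchanged: an L1 distance between mel spectrograms plus a root-mean-square distance between log-mel spectrograms. A self-contained sketch of a single scale (the MelSpectrogram settings are illustrative, and the full function presumably accumulates this over several resolutions):

import torch
import torchaudio

def mel_recon_loss(x: torch.Tensor, x_hat: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    # One scale of the mel reconstruction loss: L1 on mel magnitudes
    # plus an RMS distance on log-mels.
    melspec = torchaudio.transforms.MelSpectrogram(
        sample_rate=24000, n_fft=1024, hop_length=256, n_mels=64
    ).to(x_hat.device)
    S_x, S_x_hat = melspec(x), melspec(x_hat)
    l1_loss = (S_x - S_x_hat).abs().mean()
    l2_loss = (
        ((torch.log(S_x.abs() + eps) - torch.log(S_x_hat.abs() + eps)) ** 2)
        .mean(dim=-2)
        ** 0.5
    ).mean()
    return l1_loss + l2_loss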

View File

@@ -552,13 +552,14 @@ def train_one_epoch(
                 model=model,
                 valid_dl=valid_dl,
                 world_size=world_size,
+                rank=rank,
             )
             model.train()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
             logging.info(
                 f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
             )
-            if tb_writer is not None:
+            if tb_writer is not None and rank == 0 and speech_hat is not None:
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
@@ -647,6 +648,8 @@ def compute_validation_loss(
             inner_model = model.module if isinstance(model, DDP) else model
             audio_pred = inner_model.inference(x=audio, target_bw=params.target_bw)
             returned_sample = (audio_pred, audio)
+        else:
+            returned_sample = (None, None)

     if world_size > 1:
         tot_loss.reduce(device)
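
Only rank 0 runs the sampled inference now, so the other ranks return (None, None) rather than leaving returned_sample unbound; this pairs with the rank == 0 and speech_hat is not None guard in train_one_epoch above. The tot_loss.reduce(device) call then sums the tracked metrics across ranks, roughly like the sketch below (MetricsTracker internals are icefall-specific; reduce_metrics is an illustrative stand-in):

import torch
import torch.distributed as dist

def reduce_metrics(metrics: dict, device: torch.device) -> dict:
    # Sum each scalar metric over all ranks so every process (and the
    # rank-0 logger) sees the global totals.
    out = {}
    for name, value in metrics.items():
        t = torch.tensor(float(value), device=device)
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
        out[name] = t.item()
    return out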
@@ -796,7 +799,12 @@ def run(rank, world_size, args):
     if world_size > 1:
         logging.info("Using DDP")
         model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
-        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+        model = DDP(
+            model,
+            device_ids=[rank],
+            find_unused_parameters=True,
+            broadcast_buffers=False,
+        )

     optimizer_g = torch.optim.AdamW(
         itertools.chain(
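
broadcast_buffers=False stops DDP from re-broadcasting rank 0's module buffers (batch-norm running stats and similar state) on every forward pass, a common source of DDP trouble in GAN-style training where generator and discriminator take separate forward passes; SyncBatchNorm keeps the statistics consistent instead. A minimal sketch of the wrapping (the wrap_ddp helper is hypothetical):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_ddp(model: nn.Module, rank: int) -> nn.Module:
    # Convert BatchNorm layers to SyncBatchNorm, then wrap in DDP without
    # per-step buffer broadcasts from rank 0.
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model).to(rank)
    return DDP(
        model,
        device_ids=[rank],
        find_unused_parameters=True,  # sub-modules are used conditionally
        broadcast_buffers=False,
    )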
@@ -846,10 +854,18 @@
     if params.inf_check:
         register_inf_check_hooks(model)

-    train_dl = libritts.train_dataloaders(train_cuts)
+    train_dl = libritts.train_dataloaders(
+        train_cuts,
+        world_size=world_size,
+        rank=rank,
+    )

     valid_cuts = libritts.dev_clean_cuts()
-    valid_dl = libritts.valid_dataloaders(valid_cuts)
+    valid_dl = libritts.valid_dataloaders(
+        valid_cuts,
+        world_size=world_size,
+        rank=rank,
+    )

     # if not params.print_diagnostics:
     #     scan_pessimistic_batches_for_oom(