Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-26 18:24:18 +00:00
sort of fixed DDP training issue

commit 91f7b1ce6f (parent 2df992f98a)
@@ -139,7 +139,7 @@ class LibriTTSCodecDataModule:
         group.add_argument(
             "--num-workers",
             type=int,
-            default=8,
+            default=2,
             help="The number of training dataloader workers that "
             "collect the batches.",
         )
@@ -155,6 +155,8 @@ class LibriTTSCodecDataModule:
         self,
         cuts_train: CutSet,
         sampler_state_dict: Optional[Dict[str, Any]] = None,
+        world_size: Optional[int] = None,
+        rank: Optional[int] = None,
     ) -> DataLoader:
         """
         Args:
@@ -182,6 +184,8 @@ class LibriTTSCodecDataModule:
                 buffer_size=self.args.num_buckets * 2000,
                 shuffle_buffer_size=self.args.num_buckets * 5000,
                 drop_last=self.args.drop_last,
+                world_size=world_size,
+                rank=rank,
             )
         else:
             logging.info("Using SimpleCutSampler.")
@@ -189,6 +193,8 @@ class LibriTTSCodecDataModule:
                 cuts_train,
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
+                world_size=world_size,
+                rank=rank,
             )
         logging.info("About to create train dataloader")

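
Both samplers above now receive world_size and rank explicitly. Lhotse's cut samplers accept these arguments so that each DDP process is pinned to its own shard of the data instead of inferring its position from the default process group. A minimal sketch of that call pattern; the duration budget and bucket count below are placeholders, not values taken from this recipe:

from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler

def make_train_sampler(cuts: CutSet, world_size: int, rank: int) -> DynamicBucketingSampler:
    # Telling the sampler its shard explicitly keeps every rank's view of the
    # epoch consistent instead of relying on torch.distributed being queried later.
    return DynamicBucketingSampler(
        cuts,
        max_duration=600.0,  # placeholder: seconds of audio per batch
        shuffle=True,
        num_buckets=30,      # placeholder
        drop_last=True,
        world_size=world_size,
        rank=rank,
    )
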
@@ -206,13 +212,18 @@ class LibriTTSCodecDataModule:
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
-            persistent_workers=False,
+            persistent_workers=True,
             worker_init_fn=worker_init_fn,
         )

         return train_dl

-    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+    def valid_dataloaders(
+        self,
+        cuts_valid: CutSet,
+        world_size: Optional[int] = None,
+        rank: Optional[int] = None,
+    ) -> DataLoader:
         logging.info("About to create dev dataset")

         validate = SpeechSynthesisDataset(
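
The training DataLoader now keeps its workers alive between epochs (persistent_workers=True), so worker processes and their dataset state are reused rather than torn down and respawned at every epoch boundary. A tiny, self-contained illustration of the flag with a stand-in dataset (the real recipe feeds a Lhotse SpeechSynthesisDataset driven by a cut sampler):

import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    # Stand-in for the recipe's speech dataset; returns random "audio".
    def __len__(self) -> int:
        return 8

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.randn(16000)

if __name__ == "__main__":
    train_dl = DataLoader(
        ToyDataset(),
        batch_size=2,
        num_workers=2,
        persistent_workers=True,  # workers survive across epochs
    )
    for epoch in range(2):
        for _ in train_dl:
            pass  # in the second epoch the same worker processes are reused
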
@@ -226,14 +237,17 @@ class LibriTTSCodecDataModule:
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
+            world_size=world_size,
+            rank=rank,
         )
         logging.info("About to create valid dataloader")
         valid_dl = DataLoader(
             validate,
             sampler=valid_sampler,
             batch_size=None,
-            num_workers=2,
-            persistent_workers=False,
+            num_workers=1,
+            drop_last=False,
+            persistent_workers=True,
         )

         return valid_dl
@@ -74,14 +74,18 @@ class Encodec(nn.Module):
         if not self.cache_generator_outputs or self._cache is None:
             reuse_cache = False
             e = self.encoder(speech)
-            bw = random.choice(self.target_bandwidths)
+            index = torch.tensor(
+                random.randint(0, len(self.target_bandwidths) - 1), device=speech.device,
+            )
+            if torch.distributed.is_initialized():
+                torch.distributed.broadcast(index, src=0)
+            bw = self.target_bandwidths[index.item()]
             quantized, codes, bandwidth, commit_loss = self.quantizer(
                 e, self.frame_rate, bw
             )
             speech_hat = self.decoder(quantized)
         else:
             speech_hat = self._cache

         # store cache
         if self.training and self.cache_generator_outputs and not reuse_cache:
             self._cache = speech_hat
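
This is the core of the DDP fix in the model: previously each rank drew its own random target bandwidth, so replicas could exercise different numbers of quantizer codebooks in the same step and disagree about which parameters took part in the backward pass. Drawing the index once and broadcasting it from rank 0 keeps the choice identical everywhere. A self-contained sketch of the same pattern as a standalone helper; the bandwidth list in the example is a placeholder:

import random
import torch
import torch.distributed as dist

def pick_shared_bandwidth(target_bandwidths, device: torch.device) -> float:
    # Draw locally, then let rank 0's draw win so all replicas agree.
    index = torch.tensor(
        random.randint(0, len(target_bandwidths) - 1), device=device
    )
    if dist.is_initialized():
        dist.broadcast(index, src=0)
    return target_bandwidths[index.item()]

# Single-process example (no process group): behaves like random.choice.
bw = pick_shared_bandwidth([1.5, 3.0, 6.0, 12.0], torch.device("cpu"))
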
@@ -169,7 +173,12 @@ class Encodec(nn.Module):
         if not self.cache_generator_outputs or self._cache is None:
             reuse_cache = False
             e = self.encoder(speech)
-            bw = random.choice(self.target_bandwidths)
+            index = torch.tensor(
+                random.randint(0, len(self.target_bandwidths) - 1), device=speech.device,
+            )
+            if torch.distributed.is_initialized():
+                torch.distributed.broadcast(index, src=0)
+            bw = self.target_bandwidths[index.item()]
             quantized, codes, bandwidth, commit_loss = self.quantizer(
                 e, self.frame_rate, bw
             )
@@ -78,10 +78,10 @@ def reconstruction_loss(x, x_hat, args, eps=1e-7):
             wkwargs={"device": x_hat.device},
         ).to(x_hat.device)
         S_x = melspec(x)
-        S_G_x = melspec(x_hat)
-        l1_loss = (S_x - S_G_x).abs().mean()
+        S_x_hat = melspec(x_hat)
+        l1_loss = (S_x - S_x_hat).abs().mean()
         l2_loss = (
-            ((torch.log(S_x.abs() + eps) - torch.log(S_G_x.abs() + eps)) ** 2).mean(
+            ((torch.log(S_x.abs() + eps) - torch.log(S_x_hat.abs() + eps)) ** 2).mean(
                 dim=-2
             )
             ** 0.5
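
The rename from S_G_x to S_x_hat is cosmetic, but for reference the two terms form a mel-spectrogram reconstruction loss: an L1 distance between mel spectrograms plus the mean of a per-frame RMS of their log-magnitude difference. A standalone sketch of one scale of this loss; the sample rate, FFT size, and mel settings below are placeholders rather than the recipe's exact configuration:

import torch
import torchaudio

def mel_reconstruction_terms(x: torch.Tensor, x_hat: torch.Tensor, eps: float = 1e-7):
    # Placeholder mel analysis settings; the recipe sweeps several window sizes.
    melspec = torchaudio.transforms.MelSpectrogram(
        sample_rate=24000, n_fft=1024, hop_length=256, n_mels=64
    ).to(x_hat.device)
    S_x = melspec(x)
    S_x_hat = melspec(x_hat)
    l1_loss = (S_x - S_x_hat).abs().mean()
    l2_loss = (
        ((torch.log(S_x.abs() + eps) - torch.log(S_x_hat.abs() + eps)) ** 2).mean(dim=-2)
        ** 0.5
    ).mean()
    return l1_loss, l2_loss
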
@@ -552,13 +552,14 @@ def train_one_epoch(
                 model=model,
                 valid_dl=valid_dl,
                 world_size=world_size,
+                rank=rank,
             )
             model.train()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
             logging.info(
                 f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
             )
-            if tb_writer is not None:
+            if tb_writer is not None and rank == 0 and speech_hat is not None:
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
@@ -647,6 +648,8 @@ def compute_validation_loss(
             inner_model = model.module if isinstance(model, DDP) else model
             audio_pred = inner_model.inference(x=audio, target_bw=params.target_bw)
             returned_sample = (audio_pred, audio)
+        else:
+            returned_sample = (None, None)

     if world_size > 1:
         tot_loss.reduce(device)
@@ -796,7 +799,12 @@ def run(rank, world_size, args):
     if world_size > 1:
         logging.info("Using DDP")
         model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
-        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+        model = DDP(
+            model,
+            device_ids=[rank],
+            find_unused_parameters=True,
+            broadcast_buffers=False,
+        )

     optimizer_g = torch.optim.AdamW(
         itertools.chain(
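
broadcast_buffers=False stops DDP from re-broadcasting module buffers (for example BatchNorm running statistics) from rank 0 at every forward pass; presumably the per-step broadcast is unwanted here once SyncBatchNorm already keeps those statistics consistent. A generic sketch of this wrapping order, with the model argument standing in for any nn.Module:

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_for_ddp(model: nn.Module, rank: int) -> nn.Module:
    # Convert BatchNorm layers to SyncBatchNorm before wrapping, as in the diff above.
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.to(rank)
    return DDP(
        model,
        device_ids=[rank],
        find_unused_parameters=True,  # tolerate branches that are skipped in a step
        broadcast_buffers=False,      # no per-iteration buffer broadcast from rank 0
    )
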
@@ -846,10 +854,18 @@
     if params.inf_check:
         register_inf_check_hooks(model)

-    train_dl = libritts.train_dataloaders(train_cuts)
+    train_dl = libritts.train_dataloaders(
+        train_cuts,
+        world_size=world_size,
+        rank=rank,
+    )

     valid_cuts = libritts.dev_clean_cuts()
-    valid_dl = libritts.valid_dataloaders(valid_cuts)
+    valid_dl = libritts.valid_dataloaders(
+        valid_cuts,
+        world_size=world_size,
+        rank=rank,
+    )

     # if not params.print_diagnostics:
     #     scan_pessimistic_batches_for_oom(
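
The dataloader factory calls now thread world_size and rank from the spawned process down into the datamodule. A hedged sketch of how such a run(rank, world_size, ...) entry point is typically launched with one process per GPU; the process-group setup below is illustrative and not copied from this recipe:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank: int, world_size: int) -> None:
    # Illustrative setup; the real recipe parses args, builds the model and the
    # LibriTTSCodecDataModule here, then calls
    #   libritts.train_dataloaders(train_cuts, world_size=world_size, rank=rank)
    #   libritts.valid_dataloaders(valid_cuts, world_size=world_size, rank=rank)
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = max(torch.cuda.device_count(), 1)
    if world_size > 1:
        mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1)
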