added scheduler w/ warmup

JinZr 2024-10-06 19:07:07 +08:00
parent d83ce89fca
commit 58f6562824
4 changed files with 209 additions and 38 deletions

encodec.py

@@ -60,7 +60,7 @@ class Encodec(nn.Module):
         self.discriminator_adversarial_loss = DiscriminatorAdversarialLoss(
             average_by_discriminators=True, loss_type="hinge"
         )
-        self.feature_match_loss = FeatureLoss(average_by_layers=False)
+        self.feature_match_loss = FeatureLoss()
         self.wav_reconstruction_loss = WavReconstructionLoss()
         self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss(
             sampling_rate=self.sampling_rate
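The model now relies on FeatureLoss's defaults instead of forcing average_by_layers=False; together with the loss.py hunk below that drops the trailing `/ (len(feats) * len(feats[0]))`, the averaging happens exactly once inside the loss. A minimal standalone sketch of that averaging scheme (a hypothetical helper, not the class itself; the real FeatureLoss lives in loss.py, and the True defaults are assumed from the commented-out demo further down):

import torch
import torch.nn.functional as F


def feature_match(feats_hat, feats, average_by_layers=True, average_by_discriminators=True):
    """L1 feature matching over [discriminator][layer] feature maps."""
    feat_match_loss = 0.0
    for i, (fh_disc, f_disc) in enumerate(zip(feats_hat, feats)):
        layer_loss = 0.0
        for j, (fh, f) in enumerate(zip(fh_disc, f_disc)):
            layer_loss += F.l1_loss(fh, f.detach())
        if average_by_layers:
            layer_loss /= j + 1
        feat_match_loss += layer_loss
    if average_by_discriminators:
        feat_match_loss /= i + 1  # mirrors `feat_match_loss /= i + 1` in the diff
    return feat_match_loss


# Toy smoke test: two discriminators with three feature maps each.
feats = [[torch.rand(4, 4) for _ in range(3)] for _ in range(2)]
feats_hat = [[torch.rand(4, 4) for _ in range(3)] for _ in range(2)]
print(feature_match(feats_hat, feats))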

loss.py

@@ -45,8 +45,8 @@ class GeneratorAdversarialLoss(torch.nn.Module):
             Tensor: Generator adversarial loss value.
         """

-        if isinstance(outputs, (tuple, list)):
-            adv_loss = 0.0
+        adv_loss = 0.0
+        if isinstance(outputs, (tuple, list)):
             for i, outputs_ in enumerate(outputs):
                 if isinstance(outputs_, (tuple, list)):
                     # NOTE(kan-bayashi): case including feature maps
@@ -55,9 +55,10 @@ class GeneratorAdversarialLoss(torch.nn.Module):
             if self.average_by_discriminators:
                 adv_loss /= i + 1
         else:
-            adv_loss = self.criterion(outputs)
+            for i, outputs_ in enumerate(outputs):
+                adv_loss += self.criterion(outputs_)
+            adv_loss /= i + 1

-        return adv_loss / len(outputs)
+        return adv_loss

     def _mse_loss(self, x):
         return F.mse_loss(x, x.new_ones(x.size()))
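Both branches now accumulate one criterion term per discriminator output and divide by the count exactly once, whereas the old code additionally divided the already-averaged result by len(outputs) at the end, shrinking the loss for every extra discriminator. A toy check of the corrected averaging (assuming the hinge generator criterion is mean(relu(1 - x)), consistent with the adversarial_g_loss reference exercised in the demo at the end of this file):

import torch
import torch.nn.functional as F

outputs = [torch.rand(8) for _ in range(3)]  # three discriminator outputs

adv_loss = 0.0
for i, out in enumerate(outputs):
    adv_loss += F.relu(1.0 - out).mean()  # hinge criterion for the generator (assumed)
adv_loss /= i + 1  # average once over discriminators
print(adv_loss)  # stays on the scale of a single discriminator's loss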
@@ -112,9 +113,9 @@ class DiscriminatorAdversarialLoss(torch.nn.Module):
             Tensor: Discriminator fake loss value.
         """

-        if isinstance(outputs, (tuple, list)):
-            real_loss = 0.0
-            fake_loss = 0.0
+        real_loss = 0.0
+        fake_loss = 0.0
+        if isinstance(outputs, (tuple, list)):
             for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
                 if isinstance(outputs_hat_, (tuple, list)):
                     # NOTE(kan-bayashi): case including feature maps
@@ -126,10 +127,13 @@ class DiscriminatorAdversarialLoss(torch.nn.Module):
                 fake_loss /= i + 1
                 real_loss /= i + 1
         else:
-            real_loss = self.real_criterion(outputs)
-            fake_loss = self.fake_criterion(outputs_hat)
+            for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
+                real_loss += self.real_criterion(outputs_)
+                fake_loss += self.fake_criterion(outputs_hat_)
+            fake_loss /= i + 1
+            real_loss /= i + 1

-        return real_loss / len(outputs), fake_loss / len(outputs)
+        return real_loss, fake_loss

     def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
         return F.mse_loss(x, x.new_ones(x.size()))
@@ -204,7 +208,7 @@ class FeatureLoss(torch.nn.Module):
         if self.average_by_discriminators:
             feat_match_loss /= i + 1

-        return feat_match_loss / (len(feats) * len(feats[0]))
+        return feat_match_loss


 class MelSpectrogramReconstructionLoss(torch.nn.Module):
@@ -233,7 +237,7 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
             self.wav_to_specs.append(
                 MelSpectrogram(
                     sample_rate=sampling_rate,
-                    n_fft=s,
+                    n_fft=max(s, 512),
                     win_length=s,
                     hop_length=s // 4,
                     n_mels=n_mels,
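Clamping n_fft plausibly guards the smallest STFT scales: with n_fft = s = 32 there are only s // 2 + 1 = 17 frequency bins, too few to spread a full mel filterbank without empty all-zero filters, while win_length and hop_length still track the scale s. A standalone illustration (sample_rate=24000 and n_mels=64 are assumed values):

import torch
from torchaudio.transforms import MelSpectrogram

s = 32  # smallest scale
mel = MelSpectrogram(
    sample_rate=24000,
    n_fft=max(s, 512),  # the short window is zero-padded to 512 FFT points
    win_length=s,
    hop_length=s // 4,
    n_mels=64,
)
print(mel(torch.randn(1, 24000)).shape)  # (1, 64, n_frames)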
@@ -462,8 +466,15 @@ def loss_g(
 if __name__ == "__main__":
-    la = FeatureLoss(average_by_layers=False, average_by_discriminators=False)
-    aa = [torch.rand(192, 192) for _ in range(3)]
-    bb = [torch.rand(192, 192) for _ in range(3)]
-    print(la(bb, aa))
-    print(feature_loss(aa, bb))
+    # la = FeatureLoss(average_by_layers=True, average_by_discriminators=True)
+    # aa = [torch.rand(192, 192) for _ in range(3)]
+    # bb = [torch.rand(192, 192) for _ in range(3)]
+    # print(la(bb, aa))
+    # print(feature_loss(aa, bb))
+    la = GeneratorAdversarialLoss(average_by_discriminators=True, loss_type="hinge")
+    aa = torch.Tensor([0.1, 0.2, 0.3, 0.4])
+    bb = torch.Tensor([0.4, 0.3, 0.2, 0.1])
+    print(la(aa))
+    print(adversarial_g_loss(aa))
+    print(la(bb))
+    print(adversarial_g_loss(bb))

scheduler.py (new file)

@@ -0,0 +1,141 @@
+import math
+from bisect import bisect_right
+
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class WarmupLrScheduler(_LRScheduler):
+    def __init__(
+        self,
+        optimizer,
+        warmup_epoch=500,
+        warmup_ratio=5e-4,
+        warmup="exp",
+        last_epoch=-1,
+    ):
+        self.warmup_epoch = warmup_epoch
+        self.warmup_ratio = warmup_ratio
+        self.warmup = warmup
+        super(WarmupLrScheduler, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        ratio = self.get_lr_ratio()
+        lrs = [ratio * lr for lr in self.base_lrs]
+        return lrs
+
+    def get_lr_ratio(self):
+        if self.last_epoch < self.warmup_epoch:
+            ratio = self.get_warmup_ratio()
+        else:
+            ratio = self.get_main_ratio()
+        return ratio
+
+    def get_main_ratio(self):
+        raise NotImplementedError
+
+    def get_warmup_ratio(self):
+        assert self.warmup in ("linear", "exp")
+        alpha = self.last_epoch / self.warmup_epoch
+        if self.warmup == "linear":
+            ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha
+        elif self.warmup == "exp":
+            ratio = self.warmup_ratio ** (1.0 - alpha)
+        return ratio
+
+
+class WarmupPolyLrScheduler(WarmupLrScheduler):
+    def __init__(
+        self,
+        optimizer,
+        power,
+        max_epoch,
+        warmup_epoch=500,
+        warmup_ratio=5e-4,
+        warmup="exp",
+        last_epoch=-1,
+    ):
+        self.power = power
+        self.max_epoch = max_epoch
+        super(WarmupPolyLrScheduler, self).__init__(
+            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+        )
+
+    def get_main_ratio(self):
+        real_epoch = self.last_epoch - self.warmup_epoch
+        real_max_epoch = self.max_epoch - self.warmup_epoch
+        alpha = real_epoch / real_max_epoch
+        ratio = (1 - alpha) ** self.power
+        return ratio
+
+
+class WarmupExpLrScheduler(WarmupLrScheduler):
+    def __init__(
+        self,
+        optimizer,
+        gamma,
+        interval=1,
+        warmup_epoch=500,
+        warmup_ratio=5e-4,
+        warmup="exp",
+        last_epoch=-1,
+    ):
+        self.gamma = gamma
+        self.interval = interval
+        super(WarmupExpLrScheduler, self).__init__(
+            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+        )
+
+    def get_main_ratio(self):
+        real_epoch = self.last_epoch - self.warmup_epoch
+        ratio = self.gamma ** (real_epoch // self.interval)
+        return ratio
+
+
+class WarmupCosineLrScheduler(WarmupLrScheduler):
+    def __init__(
+        self,
+        optimizer,
+        max_epoch,
+        eta_ratio=0,
+        warmup_epoch=500,
+        warmup_ratio=5e-4,
+        warmup="exp",
+        last_epoch=-1,
+    ):
+        self.eta_ratio = eta_ratio
+        self.max_epoch = max_epoch
+        super(WarmupCosineLrScheduler, self).__init__(
+            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+        )
+
+    def get_main_ratio(self):
+        real_max_epoch = self.max_epoch - self.warmup_epoch
+        return (
+            self.eta_ratio
+            + (1 - self.eta_ratio)
+            * (1 + math.cos(math.pi * self.last_epoch / real_max_epoch))
+            / 2
+        )
+
+
+class WarmupStepLrScheduler(WarmupLrScheduler):
+    def __init__(
+        self,
+        optimizer,
+        milestones: list,
+        gamma=0.1,
+        warmup_epoch=500,
+        warmup_ratio=5e-4,
+        warmup="exp",
+        last_epoch=-1,
+    ):
+        self.milestones = milestones
+        self.gamma = gamma
+        super(WarmupStepLrScheduler, self).__init__(
+            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+        )
+
+    def get_main_ratio(self):
+        real_epoch = self.last_epoch - self.warmup_epoch
+        ratio = self.gamma ** bisect_right(self.milestones, real_epoch)
+        return ratio
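A minimal usage sketch for the new cosine scheduler (the dummy optimizer and epoch counts are illustrative; train.py below wires it to params.num_epochs and params.discriminator_epoch_start). One caveat: get_main_ratio evaluates the cosine at last_epoch / (max_epoch - warmup_epoch) rather than at (last_epoch - warmup_epoch) / (max_epoch - warmup_epoch), so with a nonzero warmup the schedule reaches its floor before max_epoch.

import torch
from scheduler import WarmupCosineLrScheduler  # the file added above

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=3e-4)

# With warmup="exp" (the default), the warmup LR at epoch e is
#   base_lr * warmup_ratio ** (1 - e / warmup_epoch),
# ramping from base_lr * warmup_ratio at e=0 up to base_lr at e=warmup_epoch,
# after which the cosine decay toward eta_ratio * base_lr takes over.
scheduler = WarmupCosineLrScheduler(
    optimizer=optimizer,
    max_epoch=20,
    eta_ratio=0.1,
    warmup_epoch=3,
    warmup_ratio=1e-4,
)

for epoch in range(20):
    optimizer.step()  # placeholder for one epoch of training
    scheduler.step()  # advances last_epoch; the next epoch sees the new lr
    print(epoch, scheduler.get_last_lr())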

train.py

@@ -16,6 +16,7 @@ from encodec import Encodec
 from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
 from loss import adopt_weight
+from scheduler import WarmupCosineLrScheduler
 from torch import nn
 from torch.cuda.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -188,10 +189,10 @@ def get_params() -> AttributeDict:
             "sampling_rate": 24000,
             "chunk_size": 1.0,  # in seconds
             "lambda_adv": 3.0,  # loss scaling coefficient for adversarial loss
-            "lambda_wav": 0.1,  # loss scaling coefficient for waveform loss
+            "lambda_wav": 1.0,  # loss scaling coefficient for waveform loss
             "lambda_feat": 3.0,  # loss scaling coefficient for feat loss
             "lambda_rec": 1.0,  # loss scaling coefficient for reconstruction loss
-            "lambda_com": 1.0,  # loss scaling coefficient for commitment loss
+            "lambda_com": 100.0,  # loss scaling coefficient for commitment loss
         }
     )
@@ -260,7 +261,7 @@ def get_model(params: AttributeDict) -> nn.Module:
     # }
     # discriminator_params = {
     #     "stft_discriminator_n_filters": 32,
-    #     "discriminator_iter_start": 500,
+    #     "discriminator_epoch_start": 5,
     # }
     # inference_params = {
     #     "target_bw": 7.5,
@@ -275,7 +276,10 @@ def get_model(params: AttributeDict) -> nn.Module:
     }
     discriminator_params = {
         "stft_discriminator_n_filters": 32,
-        "discriminator_iter_start": 500,
+        "discriminator_epoch_start": 3,
+        "n_ffts": [1024, 2048, 512],
+        "hop_lengths": [256, 512, 128],
+        "win_lengths": [1024, 2048, 512],
     }
     inference_params = {
         "target_bw": 12,
@@ -316,7 +320,10 @@ def get_model(params: AttributeDict) -> nn.Module:
         multi_scale_discriminator=None,
         multi_period_discriminator=None,
         multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator(
-            n_filters=params.stft_discriminator_n_filters
+            n_filters=params.stft_discriminator_n_filters,
+            n_ffts=params.n_ffts,
+            hop_lengths=params.hop_lengths,
+            win_lengths=params.win_lengths,
         ),
     )
     return model
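The three STFT resolutions handed to MultiScaleSTFTDiscriminator follow win_length = n_fft and hop_length = n_fft // 4. For a feel of what each scale sees, the frame counts for a one-second chunk at 24 kHz (plain arithmetic for a centered STFT, independent of the discriminator implementation):

n_samples = 24000  # chunk_size 1.0 s at sampling_rate 24000
for n_fft, hop in zip([1024, 2048, 512], [256, 512, 128]):
    print(f"n_fft={n_fft:4d} hop={hop:3d} -> {n_samples // hop + 1} frames")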
@@ -437,8 +444,8 @@ def train_one_epoch(
         with autocast(enabled=params.use_fp16):
             d_weight = adopt_weight(
                 params.lambda_adv,
-                params.batch_idx_train,
-                threshold=params.discriminator_iter_start,
+                params.cur_epoch,
+                threshold=params.discriminator_epoch_start,
             )
             # forward discriminator
             (
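The adversarial-weight gate now counts epochs (params.cur_epoch) instead of training batches, matching the discriminator_iter_start -> discriminator_epoch_start rename. For reference, a hedged sketch of the adopt_weight helper imported from loss.py (signature inferred from these call sites; the actual implementation may differ):

def adopt_weight(weight: float, step: int, threshold: int = 0, value: float = 0.0) -> float:
    # Return `weight` once `step` reaches `threshold`, else `value`;
    # used above to keep the adversarial terms switched off until the
    # discriminator is allowed to start training.
    return weight if step >= threshold else value


# With discriminator_epoch_start = 3:
assert adopt_weight(3.0, 2, threshold=3) == 0.0  # adversarial term disabled
assert adopt_weight(3.0, 3, threshold=3) == 3.0  # adversarial term enabled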
@@ -473,8 +480,8 @@ def train_one_epoch(
         with autocast(enabled=params.use_fp16):
             g_weight = adopt_weight(
                 params.lambda_adv,
-                params.batch_idx_train,
-                threshold=params.discriminator_iter_start,
+                params.cur_epoch,
+                threshold=params.discriminator_epoch_start,
             )
             # forward generator
             (
@@ -507,7 +514,7 @@ def train_one_epoch(
             gen_loss = (
                 gen_adv_loss
                 + reconstruction_loss
-                + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss
+                + params.lambda_feat * feature_loss
                 + params.lambda_com * commit_loss
             )
         for k, v in stats_g.items():
@@ -688,8 +695,8 @@ def compute_validation_loss(
             d_weight = adopt_weight(
                 params.lambda_adv,
-                params.batch_idx_train,
-                threshold=params.discriminator_iter_start,
+                params.cur_epoch,
+                threshold=params.discriminator_epoch_start,
             )
             # forward discriminator
@@ -721,8 +728,8 @@ def compute_validation_loss(
             g_weight = adopt_weight(
                 params.lambda_adv,
-                params.batch_idx_train,
-                threshold=params.discriminator_iter_start,
+                params.cur_epoch,
+                threshold=params.discriminator_epoch_start,
             )
             # forward generator
             (
@@ -753,7 +760,7 @@ def compute_validation_loss(
             gen_loss = (
                 gen_adv_loss
                 + reconstruction_loss
-                + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss
+                + params.lambda_feat * feature_loss
                 + params.lambda_com * commit_loss
             )
             assert gen_loss.requires_grad is False
@@ -831,8 +838,8 @@ def scan_pessimistic_batches_for_oom(
                 + disc_scale_fake_adv_loss
             ) * adopt_weight(
                 params.lambda_adv,
-                params.batch_idx_train,
-                threshold=params.discriminator_iter_start,
+                params.cur_epoch,
+                threshold=params.discriminator_train_start,
             )
             optimizer_d.zero_grad()
             loss_d.backward()
@@ -859,8 +866,8 @@ def scan_pessimistic_batches_for_oom(
                 (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
                 * adopt_weight(
                     params.lambda_adv,
-                    params.batch_idx_train,
-                    threshold=params.discriminator_iter_start,
+                    0,
+                    threshold=params.discriminator_epoch_start,
                 )
                 + (
                     params.lambda_wav * wav_reconstruction_loss
@@ -1000,8 +1007,20 @@ def run(rank, world_size, args):
         betas=(0.5, 0.9),
     )

-    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
-    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875)
+    scheduler_g = WarmupCosineLrScheduler(
+        optimizer=optimizer_g,
+        max_epoch=params.num_epochs,
+        eta_ratio=0.1,
+        warmup_epoch=params.discriminator_epoch_start,
+        warmup_ratio=1e-4,
+    )
+    scheduler_d = WarmupCosineLrScheduler(
+        optimizer=optimizer_d,
+        max_epoch=params.num_epochs,
+        eta_ratio=0.1,
+        warmup_epoch=params.discriminator_epoch_start,
+        warmup_ratio=1e-4,
+    )

     if checkpoints is not None:
         # load state_dict for optimizers