added scheduler w/ warmup

JinZr 2024-10-06 19:07:07 +08:00
parent d83ce89fca
commit 58f6562824
4 changed files with 209 additions and 38 deletions


@@ -60,7 +60,7 @@ class Encodec(nn.Module):
self.discriminator_adversarial_loss = DiscriminatorAdversarialLoss(
average_by_discriminators=True, loss_type="hinge"
)
self.feature_match_loss = FeatureLoss(average_by_layers=False)
self.feature_match_loss = FeatureLoss()
self.wav_reconstruction_loss = WavReconstructionLoss()
self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss(
sampling_rate=self.sampling_rate

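The recipe's FeatureLoss is now constructed with its defaults instead of forcing average_by_layers=False. For reference, a minimal self-contained sketch of this kind of feature-matching loss (L1 distance between the discriminator's intermediate feature maps for real and generated audio, optionally averaged over layers and discriminators); the function name and defaults below are illustrative, not the recipe's exact implementation:

import torch
import torch.nn.functional as F

def feature_matching_loss(feats_hat, feats, average_by_layers=True, average_by_discriminators=True):
    # feats_hat / feats: one list of layer feature maps per discriminator
    loss = 0.0
    for i, (layers_hat, layers_real) in enumerate(zip(feats_hat, feats)):
        d_loss = 0.0
        for j, (f_hat, f_real) in enumerate(zip(layers_hat, layers_real)):
            d_loss += F.l1_loss(f_hat, f_real.detach())
        if average_by_layers:
            d_loss /= j + 1
        loss += d_loss
    if average_by_discriminators:
        loss /= i + 1
    return loss

# toy check: two discriminators with three feature maps each
feats = [[torch.rand(2, 8, 16) for _ in range(3)] for _ in range(2)]
feats_hat = [[torch.rand(2, 8, 16) for _ in range(3)] for _ in range(2)]
print(feature_matching_loss(feats_hat, feats))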

@@ -45,8 +45,8 @@ class GeneratorAdversarialLoss(torch.nn.Module):
Tensor: Generator adversarial loss value.
"""
adv_loss = 0.0
if isinstance(outputs, (tuple, list)):
adv_loss = 0.0
for i, outputs_ in enumerate(outputs):
if isinstance(outputs_, (tuple, list)):
# NOTE(kan-bayashi): case including feature maps
@@ -55,9 +55,10 @@ class GeneratorAdversarialLoss(torch.nn.Module):
if self.average_by_discriminators:
adv_loss /= i + 1
else:
adv_loss = self.criterion(outputs)
return adv_loss / len(outputs)
for i, outputs_ in enumerate(outputs):
adv_loss += self.criterion(outputs_)
adv_loss /= i + 1
return adv_loss
def _mse_loss(self, x):
return F.mse_loss(x, x.new_ones(x.size()))
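As a compact illustration of what GeneratorAdversarialLoss computes for a list of discriminator outputs, here is a hinge-style generator objective with the per-discriminator averaging used above; this is a sketch, not the file's exact code:

import torch

def hinge_generator_loss(disc_outputs):
    # disc_outputs: one logit tensor per discriminator, computed on generated audio
    loss = 0.0
    for i, d_out in enumerate(disc_outputs):
        loss += -d_out.mean()   # hinge generator objective: raise D(x_hat)
    return loss / (i + 1)       # average over discriminators

print(hinge_generator_loss([torch.randn(4, 1) for _ in range(3)]))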
@@ -112,9 +113,9 @@ class DiscriminatorAdversarialLoss(torch.nn.Module):
Tensor: Discriminator fake loss value.
"""
real_loss = 0.0
fake_loss = 0.0
if isinstance(outputs, (tuple, list)):
real_loss = 0.0
fake_loss = 0.0
for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
if isinstance(outputs_hat_, (tuple, list)):
# NOTE(kan-bayashi): case including feature maps
@@ -126,10 +127,13 @@ class DiscriminatorAdversarialLoss(torch.nn.Module):
fake_loss /= i + 1
real_loss /= i + 1
else:
real_loss = self.real_criterion(outputs)
fake_loss = self.fake_criterion(outputs_hat)
for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
real_loss += self.real_criterion(outputs_)
fake_loss += self.fake_criterion(outputs_hat_)
fake_loss /= i + 1
real_loss /= i + 1
return real_loss / len(outputs), fake_loss / len(outputs)
return real_loss, fake_loss
def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
return F.mse_loss(x, x.new_ones(x.size()))
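The corresponding discriminator side, again as an illustrative sketch of the hinge real/fake objectives averaged over discriminators:

import torch
import torch.nn.functional as F

def hinge_discriminator_loss(real_outputs, fake_outputs):
    # one logit tensor per discriminator
    real_loss, fake_loss = 0.0, 0.0
    for i, (r, f) in enumerate(zip(real_outputs, fake_outputs)):
        real_loss += F.relu(1.0 - r).mean()   # push D(x) above +1
        fake_loss += F.relu(1.0 + f).mean()   # push D(x_hat) below -1
    return real_loss / (i + 1), fake_loss / (i + 1)

real = [torch.randn(4, 1) for _ in range(3)]
fake = [torch.randn(4, 1) for _ in range(3)]
print(hinge_discriminator_loss(real, fake))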
@@ -204,7 +208,7 @@ class FeatureLoss(torch.nn.Module):
if self.average_by_discriminators:
feat_match_loss /= i + 1
return feat_match_loss / (len(feats) * len(feats[0]))
return feat_match_loss
class MelSpectrogramReconstructionLoss(torch.nn.Module):
@@ -233,7 +237,7 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
self.wav_to_specs.append(
MelSpectrogram(
sample_rate=sampling_rate,
n_fft=s,
n_fft=max(s, 512),
win_length=s,
hop_length=s // 4,
n_mels=n_mels,
@@ -462,8 +466,15 @@ def loss_g(
if __name__ == "__main__":
la = FeatureLoss(average_by_layers=False, average_by_discriminators=False)
aa = [torch.rand(192, 192) for _ in range(3)]
bb = [torch.rand(192, 192) for _ in range(3)]
print(la(bb, aa))
print(feature_loss(aa, bb))
# la = FeatureLoss(average_by_layers=True, average_by_discriminators=True)
# aa = [torch.rand(192, 192) for _ in range(3)]
# bb = [torch.rand(192, 192) for _ in range(3)]
# print(la(bb, aa))
# print(feature_loss(aa, bb))
la = GeneratorAdversarialLoss(average_by_discriminators=True, loss_type="hinge")
aa = torch.Tensor([0.1, 0.2, 0.3, 0.4])
bb = torch.Tensor([0.4, 0.3, 0.2, 0.1])
print(la(aa))
print(adversarial_g_loss(aa))
print(la(bb))
print(adversarial_g_loss(bb))

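The n_fft=max(s, 512) clamp above keeps the FFT size large enough to give MelSpectrogram a reasonable number of frequency bins even for the smallest windows. A standalone sketch of a multi-scale mel reconstruction loss built the same way with torchaudio; the window sizes, n_mels, and the plain L1 distance are illustrative choices (the recipe builds its transforms once in __init__):

import torch
import torch.nn.functional as F
from torchaudio.transforms import MelSpectrogram

def multi_scale_mel_loss(x_hat, x, sampling_rate=24000, n_mels=64):
    loss = 0.0
    for s in (32, 64, 128, 256, 512, 1024, 2048):
        mel = MelSpectrogram(
            sample_rate=sampling_rate,
            n_fft=max(s, 512),   # clamp: a tiny n_fft would give too few bins for the mel filterbank
            win_length=s,
            hop_length=s // 4,
            n_mels=n_mels,
        )
        loss = loss + F.l1_loss(mel(x_hat), mel(x))
    return loss

print(multi_scale_mel_loss(torch.randn(1, 24000), torch.randn(1, 24000)))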

@@ -0,0 +1,141 @@
import math
from bisect import bisect_right
from torch.optim.lr_scheduler import _LRScheduler
class WarmupLrScheduler(_LRScheduler):
def __init__(
self,
optimizer,
warmup_epoch=500,
warmup_ratio=5e-4,
warmup="exp",
last_epoch=-1,
):
self.warmup_epoch = warmup_epoch
self.warmup_ratio = warmup_ratio
self.warmup = warmup
super(WarmupLrScheduler, self).__init__(optimizer, last_epoch)
def get_lr(self):
ratio = self.get_lr_ratio()
lrs = [ratio * lr for lr in self.base_lrs]
return lrs
def get_lr_ratio(self):
if self.last_epoch < self.warmup_epoch:
ratio = self.get_warmup_ratio()
else:
ratio = self.get_main_ratio()
return ratio
def get_main_ratio(self):
raise NotImplementedError
def get_warmup_ratio(self):
assert self.warmup in ("linear", "exp")
alpha = self.last_epoch / self.warmup_epoch
if self.warmup == "linear":
ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha
elif self.warmup == "exp":
ratio = self.warmup_ratio ** (1.0 - alpha)
return ratio
class WarmupPolyLrScheduler(WarmupLrScheduler):
def __init__(
self,
optimizer,
power,
max_epoch,
warmup_epoch=500,
warmup_ratio=5e-4,
warmup="exp",
last_epoch=-1,
):
self.power = power
self.max_epoch = max_epoch
super(WarmupPolyLrScheduler, self).__init__(
optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
)
def get_main_ratio(self):
real_epoch = self.last_epoch - self.warmup_epoch
real_max_epoch = self.max_epoch - self.warmup_epoch
alpha = real_epoch / real_max_epoch
ratio = (1 - alpha) ** self.power
return ratio
class WarmupExpLrScheduler(WarmupLrScheduler):
def __init__(
self,
optimizer,
gamma,
interval=1,
warmup_epoch=500,
warmup_ratio=5e-4,
warmup="exp",
last_epoch=-1,
):
self.gamma = gamma
self.interval = interval
super(WarmupExpLrScheduler, self).__init__(
optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
)
def get_main_ratio(self):
real_epoch = self.last_epoch - self.warmup_epoch
ratio = self.gamma ** (real_epoch // self.interval)
return ratio
class WarmupCosineLrScheduler(WarmupLrScheduler):
def __init__(
self,
optimizer,
max_epoch,
eta_ratio=0,
warmup_epoch=500,
warmup_ratio=5e-4,
warmup="exp",
last_epoch=-1,
):
self.eta_ratio = eta_ratio
self.max_epoch = max_epoch
super(WarmupCosineLrScheduler, self).__init__(
optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
)
def get_main_ratio(self):
real_max_epoch = self.max_epoch - self.warmup_epoch
return (
self.eta_ratio
+ (1 - self.eta_ratio)
* (1 + math.cos(math.pi * self.last_epoch / real_max_epoch))
/ 2
)
class WarmupStepLrScheduler(WarmupLrScheduler):
def __init__(
self,
optimizer,
milestones: list,
gamma=0.1,
warmup_epoch=500,
warmup_ratio=5e-4,
warmup="exp",
last_epoch=-1,
):
self.milestones = milestones
self.gamma = gamma
super(WarmupStepLrScheduler, self).__init__(
optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
)
def get_main_ratio(self):
real_epoch = self.last_epoch - self.warmup_epoch
ratio = self.gamma ** bisect_right(self.milestones, real_epoch)
return ratio

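A short usage sketch for the WarmupCosineLrScheduler added above, stepped once per epoch; the toy model and hyper-parameters are placeholders:

import torch
from scheduler import WarmupCosineLrScheduler

model = torch.nn.Linear(10, 10)   # toy model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = WarmupCosineLrScheduler(
    optimizer=optimizer,
    max_epoch=20,
    eta_ratio=0.1,        # lr decays towards 10% of the base lr
    warmup_epoch=3,
    warmup_ratio=1e-4,    # lr starts at base_lr * 1e-4 and ramps up exponentially
)
for epoch in range(20):
    # ... one epoch of training would go here ...
    optimizer.step()      # keep PyTorch's step-order check happy in this toy loop
    scheduler.step()
    print(epoch, scheduler.get_last_lr())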

@@ -16,6 +16,7 @@ from encodec import Encodec
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from loss import adopt_weight
from scheduler import WarmupCosineLrScheduler
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -188,10 +189,10 @@ def get_params() -> AttributeDict:
"sampling_rate": 24000,
"chunk_size": 1.0, # in seconds
"lambda_adv": 3.0, # loss scaling coefficient for adversarial loss
"lambda_wav": 0.1, # loss scaling coefficient for waveform loss
"lambda_wav": 1.0, # loss scaling coefficient for waveform loss
"lambda_feat": 3.0, # loss scaling coefficient for feat loss
"lambda_rec": 1.0, # loss scaling coefficient for reconstruction loss
"lambda_com": 1.0, # loss scaling coefficient for commitment loss
"lambda_com": 100.0, # loss scaling coefficient for commitment loss
}
)
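These lambdas scale the individual generator-side losses, while adopt_weight (imported from loss.py) zeroes the adversarial weight until the discriminator has been trained for discriminator_epoch_start epochs. A sketch of that gating, assuming the usual VQGAN-style adopt_weight definition; the recipe's actual implementation lives in loss.py:

def adopt_weight(weight, step, threshold=0, value=0.0):
    # illustrative stand-in: return `value` until `step` reaches `threshold`, then the real weight
    return value if step < threshold else weight

lambda_adv = 3.0
discriminator_epoch_start = 3
for cur_epoch in range(1, 6):
    g_weight = adopt_weight(lambda_adv, cur_epoch, threshold=discriminator_epoch_start)
    print(cur_epoch, g_weight)  # 0.0 for epochs 1 and 2, then 3.0 from epoch 3 onwards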
@@ -260,7 +261,7 @@ def get_model(params: AttributeDict) -> nn.Module:
# }
# discriminator_params = {
# "stft_discriminator_n_filters": 32,
# "discriminator_iter_start": 500,
# "discriminator_epoch_start": 5,
# }
# inference_params = {
# "target_bw": 7.5,
@@ -275,7 +276,10 @@ def get_model(params: AttributeDict) -> nn.Module:
}
discriminator_params = {
"stft_discriminator_n_filters": 32,
"discriminator_iter_start": 500,
"discriminator_epoch_start": 3,
"n_ffts": [1024, 2048, 512],
"hop_lengths": [256, 512, 128],
"win_lengths": [1024, 2048, 512],
}
inference_params = {
"target_bw": 12,
@@ -316,7 +320,10 @@ def get_model(params: AttributeDict) -> nn.Module:
multi_scale_discriminator=None,
multi_period_discriminator=None,
multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator(
n_filters=params.stft_discriminator_n_filters
n_filters=params.stft_discriminator_n_filters,
n_ffts=params.n_ffts,
hop_lengths=params.hop_lengths,
win_lengths=params.win_lengths,
),
)
return model
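The three STFT resolutions passed to MultiScaleSTFTDiscriminator follow the common pattern of hop_length = n_fft // 4 and win_length = n_fft; a quick check of the values configured above:

n_ffts = [1024, 2048, 512]
hop_lengths = [n // 4 for n in n_ffts]  # -> [256, 512, 128], matching the config above
win_lengths = list(n_ffts)              # -> [1024, 2048, 512]
print(hop_lengths, win_lengths)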
@@ -437,8 +444,8 @@ def train_one_epoch(
with autocast(enabled=params.use_fp16):
d_weight = adopt_weight(
params.lambda_adv,
params.batch_idx_train,
threshold=params.discriminator_iter_start,
params.cur_epoch,
threshold=params.discriminator_epoch_start,
)
# forward discriminator
(
@@ -473,8 +480,8 @@ def train_one_epoch(
with autocast(enabled=params.use_fp16):
g_weight = adopt_weight(
params.lambda_adv,
params.batch_idx_train,
threshold=params.discriminator_iter_start,
params.cur_epoch,
threshold=params.discriminator_epoch_start,
)
# forward generator
(
@@ -507,7 +514,7 @@ def train_one_epoch(
gen_loss = (
gen_adv_loss
+ reconstruction_loss
+ (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss
+ params.lambda_feat * feature_loss
+ params.lambda_com * commit_loss
)
for k, v in stats_g.items():
@@ -688,8 +695,8 @@ def compute_validation_loss(
d_weight = adopt_weight(
params.lambda_adv,
params.batch_idx_train,
threshold=params.discriminator_iter_start,
params.cur_epoch,
threshold=params.discriminator_epoch_start,
)
# forward discriminator
@@ -721,8 +728,8 @@ def compute_validation_loss(
g_weight = adopt_weight(
params.lambda_adv,
params.batch_idx_train,
threshold=params.discriminator_iter_start,
params.cur_epoch,
threshold=params.discriminator_epoch_start,
)
# forward generator
(
@@ -753,7 +760,7 @@ def compute_validation_loss(
gen_loss = (
gen_adv_loss
+ reconstruction_loss
+ (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss
+ params.lambda_feat * feature_loss
+ params.lambda_com * commit_loss
)
assert gen_loss.requires_grad is False
@@ -831,8 +838,8 @@ def scan_pessimistic_batches_for_oom(
+ disc_scale_fake_adv_loss
) * adopt_weight(
params.lambda_adv,
params.batch_idx_train,
threshold=params.discriminator_iter_start,
params.cur_epoch,
threshold=params.discriminator_epoch_start,
)
optimizer_d.zero_grad()
loss_d.backward()
@@ -859,8 +866,8 @@ def scan_pessimistic_batches_for_oom(
(gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
* adopt_weight(
params.lambda_adv,
params.batch_idx_train,
threshold=params.discriminator_iter_start,
0,
threshold=params.discriminator_epoch_start,
)
+ (
params.lambda_wav * wav_reconstruction_loss
@@ -1000,8 +1007,20 @@ def run(rank, world_size, args):
betas=(0.5, 0.9),
)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875)
scheduler_g = WarmupCosineLrScheduler(
optimizer=optimizer_g,
max_epoch=params.num_epochs,
eta_ratio=0.1,
warmup_epoch=params.discriminator_epoch_start,
warmup_ratio=1e-4,
)
scheduler_d = WarmupCosineLrScheduler(
optimizer=optimizer_d,
max_epoch=params.num_epochs,
eta_ratio=0.1,
warmup_epoch=params.discriminator_epoch_start,
warmup_ratio=1e-4,
)
if checkpoints is not None:
# load state_dict for optimizers
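To see the schedule this configures (exponential warmup for the first discriminator_epoch_start epochs, then cosine decay towards eta_ratio of the base learning rate), here is a standalone recomputation of the ratio applied to a base lr; num_epochs and the base lr are placeholders, and the cosine term uses the raw epoch index exactly as get_main_ratio does above:

import math

base_lr = 3e-4            # placeholder; the recipe sets its own base lr for the optimizers
num_epochs = 30           # placeholder for params.num_epochs
warmup_epoch = 3          # params.discriminator_epoch_start
warmup_ratio, eta_ratio = 1e-4, 0.1

def lr_ratio(epoch):
    if epoch < warmup_epoch:                       # exponential warmup
        return warmup_ratio ** (1.0 - epoch / warmup_epoch)
    real_max_epoch = num_epochs - warmup_epoch     # cosine decay towards eta_ratio
    return eta_ratio + (1 - eta_ratio) * (1 + math.cos(math.pi * epoch / real_max_epoch)) / 2

for epoch in (0, 1, 2, 3, 10, 20, 29):
    print(epoch, base_lr * lr_ratio(epoch))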