Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 10:16:14 +00:00)

commit 5492a6a5e2 (parent df87a0fe2c)

    comments updated
@@ -1,8 +1,19 @@
+# Modified from egs/ljspeech/TTS/vits/loss.py by: Zengrui JIN (Tsinghua University)
+# original implementation is from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Encodec-related loss modules.
+
+This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
+
+"""
+
 from typing import List, Tuple, Union
 
 import torch
 import torch.nn.functional as F
-from lhotse.features.kaldi import Wav2LogFilterBank
 from torchaudio.transforms import MelSpectrogram
 
 
@@ -225,15 +236,6 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
         self.wav_to_specs = []
         for i in range(5, 12):
             s = 2**i
-            # self.wav_to_specs.append(
-            #     Wav2LogFilterBank(
-            #         sampling_rate=sampling_rate,
-            #         frame_length=s,
-            #         frame_shift=s // 4,
-            #         use_fft_mag=use_fft_mag,
-            #         num_filters=n_mels,
-            #     )
-            # )
             self.wav_to_specs.append(
                 MelSpectrogram(
                     sample_rate=sampling_rate,
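For context, the branch that survives this hunk builds one torchaudio MelSpectrogram per window size 2**i for i in 5..11, replacing the commented-out lhotse Wav2LogFilterBank path. The sketch below shows how such a bank of transforms could feed a multi-scale mel reconstruction loss; the function name, the L1 distance, the hop length of s // 4, and the n_mels cap are illustrative assumptions, not the recipe's exact implementation.

import torch
from torchaudio.transforms import MelSpectrogram


def multi_scale_mel_loss(
    x_hat: torch.Tensor,  # reconstructed waveform, shape (batch, samples)
    x: torch.Tensor,  # reference waveform, same shape
    sampling_rate: int = 24000,
    n_mels: int = 64,
) -> torch.Tensor:
    """Sum of L1 mel-spectrogram distances over window sizes 2**5 .. 2**11."""
    loss = torch.zeros((), device=x.device)
    for i in range(5, 12):
        s = 2**i
        wav_to_spec = MelSpectrogram(
            sample_rate=sampling_rate,
            n_fft=s,
            win_length=s,
            hop_length=s // 4,
            # Cap n_mels so the smallest windows never request more mel bins
            # than available FFT bins (an illustrative choice, not the recipe's).
            n_mels=min(n_mels, s // 2 + 1),
        ).to(x.device)
        loss = loss + (wav_to_spec(x_hat) - wav_to_spec(x)).abs().mean()
    return loss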
@@ -186,12 +186,11 @@ def get_params() -> AttributeDict:
             "env_info": get_env_info(),
             "sampling_rate": 24000,
             "audio_normalization": False,
-            "chunk_size": 1.0,  # in seconds
             "lambda_adv": 3.0,  # loss scaling coefficient for adversarial loss
             "lambda_wav": 0.1,  # loss scaling coefficient for waveform loss
             "lambda_feat": 4.0,  # loss scaling coefficient for feat loss
             "lambda_rec": 1.0,  # loss scaling coefficient for reconstruction loss
-            "lambda_com": 1.0,  # loss scaling coefficient for commitment loss
+            "lambda_com": 1000.0,  # loss scaling coefficient for commitment loss
         }
     )
 
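The lambda_* entries above are scaling coefficients for the individual generator-side losses. As a reading aid, a weighted sum along the following lines is how such coefficients typically enter the total loss; the function name and dictionary keys here are illustrative, not the names used in the training script.

from typing import Dict

import torch


def combine_generator_losses(
    losses: Dict[str, torch.Tensor], params: Dict[str, float]
) -> torch.Tensor:
    """Weighted sum of generator losses, assuming one scalar tensor per term."""
    return (
        params["lambda_adv"] * losses["adv"]  # adversarial loss
        + params["lambda_wav"] * losses["wav"]  # waveform loss
        + params["lambda_feat"] * losses["feat"]  # feature-matching loss
        + params["lambda_rec"] * losses["rec"]  # reconstruction loss
        + params["lambda_com"] * losses["com"]  # commitment loss
    )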
@@ -342,12 +341,10 @@ def prepare_input(
 
     if is_training:
         audio_dims = audio.size(-1)
-        start_idx = random.randint(
-            0, max(0, audio_dims - params.chunk_size * params.sampling_rate)
-        )
+        start_idx = random.randint(0, max(0, audio_dims - params.sampling_rate))
         audio = audio[:, start_idx : params.sampling_rate + start_idx]
     else:
-        # NOTE: a very coarse setup
+        # NOTE(zengrui): a very coarse setup
        audio = audio[
             :, params.sampling_rate : params.sampling_rate + params.sampling_rate
         ]
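With chunk_size removed from get_params, the training crop above is fixed at exactly params.sampling_rate samples, i.e. one second of audio, while evaluation always takes the window from sample sampling_rate to 2 * sampling_rate. A standalone sketch of that behaviour follows; the function name and the batch-first waveform layout are assumptions for illustration.

import random

import torch


def crop_to_one_second(
    audio: torch.Tensor, sampling_rate: int, is_training: bool
) -> torch.Tensor:
    """Crop a (batch, samples) waveform to sampling_rate samples, mirroring the hunk above."""
    num_samples = audio.size(-1)
    if is_training:
        # Random one-second window; clamps to 0 when the clip is shorter than a second.
        start_idx = random.randint(0, max(0, num_samples - sampling_rate))
    else:
        # NOTE: a very coarse setup -- always take the second one-second window.
        start_idx = sampling_rate
    return audio[:, start_idx : start_idx + sampling_rate]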