Initial commit for the EEG setup

zr_jin 2024-10-22 20:35:54 +08:00
parent 39b3d1a050
commit c345a4a572
4 changed files with 166 additions and 30 deletions

View File

@ -26,6 +26,7 @@ from loss import (
FeatureLoss,
GeneratorAdversarialLoss,
MelSpectrogramReconstructionLoss,
+    SpectrogramReconstructionLoss,
WavReconstructionLoss,
)
from torch import nn
@ -79,7 +80,7 @@ class Encodec(nn.Module):
)
self.feature_match_loss = FeatureLoss()
self.wav_reconstruction_loss = WavReconstructionLoss()
-        self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss(
+        self.spec_reconstruction_loss = SpectrogramReconstructionLoss(
sampling_rate=self.sampling_rate
)
@ -170,7 +171,7 @@ class Encodec(nn.Module):
wav_reconstruction_loss = self.wav_reconstruction_loss(
x=speech, x_hat=speech_hat
)
-        mel_reconstruction_loss = self.mel_reconstruction_loss(
+        mel_reconstruction_loss = self.spec_reconstruction_loss(
x=speech, x_hat=speech_hat
)

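For orientation (not part of this diff): the renamed term enters the generator objective alongside the adversarial, feature-matching, and waveform terms, whose lambda_* weights appear in get_params() in train.py below. A hedged sketch of that combination, with illustrative variable names and an assumed weight of 1.0 on the reconstruction term:

# Sketch only: the actual aggregation in encodec.py is outside this diff,
# and the names/weighting here are assumptions, not the recipe's code.
loss_g = (
    params.lambda_adv * generator_adv_loss
    + params.lambda_feat * feature_match_loss
    + params.lambda_wav * wav_reconstruction_loss
    + mel_reconstruction_loss  # now produced by SpectrogramReconstructionLoss
)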
View File

@ -14,7 +14,7 @@ from typing import List, Tuple, Union
import torch
import torch.nn.functional as F
-from torchaudio.transforms import MelSpectrogram
+from torchaudio.transforms import MelSpectrogram, Spectrogram
class GeneratorAdversarialLoss(torch.nn.Module):
@ -295,6 +295,80 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
return mel_loss
class SpectrogramReconstructionLoss(torch.nn.Module):
"""Spec Reconstruction loss designed for EEG signals."""
    def __init__(
        self,
        sampling_rate: int = 22050,
        return_spec: bool = False,
    ):
        super().__init__()
        # Kept for API parity with MelSpectrogramReconstructionLoss; the STFT
        # itself is rate-agnostic, and torchaudio's Spectrogram (unlike
        # MelSpectrogram) accepts no sample_rate argument.
        self.sampling_rate = sampling_rate
        self.wav_to_specs = []
        # Five STFT resolutions: s = 2**i // 8 gives n_fft in {4, 8, 16, 32, 64},
        # short windows suited to low-sampling-rate EEG.
        for i in range(5, 10):
            s = 2**i // 8
            self.wav_to_specs.append(
                Spectrogram(
                    n_fft=s,
                    win_length=s,
                    hop_length=s // 4,
                    normalized=True,
                    center=False,
                    pad_mode=None,
                    power=None,
                )
            )
        self.return_spec = return_spec
def forward(
self,
x_hat: torch.Tensor,
x: torch.Tensor,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
"""Calculate Mel-spectrogram loss.
Args:
x_hat (Tensor): Generated waveform tensor (B, 1, T).
x (Tensor): Groundtruth waveform tensor (B, 1, T).
spec (Optional[Tensor]): Groundtruth linear amplitude spectrum tensor
(B, T, n_fft // 2 + 1). if provided, use it instead of groundtruth
waveform.
Returns:
Tensor: Mel-spectrogram loss value.
"""
        spec_loss = 0.0
        for wav_to_spec in self.wav_to_specs:
            # The transforms live in a plain Python list rather than an
            # nn.ModuleList, so move each one to the input device here.
            wav_to_spec.to(x.device)
            spec_hat = wav_to_spec(x_hat.squeeze(1))
            spec = wav_to_spec(x.squeeze(1))
            # L1 distance on the complex spectrogram, plus an RMS distance
            # between log-magnitudes taken over the frequency axis.
            spec_loss += (
                F.l1_loss(spec_hat, spec, reduction="mean")
                + (
                    (
                        (torch.log(spec.abs() + 1e-7) - torch.log(spec_hat.abs() + 1e-7))
                        ** 2
                    ).mean(dim=-2)
                    ** 0.5
                ).mean()
            )
        if self.return_spec:
            # Note: spec_hat and spec come from the last resolution only.
            return spec_loss, (spec_hat, spec)
        return spec_loss
class WavReconstructionLoss(torch.nn.Module):
"""Wav Reconstruction loss."""

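A quick end-to-end check of the new loss. The import is grounded in the `from loss import (...)` block of the first file; the tensor sizes mirror the train.py defaults below (128 Hz x 180 s = 23040 samples per chunk) but are otherwise an arbitrary choice:

# Smoke test for SpectrogramReconstructionLoss; shapes follow its
# docstring, (B, 1, T), with T = one EEG chunk of 23040 samples.
import torch

from loss import SpectrogramReconstructionLoss

loss_fn = SpectrogramReconstructionLoss(sampling_rate=128)
x = torch.randn(4, 1, 23040)      # ground-truth EEG chunks
x_hat = torch.randn(4, 1, 23040)  # stand-in for the decoder output
loss = loss_fn(x_hat=x_hat, x=x)
print(loss.item())  # scalar: L1 + log-magnitude RMS, summed over 5 resolutions

Since every Spectrogram is built with power=None, the L1 term compares complex spectrograms directly, while the second term compares log-magnitudes.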
View File

@ -32,13 +32,14 @@ import torch.nn as nn
from codec_datamodule import CodecDataModule
from encodec import Encodec
from lhotse.utils import fix_random_seed
+from PIL import Image
from scheduler import WarmupCosineLrScheduler
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
-from utils import MetricsTracker, save_checkpoint
+from utils import MetricsTracker, plot_curve, plot_feature, save_checkpoint
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint
@ -155,6 +156,20 @@ def get_parser():
help="Whether to use half precision training.",
)
+    parser.add_argument(
+        "--sampling-rate",
+        type=int,
+        default=128,
+        help="The sampling rate of the biomarker signal, in Hz.",
+    )
+    parser.add_argument(
+        "--chunk-size",
+        type=int,
+        default=180,
+        help="The chunk size of the biomarker, in seconds.",
+    )
return parser
@ -202,8 +217,7 @@ def get_params() -> AttributeDict:
"log_interval": 50,
"valid_interval": 200,
"env_info": get_env_info(),
"sampling_rate": 24000,
"audio_normalization": False,
"wave_normalization": False,
"lambda_adv": 3.0, # loss scaling coefficient for adversarial loss
"lambda_wav": 0.1, # loss scaling coefficient for waveform loss
"lambda_feat": 4.0, # loss scaling coefficient for feat loss
@ -352,27 +366,27 @@ def prepare_input(
is_training: bool = True,
):
"""Parse batch data"""
audio = batch["audio"].to(device, memory_format=torch.contiguous_format)
wave = batch["audio"].to(device, memory_format=torch.contiguous_format)
features = batch["features"].to(device, memory_format=torch.contiguous_format)
audio_lens = batch["audio_lens"].to(device)
wave_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device)
if is_training:
audio_dims = audio.size(-1)
audio_dims = wave.size(-1)
start_idx = random.randint(0, max(0, audio_dims - params.sampling_rate))
audio = audio[:, start_idx : params.sampling_rate + start_idx]
wave = wave[:, start_idx : params.sampling_rate + start_idx]
else:
# NOTE(zengrui): a very coarse setup
audio = audio[
wave = wave[
:, params.sampling_rate : params.sampling_rate + params.sampling_rate
]
if params.audio_normalization:
mean = audio.mean(dim=-1, keepdim=True)
std = audio.std(dim=-1, keepdim=True)
audio = (audio - mean) / (std + 1e-7)
if params.wave_normalization:
mean = wave.mean(dim=-1, keepdim=True)
std = wave.std(dim=-1, keepdim=True)
wave = (wave - mean) / (std + 1e-7)
return audio, audio_lens, features, features_lens
return wave, wave_lens, features, features_lens
def train_discriminator(weight, global_step, threshold=0, value=0.0):
@ -623,17 +637,29 @@ def train_one_epoch(
if speech_hat_i.dim() > 1:
speech_hat_i = speech_hat_i.squeeze(0)
speech_i = speech_i.squeeze(0)
-                tb_writer.add_audio(
-                    f"train/speech_hat_",
-                    speech_hat_i,
+                # tb_writer.add_audio(
+                #     f"train/speech_hat_",
+                #     speech_hat_i,
+                #     params.batch_idx_train,
+                #     params.sampling_rate,
+                # )
+                # tb_writer.add_audio(
+                #     f"train/speech_",
+                #     speech_i,
+                #     params.batch_idx_train,
+                #     params.sampling_rate,
+                # )
+                tb_writer.add_image(
+                    "train/speech_hat_",
+                    np.array(Image.open(plot_curve(speech_hat_i, params.sampling_rate))),
                    params.batch_idx_train,
-                    params.sampling_rate,
+                    dataformats="HWC",
                )
-                tb_writer.add_audio(
-                    f"train/speech_",
-                    speech_i,
+                tb_writer.add_image(
+                    "train/speech_",
+                    np.array(Image.open(plot_curve(speech_i, params.sampling_rate))),
                    params.batch_idx_train,
-                    params.sampling_rate,
+                    dataformats="HWC",
                )
# tb_writer.add_image(
# "train/mel_hat_",
@ -675,19 +701,32 @@ def train_one_epoch(
if speech_hat_i.dim() > 1:
speech_hat_i = speech_hat_i.squeeze(0)
speech_i = speech_i.squeeze(0)
-                    tb_writer.add_audio(
+                    # tb_writer.add_audio(
+                    #     f"train/valid_speech_hat_{index}",
+                    #     speech_hat_i,
+                    #     params.batch_idx_train,
+                    #     params.sampling_rate,
+                    # )
+                    # tb_writer.add_audio(
+                    #     f"train/valid_speech_{index}",
+                    #     speech_i,
+                    #     params.batch_idx_train,
+                    #     params.sampling_rate,
+                    # )
+                    tb_writer.add_image(
                        f"train/valid_speech_hat_{index}",
-                        speech_hat_i,
+                        np.array(Image.open(plot_curve(speech_hat_i, params.sampling_rate))),
                        params.batch_idx_train,
-                        params.sampling_rate,
+                        dataformats="HWC",
                    )
-                    tb_writer.add_audio(
+                    tb_writer.add_image(
                        f"train/valid_speech_{index}",
-                        speech_i,
+                        np.array(Image.open(plot_curve(speech_i, params.sampling_rate))),
                        params.batch_idx_train,
-                        params.sampling_rate,
+                        dataformats="HWC",
                    )
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
@ -1164,11 +1203,19 @@ def run(rank, world_size, args):
cleanup_dist()
+def override_sampling_rate(args) -> int:
+    """Repurpose --sampling-rate as samples per chunk: rate (Hz) x chunk size (s)."""
+    logging.info(
+        f"Overriding sampling rate from {args.sampling_rate} to {args.sampling_rate * args.chunk_size}"
+    )
+    return int(args.sampling_rate * args.chunk_size)
def main():
parser = get_parser()
CodecDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
+    args.sampling_rate = override_sampling_rate(args)
world_size = args.world_size
assert world_size >= 1

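One detail worth making explicit: after override_sampling_rate runs, args.sampling_rate no longer means samples per second but samples per chunk, and prepare_input crops windows of exactly that many samples. With the defaults above:

# The overridden value is a chunk length in samples, not a rate in Hz.
sampling_rate = 128  # --sampling-rate default, in Hz
chunk_size = 180     # --chunk-size default, in seconds
assert int(sampling_rate * chunk_size) == 23040  # crop length used by prepare_input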
View File

@ -118,6 +118,20 @@ def plot_feature(spectrogram):
plt.close()
return data
def plot_curve(speech: torch.Tensor, sampling_rate: int) -> "io.BytesIO":
    import io

    import matplotlib.pyplot as plt
    import numpy as np

    # Assumes `speech` holds exactly `sampling_rate` samples (one cropped
    # chunk), so the x-axis is normalized to [0, 1).
    plt.figure()
    plt.plot(np.arange(sampling_rate) / sampling_rate, speech.detach().cpu().numpy().T)
    # Return the rendered curve as an in-memory JPEG (a BytesIO, not raw bytes).
    buf = io.BytesIO()
    plt.savefig(buf, format="jpeg")
    buf.seek(0)
    plt.close()
    return buf
class MetricsTracker(collections.defaultdict):
def __init__(self):
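Because plot_curve returns an in-memory JPEG buffer (a BytesIO) rather than raw bytes, callers open it with PIL before handing a numpy array to TensorBoard, as train.py does above. A minimal round trip under the same assumptions (a 1-D signal of exactly sampling_rate samples; the log directory is an arbitrary choice):

# Usage sketch mirroring the tb_writer.add_image calls in train.py.
import numpy as np
import torch
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

from utils import plot_curve

writer = SummaryWriter("exp/tb")  # hypothetical log directory
speech = torch.randn(23040)       # one 180 s chunk at 128 Hz
img = np.array(Image.open(plot_curve(speech, sampling_rate=23040)))
writer.add_image("demo/speech_", img, 0, dataformats="HWC")
writer.close()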