init commit for the EEG setup

zr_jin 2024-10-22 20:35:54 +08:00
parent 39b3d1a050
commit c345a4a572
4 changed files with 166 additions and 30 deletions

View File: encodec.py

@@ -26,6 +26,7 @@ from loss import (
     FeatureLoss,
     GeneratorAdversarialLoss,
     MelSpectrogramReconstructionLoss,
+    SpectrogramReconstructionLoss,
     WavReconstructionLoss,
 )
 from torch import nn
@@ -79,7 +80,7 @@ class Encodec(nn.Module):
         )
         self.feature_match_loss = FeatureLoss()
         self.wav_reconstruction_loss = WavReconstructionLoss()
-        self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss(
+        self.spec_reconstruction_loss = SpectrogramReconstructionLoss(
             sampling_rate=self.sampling_rate
         )
@@ -170,7 +171,7 @@ class Encodec(nn.Module):
         wav_reconstruction_loss = self.wav_reconstruction_loss(
             x=speech, x_hat=speech_hat
         )
-        mel_reconstruction_loss = self.mel_reconstruction_loss(
+        mel_reconstruction_loss = self.spec_reconstruction_loss(
             x=speech, x_hat=speech_hat
         )

View File: loss.py

@@ -14,7 +14,7 @@ from typing import List, Tuple, Union

 import torch
 import torch.nn.functional as F
-from torchaudio.transforms import MelSpectrogram
+from torchaudio.transforms import MelSpectrogram, Spectrogram


 class GeneratorAdversarialLoss(torch.nn.Module):
@@ -295,6 +295,80 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
         return mel_loss


+class SpectrogramReconstructionLoss(torch.nn.Module):
+    """Spectrogram reconstruction loss designed for EEG signals."""
+
+    def __init__(
+        self,
+        sampling_rate: int = 22050,
+        return_spec: bool = False,
+    ):
+        super().__init__()
+        # `sampling_rate` is kept for API compatibility with
+        # MelSpectrogramReconstructionLoss; a linear STFT does not need it
+        # (torchaudio's Spectrogram accepts no sample_rate argument).
+        self.sampling_rate = sampling_rate
+        self.wav_to_specs = []
+        for i in range(5, 10):
+            s = 2**i // 8
+            self.wav_to_specs.append(
+                Spectrogram(
+                    n_fft=s,
+                    win_length=s,
+                    hop_length=s // 4,
+                    normalized=True,
+                    center=False,
+                    power=None,
+                )
+            )
+        self.return_spec = return_spec
+
+    def forward(
+        self,
+        x_hat: torch.Tensor,
+        x: torch.Tensor,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
+        """Calculate the multi-resolution spectrogram loss.
+
+        Args:
+            x_hat (Tensor): Generated waveform tensor (B, 1, T).
+            x (Tensor): Groundtruth waveform tensor (B, 1, T).
+
+        Returns:
+            Tensor: Spectrogram loss value.
+        """
+        spec_loss = 0.0
+        for wav_to_spec in self.wav_to_specs:
+            wav_to_spec.to(x.device)
+            spec_hat = wav_to_spec(x_hat.squeeze(1))
+            spec = wav_to_spec(x.squeeze(1))
+            # L1 distance on the complex spectrograms, plus an RMS distance
+            # over frequency bins of the log-magnitude spectrograms.
+            spec_loss += (
+                F.l1_loss(spec_hat, spec, reduction="mean")
+                + (
+                    (
+                        (torch.log(spec.abs() + 1e-7) - torch.log(spec_hat.abs() + 1e-7))
+                        ** 2
+                    ).mean(dim=-2)
+                    ** 0.5
+                ).mean()
+            )
+
+        if self.return_spec:
+            return spec_loss, (spec_hat, spec)
+
+        return spec_loss
+
+
 class WavReconstructionLoss(torch.nn.Module):
     """Wav Reconstruction loss."""

View File: train.py

@@ -32,13 +32,14 @@ import torch.nn as nn
 from codec_datamodule import CodecDataModule
 from encodec import Encodec
 from lhotse.utils import fix_random_seed
+from PIL import Image
 from scheduler import WarmupCosineLrScheduler
 from torch import nn
 from torch.cuda.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.utils.tensorboard import SummaryWriter
-from utils import MetricsTracker, save_checkpoint
+from utils import MetricsTracker, plot_curve, plot_feature, save_checkpoint

 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint
@@ -155,6 +156,20 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--sampling-rate",
+        type=int,
+        default=128,
+        help="The sampling rate of the biomarker.",
+    )
+
+    parser.add_argument(
+        "--chunk-size",
+        type=int,
+        default=180,
+        help="The chunk size of the biomarker (in seconds).",
+    )
+
     return parser
@@ -202,8 +217,7 @@ def get_params() -> AttributeDict:
             "log_interval": 50,
             "valid_interval": 200,
             "env_info": get_env_info(),
-            "sampling_rate": 24000,
-            "audio_normalization": False,
+            "wave_normalization": False,
             "lambda_adv": 3.0,  # loss scaling coefficient for adversarial loss
             "lambda_wav": 0.1,  # loss scaling coefficient for waveform loss
             "lambda_feat": 4.0,  # loss scaling coefficient for feat loss
@@ -352,27 +366,27 @@ def prepare_input(
     is_training: bool = True,
 ):
     """Parse batch data"""
-    audio = batch["audio"].to(device, memory_format=torch.contiguous_format)
+    wave = batch["audio"].to(device, memory_format=torch.contiguous_format)
     features = batch["features"].to(device, memory_format=torch.contiguous_format)
-    audio_lens = batch["audio_lens"].to(device)
+    wave_lens = batch["audio_lens"].to(device)
     features_lens = batch["features_lens"].to(device)

     if is_training:
-        audio_dims = audio.size(-1)
+        audio_dims = wave.size(-1)
         start_idx = random.randint(0, max(0, audio_dims - params.sampling_rate))
-        audio = audio[:, start_idx : params.sampling_rate + start_idx]
+        wave = wave[:, start_idx : params.sampling_rate + start_idx]
     else:
         # NOTE(zengrui): a very coarse setup
-        audio = audio[
+        wave = wave[
             :, params.sampling_rate : params.sampling_rate + params.sampling_rate
         ]

-    if params.audio_normalization:
-        mean = audio.mean(dim=-1, keepdim=True)
-        std = audio.std(dim=-1, keepdim=True)
-        audio = (audio - mean) / (std + 1e-7)
+    if params.wave_normalization:
+        mean = wave.mean(dim=-1, keepdim=True)
+        std = wave.std(dim=-1, keepdim=True)
+        wave = (wave - mean) / (std + 1e-7)

-    return audio, audio_lens, features, features_lens
+    return wave, wave_lens, features, features_lens


 def train_discriminator(weight, global_step, threshold=0, value=0.0):
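The training branch above crops a random window of params.sampling_rate samples from each waveform. A standalone sketch of that cropping logic with made-up shapes (23040 is what the chunk-size override in main() turns params.sampling_rate into):

import random

import torch

wave = torch.randn(8, 100_000)  # hypothetical (B, T) batch of EEG samples
window = 23_040                 # params.sampling_rate after the override in main()

start_idx = random.randint(0, max(0, wave.size(-1) - window))
wave = wave[:, start_idx : start_idx + window]
assert wave.shape == (8, window)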
@@ -623,17 +637,29 @@ def train_one_epoch(
                     if speech_hat_i.dim() > 1:
                         speech_hat_i = speech_hat_i.squeeze(0)
                         speech_i = speech_i.squeeze(0)
-                    tb_writer.add_audio(
-                        f"train/speech_hat_",
-                        speech_hat_i,
-                        params.batch_idx_train,
-                        params.sampling_rate,
-                    )
-                    tb_writer.add_audio(
-                        f"train/speech_",
-                        speech_i,
-                        params.batch_idx_train,
-                        params.sampling_rate,
-                    )
+                    # tb_writer.add_audio(
+                    #     f"train/speech_hat_",
+                    #     speech_hat_i,
+                    #     params.batch_idx_train,
+                    #     params.sampling_rate,
+                    # )
+                    # tb_writer.add_audio(
+                    #     f"train/speech_",
+                    #     speech_i,
+                    #     params.batch_idx_train,
+                    #     params.sampling_rate,
+                    # )
+                    tb_writer.add_image(
+                        "train/speech_hat_",
+                        np.array(Image.open(plot_curve(speech_hat_i, params.sampling_rate))),
+                        params.batch_idx_train,
+                        dataformats="HWC",
+                    )
+                    tb_writer.add_image(
+                        "train/speech_",
+                        np.array(Image.open(plot_curve(speech_i, params.sampling_rate))),
+                        params.batch_idx_train,
+                        dataformats="HWC",
+                    )
                 # tb_writer.add_image(
                 #     "train/mel_hat_",
@@ -675,18 +701,31 @@ def train_one_epoch(
                     if speech_hat_i.dim() > 1:
                         speech_hat_i = speech_hat_i.squeeze(0)
                         speech_i = speech_i.squeeze(0)
-                    tb_writer.add_audio(
-                        f"train/valid_speech_hat_{index}",
-                        speech_hat_i,
-                        params.batch_idx_train,
-                        params.sampling_rate,
-                    )
-                    tb_writer.add_audio(
-                        f"train/valid_speech_{index}",
-                        speech_i,
-                        params.batch_idx_train,
-                        params.sampling_rate,
-                    )
+                    # tb_writer.add_audio(
+                    #     f"train/valid_speech_hat_{index}",
+                    #     speech_hat_i,
+                    #     params.batch_idx_train,
+                    #     params.sampling_rate,
+                    # )
+                    # tb_writer.add_audio(
+                    #     f"train/valid_speech_{index}",
+                    #     speech_i,
+                    #     params.batch_idx_train,
+                    #     params.sampling_rate,
+                    # )
+                    tb_writer.add_image(
+                        f"train/valid_speech_hat_{index}",
+                        np.array(Image.open(plot_curve(speech_hat_i, params.sampling_rate))),
+                        params.batch_idx_train,
+                        dataformats="HWC",
+                    )
+                    tb_writer.add_image(
+                        f"train/valid_speech_{index}",
+                        np.array(Image.open(plot_curve(speech_i, params.sampling_rate))),
+                        params.batch_idx_train,
+                        dataformats="HWC",
+                    )

         loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
         params.train_loss = loss_value
@@ -1164,11 +1203,19 @@ def run(rank, world_size, args):
         cleanup_dist()


+def override_sampling_rate(args) -> int:
+    logging.info(
+        f"Overriding sampling rate from {args.sampling_rate} to "
+        f"{args.sampling_rate * args.chunk_size}"
+    )
+    return int(args.sampling_rate * args.chunk_size)
+
+
 def main():
     parser = get_parser()
     CodecDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
+    args.sampling_rate = override_sampling_rate(args)

     world_size = args.world_size
     assert world_size >= 1
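With the parser defaults above, the override turns the 128 Hz rate into a samples-per-chunk count: 128 Hz x 180 s = 23040. Every later use of args.sampling_rate (the random crop in prepare_input, the x-axis in plot_curve) therefore counts samples in one chunk rather than samples per second.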

View File: utils.py

@@ -118,6 +118,20 @@ def plot_feature(spectrogram):
     plt.close()
     return data


+def plot_curve(speech: torch.Tensor, sampling_rate: int) -> "io.BytesIO":
+    import io
+
+    import matplotlib.pyplot as plt
+    import numpy as np
+
+    plt.figure()
+    # NOTE: assumes `speech` holds exactly `sampling_rate` samples (one chunk).
+    plt.plot(np.arange(sampling_rate) / sampling_rate, speech.detach().cpu().numpy().T)
+    buf = io.BytesIO()
+    plt.savefig(buf, format="jpeg")
+    buf.seek(0)
+    plt.close()
+    return buf
+
+
 class MetricsTracker(collections.defaultdict):
     def __init__(self):
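For reference, a sketch of how train.py consumes plot_curve after this commit; the log directory, step value, and tensor here are placeholders.

import numpy as np
import torch
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

from utils import plot_curve

writer = SummaryWriter("exp/tb")  # hypothetical log directory
speech_i = torch.randn(23_040)    # one decoded chunk (see the sampling-rate override)

# plot_curve returns an in-memory JPEG; decode it to an (H, W, C) uint8 array.
buf = plot_curve(speech_i, sampling_rate=23_040)
writer.add_image(
    "train/speech_",
    np.array(Image.open(buf)),
    0,  # global step placeholder
    dataformats="HWC",
)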