diff --git a/egs/tokenizer/CODEC/encodec/encodec.py b/egs/tokenizer/CODEC/encodec/encodec.py
index f21d494b6..9dc0ad61f 100644
--- a/egs/tokenizer/CODEC/encodec/encodec.py
+++ b/egs/tokenizer/CODEC/encodec/encodec.py
@@ -26,6 +26,7 @@ from loss import (
     FeatureLoss,
     GeneratorAdversarialLoss,
     MelSpectrogramReconstructionLoss,
+    SpectrogramReconstructionLoss,
     WavReconstructionLoss,
 )
 from torch import nn
@@ -79,7 +80,7 @@ class Encodec(nn.Module):
         )
         self.feature_match_loss = FeatureLoss()
         self.wav_reconstruction_loss = WavReconstructionLoss()
-        self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss(
+        self.spec_reconstruction_loss = SpectrogramReconstructionLoss(
             sampling_rate=self.sampling_rate
         )
 
@@ -170,7 +171,7 @@ class Encodec(nn.Module):
         wav_reconstruction_loss = self.wav_reconstruction_loss(
             x=speech, x_hat=speech_hat
         )
-        mel_reconstruction_loss = self.mel_reconstruction_loss(
+        mel_reconstruction_loss = self.spec_reconstruction_loss(
             x=speech, x_hat=speech_hat
         )
 
diff --git a/egs/tokenizer/CODEC/encodec/loss.py b/egs/tokenizer/CODEC/encodec/loss.py
index 9cf1d42d2..fc7161e85 100644
--- a/egs/tokenizer/CODEC/encodec/loss.py
+++ b/egs/tokenizer/CODEC/encodec/loss.py
@@ -14,7 +14,7 @@ from typing import List, Tuple, Union
 
 import torch
 import torch.nn.functional as F
-from torchaudio.transforms import MelSpectrogram
+from torchaudio.transforms import MelSpectrogram, Spectrogram
 
 
 class GeneratorAdversarialLoss(torch.nn.Module):
@@ -295,6 +295,72 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
         return mel_loss
 
 
+class SpectrogramReconstructionLoss(torch.nn.Module):
+    """Multi-scale spectrogram reconstruction loss designed for EEG signals."""
+
+    def __init__(
+        self,
+        sampling_rate: int = 22050,
+        return_spec: bool = False,
+    ):
+        super().__init__()
+        # Spectrogram takes no sample-rate argument; sampling_rate is kept
+        # only for API compatibility with MelSpectrogramReconstructionLoss.
+        self.sampling_rate = sampling_rate
+        self.wav_to_specs = []
+        for i in range(5, 10):
+            s = 2**i // 8
+            self.wav_to_specs.append(
+                Spectrogram(
+                    n_fft=s,
+                    win_length=s,
+                    hop_length=s // 4,
+                    normalized=True,
+                    center=False,
+                    power=None,
+                )
+            )
+        self.return_spec = return_spec
+
+    def forward(
+        self,
+        x_hat: torch.Tensor,
+        x: torch.Tensor,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
+        """Calculate the multi-scale spectrogram reconstruction loss.
+
+        Args:
+            x_hat (Tensor): Generated waveform tensor (B, 1, T).
+            x (Tensor): Groundtruth waveform tensor (B, 1, T).
+
+        Returns:
+            Tensor: Spectrogram loss value. If return_spec is True, the
+                spectrograms of the last scale are returned as well.
+ + """ + mel_loss = 0.0 + + for i, wav_to_spec in enumerate(self.wav_to_specs): + s = 2 ** (i + 5) + wav_to_spec.to(x.device) + + mel_hat = wav_to_spec(x_hat.squeeze(1)) + mel = wav_to_spec(x.squeeze(1)) + + mel_loss += ( + F.l1_loss(mel_hat, mel, reduce=True, reduction="mean") + + ( + ( + (torch.log(mel.abs() + 1e-7) - torch.log(mel_hat.abs() + 1e-7)) + ** 2 + ).mean(dim=-2) + ** 0.5 + ).mean() + ) + + # mel_hat = self.wav_to_spec(x_hat.squeeze(1)) + # mel = self.wav_to_spec(x.squeeze(1)) + # mel_loss = F.l1_loss(mel_hat, mel) + F.mse_loss(mel_hat, mel) + + if self.return_mel: + return mel_loss, (mel_hat, mel) + + return mel_loss + + class WavReconstructionLoss(torch.nn.Module): """Wav Reconstruction loss.""" diff --git a/egs/tokenizer/CODEC/encodec/train.py b/egs/tokenizer/CODEC/encodec/train.py index 380d05ded..41f6e2cb7 100755 --- a/egs/tokenizer/CODEC/encodec/train.py +++ b/egs/tokenizer/CODEC/encodec/train.py @@ -32,13 +32,14 @@ import torch.nn as nn from codec_datamodule import CodecDataModule from encodec import Encodec from lhotse.utils import fix_random_seed +from PIL import Image from scheduler import WarmupCosineLrScheduler from torch import nn from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter -from utils import MetricsTracker, save_checkpoint +from utils import MetricsTracker, plot_curve, plot_feature, save_checkpoint from icefall import diagnostics from icefall.checkpoint import load_checkpoint @@ -155,6 +156,20 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--sampling-rate", + type=int, + default=128, + help="The sampling rate of the biomarker.", + ) + + parser.add_argument( + "--chunk-size", + type=int, + default=180, + help="The chunk size of the biomarker (in second).", + ) + return parser @@ -202,8 +217,7 @@ def get_params() -> AttributeDict: "log_interval": 50, "valid_interval": 200, "env_info": get_env_info(), - "sampling_rate": 24000, - "audio_normalization": False, + "wave_normalization": False, "lambda_adv": 3.0, # loss scaling coefficient for adversarial loss "lambda_wav": 0.1, # loss scaling coefficient for waveform loss "lambda_feat": 4.0, # loss scaling coefficient for feat loss @@ -352,27 +366,27 @@ def prepare_input( is_training: bool = True, ): """Parse batch data""" - audio = batch["audio"].to(device, memory_format=torch.contiguous_format) + wave = batch["audio"].to(device, memory_format=torch.contiguous_format) features = batch["features"].to(device, memory_format=torch.contiguous_format) - audio_lens = batch["audio_lens"].to(device) + wave_lens = batch["audio_lens"].to(device) features_lens = batch["features_lens"].to(device) if is_training: - audio_dims = audio.size(-1) + audio_dims = wave.size(-1) start_idx = random.randint(0, max(0, audio_dims - params.sampling_rate)) - audio = audio[:, start_idx : params.sampling_rate + start_idx] + wave = wave[:, start_idx : params.sampling_rate + start_idx] else: # NOTE(zengrui): a very coarse setup - audio = audio[ + wave = wave[ :, params.sampling_rate : params.sampling_rate + params.sampling_rate ] - if params.audio_normalization: - mean = audio.mean(dim=-1, keepdim=True) - std = audio.std(dim=-1, keepdim=True) - audio = (audio - mean) / (std + 1e-7) + if params.wave_normalization: + mean = wave.mean(dim=-1, keepdim=True) + std = wave.std(dim=-1, keepdim=True) + wave = (wave - mean) / (std + 1e-7) - return audio, 
+    return wave, wave_lens, features, features_lens
 
 
 def train_discriminator(weight, global_step, threshold=0, value=0.0):
@@ -623,17 +637,17 @@
             if speech_hat_i.dim() > 1:
                 speech_hat_i = speech_hat_i.squeeze(0)
                 speech_i = speech_i.squeeze(0)
-            tb_writer.add_audio(
-                f"train/speech_hat_",
-                speech_hat_i,
+            tb_writer.add_image(
+                "train/speech_hat_",
+                np.array(Image.open(plot_curve(speech_hat_i, params.sampling_rate))),
                 params.batch_idx_train,
-                params.sampling_rate,
+                dataformats="HWC",
             )
-            tb_writer.add_audio(
-                f"train/speech_",
-                speech_i,
+            tb_writer.add_image(
+                "train/speech_",
+                np.array(Image.open(plot_curve(speech_i, params.sampling_rate))),
                 params.batch_idx_train,
-                params.sampling_rate,
+                dataformats="HWC",
             )
             # tb_writer.add_image(
             #     "train/mel_hat_",
@@ -675,18 +689,19 @@
             if speech_hat_i.dim() > 1:
                 speech_hat_i = speech_hat_i.squeeze(0)
                 speech_i = speech_i.squeeze(0)
-            tb_writer.add_audio(
+            tb_writer.add_image(
                 f"train/valid_speech_hat_{index}",
-                speech_hat_i,
+                np.array(Image.open(plot_curve(speech_hat_i, params.sampling_rate))),
                 params.batch_idx_train,
-                params.sampling_rate,
+                dataformats="HWC",
             )
-            tb_writer.add_audio(
+            tb_writer.add_image(
                 f"train/valid_speech_{index}",
-                speech_i,
+                np.array(Image.open(plot_curve(speech_i, params.sampling_rate))),
                 params.batch_idx_train,
-                params.sampling_rate,
+                dataformats="HWC",
             )
+
     loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
     params.train_loss = loss_value
 
@@ -1164,11 +1179,21 @@ def run(rank, world_size, args):
     cleanup_dist()
 
 
+def override_sampling_rate(args) -> int:
+    # NOTE: after this override, args.sampling_rate holds the number of
+    # samples per chunk (rate * duration), which is used as the crop length.
+    logging.info(
+        f"Overriding sampling rate from {args.sampling_rate} to {args.sampling_rate * args.chunk_size}"
+    )
+    return int(args.sampling_rate * args.chunk_size)
+
+
 def main():
     parser = get_parser()
     CodecDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
+    args.sampling_rate = override_sampling_rate(args)
 
     world_size = args.world_size
     assert world_size >= 1
diff --git a/egs/tokenizer/CODEC/encodec/utils.py b/egs/tokenizer/CODEC/encodec/utils.py
index 6a067f596..77cfd3c50 100644
--- a/egs/tokenizer/CODEC/encodec/utils.py
+++ b/egs/tokenizer/CODEC/encodec/utils.py
@@ -118,6 +118,26 @@ def plot_feature(spectrogram):
     plt.close()
     return data
 
+
+def plot_curve(speech: torch.Tensor, sampling_rate: int) -> "io.BytesIO":
+    """Plot the waveform as a curve and return it as a JPEG image buffer."""
+    import io
+
+    import matplotlib.pyplot as plt
+    import numpy as np
+
+    plt.figure()
+    # Use the actual number of samples so x and y always have equal length.
+    plt.plot(
+        np.arange(speech.size(-1)) / sampling_rate,
+        speech.detach().cpu().numpy().T,
+    )
+    buf = io.BytesIO()
+    plt.savefig(buf, format="jpeg")
+    buf.seek(0)
+    plt.close()
+    return buf
+
 
 class MetricsTracker(collections.defaultdict):
     def __init__(self):
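
Reviewer note: below is a minimal, self-contained sketch for sanity-checking the
new SpectrogramReconstructionLoss, assuming the patched loss.py above is on the
Python path, a recent torchaudio (where power=None yields a complex spectrogram),
and a PyTorch with complex support in F.l1_loss. The 128 Hz rate and 180 s chunk
mirror the new defaults; the shapes are illustrative, not taken from the recipe.

    import torch

    from loss import SpectrogramReconstructionLoss

    # One batch of two chunks shaped like the training crops: (B, 1, T)
    # with T = sampling_rate * chunk_size = 128 * 180 samples.
    sampling_rate, chunk_size = 128, 180
    x = torch.randn(2, 1, sampling_rate * chunk_size)
    x_hat = torch.randn(2, 1, sampling_rate * chunk_size)

    loss_fn = SpectrogramReconstructionLoss(sampling_rate=sampling_rate)
    loss = loss_fn(x_hat=x_hat, x=x)
    print(loss.item())  # scalar: sum of per-scale L1 + log-magnitude L2 terms

    # With return_spec=True, the complex spectrograms of the last
    # (largest, n_fft=64) scale are returned alongside the loss.
    loss_fn = SpectrogramReconstructionLoss(
        sampling_rate=sampling_rate, return_spec=True
    )
    loss, (spec_hat, spec) = loss_fn(x_hat=x_hat, x=x)
    print(spec.dtype, spec.shape)  # complex, (B, n_fft // 2 + 1, num_frames)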