remove unused code

Fangjun Kuang 2024-10-28 19:18:21 +08:00
parent 7994684bf4
commit c558328dc5
3 changed files with 154 additions and 112 deletions

View File

@@ -27,28 +27,100 @@ The generated fbank features are saved in data/fbank.
 import argparse
 import logging
 import os
+from dataclasses import dataclass
 from pathlib import Path
+from typing import Union

+import numpy as np
 import torch
-from lhotse import (
-    CutSet,
-    Fbank,
-    FbankConfig,
-    LilcomChunkyWriter,
-    load_manifest,
-    load_manifest_lazy,
-)
+from lhotse import CutSet, LilcomChunkyWriter, load_manifest
 from lhotse.audio import RecordingSet
+from lhotse.features.base import FeatureExtractor, register_extractor
 from lhotse.supervision import SupervisionSet
+from lhotse.utils import Seconds, compute_num_frames
+from matcha.utils.audio import mel_spectrogram

 from icefall.utils import get_executor

-# Torch's multithreaded behavior needs to be disabled or
-# it wastes a lot of CPU and slow things down.
-# Do this outside of main() in case it needs to take effect
-# even when we are not invoking the main (e.g. when spawning subprocesses).
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
+
+@dataclass
+class MyFbankConfig:
+    n_fft: int
+    n_mels: int
+    sampling_rate: int
+    hop_length: int
+    win_length: int
+    f_min: float
+    f_max: float
+
+
+@register_extractor
+class MyFbank(FeatureExtractor):
+    name = "MyFbank"
+    config_type = MyFbankConfig
+
+    def __init__(self, config):
+        super().__init__(config=config)
+
+    @property
+    def device(self) -> Union[str, torch.device]:
+        return self.config.device
+
+    def feature_dim(self, sampling_rate: int) -> int:
+        return self.config.n_mels
+
+    def extract(
+        self,
+        samples: np.ndarray,
+        sampling_rate: int,
+    ) -> torch.Tensor:
+        # Check for sampling rate compatibility.
+        expected_sr = self.config.sampling_rate
+        assert sampling_rate == expected_sr, (
+            f"Mismatched sampling rate: extractor expects {expected_sr}, "
+            f"got {sampling_rate}"
+        )
+        samples = torch.from_numpy(samples)
+        assert samples.ndim == 2, samples.shape
+        assert samples.shape[0] == 1, samples.shape
+
+        mel = (
+            mel_spectrogram(
+                samples,
+                self.config.n_fft,
+                self.config.n_mels,
+                self.config.sampling_rate,
+                self.config.hop_length,
+                self.config.win_length,
+                self.config.f_min,
+                self.config.f_max,
+                center=False,
+            )
+            .squeeze()
+            .t()
+        )
+        assert mel.ndim == 2, mel.shape
+        assert mel.shape[1] == self.config.n_mels, mel.shape
+
+        num_frames = compute_num_frames(
+            samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
+        )
+
+        if mel.shape[0] > num_frames:
+            mel = mel[:num_frames]
+        elif mel.shape[0] < num_frames:
+            mel = mel.unsqueeze(0)
+            mel = torch.nn.functional.pad(
+                mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
+            ).squeeze(0)
+
+        return mel.numpy()
+
+    @property
+    def frame_shift(self) -> Seconds:
+        return self.config.hop_length / self.config.sampling_rate


 def get_parser():
@@ -77,10 +149,15 @@ def compute_fbank_ljspeech(num_jobs: int):
     logging.info(f"num_jobs: {num_jobs}")
     logging.info(f"src_dir: {src_dir}")
     logging.info(f"output_dir: {output_dir}")
-
-    sampling_rate = 22050
-    frame_length = 1024 / sampling_rate  # (in second)
-    frame_shift = 256 / sampling_rate  # (in second)
+    config = MyFbankConfig(
+        n_fft=1024,
+        n_mels=80,
+        sampling_rate=22050,
+        hop_length=256,
+        win_length=1024,
+        f_min=0,
+        f_max=8000,
+    )

     prefix = "ljspeech"
     suffix = "jsonl.gz"
@@ -93,25 +170,7 @@ def compute_fbank_ljspeech(num_jobs: int):
             src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
         )

-        # Differences with matcha-tts
-        # 1. we use pre-emphasis
-        # 2. we remove dc offset
-        # 3. we use a different window
-        # 4. we use a different mel filter bank matrix
-        # 5. we don't normalize features
-        config = FbankConfig(
-            sampling_rate=sampling_rate,
-            frame_length=frame_length,
-            frame_shift=frame_shift,
-            use_fft_mag=True,
-            low_freq=0,
-            high_freq=8000,
-            remove_dc_offset=False,
-            preemph_coeff=0,
-            # should be identical to n_feats in ../matcha/train.py
-            num_filters=80,
-        )
-        extractor = Fbank(config)
+        extractor = MyFbank(config)

        with get_executor() as ex:  # Initialize the executor only once.
             cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
@@ -135,6 +194,12 @@ def compute_fbank_ljspeech(num_jobs: int):

 if __name__ == "__main__":
+    # Torch's multithreaded behavior needs to be disabled or
+    # it wastes a lot of CPU and slow things down.
+    # Do this outside of main() in case it needs to take effect
+    # even when we are not invoking the main (e.g. when spawning subprocesses).
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
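
For context, not part of this commit: the new MyFbank extractor plugs into lhotse's standard feature pipeline, and its extract() trims or replicate-pads the matcha mel output so the frame count matches what compute_num_frames() derives from the cut duration (with hop_length=256 at 22050 Hz, roughly 86 frames per second of audio). A minimal sketch of that flow, assuming the usual lhotse API; the manifest paths and num_jobs are placeholders:

    from lhotse import CutSet, LilcomChunkyWriter

    from compute_fbank_ljspeech import MyFbank, MyFbankConfig

    # Placeholder manifest path; the real script derives its paths from
    # src_dir / output_dir and loops over dataset partitions.
    cuts = CutSet.from_file("data/manifests/ljspeech_cuts_train.jsonl.gz")

    extractor = MyFbank(
        MyFbankConfig(
            n_fft=1024,
            n_mels=80,
            sampling_rate=22050,
            hop_length=256,
            win_length=1024,
            f_min=0,
            f_max=8000,
        )
    )

    # lhotse calls extractor.extract() on each cut's samples and stores
    # the 80-dim mel features with the lilcom chunky writer.
    cuts = cuts.compute_and_store_features(
        extractor=extractor,
        storage_path="data/fbank/ljspeech_feats_train",
        num_jobs=4,
        storage_type=LilcomChunkyWriter,
    )
    cuts.to_file("data/fbank/ljspeech_cuts_train.jsonl.gz")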

View File

@@ -35,6 +35,7 @@ from pathlib import Path
 from lhotse import CutSet, load_manifest_lazy
 from lhotse.dataset.speech_synthesis import validate_for_tts
+from compute_fbank_ljspeech import MyFbank


 def get_args():
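
An aside on why this one-line import matters (not stated in the diff): importing the module runs the @register_extractor decorator as a side effect, which records MyFbank in lhotse's extractor registry under its name. Code that later resolves an extractor from the string stored in a manifest would otherwise fail on the unknown feature type. A rough sketch, assuming lhotse's registry helper get_extractor_type:

    from lhotse.features.base import get_extractor_type

    # The import runs @register_extractor as a side effect; noqa because
    # the symbol itself is otherwise unused here.
    from compute_fbank_ljspeech import MyFbank  # noqa: F401

    # Assumed registry lookup; raises for names that were never registered.
    assert get_extractor_type("MyFbank") is MyFbank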

View File

@@ -3,18 +3,17 @@
 import argparse
-import json
 import logging
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Union

+import json
 import k2
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from lhotse.utils import fix_random_seed
-from matcha.data.text_mel_datamodule import TextMelDataModule
 from matcha.models.matcha_tts import MatchaTTS
 from matcha.tokenizer import Tokenizer
 from matcha.utils.model import fix_len_compatibility
@@ -355,36 +354,27 @@ def compute_validation_loss(
     with torch.no_grad():
         for batch_idx, batch in enumerate(valid_dl):
-            if "tokens" in batch:
-
-                (
-                    audio,
-                    audio_lens,
-                    features,
-                    features_lens,
-                    tokens,
-                    tokens_lens,
-                ) = prepare_input(batch, tokenizer, device, params)
-
-                losses = get_losses(
-                    {
-                        "x": tokens,
-                        "x_lengths": tokens_lens,
-                        "y": features.permute(0, 2, 1),
-                        "y_lengths": features_lens,
-                        "spks": None,  # should change it for multi-speakers
-                        "durations": None,
-                    }
-                )
-
-                batch_size = len(batch["tokens"])
-            else:
-                batch_size = batch["x"].shape[0]
-
-                batch["x"] = batch["x"].to(device)
-                batch["x_lengths"] = batch["x_lengths"].to(device)
-                batch["y"] = batch["y"].to(device)
-                batch["y_lengths"] = batch["y_lengths"].to(device)
-
-                losses = get_losses(batch)
+            (
+                audio,
+                audio_lens,
+                features,
+                features_lens,
+                tokens,
+                tokens_lens,
+            ) = prepare_input(batch, tokenizer, device, params)
+
+            losses = get_losses(
+                {
+                    "x": tokens,
+                    "x_lengths": tokens_lens,
+                    "y": features.permute(0, 2, 1),
+                    "y_lengths": features_lens,
+                    "spks": None,  # should change it for multi-speakers
+                    "durations": None,
+                }
+            )
+
+            batch_size = len(batch["tokens"])

             loss_info = MetricsTracker()
             loss_info["samples"] = batch_size
@@ -478,38 +468,28 @@ def train_one_epoch(
         # features_lens, (N,), int32
         # tokens: List[List[str]], len(tokens) == N

-        if "tokens" in batch:
-            batch_size = len(batch["tokens"])
-            (
-                audio,
-                audio_lens,
-                features,
-                features_lens,
-                tokens,
-                tokens_lens,
-            ) = prepare_input(batch, tokenizer, device, params)
-        else:
-            batch_size = batch["x"].shape[0]
+        batch_size = len(batch["tokens"])
+        (
+            audio,
+            audio_lens,
+            features,
+            features_lens,
+            tokens,
+            tokens_lens,
+        ) = prepare_input(batch, tokenizer, device, params)

         try:
             with autocast(enabled=params.use_fp16):
-                if "tokens" in batch:
-                    losses = get_losses(
-                        {
-                            "x": tokens,
-                            "x_lengths": tokens_lens,
-                            "y": features.permute(0, 2, 1),
-                            "y_lengths": features_lens,
-                            "spks": None,  # should change it for multi-speakers
-                            "durations": None,
-                        }
-                    )
-                else:
-                    batch["x"] = batch["x"].to(device)
-                    batch["x_lengths"] = batch["x_lengths"].to(device)
-                    batch["y"] = batch["y"].to(device)
-                    batch["y_lengths"] = batch["y_lengths"].to(device)
-                    losses = get_losses(batch)
+                losses = get_losses(
+                    {
+                        "x": tokens,
+                        "x_lengths": tokens_lens,
+                        "y": features.permute(0, 2, 1),
+                        "y_lengths": features_lens,
+                        "spks": None,  # should change it for multi-speakers
+                        "durations": None,
+                    }
+                )

             loss = sum(losses.values())
@@ -535,8 +515,9 @@ def train_one_epoch(
                 raise

         if params.batch_idx_train % 100 == 0 and params.use_fp16:
-            # If the grad scale was less than 1, try increasing it. The _growth_interval
-            # of the grad scaler is configurable, but we can't configure it to have different
+            # If the grad scale was less than 1, try increasing it.
+            # The _growth_interval of the grad scaler is configurable,
+            # but we can't configure it to have different
             # behavior depending on the current grad scale.
             cur_grad_scale = scaler._scale.item()
@@ -560,7 +541,8 @@ def train_one_epoch(
             logging.info(
                 f"Epoch {params.cur_epoch}, batch {batch_idx}, "
-                f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
+                f"global_batch_idx: {params.batch_idx_train}, "
+                f"batch size: {batch_size}, "
                 f"loss[{loss_info}], tot_loss[{tot_loss}], "
                 + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
             )
@@ -588,7 +570,8 @@ def train_one_epoch(
             model.train()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
             logging.info(
-                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+                "Maximum memory allocated so far is "
+                f"{torch.cuda.max_memory_allocated()//1000000}MB"
             )
             if tb_writer is not None:
                 valid_info.write_summary(
@@ -658,20 +641,13 @@ def run(rank, world_size, args):
     logging.info("About to create datamodule")

-    if False:
-        params.data_args.tokenizer = tokenizer
-        data_module = TextMelDataModule(hparams=params.data_args)
-        del params.data_args.tokenizer
-        train_dl = data_module.train_dataloader()
-        valid_dl = data_module.val_dataloader()
-    else:
-        ljspeech = LJSpeechTtsDataModule(args)
+    ljspeech = LJSpeechTtsDataModule(args)

-        train_cuts = ljspeech.train_cuts()
-        train_dl = ljspeech.train_dataloaders(train_cuts)
+    train_cuts = ljspeech.train_cuts()
+    train_dl = ljspeech.train_dataloaders(train_cuts)

-        valid_cuts = ljspeech.valid_cuts()
-        valid_dl = ljspeech.valid_dataloaders(valid_cuts)
+    valid_cuts = ljspeech.valid_cuts()
+    valid_dl = ljspeech.valid_dataloaders(valid_cuts)

     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
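
Background for the grad-scale hunk above (not part of the commit): every 100 batches the loop reads the scaler's private _scale and, when the scale has dropped below 1, explicitly grows it, because GradScaler's _growth_interval cannot be made conditional on the current scale. A self-contained sketch of that pattern with a toy model; it needs a CUDA device, and the doubling factor is illustrative:

    import torch
    from torch.cuda.amp import GradScaler, autocast

    # Toy stand-ins for the real model/optimizer in train.py.
    model = torch.nn.Linear(80, 80).cuda()
    optimizer = torch.optim.Adam(model.parameters())
    scaler = GradScaler(enabled=True, init_scale=1.0)

    for batch_idx in range(300):
        x = torch.randn(16, 80, device="cuda")
        with autocast(enabled=True):
            loss = model(x).pow(2).mean()
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        if batch_idx % 100 == 0:
            # _scale is private API; the diff reads it the same way.
            cur_grad_scale = scaler._scale.item()
            if cur_grad_scale < 1.0:
                # Ask the scaler to grow, mirroring the intent of the
                # comment in the hunk above.
                scaler.update(cur_grad_scale * 2.0)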