mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
remove unused code
This commit is contained in:
parent
7994684bf4
commit
c558328dc5
@ -27,28 +27,100 @@ The generated fbank features are saved in data/fbank.
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from lhotse import (
|
from lhotse import CutSet, LilcomChunkyWriter, load_manifest
|
||||||
CutSet,
|
|
||||||
Fbank,
|
|
||||||
FbankConfig,
|
|
||||||
LilcomChunkyWriter,
|
|
||||||
load_manifest,
|
|
||||||
load_manifest_lazy,
|
|
||||||
)
|
|
||||||
from lhotse.audio import RecordingSet
|
from lhotse.audio import RecordingSet
|
||||||
|
from lhotse.features.base import FeatureExtractor, register_extractor
|
||||||
from lhotse.supervision import SupervisionSet
|
from lhotse.supervision import SupervisionSet
|
||||||
|
from lhotse.utils import Seconds, compute_num_frames
|
||||||
|
from matcha.utils.audio import mel_spectrogram
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
from icefall.utils import get_executor
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
|
||||||
# it wastes a lot of CPU and slow things down.
|
@dataclass
|
||||||
# Do this outside of main() in case it needs to take effect
|
class MyFbankConfig:
|
||||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
n_fft: int
|
||||||
torch.set_num_threads(1)
|
n_mels: int
|
||||||
torch.set_num_interop_threads(1)
|
sampling_rate: int
|
||||||
|
hop_length: int
|
||||||
|
win_length: int
|
||||||
|
f_min: float
|
||||||
|
f_max: float
|
||||||
|
|
||||||
|
|
||||||
|
@register_extractor
|
||||||
|
class MyFbank(FeatureExtractor):
|
||||||
|
|
||||||
|
name = "MyFbank"
|
||||||
|
config_type = MyFbankConfig
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config=config)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> Union[str, torch.device]:
|
||||||
|
return self.config.device
|
||||||
|
|
||||||
|
def feature_dim(self, sampling_rate: int) -> int:
|
||||||
|
return self.config.n_mels
|
||||||
|
|
||||||
|
def extract(
|
||||||
|
self,
|
||||||
|
samples: np.ndarray,
|
||||||
|
sampling_rate: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Check for sampling rate compatibility.
|
||||||
|
expected_sr = self.config.sampling_rate
|
||||||
|
assert sampling_rate == expected_sr, (
|
||||||
|
f"Mismatched sampling rate: extractor expects {expected_sr}, "
|
||||||
|
f"got {sampling_rate}"
|
||||||
|
)
|
||||||
|
samples = torch.from_numpy(samples)
|
||||||
|
assert samples.ndim == 2, samples.shape
|
||||||
|
assert samples.shape[0] == 1, samples.shape
|
||||||
|
|
||||||
|
mel = (
|
||||||
|
mel_spectrogram(
|
||||||
|
samples,
|
||||||
|
self.config.n_fft,
|
||||||
|
self.config.n_mels,
|
||||||
|
self.config.sampling_rate,
|
||||||
|
self.config.hop_length,
|
||||||
|
self.config.win_length,
|
||||||
|
self.config.f_min,
|
||||||
|
self.config.f_max,
|
||||||
|
center=False,
|
||||||
|
)
|
||||||
|
.squeeze()
|
||||||
|
.t()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mel.ndim == 2, mel.shape
|
||||||
|
assert mel.shape[1] == self.config.n_mels, mel.shape
|
||||||
|
|
||||||
|
num_frames = compute_num_frames(
|
||||||
|
samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
|
||||||
|
)
|
||||||
|
|
||||||
|
if mel.shape[0] > num_frames:
|
||||||
|
mel = mel[:num_frames]
|
||||||
|
elif mel.shape[0] < num_frames:
|
||||||
|
mel = mel.unsqueeze(0)
|
||||||
|
mel = torch.nn.functional.pad(
|
||||||
|
mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
|
||||||
|
).squeeze(0)
|
||||||
|
|
||||||
|
return mel.numpy()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def frame_shift(self) -> Seconds:
|
||||||
|
return self.config.hop_length / self.config.sampling_rate
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -77,10 +149,15 @@ def compute_fbank_ljspeech(num_jobs: int):
|
|||||||
logging.info(f"num_jobs: {num_jobs}")
|
logging.info(f"num_jobs: {num_jobs}")
|
||||||
logging.info(f"src_dir: {src_dir}")
|
logging.info(f"src_dir: {src_dir}")
|
||||||
logging.info(f"output_dir: {output_dir}")
|
logging.info(f"output_dir: {output_dir}")
|
||||||
|
config = MyFbankConfig(
|
||||||
sampling_rate = 22050
|
n_fft=1024,
|
||||||
frame_length = 1024 / sampling_rate # (in second)
|
n_mels=80,
|
||||||
frame_shift = 256 / sampling_rate # (in second)
|
sampling_rate=22050,
|
||||||
|
hop_length=256,
|
||||||
|
win_length=1024,
|
||||||
|
f_min=0,
|
||||||
|
f_max=8000,
|
||||||
|
)
|
||||||
|
|
||||||
prefix = "ljspeech"
|
prefix = "ljspeech"
|
||||||
suffix = "jsonl.gz"
|
suffix = "jsonl.gz"
|
||||||
@ -93,25 +170,7 @@ def compute_fbank_ljspeech(num_jobs: int):
|
|||||||
src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
|
src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
|
||||||
)
|
)
|
||||||
|
|
||||||
# Differences with matcha-tts
|
extractor = MyFbank(config)
|
||||||
# 1. we use pre-emphasis
|
|
||||||
# 2. we remove dc offset
|
|
||||||
# 3. we use a different window
|
|
||||||
# 4. we use a different mel filter bank matrix
|
|
||||||
# 5. we don't normalize features
|
|
||||||
config = FbankConfig(
|
|
||||||
sampling_rate=sampling_rate,
|
|
||||||
frame_length=frame_length,
|
|
||||||
frame_shift=frame_shift,
|
|
||||||
use_fft_mag=True,
|
|
||||||
low_freq=0,
|
|
||||||
high_freq=8000,
|
|
||||||
remove_dc_offset=False,
|
|
||||||
preemph_coeff=0,
|
|
||||||
# should be identical to n_feats in ../matcha/train.py
|
|
||||||
num_filters=80,
|
|
||||||
)
|
|
||||||
extractor = Fbank(config)
|
|
||||||
|
|
||||||
with get_executor() as ex: # Initialize the executor only once.
|
with get_executor() as ex: # Initialize the executor only once.
|
||||||
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
|
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
|
||||||
@ -135,6 +194,12 @@ def compute_fbank_ljspeech(num_jobs: int):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
|
# it wastes a lot of CPU and slow things down.
|
||||||
|
# Do this outside of main() in case it needs to take effect
|
||||||
|
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
@ -35,6 +35,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
from lhotse import CutSet, load_manifest_lazy
|
from lhotse import CutSet, load_manifest_lazy
|
||||||
from lhotse.dataset.speech_synthesis import validate_for_tts
|
from lhotse.dataset.speech_synthesis import validate_for_tts
|
||||||
|
from compute_fbank_ljspeech import MyFbank
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
|
@ -3,18 +3,17 @@
|
|||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from typing import Any, Dict, Optional, Union
|
from typing import Any, Dict, Optional, Union
|
||||||
import json
|
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from matcha.data.text_mel_datamodule import TextMelDataModule
|
|
||||||
from matcha.models.matcha_tts import MatchaTTS
|
from matcha.models.matcha_tts import MatchaTTS
|
||||||
from matcha.tokenizer import Tokenizer
|
from matcha.tokenizer import Tokenizer
|
||||||
from matcha.utils.model import fix_len_compatibility
|
from matcha.utils.model import fix_len_compatibility
|
||||||
@ -355,36 +354,27 @@ def compute_validation_loss(
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch_idx, batch in enumerate(valid_dl):
|
for batch_idx, batch in enumerate(valid_dl):
|
||||||
if "tokens" in batch:
|
(
|
||||||
|
audio,
|
||||||
|
audio_lens,
|
||||||
|
features,
|
||||||
|
features_lens,
|
||||||
|
tokens,
|
||||||
|
tokens_lens,
|
||||||
|
) = prepare_input(batch, tokenizer, device, params)
|
||||||
|
|
||||||
(
|
losses = get_losses(
|
||||||
audio,
|
{
|
||||||
audio_lens,
|
"x": tokens,
|
||||||
features,
|
"x_lengths": tokens_lens,
|
||||||
features_lens,
|
"y": features.permute(0, 2, 1),
|
||||||
tokens,
|
"y_lengths": features_lens,
|
||||||
tokens_lens,
|
"spks": None, # should change it for multi-speakers
|
||||||
) = prepare_input(batch, tokenizer, device, params)
|
"durations": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
losses = get_losses(
|
batch_size = len(batch["tokens"])
|
||||||
{
|
|
||||||
"x": tokens,
|
|
||||||
"x_lengths": tokens_lens,
|
|
||||||
"y": features.permute(0, 2, 1),
|
|
||||||
"y_lengths": features_lens,
|
|
||||||
"spks": None, # should change it for multi-speakers
|
|
||||||
"durations": None,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
batch_size = len(batch["tokens"])
|
|
||||||
else:
|
|
||||||
batch_size = batch["x"].shape[0]
|
|
||||||
batch["x"] = batch["x"].to(device)
|
|
||||||
batch["x_lengths"] = batch["x_lengths"].to(device)
|
|
||||||
batch["y"] = batch["y"].to(device)
|
|
||||||
batch["y_lengths"] = batch["y_lengths"].to(device)
|
|
||||||
losses = get_losses(batch)
|
|
||||||
|
|
||||||
loss_info = MetricsTracker()
|
loss_info = MetricsTracker()
|
||||||
loss_info["samples"] = batch_size
|
loss_info["samples"] = batch_size
|
||||||
@ -478,38 +468,28 @@ def train_one_epoch(
|
|||||||
# features_lens, (N,), int32
|
# features_lens, (N,), int32
|
||||||
# tokens: List[List[str]], len(tokens) == N
|
# tokens: List[List[str]], len(tokens) == N
|
||||||
|
|
||||||
if "tokens" in batch:
|
batch_size = len(batch["tokens"])
|
||||||
batch_size = len(batch["tokens"])
|
|
||||||
|
|
||||||
(
|
(
|
||||||
audio,
|
audio,
|
||||||
audio_lens,
|
audio_lens,
|
||||||
features,
|
features,
|
||||||
features_lens,
|
features_lens,
|
||||||
tokens,
|
tokens,
|
||||||
tokens_lens,
|
tokens_lens,
|
||||||
) = prepare_input(batch, tokenizer, device, params)
|
) = prepare_input(batch, tokenizer, device, params)
|
||||||
else:
|
|
||||||
batch_size = batch["x"].shape[0]
|
|
||||||
try:
|
try:
|
||||||
with autocast(enabled=params.use_fp16):
|
with autocast(enabled=params.use_fp16):
|
||||||
if "tokens" in batch:
|
losses = get_losses(
|
||||||
losses = get_losses(
|
{
|
||||||
{
|
"x": tokens,
|
||||||
"x": tokens,
|
"x_lengths": tokens_lens,
|
||||||
"x_lengths": tokens_lens,
|
"y": features.permute(0, 2, 1),
|
||||||
"y": features.permute(0, 2, 1),
|
"y_lengths": features_lens,
|
||||||
"y_lengths": features_lens,
|
"spks": None, # should change it for multi-speakers
|
||||||
"spks": None, # should change it for multi-speakers
|
"durations": None,
|
||||||
"durations": None,
|
}
|
||||||
}
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
batch["x"] = batch["x"].to(device)
|
|
||||||
batch["x_lengths"] = batch["x_lengths"].to(device)
|
|
||||||
batch["y"] = batch["y"].to(device)
|
|
||||||
batch["y_lengths"] = batch["y_lengths"].to(device)
|
|
||||||
losses = get_losses(batch)
|
|
||||||
|
|
||||||
loss = sum(losses.values())
|
loss = sum(losses.values())
|
||||||
|
|
||||||
@ -535,8 +515,9 @@ def train_one_epoch(
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
if params.batch_idx_train % 100 == 0 and params.use_fp16:
|
if params.batch_idx_train % 100 == 0 and params.use_fp16:
|
||||||
# If the grad scale was less than 1, try increasing it. The _growth_interval
|
# If the grad scale was less than 1, try increasing it.
|
||||||
# of the grad scaler is configurable, but we can't configure it to have different
|
# The _growth_interval of the grad scaler is configurable,
|
||||||
|
# but we can't configure it to have different
|
||||||
# behavior depending on the current grad scale.
|
# behavior depending on the current grad scale.
|
||||||
cur_grad_scale = scaler._scale.item()
|
cur_grad_scale = scaler._scale.item()
|
||||||
|
|
||||||
@ -560,7 +541,8 @@ def train_one_epoch(
|
|||||||
|
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
||||||
f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
|
f"global_batch_idx: {params.batch_idx_train}, "
|
||||||
|
f"batch size: {batch_size}, "
|
||||||
f"loss[{loss_info}], tot_loss[{tot_loss}], "
|
f"loss[{loss_info}], tot_loss[{tot_loss}], "
|
||||||
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
|
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
|
||||||
)
|
)
|
||||||
@ -588,7 +570,8 @@ def train_one_epoch(
|
|||||||
model.train()
|
model.train()
|
||||||
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
"Maximum memory allocated so far is "
|
||||||
|
f"{torch.cuda.max_memory_allocated()//1000000}MB"
|
||||||
)
|
)
|
||||||
if tb_writer is not None:
|
if tb_writer is not None:
|
||||||
valid_info.write_summary(
|
valid_info.write_summary(
|
||||||
@ -658,20 +641,13 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
logging.info("About to create datamodule")
|
logging.info("About to create datamodule")
|
||||||
|
|
||||||
if False:
|
ljspeech = LJSpeechTtsDataModule(args)
|
||||||
params.data_args.tokenizer = tokenizer
|
|
||||||
data_module = TextMelDataModule(hparams=params.data_args)
|
|
||||||
del params.data_args.tokenizer
|
|
||||||
train_dl = data_module.train_dataloader()
|
|
||||||
valid_dl = data_module.val_dataloader()
|
|
||||||
else:
|
|
||||||
ljspeech = LJSpeechTtsDataModule(args)
|
|
||||||
|
|
||||||
train_cuts = ljspeech.train_cuts()
|
train_cuts = ljspeech.train_cuts()
|
||||||
train_dl = ljspeech.train_dataloaders(train_cuts)
|
train_dl = ljspeech.train_dataloaders(train_cuts)
|
||||||
|
|
||||||
valid_cuts = ljspeech.valid_cuts()
|
valid_cuts = ljspeech.valid_cuts()
|
||||||
valid_dl = ljspeech.valid_dataloaders(valid_cuts)
|
valid_dl = ljspeech.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||||
if checkpoints and "grad_scaler" in checkpoints:
|
if checkpoints and "grad_scaler" in checkpoints:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user