Use official repo

yuekaiz 2024-12-22 18:48:26 +08:00
parent 604ab6f6b3
commit 511f63b551
2 changed files with 100 additions and 95 deletions

View File

@@ -44,7 +44,7 @@ from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model.cfm import CFM
 from model.dit import DiT
-from model.utils import MelSpec
+from model.utils import convert_char_to_pinyin
 from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch.cuda.amp import GradScaler
@@ -151,7 +151,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
-        type=str,
+        type=Path,
         default="exp/valle_dev",
         help="""The experiment dir.
         It specifies the directory where all training related
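A side note on the type=str to type=Path change above: argparse calls the given type on the raw string, so --exp-dir arrives as a pathlib.Path and downstream code gets path semantics for free. A minimal self-contained sketch (the parser here is illustrative, not the script's full parser):

    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-dir", type=Path, default=Path("exp/valle_dev"))
    args = parser.parse_args([])  # use the default

    # exp_dir is a Path, so pathlib operations work directly:
    (args.exp_dir / "tensorboard").mkdir(parents=True, exist_ok=True)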
@@ -162,7 +162,7 @@ def get_parser():
     parser.add_argument(
         "--tokens",
         type=str,
-        default="ft-tts/vocab.txt",
+        default="f5-tts/vocab.txt",
         help="Path to the unique text tokens file",
     )
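The default now points at the F5-TTS vocab path rather than the "ft-tts" typo. For orientation, F5-TTS vocab files list one token per line, with the line index serving as the token id; a hedged sketch of loading such a file (the helper name is illustrative, and the one-token-per-line format is an assumption):

    def load_token_map(vocab_path: str) -> dict:
        # Map each token (one per line, assumed format) to its line index.
        with open(vocab_path, encoding="utf-8") as f:
            return {line.rstrip("\n"): idx for idx, line in enumerate(f)}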
@@ -409,7 +409,7 @@ def get_model(params):
 def load_pretrained_checkpoint(
     model, ckpt_path, device: str = "cpu", dtype=torch.float32
 ):
-    model = model.to(dtype)
+    # model = model.to(dtype)
     checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
     checkpoint["model_state_dict"] = {
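The dict comprehension is truncated in this view, but its shape suggests state-dict keys are remapped before loading. A hedged sketch of the usual prefix-stripping pattern for EMA-style checkpoints (the "ema_model." prefix is an assumption, not confirmed by this diff); note also that weights_only=True (PyTorch 1.13+) restricts torch.load to tensors and plain containers, which is safer for third-party checkpoints:

    import torch

    def strip_prefix(state_dict: dict, prefix: str = "ema_model.") -> dict:
        # Drop a wrapper prefix from every key so the bare model can load it.
        return {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }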
@@ -548,14 +548,15 @@ def save_checkpoint(
         copyfile(src=filename, dst=best_valid_filename)


-def prepare_input(batch: dict, tokenizer, device: torch.device):
+def prepare_input(batch: dict, device: torch.device):
     """Parse batch data"""
+    print(batch.keys())
+    print(batch)
     text_inputs = batch["text"]
-    mel_spec = batch["mel"].permute(0, 2, 1)
-    mel_lengths = batch["mel_lengths"]
-    return text_inputs, mel_spec, mel_lengths
+    # texts.extend(convert_char_to_pinyin([text], polyphone=true))
+    text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True)
+    print(text_inputs)
+    mel_spec = batch["features"]
+    mel_lengths = batch["features_lens"]
+    return text_inputs, mel_spec.to(device), mel_lengths.to(device)


 def compute_loss(
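convert_char_to_pinyin is the official F5-TTS text frontend: it maps a list of transcript strings to per-utterance token lists, converting Chinese characters to pinyin (with polyphone disambiguation when polyphone=True) and passing other characters through. A rough sketch of the call shape, with illustrative output:

    from model.utils import convert_char_to_pinyin  # official F5-TTS helper

    texts = ["hello 你好"]
    tokens = convert_char_to_pinyin(texts, polyphone=True)
    # One token list per input string, roughly (illustrative):
    # [["h", "e", "l", "l", "o", " ", "nǐ", "hǎo"]]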
@@ -584,34 +585,28 @@ def compute_loss(
       values >= 1.0 are fully warmed up and have all modules present.
     """
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
-    (mel_spec, text_inputs, mel_lengths) = prepare_input(batch, device)
+    (text_inputs, mel_spec, mel_lengths) = prepare_input(batch, device=device)
     # at entry, TextTokens is (N, P)
-    assert text_inputs.ndim == 2
-    assert mel_spec.ndim == 3

     with torch.set_grad_enabled(is_training):
         loss, cond, pred = model(mel_spec, text=text_inputs, lens=mel_lengths)
         assert loss.requires_grad == is_training
+        print(loss)
+        # from accelerate import Accelerator
+        # from accelerate.utils import DistributedDataParallelKwargs
+        # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+        # accelerator = Accelerator(
+        #     kwargs_handlers=[ddp_kwargs],
+        # )
+        # accelerator.backward(loss)
+        # loss.backward()

     info = MetricsTracker()
+    exit(0)
     # with warnings.catch_warnings():
     #     warnings.simplefilter("ignore")
-    #     info["frames"] = (audio_features_lens).sum().item()
-    #     info["utterances"] = text_tokens.size(0)
-    #     # # Note: We use reduction=sum while computing the loss.
-    #     # info["loss"] = loss.detach().cpu().item()
-    #     # for metric in metrics:
-    #     #     info[metric] = metrics[metric].detach().cpu().item()
-    #     # del metrics
-    # # Note: We use reduction=sum while computing the loss.
-    # info["loss"] = loss.detach().cpu().item() * info["frames"]
-    # for i in range(len(loss_list)):
-    #     info[f"loss_{i}"] = loss_list[i].detach().cpu().item() * info["frames"]
-    # for i in range(len(acc_list)):
-    #     info[f"acc_{i}"] = acc_list[i] * info["frames"]
+    #     info["samples"] = mel_lengths.size(0)
+    #     info["loss"] = loss.detach().cpu().item() * info["samples"]

     return loss, info
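The commented-out block keeps the intended MetricsTracker bookkeeping visible: per-batch sample counts plus a sum-reduced loss, so averages can be recovered after aggregation across batches and ranks. A self-contained toy version of that convention (a plain dict stands in for MetricsTracker here):

    # Store sums plus a sample count; averages are recomputed at report time.
    info = {"samples": 0, "loss": 0.0}
    for batch_loss, batch_size in [(2.0, 8), (1.5, 8)]:
        info["samples"] += batch_size
        info["loss"] += batch_loss * batch_size  # reduction=sum convention
    print(info["loss"] / info["samples"])        # average loss: 1.75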
@@ -734,6 +729,7 @@ def train_one_epoch(
                 batch=batch,
                 is_training=True,
             )

+        # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info * (
             1 / params.reset_interval
@@ -753,7 +749,9 @@ def train_one_epoch(
                 scaler.step(optimizer)
                 scaler.update()
-                optimizer.zero_grad()
+                # optimizer.zero_grad()
+                # loss.backward()
+                # optimizer.step()

         for k in range(params.accumulate_grad_steps):
             if isinstance(scheduler, Eden):
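For orientation, the surrounding loop follows the standard torch.cuda.amp recipe: scale the loss before backward, then step through the scaler so inf/NaN batches are skipped. A minimal self-contained sketch of that pattern (toy model, not the trainer's code); note that with optimizer.zero_grad() commented out above, gradients keep accumulating unless they are cleared elsewhere:

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler(enabled=False)  # enabled=True on CUDA

    for _ in range(2):
        x = torch.randn(8, 4)
        loss = model(x).pow(2).mean()
        scaler.scale(loss).backward()  # scale to limit fp16 underflow
        scaler.step(optimizer)         # unscales grads; skips step on inf/NaN
        scaler.update()
        optimizer.zero_grad()          # clear grads for the next iteration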
@@ -926,12 +924,7 @@ def run(rank, world_size, args):
     logging.info("Training started")

     if args.tensorboard and rank == 0:
-        if params.train_stage:
-            tb_writer = SummaryWriter(
-                log_dir=f"{params.exp_dir}/tensorboard_stage{params.train_stage}"
-            )
-        else:
-            tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
     else:
         tb_writer = None
@@ -950,7 +943,7 @@ def run(rank, world_size, args):
     logging.info("About to create model")
     model = get_model(params)
-    model = load_pretrained_checkpoint(model, params.pretrained_model_path)
+    # model = load_pretrained_checkpoint(model, params.pretrained_model_path)

     model = model.to(device)
@@ -968,6 +961,7 @@ def run(rank, world_size, args):
         model_avg = copy.deepcopy(model).to(torch.float64)

     assert params.start_epoch > 0, params.start_epoch
+
     checkpoints = load_checkpoint_if_available(
         params=params, model=model, model_avg=model_avg
     )
@@ -1029,7 +1023,7 @@ def run(rank, world_size, args):
     dataset = TtsDataModule(args)
     train_cuts = dataset.train_cuts()
-    valid_cuts = dataset.dev_cuts()
+    valid_cuts = dataset.valid_cuts()

     train_cuts = filter_short_and_long_utterances(
         train_cuts, params.filter_min_duration, params.filter_max_duration
@@ -1041,7 +1035,7 @@ def run(rank, world_size, args):
     train_dl = dataset.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
-    valid_dl = dataset.dev_dataloaders(valid_cuts)
+    valid_dl = dataset.valid_dataloaders(valid_cuts)

     if params.oom_check:
         scan_pessimistic_batches_for_oom(
@@ -1136,7 +1130,7 @@ def scan_pessimistic_batches_for_oom(
         "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
     )
     batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    print(23333)
     dtype = torch.float32
     if params.dtype in ["bfloat16", "bf16"]:
         dtype = torch.bfloat16
@@ -1145,16 +1139,17 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
+        print(batch.keys())
         try:
             with torch.cuda.amp.autocast(dtype=dtype):
-                _, loss, _ = compute_loss(
+                loss, loss_info = compute_loss(
                     params=params,
                     model=model,
                     tokenizer=tokenizer,
                     batch=batch,
                     is_training=True,
                 )
-            loss.backward()
+            loss.backward(retain_graph=True)
             optimizer.zero_grad()
         except Exception as e:
             if "CUDA out of memory" in str(e):

View File

@@ -24,21 +24,22 @@ from pathlib import Path
 from typing import Any, Dict, Optional

 import torch
-from fbank import MatchaFbank, MatchaFbankConfig
+
+# from fbank import MatchaFbank, MatchaFbankConfig
 from lhotse import CutSet, load_manifest_lazy
-from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
+from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures; SpeechSynthesisDataset,
     CutConcatenate,
     CutMix,
     DynamicBucketingSampler,
     PrecomputedFeatures,
     SimpleCutSampler,
-    SpeechSynthesisDataset,
 )
 from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
     AudioSamples,
     OnTheFlyFeatures,
 )
 from lhotse.utils import fix_random_seed
+from speech_synthesis import SpeechSynthesisDataset  # noqa F401
 from torch.utils.data import DataLoader

 from icefall.utils import str2bool
@@ -174,29 +175,32 @@ class TtsDataModule:
         """
         logging.info("About to create train dataset")
         train = SpeechSynthesisDataset(
-            return_text=False,
-            return_tokens=True,
+            return_text=True,
+            return_tokens=False,
             feature_input_strategy=eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
         )

         if self.args.on_the_fly_feats:
-            sampling_rate = 22050
-            config = MatchaFbankConfig(
-                n_fft=1024,
-                n_mels=80,
-                sampling_rate=sampling_rate,
-                hop_length=256,
-                win_length=1024,
-                f_min=0,
-                f_max=8000,
-            )
-            train = SpeechSynthesisDataset(
-                return_text=False,
-                return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
-                return_cuts=self.args.return_cuts,
+            raise NotImplementedError(
+                "On-the-fly feature extraction is not implemented yet."
             )
+            # sampling_rate = 22050
+            # config = MatchaFbankConfig(
+            #     n_fft=1024,
+            #     n_mels=80,
+            #     sampling_rate=sampling_rate,
+            #     hop_length=256,
+            #     win_length=1024,
+            #     f_min=0,
+            #     f_max=8000,
+            # )
+            # train = SpeechSynthesisDataset(
+            #     return_text=True,
+            #     return_tokens=False,
+            #     feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
+            #     return_cuts=self.args.return_cuts,
+            # )

         if self.args.bucketing_sampler:
             logging.info("Using DynamicBucketingSampler.")
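With return_text=True and return_tokens=False, the (now local) SpeechSynthesisDataset yields raw transcripts alongside precomputed features, which is exactly what prepare_input in the training script reads. A hedged sketch of the batch contract assumed there (keys taken from prepare_input; shapes and n_mels=100 are assumptions):

    import torch

    # Expected batch layout, matching the keys read by prepare_input():
    batch = {
        "text": ["transcript one", "transcript two"],  # raw strings
        "features": torch.zeros(2, 120, 100),          # (batch, frames, n_mels)
        "features_lens": torch.tensor([120, 90]),      # valid frames per cut
    }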
@@ -242,26 +246,29 @@ class TtsDataModule:
     def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
         logging.info("About to create dev dataset")
         if self.args.on_the_fly_feats:
-            sampling_rate = 22050
-            config = MatchaFbankConfig(
-                n_fft=1024,
-                n_mels=80,
-                sampling_rate=sampling_rate,
-                hop_length=256,
-                win_length=1024,
-                f_min=0,
-                f_max=8000,
-            )
-            validate = SpeechSynthesisDataset(
-                return_text=False,
-                return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
-                return_cuts=self.args.return_cuts,
+            raise NotImplementedError(
+                "On-the-fly feature extraction is not implemented yet."
             )
+            # sampling_rate = 22050
+            # config = MatchaFbankConfig(
+            #     n_fft=1024,
+            #     n_mels=80,
+            #     sampling_rate=sampling_rate,
+            #     hop_length=256,
+            #     win_length=1024,
+            #     f_min=0,
+            #     f_max=8000,
+            # )
+            # validate = SpeechSynthesisDataset(
+            #     return_text=True,
+            #     return_tokens=False,
+            #     feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
+            #     return_cuts=self.args.return_cuts,
+            # )
         else:
             validate = SpeechSynthesisDataset(
-                return_text=False,
-                return_tokens=True,
+                return_text=True,
+                return_tokens=False,
                 feature_input_strategy=eval(self.args.input_strategy)(),
                 return_cuts=self.args.return_cuts,
             )
@@ -286,26 +293,29 @@ class TtsDataModule:
     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
         logging.info("About to create test dataset")
         if self.args.on_the_fly_feats:
-            sampling_rate = 22050
-            config = MatchaFbankConfig(
-                n_fft=1024,
-                n_mels=80,
-                sampling_rate=sampling_rate,
-                hop_length=256,
-                win_length=1024,
-                f_min=0,
-                f_max=8000,
-            )
-            test = SpeechSynthesisDataset(
-                return_text=False,
-                return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
-                return_cuts=self.args.return_cuts,
+            raise NotImplementedError(
+                "On-the-fly feature extraction is not implemented yet."
             )
+            # sampling_rate = 22050
+            # config = MatchaFbankConfig(
+            #     n_fft=1024,
+            #     n_mels=80,
+            #     sampling_rate=sampling_rate,
+            #     hop_length=256,
+            #     win_length=1024,
+            #     f_min=0,
+            #     f_max=8000,
+            # )
+            # test = SpeechSynthesisDataset(
+            #     return_text=True,
+            #     return_tokens=False,
+            #     feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
+            #     return_cuts=self.args.return_cuts,
+            # )
         else:
             test = SpeechSynthesisDataset(
-                return_text=False,
-                return_tokens=True,
+                return_text=True,
+                return_tokens=False,
                 feature_input_strategy=eval(self.args.input_strategy)(),
                 return_cuts=self.args.return_cuts,
             )