Mirror of https://github.com/k2-fsa/icefall.git
commit 511f63b551 (parent 604ab6f6b3)
@@ -44,7 +44,7 @@ from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model.cfm import CFM
 from model.dit import DiT
-from model.utils import MelSpec
+from model.utils import convert_char_to_pinyin
 from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch.cuda.amp import GradScaler
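The import swap above replaces the mel-spectrogram helper with F5-TTS's pinyin converter, used later in this diff in `prepare_input`. A minimal sketch of the call, assuming the F5-TTS environment is on the path; the exact return format is an assumption, not confirmed by this commit:

from model.utils import convert_char_to_pinyin

texts = ["hello world", "你好世界"]
# Per the call site below, the converter takes a batch of raw transcripts
# and a polyphone flag, returning one token sequence per utterance with
# Chinese characters mapped to pinyin and other characters kept as-is.
token_lists = convert_char_to_pinyin(texts, polyphone=True)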
@@ -151,7 +151,7 @@ def get_parser():

     parser.add_argument(
         "--exp-dir",
-        type=str,
+        type=Path,
         default="exp/valle_dev",
         help="""The experiment dir.
         It specifies the directory where all training related
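Changing `--exp-dir` from `type=str` to `type=Path` makes argparse construct a `pathlib.Path` directly, so downstream code can use path operators without wrapping. A standalone check (argument name reused from the hunk above, value illustrative):

from argparse import ArgumentParser
from pathlib import Path

parser = ArgumentParser()
parser.add_argument("--exp-dir", type=Path, default=Path("exp/valle_dev"))
args = parser.parse_args(["--exp-dir", "exp/f5"])
assert isinstance(args.exp_dir, Path)  # argparse applies Path() to the string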
@@ -162,7 +162,7 @@ def get_parser():
     parser.add_argument(
         "--tokens",
         type=str,
-        default="ft-tts/vocab.txt",
+        default="f5-tts/vocab.txt",
         help="Path to the unique text tokens file",
     )

@@ -409,7 +409,7 @@ def get_model(params):
 def load_pretrained_checkpoint(
     model, ckpt_path, device: str = "cpu", dtype=torch.float32
 ):
-    model = model.to(dtype)
+    # model = model.to(dtype)
     checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)

     checkpoint["model_state_dict"] = {
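`torch.load(..., weights_only=True)` restricts unpickling to tensors and plain containers, the safe choice for third-party checkpoints. The hunk truncates before the state-dict rewrite, so the sketch below fills that gap with a hypothetical prefix strip; treat the key handling as an assumption:

import torch

def load_pretrained_sketch(model, ckpt_path, device="cpu"):
    # weights_only=True refuses arbitrary pickled objects in the checkpoint.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
    # Hypothetical remap: strip a wrapper prefix such as "ema_model."
    # (the actual mapping is cut off in this diff).
    state_dict = {
        k.replace("ema_model.", ""): v
        for k, v in checkpoint["model_state_dict"].items()
    }
    model.load_state_dict(state_dict, strict=False)
    return model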
@@ -548,14 +548,15 @@ def save_checkpoint(
         copyfile(src=filename, dst=best_valid_filename)


-def prepare_input(batch: dict, tokenizer, device: torch.device):
+def prepare_input(batch: dict, device: torch.device):
     """Parse batch data"""
-    print(batch.keys())
-    print(batch)
     text_inputs = batch["text"]
-    mel_spec = batch["mel"].permute(0, 2, 1)
-    mel_lengths = batch["mel_lengths"]
-    return text_inputs, mel_spec, mel_lengths
+    # texts.extend(convert_char_to_pinyin([text], polyphone=true))
+    text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True)
+    print(text_inputs)
+    mel_spec = batch["features"]
+    mel_lengths = batch["features_lens"]
+    return text_inputs, mel_spec.to(device), mel_lengths.to(device)


 def compute_loss(
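For reference, a hedged sketch of the batch layout the rewritten `prepare_input` expects. The key names ("text", "features", "features_lens") come from this diff; the shapes (N, T, n_mels) are assumptions, and the call only runs inside this recipe's environment where `convert_char_to_pinyin` is importable:

import torch

batch = {
    "text": ["some transcript", "另一条文本"],
    "features": torch.randn(2, 320, 100),       # lhotse-style (N, T, F) mels
    "features_lens": torch.tensor([320, 256]),  # valid frames per utterance
}
text, mel, mel_lens = prepare_input(batch, device=torch.device("cpu"))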
@@ -584,34 +585,28 @@ def compute_loss(
       values >= 1.0 are fully warmed up and have all modules present.
     """
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
-    (mel_spec, text_inputs, mel_lengths) = prepare_input(batch, device)
+    (text_inputs, mel_spec, mel_lengths) = prepare_input(batch, device=device)
     # at entry, TextTokens is (N, P)
-    assert text_inputs.ndim == 2
-    assert mel_spec.ndim == 3

     with torch.set_grad_enabled(is_training):
         loss, cond, pred = model(mel_spec, text=text_inputs, lens=mel_lengths)
         assert loss.requires_grad == is_training
+        print(loss)
+        # from accelerate import Accelerator
+        # from accelerate.utils import DistributedDataParallelKwargs
+        # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+        # accelerator = Accelerator(
+        #     kwargs_handlers=[ddp_kwargs],
+        # )
+        # accelerator.backward(loss)
+        # loss.backward()

     info = MetricsTracker()
-    exit(0)
     # with warnings.catch_warnings():
     #     warnings.simplefilter("ignore")
-    #     info["frames"] = (audio_features_lens).sum().item()
-    #     info["utterances"] = text_tokens.size(0)
+    #     info["samples"] = mel_lengths.size(0)

-    # # # Note: We use reduction=sum while computing the loss.
-    # # info["loss"] = loss.detach().cpu().item()
-    # # for metric in metrics:
-    # #     info[metric] = metrics[metric].detach().cpu().item()
-    # # del metrics
-    # # Note: We use reduction=sum while computing the loss.
-    # info["loss"] = loss.detach().cpu().item() * info["frames"]
+    #     info["loss"] = loss.detach().cpu().item() * info["samples"]

-    # for i in range(len(loss_list)):
-    #     info[f"loss_{i}"] = loss_list[i].detach().cpu().item() * info["frames"]
-    # for i in range(len(acc_list)):
-    #     info[f"acc_{i}"] = acc_list[i] * info["frames"]

     return loss, info

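The device lookup kept at the top of `compute_loss` is a common DDP idiom; isolated as a helper, it reads:

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def model_device(model: torch.nn.Module) -> torch.device:
    # A DDP wrapper records its device directly; a plain module is located
    # via its first parameter, exactly as compute_loss does above.
    return model.device if isinstance(model, DDP) else next(model.parameters()).device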
@@ -734,6 +729,7 @@ def train_one_epoch(
                 batch=batch,
                 is_training=True,
             )

         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info * (
             1 / params.reset_interval
@@ -753,7 +749,9 @@ def train_one_epoch(

             scaler.step(optimizer)
             scaler.update()
-            optimizer.zero_grad()
+            # optimizer.zero_grad()
+            # loss.backward()
+            # optimizer.step()

             for k in range(params.accumulate_grad_steps):
                 if isinstance(scheduler, Eden):
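For context, the standard torch.cuda.amp pattern behind the scaler calls kept in this hunk; `train_dl`, `optimizer`, and `compute_loss_fn` are placeholders, and note the commit comments out the `zero_grad` that normally follows the update:

import torch
from torch.cuda.amp import GradScaler

scaler = GradScaler(enabled=True)
for batch in train_dl:                         # placeholder loader
    with torch.cuda.amp.autocast(dtype=torch.float16):
        loss = compute_loss_fn(batch)          # placeholder loss function
    scaler.scale(loss).backward()              # scale to avoid fp16 underflow
    scaler.step(optimizer)                     # unscales grads; skips on inf/nan
    scaler.update()                            # adapts the scale factor
    optimizer.zero_grad(set_to_none=True)      # the step this commit comments out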
@@ -926,12 +924,7 @@ def run(rank, world_size, args):
     logging.info("Training started")

     if args.tensorboard and rank == 0:
-        if params.train_stage:
-            tb_writer = SummaryWriter(
-                log_dir=f"{params.exp_dir}/tensorboard_stage{params.train_stage}"
-            )
-        else:
-            tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
     else:
         tb_writer = None

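The simplification drops the per-stage log directory and always writes to a single `tensorboard` folder. SummaryWriter usage for reference; the tag and values are illustrative:

from torch.utils.tensorboard import SummaryWriter

tb_writer = SummaryWriter(log_dir="exp/valle_dev/tensorboard")
tb_writer.add_scalar("train/loss", 0.5, global_step=1)  # illustrative values
tb_writer.close()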
@@ -950,7 +943,7 @@ def run(rank, world_size, args):
     logging.info("About to create model")

     model = get_model(params)
-    model = load_pretrained_checkpoint(model, params.pretrained_model_path)
+    # model = load_pretrained_checkpoint(model, params.pretrained_model_path)

     model = model.to(device)

@@ -968,6 +961,7 @@ def run(rank, world_size, args):
         model_avg = copy.deepcopy(model).to(torch.float64)

     assert params.start_epoch > 0, params.start_epoch

     checkpoints = load_checkpoint_if_available(
         params=params, model=model, model_avg=model_avg
     )
@@ -1029,7 +1023,7 @@ def run(rank, world_size, args):

     dataset = TtsDataModule(args)
     train_cuts = dataset.train_cuts()
-    valid_cuts = dataset.dev_cuts()
+    valid_cuts = dataset.valid_cuts()

     train_cuts = filter_short_and_long_utterances(
         train_cuts, params.filter_min_duration, params.filter_max_duration
@@ -1041,7 +1035,7 @@ def run(rank, world_size, args):
     train_dl = dataset.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
-    valid_dl = dataset.dev_dataloaders(valid_cuts)
+    valid_dl = dataset.valid_dataloaders(valid_cuts)

     if params.oom_check:
         scan_pessimistic_batches_for_oom(
@@ -1136,7 +1130,7 @@ def scan_pessimistic_batches_for_oom(
         "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
     )
     batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+    print(23333)
     dtype = torch.float32
     if params.dtype in ["bfloat16", "bf16"]:
         dtype = torch.bfloat16
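The dtype probe above maps the string flag to a torch dtype; a sketch extended with the half-precision branch that the hunk cuts off, where the "fp16" spellings are an assumption:

dtype = torch.float32
if params.dtype in ["bfloat16", "bf16"]:
    dtype = torch.bfloat16
elif params.dtype in ["float16", "fp16"]:  # assumed spelling of the fp16 branch
    dtype = torch.float16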
@@ -1145,16 +1139,17 @@ def scan_pessimistic_batches_for_oom(

     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
+        print(batch.keys())
         try:
             with torch.cuda.amp.autocast(dtype=dtype):
-                _, loss, _ = compute_loss(
+                loss, loss_info = compute_loss(
                     params=params,
                     model=model,
                     tokenizer=tokenizer,
                     batch=batch,
                     is_training=True,
                 )
-            loss.backward()
+            loss.backward(retain_graph=True)
             optimizer.zero_grad()
         except Exception as e:
             if "CUDA out of memory" in str(e):
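`loss.backward(retain_graph=True)` keeps the autograd graph alive after the backward pass instead of freeing it, which permits a second backward through the same graph. A self-contained illustration of the difference:

import torch

x = torch.ones(3, requires_grad=True)
y = (x * 2).sum()
y.backward(retain_graph=True)  # graph kept; a second backward is allowed
y.backward()                   # would raise without retain_graph above
print(x.grad)                  # gradients accumulate: tensor([4., 4., 4.])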
@@ -24,21 +24,22 @@ from pathlib import Path
 from typing import Any, Dict, Optional

 import torch
-from fbank import MatchaFbank, MatchaFbankConfig
+# from fbank import MatchaFbank, MatchaFbankConfig
 from lhotse import CutSet, load_manifest_lazy
-from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
+from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures; SpeechSynthesisDataset,
     CutConcatenate,
     CutMix,
     DynamicBucketingSampler,
     PrecomputedFeatures,
     SimpleCutSampler,
-    SpeechSynthesisDataset,
 )
 from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
     AudioSamples,
     OnTheFlyFeatures,
 )
 from lhotse.utils import fix_random_seed
+from speech_synthesis import SpeechSynthesisDataset  # noqa F401
 from torch.utils.data import DataLoader

 from icefall.utils import str2bool
@@ -174,29 +175,32 @@ class TtsDataModule:
         """
         logging.info("About to create train dataset")
         train = SpeechSynthesisDataset(
-            return_text=False,
-            return_tokens=True,
+            return_text=True,
+            return_tokens=False,
             feature_input_strategy=eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
         )

         if self.args.on_the_fly_feats:
-            sampling_rate = 22050
-            config = MatchaFbankConfig(
-                n_fft=1024,
-                n_mels=80,
-                sampling_rate=sampling_rate,
-                hop_length=256,
-                win_length=1024,
-                f_min=0,
-                f_max=8000,
-            )
-            train = SpeechSynthesisDataset(
-                return_text=False,
-                return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
-                return_cuts=self.args.return_cuts,
+            raise NotImplementedError(
+                "On-the-fly feature extraction is not implemented yet."
             )
+            # sampling_rate = 22050
+            # config = MatchaFbankConfig(
+            #     n_fft=1024,
+            #     n_mels=80,
+            #     sampling_rate=sampling_rate,
+            #     hop_length=256,
+            #     win_length=1024,
+            #     f_min=0,
+            #     f_max=8000,
+            # )
+            # train = SpeechSynthesisDataset(
+            #     return_text=True,
+            #     return_tokens=False,
+            #     feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
+            #     return_cuts=self.args.return_cuts,
+            # )

         if self.args.bucketing_sampler:
             logging.info("Using DynamicBucketingSampler.")
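The sampler branch logged here pairs the dataset with lhotse's DynamicBucketingSampler, which groups cuts of similar duration to limit padding. A minimal construction sketch; `train_cuts` and the `train` dataset come from the surrounding code, and the max_duration and num_buckets values are illustrative, not taken from this file:

from lhotse.dataset import DynamicBucketingSampler
from torch.utils.data import DataLoader

sampler = DynamicBucketingSampler(
    train_cuts,
    max_duration=200.0,  # total seconds of audio per batch
    shuffle=True,
    num_buckets=30,      # bucket cuts by duration to reduce padding waste
)
# lhotse samplers yield whole CutSet mini-batches, so batch_size=None.
train_dl = DataLoader(train, sampler=sampler, batch_size=None, num_workers=4)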
@@ -242,26 +246,29 @@ class TtsDataModule:
     def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
         logging.info("About to create dev dataset")
         if self.args.on_the_fly_feats:
-            sampling_rate = 22050
-            config = MatchaFbankConfig(
-                n_fft=1024,
-                n_mels=80,
-                sampling_rate=sampling_rate,
-                hop_length=256,
-                win_length=1024,
-                f_min=0,
-                f_max=8000,
-            )
-            validate = SpeechSynthesisDataset(
-                return_text=False,
-                return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
-                return_cuts=self.args.return_cuts,
+            raise NotImplementedError(
+                "On-the-fly feature extraction is not implemented yet."
             )
+            # sampling_rate = 22050
+            # config = MatchaFbankConfig(
+            #     n_fft=1024,
+            #     n_mels=80,
+            #     sampling_rate=sampling_rate,
+            #     hop_length=256,
+            #     win_length=1024,
+            #     f_min=0,
+            #     f_max=8000,
+            # )
+            # validate = SpeechSynthesisDataset(
+            #     return_text=True,
+            #     return_tokens=False,
+            #     feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
+            #     return_cuts=self.args.return_cuts,
+            # )
         else:
             validate = SpeechSynthesisDataset(
-                return_text=False,
-                return_tokens=True,
+                return_text=True,
+                return_tokens=False,
                 feature_input_strategy=eval(self.args.input_strategy)(),
                 return_cuts=self.args.return_cuts,
             )
@@ -286,26 +293,29 @@ class TtsDataModule:
     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
         logging.info("About to create test dataset")
         if self.args.on_the_fly_feats:
-            sampling_rate = 22050
-            config = MatchaFbankConfig(
-                n_fft=1024,
-                n_mels=80,
-                sampling_rate=sampling_rate,
-                hop_length=256,
-                win_length=1024,
-                f_min=0,
-                f_max=8000,
-            )
-            test = SpeechSynthesisDataset(
-                return_text=False,
-                return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
-                return_cuts=self.args.return_cuts,
+            raise NotImplementedError(
+                "On-the-fly feature extraction is not implemented yet."
             )
+            # sampling_rate = 22050
+            # config = MatchaFbankConfig(
+            #     n_fft=1024,
+            #     n_mels=80,
+            #     sampling_rate=sampling_rate,
+            #     hop_length=256,
+            #     win_length=1024,
+            #     f_min=0,
+            #     f_max=8000,
+            # )
+            # test = SpeechSynthesisDataset(
+            #     return_text=True,
+            #     return_tokens=False,
+            #     feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
+            #     return_cuts=self.args.return_cuts,
+            # )
         else:
             test = SpeechSynthesisDataset(
-                return_text=False,
-                return_tokens=True,
+                return_text=True,
+                return_tokens=False,
                 feature_input_strategy=eval(self.args.input_strategy)(),
                 return_cuts=self.args.return_cuts,
             )