add tts training

root 2025-05-27 00:18:23 -07:00
parent 39700d5c94
commit 1281d7a515


@@ -50,6 +50,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import transformers
+from datasets import load_dataset
 from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 from label_smoothing import LabelSmoothingLoss
@@ -68,7 +69,7 @@ from transformers import (
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torch.utils.data import DistributedSampler, DataLoader
-from train import add_model_arguments, add_training_arguments, get_params, compute_validation_loss, get_model, display_and_save_batch
+from train import add_model_arguments, add_training_arguments, get_params, get_model
 from utils import (  # filter_uneven_sized_batch,
     AttributeDict,
     MetricsTracker,
@@ -91,12 +92,12 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
-    # parser.add_argument(
-    #     "--loss-type",
-    #     type=str,
-    #     default="ce",
-    #     help="The type of loss to use.",
-    # )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=16,
+        help="The batch size to use.",
+    )
     parser = deepspeed.add_config_arguments(parser)
     add_model_arguments(parser)
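
Side note on the new flag (not part of the diff itself): `deepspeed.add_config_arguments(parser)` attaches DeepSpeed's own CLI options (notably `--deepspeed` and `--deepspeed_config`) alongside the script's arguments, so the added `--batch-size` is parsed together with them. A minimal sketch of that parsing, with made-up example values:

```python
import argparse

import deepspeed

parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
    "--batch-size", type=int, default=16, help="The batch size to use."
)
# Adds DeepSpeed's launcher/config flags such as --deepspeed and --deepspeed_config.
parser = deepspeed.add_config_arguments(parser)

args = parser.parse_args(["--batch-size", "32", "--deepspeed_config", "ds_config.json"])
print(args.batch_size)        # 32
print(args.deepspeed_config)  # ds_config.json
```
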
@@ -161,7 +162,7 @@ def preprocess(
     attention_mask = input_ids.ne(tokenizer.pad_token_id)
     return input_ids, attention_mask, target_ids
-def data_collator(batch, tokenizer, cut_off_len=2048):
+def data_collator(batch):
     speech_tokens, messages, durations, ids, lang, dnsmos = [], [], [], [], [], []
     for i, item in enumerate(batch):
         speech_tokens.append(item["code"])
@@ -176,21 +177,15 @@ def data_collator(batch, tokenizer, cut_off_len=2048):
         lang.append(item["language"])
         dnsmos.append(item["dnsmos"])
-    input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
-    target_ids = target_ids.type(torch.LongTensor)
-    input_ids = input_ids.type(torch.LongTensor)
     return {
         "speech_tokens": speech_tokens,
-        "input_ids": input_ids,
-        "attention_mask": attention_mask,
-        "target_ids": target_ids,
+        "messages": messages,
         "durations": durations,
         "ids": ids,
         "lang": lang,
         "dnsmos": dnsmos,
     }
 def compute_loss(
     params: AttributeDict,
     tokenizer: AutoTokenizer,
@@ -216,7 +211,10 @@ def compute_loss(
     Return a tuple of two elements. The first element is the loss tensor.
     """
     device = next(model.parameters()).device
-    input_ids, attention_mask, target_ids, answer_cosyvoice_speech_token = batch["input_ids"], batch["attention_mask"], batch["target_ids"], batch["speech_tokens"]
+    messages, answer_cosyvoice_speech_token = batch["messages"], batch["speech_tokens"]
+    input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
+    target_ids = target_ids.type(torch.LongTensor)
+    input_ids = input_ids.type(torch.LongTensor)
     with torch.set_grad_enabled(is_training):
         (
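
The tokenization step that the collator used to perform now runs inside `compute_loss`. The repository's `preprocess()` is only partially visible in this diff, so the helper below is a rough, hypothetical sketch of that kind of step (pad a batch of texts, derive the attention mask from the pad token, mask padding in the targets), not the actual implementation:

```python
import torch
from transformers import AutoTokenizer

def preprocess_sketch(texts, tokenizer: AutoTokenizer, max_len: int = 2048):
    """Hypothetical stand-in for preprocess(): tokenize a padded batch and build targets."""
    enc = tokenizer(
        texts, return_tensors="pt", padding=True, truncation=True, max_length=max_len
    )
    input_ids = enc["input_ids"]
    # Mirrors the visible tail of preprocess(): non-pad positions form the attention mask.
    attention_mask = input_ids.ne(tokenizer.pad_token_id)
    target_ids = input_ids.clone()
    target_ids[~attention_mask] = -100  # conventional "ignore" index for CE-style losses
    return input_ids, attention_mask, target_ids
```
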
@@ -235,24 +233,51 @@ def compute_loss(
     assert loss.requires_grad == is_training
     info = MetricsTracker()
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore")
-        feature_lens = batch["supervisions"]["num_frames"]
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+    info["frames"] = len(messages)
     # Note: We use reduction=sum while computing the loss.
+    info["acc"] = acc * len(messages)
+    info["codec_acc"] = codec_acc * len(messages)
+    info["codec_topk_acc"] = codec_topk_acc * len(messages)
     info["loss"] = loss.detach().cpu().item()
-    info["acc"] = (
-        acc * info["frames"]
-    )  # WAR: to avoid normalization by the number of frames
-    info["codec_acc"] = codec_acc * info["frames"]
-    info["codec_topk_acc"] = codec_topk_acc * info["frames"]
     info["codec_loss"] = codec_loss.detach().cpu().item()
     info["text_loss"] = text_loss.detach().cpu().item()
     return loss, info

+def compute_validation_loss(
+    params: AttributeDict,
+    tokenizer: AutoTokenizer,
+    model: nn.Module,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+    tot_loss = MetricsTracker()
+    for batch_idx, batch in enumerate(valid_dl):
+        with torch.amp.autocast("cuda", enabled=params.use_fp16):
+            loss, loss_info = compute_loss(
+                params=params,
+                tokenizer=tokenizer,
+                model=model,
+                batch=batch,
+                is_training=False,
+            )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+    # FIX ME
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+    loss_value = tot_loss["loss"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+    return tot_loss

 def train_one_epoch(
     params: AttributeDict,
     tokenizer: AutoTokenizer,
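
When `world_size > 1`, the new `compute_validation_loss` relies on `tot_loss.reduce(loss.device)` to aggregate per-rank statistics. The actual `MetricsTracker` lives in `utils` and is not shown here; the sketch below only illustrates the kind of cross-rank summation such a call implies, assuming `torch.distributed` is already initialized:

```python
import torch
import torch.distributed as dist

def reduce_metrics_sketch(metrics: dict, device: torch.device) -> dict:
    """Illustrative only: sum scalar metrics across all ranks with an all-reduce."""
    keys = sorted(metrics)
    buf = torch.tensor([float(metrics[k]) for k in keys], device=device)
    dist.all_reduce(buf, op=dist.ReduceOp.SUM)  # every rank receives the global sums
    return {k: buf[i].item() for i, k in enumerate(keys)}
```
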
@@ -297,14 +322,14 @@ def train_one_epoch(
       be set to 0.
     """
     model.train()
-    model.encoder.eval()
+    # model.encoder.eval()
     if not params.unfreeze_llm:
         model.llm.eval()
     tot_loss = MetricsTracker()
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
-        batch_size = len(batch["supervisions"]["text"])
+        batch_size = len(batch["durations"])
         if batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
@@ -315,7 +340,7 @@ def train_one_epoch(
                 world_size=world_size,
             )
             model.train()
-            model.encoder.eval()
+            # model.encoder.eval()
            if not params.unfreeze_llm:
                 model.llm.eval()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
@@ -373,7 +398,6 @@ def train_one_epoch(
             model.step()
         except:  # noqa
-            display_and_save_batch(batch, params=params)
             raise

         if batch_idx % params.log_interval == 0:
@@ -399,7 +423,7 @@ def train_one_epoch(
                 )
                 tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)

-    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    loss_value = tot_loss["loss"]
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
@@ -421,6 +445,7 @@ def run(rank, world_size, args):
     """
     params = get_params()
     params.update(vars(args))
+    params.valid_interval = 2000

     fix_random_seed(params.seed)
@@ -428,9 +453,7 @@ def run(rank, world_size, args):
         setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info(params)
     logging.info("About to create model")
     model, tokenizer = get_model(params)
     if torch.cuda.is_available():
         device = torch.device("cuda", get_local_rank())
     else:
@@ -447,36 +470,34 @@ def run(rank, world_size, args):
     sampler_state_dict = None
     if params.sampler_state_dict_path:
         sampler_state_dict = torch.load(params.sampler_state_dict_path)
-    data_path = "/lustre/fsw/general_sa/yuekaiz/s2s" + "/emilia_en"
-    ds = load_dataset(data_path, split="train")
-    train_test_split = dataset.train_test_split(test_size=1000, seed=42)
+    # print(params.dataset)
+    ds = load_dataset(params.dataset, split="train")
+    # shuffle the dataset
+    ds = ds.shuffle(seed=42)
+    train_test_split = ds.train_test_split(test_size=1000, seed=42)
     train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
+    # train_dataset, eval_dataset = train_test_split["test"], train_test_split["test"]
     sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
     train_dl = StatefulDataLoader(
         train_dataset,
-        batch_size=2,
+        batch_size=params.batch_size,
         sampler=sampler,
         shuffle=False,
-        num_workers=1,
-        prefetch_factor=1,
-        collate_fn=lambda features: data_collator(
-            features, tokenizer
-        ),
+        num_workers=4,
+        prefetch_factor=2,
+        collate_fn=data_collator,
     )
     train_dl.load_state_dict(sampler_state_dict)
     valid_sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank)
     valid_dl = DataLoader(
         eval_dataset,
-        batch_size=2,
+        batch_size=params.batch_size,
         sampler=valid_sampler,
         shuffle=False,
         num_workers=1,
         prefetch_factor=1,
-        collate_fn=lambda features: data_collator(
-            features
-        ),
+        collate_fn=data_collator,
     )

     if args.tensorboard and rank == 0:
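
The combination of `StatefulDataLoader` and `train_dl.load_state_dict(sampler_state_dict)` is what makes mid-epoch resumption work: the loader's `state_dict()` records how far iteration has progressed, and loading it into a freshly built loader continues from that point. A minimal, self-contained sketch with a toy dataset (the training script instead restores the state saved at `params.sampler_state_dict_path`):

```python
import torch
from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader

dataset = TensorDataset(torch.arange(100))
loader = StatefulDataLoader(dataset, batch_size=16, shuffle=False, num_workers=0)

it = iter(loader)
first = next(it)              # consume one batch: items 0..15
state = loader.state_dict()   # snapshot of the loader's position
# torch.save(state, "sampler_state_dict.pt")  # what the script would persist

resumed = StatefulDataLoader(dataset, batch_size=16, shuffle=False, num_workers=0)
resumed.load_state_dict(state)        # pick up where the snapshot left off
second = next(iter(resumed))          # yields items 16..31, not 0..15
```
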
@@ -533,7 +554,6 @@ def run(rank, world_size, args):
     logging.info("Done!")

 def main():
     parser = get_parser()
     args = parser.parse_args()