mirror of https://github.com/k2-fsa/icefall.git
add tts training

parent 39700d5c94
commit 1281d7a515
@@ -50,6 +50,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import transformers
+from datasets import load_dataset

 from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 from label_smoothing import LabelSmoothingLoss
@@ -68,7 +69,7 @@ from transformers import (
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torch.utils.data import DistributedSampler, DataLoader

-from train import add_model_arguments, add_training_arguments, get_params, compute_validation_loss, get_model, display_and_save_batch
+from train import add_model_arguments, add_training_arguments, get_params, get_model
 from utils import (  # filter_uneven_sized_batch,
     AttributeDict,
     MetricsTracker,
@@ -91,12 +92,12 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )

-    # parser.add_argument(
-    #     "--loss-type",
-    #     type=str,
-    #     default="ce",
-    #     help="The type of loss to use.",
-    # )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=16,
+        help="The batch size to use.",
+    )

     parser = deepspeed.add_config_arguments(parser)
     add_model_arguments(parser)
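Note on the hunk above: the commented-out --loss-type block is dropped and a --batch-size flag takes its place; the parsed value reaches the data loaders later in the diff as params.batch_size. A minimal, self-contained sketch of just the argparse side (the DeepSpeed config arguments and add_model_arguments calls shown above are omitted here):

# Toy parser mirroring the new --batch-size option; only illustrative,
# the real get_parser() also attaches DeepSpeed and model arguments.
import argparse


def get_toy_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=16,
        help="The batch size to use.",
    )
    return parser


if __name__ == "__main__":
    args = get_toy_parser().parse_args(["--batch-size", "8"])
    print(args.batch_size)  # -> 8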
@@ -161,7 +162,7 @@ def preprocess(
     attention_mask = input_ids.ne(tokenizer.pad_token_id)
     return input_ids, attention_mask, target_ids

-def data_collator(batch, tokenizer, cut_off_len=2048):
+def data_collator(batch):
     speech_tokens, messages, durations, ids, lang, dnsmos = [], [], [], [], [], []
     for i, item in enumerate(batch):
         speech_tokens.append(item["code"])
@@ -176,15 +177,9 @@ def data_collator(batch, tokenizer, cut_off_len=2048):
         lang.append(item["language"])
         dnsmos.append(item["dnsmos"])

-    input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
-    target_ids = target_ids.type(torch.LongTensor)
-    input_ids = input_ids.type(torch.LongTensor)
-
     return {
         "speech_tokens": speech_tokens,
-        "input_ids": input_ids,
-        "attention_mask": attention_mask,
-        "target_ids": target_ids,
+        "messages": messages,
         "durations": durations,
         "ids": ids,
         "lang": lang,
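The collator is stripped down to a plain field gatherer: it no longer calls preprocess or touches the tokenizer, so it can be passed directly as collate_fn without a lambda closure. A sketch under assumed field names (the middle of the per-item loop is not visible in this hunk, so "text", "duration" and "id" are illustrative):

# Sketch of the simplified collator; field names marked "assumed" are not
# confirmed by the diff and stand in for whatever the dataset actually uses.
def toy_data_collator(batch):
    speech_tokens, messages, durations, ids, lang, dnsmos = [], [], [], [], [], []
    for item in batch:
        speech_tokens.append(item["code"])
        messages.append(item["text"])       # assumed field name
        durations.append(item["duration"])  # assumed field name
        ids.append(item["id"])              # assumed field name
        lang.append(item["language"])
        dnsmos.append(item["dnsmos"])
    return {
        "speech_tokens": speech_tokens,
        "messages": messages,
        "durations": durations,
        "ids": ids,
        "lang": lang,
        "dnsmos": dnsmos,
    }


example = {
    "code": [1, 2, 3],
    "text": "hello world",
    "duration": 1.2,
    "id": "utt-0",
    "language": "en",
    "dnsmos": 3.5,
}
print(toy_data_collator([example])["messages"])  # -> ['hello world']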
@@ -216,7 +211,10 @@ def compute_loss(
       Return a tuple of two elements. The first element is the loss tensor.
     """
     device = next(model.parameters()).device
-    input_ids, attention_mask, target_ids, answer_cosyvoice_speech_token = batch["input_ids"], batch["attention_mask"], batch["target_ids"], batch["speech_tokens"]
+    messages, answer_cosyvoice_speech_token = batch["messages"], batch["speech_tokens"]
+    input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
+    target_ids = target_ids.type(torch.LongTensor)
+    input_ids = input_ids.type(torch.LongTensor)

     with torch.set_grad_enabled(is_training):
         (
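With the collator simplified, compute_loss now pulls the raw messages out of the batch and runs preprocess itself before the forward pass. A toy, self-contained illustration of that flow, where a whitespace "tokenizer" stands in for the real preprocess/tokenizer pair:

# Toy stand-in for preprocess(messages, tokenizer): tokenize, pad, build an
# attention mask, and cast to LongTensor, mirroring the order of operations
# in the hunk above.
import torch


def toy_preprocess(messages, pad_id=0):
    vocab = {}
    ids = []
    for text in messages:
        ids.append([vocab.setdefault(tok, len(vocab) + 1) for tok in text.split()])
    max_len = max(len(seq) for seq in ids)
    input_ids = torch.tensor(
        [seq + [pad_id] * (max_len - len(seq)) for seq in ids]
    )
    attention_mask = input_ids.ne(pad_id)
    target_ids = input_ids.clone()
    return input_ids, attention_mask, target_ids


batch = {"messages": ["hello world", "hi"], "speech_tokens": [[7, 8], [9]]}
input_ids, attention_mask, target_ids = toy_preprocess(batch["messages"])
input_ids = input_ids.type(torch.LongTensor)
target_ids = target_ids.type(torch.LongTensor)
print(input_ids.shape, attention_mask.shape)  # torch.Size([2, 2]) torch.Size([2, 2])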
@@ -235,23 +233,50 @@ def compute_loss(
     assert loss.requires_grad == is_training

     info = MetricsTracker()
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore")
-        feature_lens = batch["supervisions"]["num_frames"]
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
-
+    info["frames"] = len(messages)
     # Note: We use reduction=sum while computing the loss.
+    info["acc"] = acc * len(messages)
+    info["codec_acc"] = codec_acc * len(messages)
+    info["codec_topk_acc"] = codec_topk_acc * len(messages)
     info["loss"] = loss.detach().cpu().item()
-    info["acc"] = (
-        acc * info["frames"]
-    )  # WAR: to avoid normalization by the number of frames
-
-    info["codec_acc"] = codec_acc * info["frames"]
-    info["codec_topk_acc"] = codec_topk_acc * info["frames"]
     info["codec_loss"] = codec_loss.detach().cpu().item()
     info["text_loss"] = text_loss.detach().cpu().item()
     return loss, info

+def compute_validation_loss(
+    params: AttributeDict,
+    tokenizer: AutoTokenizer,
+    model: nn.Module,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        with torch.amp.autocast("cuda", enabled=params.use_fp16):
+            loss, loss_info = compute_loss(
+                params=params,
+                tokenizer=tokenizer,
+                model=model,
+                batch=batch,
+                is_training=False,
+            )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    # FIX ME
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss

 def train_one_epoch(
     params: AttributeDict,
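Two things change in this hunk: the per-batch metrics now use len(messages) as the "frames" count instead of subsampled feature frames, and a local compute_validation_loss is defined so it no longer needs to be imported from train.py. A toy sketch of why each metric is multiplied by the batch count before being stored, assuming the real MetricsTracker (from icefall's utils) sums entries across batches and normalizes by the accumulated "frames" when reporting:

# ToyTracker is a stand-in for MetricsTracker, used only to show the
# weighted-average bookkeeping implied by info["acc"] = acc * len(messages).
from collections import defaultdict


class ToyTracker(defaultdict):
    def __init__(self):
        super().__init__(float)

    def __add__(self, other):
        out = ToyTracker()
        for key in set(self) | set(other):
            out[key] = self[key] + other[key]
        return out

    def avg(self, key):
        return self[key] / max(self["frames"], 1)


batch_a, batch_b = ToyTracker(), ToyTracker()
batch_a["frames"], batch_a["acc"] = 4, 0.50 * 4   # 4 utterances, acc 0.50
batch_b["frames"], batch_b["acc"] = 2, 0.80 * 2   # 2 utterances, acc 0.80

tot = batch_a + batch_b
print(round(tot.avg("acc"), 3))  # -> 0.6, i.e. (0.5*4 + 0.8*2) / 6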
@@ -297,14 +322,14 @@ def train_one_epoch(
         be set to 0.
     """
     model.train()
-    model.encoder.eval()
+    # model.encoder.eval()
     if not params.unfreeze_llm:
         model.llm.eval()
     tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
-        batch_size = len(batch["supervisions"]["text"])
+        batch_size = len(batch["durations"])
         if batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
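Note the mode handling in this hunk: the whole model is put in train mode, the encoder eval() call is now commented out (so the encoder trains as well), and the LLM is pushed back to eval mode unless --unfreeze-llm is set, keeping its dropout-style layers deterministic while frozen. A minimal sketch with stand-in modules whose attribute names mirror the recipe's model.encoder / model.llm:

# ToySpeechLLM is purely illustrative; only the train()/eval() toggling
# pattern from the hunk above is demonstrated here.
import torch.nn as nn


class ToySpeechLLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.1))
        self.llm = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.1))


def set_train_mode(model: ToySpeechLLM, unfreeze_llm: bool) -> None:
    model.train()
    # model.encoder.eval()  # commented out in the diff, so the encoder trains too
    if not unfreeze_llm:
        model.llm.eval()  # keep the frozen LLM's dropout layers deterministic


model = ToySpeechLLM()
set_train_mode(model, unfreeze_llm=False)
print(model.training, model.llm.training)  # -> True False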
@@ -315,7 +340,7 @@ def train_one_epoch(
                 world_size=world_size,
             )
             model.train()
-            model.encoder.eval()
+            # model.encoder.eval()
             if not params.unfreeze_llm:
                 model.llm.eval()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
@@ -373,7 +398,6 @@ def train_one_epoch(
             model.step()

         except:  # noqa
-            display_and_save_batch(batch, params=params)
             raise

         if batch_idx % params.log_interval == 0:
@@ -399,7 +423,7 @@ def train_one_epoch(
                 )
                 tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)

-    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    loss_value = tot_loss["loss"]
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
@@ -421,6 +445,7 @@ def run(rank, world_size, args):
     """
     params = get_params()
     params.update(vars(args))
+    params.valid_interval = 2000

     fix_random_seed(params.seed)

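params.valid_interval is hard-coded to 2000 after the CLI arguments are merged into params, so with the batch_idx % params.valid_interval check in train_one_epoch, validation runs on the first batch of every epoch and then every 2000 training batches. A trivial illustration of that cadence:

# Which batch indices trigger validation under valid_interval = 2000.
valid_interval = 2000
validate_at = [i for i in range(6001) if i % valid_interval == 0]
print(validate_at)  # -> [0, 2000, 4000, 6000]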
@@ -428,9 +453,7 @@ def run(rank, world_size, args):
     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info(params)
     logging.info("About to create model")
-
     model, tokenizer = get_model(params)
-
     if torch.cuda.is_available():
         device = torch.device("cuda", get_local_rank())
     else:
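A sketch of the device selection above. The get_local_rank() below is a stand-in for the recipe's own helper (imported elsewhere in the file); launchers such as torchrun and DeepSpeed expose the per-node rank through the LOCAL_RANK environment variable, which is the usual source for it:

# Stand-in helper; the real recipe imports its own get_local_rank().
import os

import torch


def get_local_rank() -> int:
    return int(os.environ.get("LOCAL_RANK", "0"))


if torch.cuda.is_available():
    device = torch.device("cuda", get_local_rank())
else:
    device = torch.device("cpu")
print(device)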
@@ -447,36 +470,34 @@ def run(rank, world_size, args):
     sampler_state_dict = None
     if params.sampler_state_dict_path:
         sampler_state_dict = torch.load(params.sampler_state_dict_path)
-
-    data_path = "/lustre/fsw/general_sa/yuekaiz/s2s" + "/emilia_en"
-    ds = load_dataset(data_path, split="train")
-    train_test_split = dataset.train_test_split(test_size=1000, seed=42)
+    # print(params.dataset)
+    ds = load_dataset(params.dataset, split="train")
+    # shuffle the dataset
+    ds = ds.shuffle(seed=42)
+    train_test_split = ds.train_test_split(test_size=1000, seed=42)
     train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
+    # train_dataset, eval_dataset = train_test_split["test"], train_test_split["test"]

     sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
     train_dl = StatefulDataLoader(
         train_dataset,
-        batch_size=2,
+        batch_size=params.batch_size,
         sampler=sampler,
         shuffle=False,
-        num_workers=1,
-        prefetch_factor=1,
-        collate_fn=lambda features: data_collator(
-            features, tokenizer
-        ),
+        num_workers=4,
+        prefetch_factor=2,
+        collate_fn=data_collator
     )
     train_dl.load_state_dict(sampler_state_dict)
     valid_sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank)
     valid_dl = DataLoader(
         eval_dataset,
-        batch_size=2,
+        batch_size=params.batch_size,
         sampler=valid_sampler,
         shuffle=False,
         num_workers=1,
         prefetch_factor=1,
-        collate_fn=lambda features: data_collator(
-            features
-        ),
+        collate_fn=data_collator
     )

     if args.tensorboard and rank == 0:
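In this hunk the hard-coded /lustre path gives way to params.dataset, the dataset is shuffled before the 1000-example held-out split, and both loaders take params.batch_size with data_collator passed directly as collate_fn. The StatefulDataLoader from torchdata is what allows resuming mid-epoch via the saved sampler_state_dict. A self-contained sketch of the same pipeline, with a small in-memory dataset standing in for params.dataset and a trivial collate_fn:

# Minimal sketch of the data pipeline; numbers and the collate_fn are toy
# values, and world_size=1 / rank=0 stand in for the real distributed setup.
from datasets import Dataset
from torch.utils.data import DistributedSampler
from torchdata.stateful_dataloader import StatefulDataLoader

ds = Dataset.from_dict({"text": [f"utt-{i}" for i in range(100)]})
ds = ds.shuffle(seed=42)
split = ds.train_test_split(test_size=10, seed=42)  # the recipe holds out 1000
train_dataset, eval_dataset = split["train"], split["test"]

sampler = DistributedSampler(train_dataset, num_replicas=1, rank=0)
train_dl = StatefulDataLoader(
    train_dataset,
    batch_size=4,                      # params.batch_size in the recipe
    sampler=sampler,
    shuffle=False,                     # the sampler handles sharding/shuffling
    num_workers=0,
    collate_fn=lambda items: items,    # the recipe passes data_collator here
)

# The state dict can be saved alongside a checkpoint and restored later with
# load_state_dict() to resume from the middle of an epoch.
state = train_dl.state_dict()
train_dl.load_state_dict(state)
print(len(next(iter(train_dl))))  # -> 4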
@@ -533,7 +554,6 @@ def run(rank, world_size, args):

     logging.info("Done!")

-
 def main():
     parser = get_parser()
     args = parser.parse_args()