Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-09 14:05:33 +00:00)

Commit 1281d7a515: add tts training
Parent commit: 39700d5c94
@@ -50,6 +50,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import transformers
+from datasets import load_dataset

 from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 from label_smoothing import LabelSmoothingLoss
@@ -68,7 +69,7 @@ from transformers import (
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torch.utils.data import DistributedSampler, DataLoader

-from train import add_model_arguments, add_training_arguments, get_params, compute_validation_loss, get_model, display_and_save_batch
+from train import add_model_arguments, add_training_arguments, get_params, get_model
 from utils import ( # filter_uneven_sized_batch,
     AttributeDict,
     MetricsTracker,
@@ -91,12 +92,12 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )

-    # parser.add_argument(
-    #     "--loss-type",
-    #     type=str,
-    #     default="ce",
-    #     help="The type of loss to use.",
-    # )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=16,
+        help="The batch size to use.",
+    )

     parser = deepspeed.add_config_arguments(parser)
     add_model_arguments(parser)
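A minimal sketch of how the new --batch-size flag is expected to reach the rest of the script: argparse turns the dash into an underscore, and the params.update(vars(args)) call later in run() exposes it as params.batch_size. The bare dict below stands in for the AttributeDict returned by get_params() and is only for illustration.

import argparse

parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
    "--batch-size",
    type=int,
    default=16,
    help="The batch size to use.",
)
args = parser.parse_args(["--batch-size", "32"])

params = {}  # stand-in for the AttributeDict returned by get_params()
params.update(vars(args))
assert params["batch_size"] == 32  # '-' becomes '_' in the parsed namespace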
@@ -161,7 +162,7 @@ def preprocess(
     attention_mask = input_ids.ne(tokenizer.pad_token_id)
     return input_ids, attention_mask, target_ids

-def data_collator(batch, tokenizer, cut_off_len=2048):
+def data_collator(batch):
     speech_tokens, messages, durations, ids, lang, dnsmos = [], [], [], [], [], []
     for i, item in enumerate(batch):
         speech_tokens.append(item["code"])
@@ -176,21 +177,15 @@ def data_collator(batch, tokenizer, cut_off_len=2048):
         lang.append(item["language"])
         dnsmos.append(item["dnsmos"])

-    input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
-    target_ids = target_ids.type(torch.LongTensor)
-    input_ids = input_ids.type(torch.LongTensor)
-
     return {
         "speech_tokens": speech_tokens,
-        "input_ids": input_ids,
-        "attention_mask": attention_mask,
-        "target_ids": target_ids,
+        "messages": messages,
         "durations": durations,
         "ids": ids,
         "lang": lang,
         "dnsmos": dnsmos,
     }

 def compute_loss(
     params: AttributeDict,
     tokenizer: AutoTokenizer,
@@ -216,7 +211,10 @@ def compute_loss(
     Return a tuple of two elements. The first element is the loss tensor.
     """
     device = next(model.parameters()).device
-    input_ids, attention_mask, target_ids, answer_cosyvoice_speech_token = batch["input_ids"], batch["attention_mask"], batch["target_ids"], batch["speech_tokens"]
+    messages, answer_cosyvoice_speech_token = batch["messages"], batch["speech_tokens"]
+    input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
+    target_ids = target_ids.type(torch.LongTensor)
+    input_ids = input_ids.type(torch.LongTensor)

    with torch.set_grad_enabled(is_training):
        (
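Taken together with the data_collator hunk above, this moves tokenization out of the collator and into compute_loss: the collator now passes raw messages through, and the loss function calls preprocess itself. A rough sketch of that division of labor, with toy stand-ins for the tokenizer, preprocess, and the dataset item fields (the real ones are defined elsewhere in this file):

import torch

def toy_preprocess(messages, tokenizer):
    # Stand-in for preprocess(): tokenize chat messages, build masks/targets.
    input_ids = torch.tensor([tokenizer(m) for m in messages])
    attention_mask = input_ids.ne(0)       # 0 stands in for pad_token_id
    target_ids = input_ids.clone()
    return input_ids, attention_mask, target_ids

def toy_collator(batch):
    # New-style collator: pass raw fields through, no tokenizer argument.
    return {"speech_tokens": [item["code"] for item in batch],
            "messages": [item["message"] for item in batch]}

def toy_compute_loss(batch, tokenizer):
    # New-style compute_loss: tokenization now happens here, per batch.
    messages = batch["messages"]
    input_ids, attention_mask, target_ids = toy_preprocess(messages, tokenizer)
    target_ids = target_ids.type(torch.LongTensor)
    input_ids = input_ids.type(torch.LongTensor)
    return input_ids, attention_mask, target_ids

tokenizer = lambda m: [len(m), 0]          # toy tokenizer: two ids per message
batch = toy_collator([{"code": [1, 2], "message": "hi"},
                      {"code": [3], "message": "hello"}])
input_ids, attention_mask, target_ids = toy_compute_loss(batch, tokenizer)
assert input_ids.dtype == torch.int64 and attention_mask.dtype == torch.bool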
@@ -235,24 +233,51 @@ def compute_loss(
     assert loss.requires_grad == is_training

     info = MetricsTracker()
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore")
-        feature_lens = batch["supervisions"]["num_frames"]
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+    info["frames"] = len(messages)

     # Note: We use reduction=sum while computing the loss.
+    info["acc"] = acc * len(messages)
+    info["codec_acc"] = codec_acc * len(messages)
+    info["codec_topk_acc"] = codec_topk_acc * len(messages)
     info["loss"] = loss.detach().cpu().item()
-    info["acc"] = (
-        acc * info["frames"]
-    ) # WAR: to avoid normalization by the number of frames
-
-    info["codec_acc"] = codec_acc * info["frames"]
-    info["codec_topk_acc"] = codec_topk_acc * info["frames"]
     info["codec_loss"] = codec_loss.detach().cpu().item()
     info["text_loss"] = text_loss.detach().cpu().item()
     return loss, info


+def compute_validation_loss(
+    params: AttributeDict,
+    tokenizer: AutoTokenizer,
+    model: nn.Module,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        with torch.amp.autocast("cuda", enabled=params.use_fp16):
+            loss, loss_info = compute_loss(
+                params=params,
+                tokenizer=tokenizer,
+                model=model,
+                batch=batch,
+                is_training=False,
+            )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    # FIX ME
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
 def train_one_epoch(
     params: AttributeDict,
     tokenizer: AutoTokenizer,
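The validation helper added above runs compute_loss with gradients disabled and sums the per-batch MetricsTracker counts; since info["frames"] is now len(messages), the tracked statistics are effectively per sample rather than per subsampled frame. A rough sketch of that accumulation pattern, using a plain dict in place of icefall's MetricsTracker (whose exact reduction and logging behavior is defined in utils.py and is an assumption here):

import collections

def accumulate(tot, info):
    # Stand-in for `tot_loss = tot_loss + loss_info` on MetricsTracker objects.
    for k, v in info.items():
        tot[k] += v
    return tot

tot_loss = collections.defaultdict(float)
# Two fake validation batches of 16 and 12 samples (all values made up).
for n_samples, loss, acc in [(16, 40.0, 0.75), (12, 36.0, 0.80)]:
    info = {
        "frames": n_samples,          # "frames" now counts samples, not frames
        "loss": loss,
        "acc": acc * n_samples,       # scaled so a later /"frames" gives the mean
    }
    tot_loss = accumulate(tot_loss, info)

print(tot_loss["loss"])                      # 76.0: summed batch losses
print(tot_loss["acc"] / tot_loss["frames"])  # ~0.77: sample-weighted accuracy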
@@ -297,14 +322,14 @@ def train_one_epoch(
       be set to 0.
     """
     model.train()
-    model.encoder.eval()
+    # model.encoder.eval()
     if not params.unfreeze_llm:
         model.llm.eval()
     tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
-        batch_size = len(batch["supervisions"]["text"])
+        batch_size = len(batch["durations"])
         if batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
@@ -315,7 +340,7 @@
                 world_size=world_size,
             )
             model.train()
-            model.encoder.eval()
+            # model.encoder.eval()
             if not params.unfreeze_llm:
                 model.llm.eval()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
@@ -373,7 +398,6 @@ def train_one_epoch(
                 model.step()

         except: # noqa
-            display_and_save_batch(batch, params=params)
             raise

         if batch_idx % params.log_interval == 0:
@@ -399,7 +423,7 @@ def train_one_epoch(
                 )
                 tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)

-    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    loss_value = tot_loss["loss"]
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
@@ -421,6 +445,7 @@ def run(rank, world_size, args):
     """
     params = get_params()
     params.update(vars(args))
+    params.valid_interval = 2000

     fix_random_seed(params.seed)

@@ -428,9 +453,7 @@ def run(rank, world_size, args):
     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info(params)
     logging.info("About to create model")
-
     model, tokenizer = get_model(params)
-
     if torch.cuda.is_available():
         device = torch.device("cuda", get_local_rank())
     else:
@@ -447,36 +470,34 @@ def run(rank, world_size, args):
     sampler_state_dict = None
     if params.sampler_state_dict_path:
         sampler_state_dict = torch.load(params.sampler_state_dict_path)
-    data_path = "/lustre/fsw/general_sa/yuekaiz/s2s" + "/emilia_en"
-    ds = load_dataset(data_path, split="train")
-    train_test_split = dataset.train_test_split(test_size=1000, seed=42)
+    # print(params.dataset)
+    ds = load_dataset(params.dataset, split="train")
+    # shuffle the dataset
+    ds = ds.shuffle(seed=42)
+    train_test_split = ds.train_test_split(test_size=1000, seed=42)
     train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
+    # train_dataset, eval_dataset = train_test_split["test"], train_test_split["test"]

     sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
     train_dl = StatefulDataLoader(
         train_dataset,
-        batch_size=2,
+        batch_size=params.batch_size,
         sampler=sampler,
         shuffle=False,
-        num_workers=1,
-        prefetch_factor=1,
-        collate_fn=lambda features: data_collator(
-            features, tokenizer
-        ),
+        num_workers=4,
+        prefetch_factor=2,
+        collate_fn=data_collator
     )
     train_dl.load_state_dict(sampler_state_dict)
     valid_sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank)
     valid_dl = DataLoader(
         eval_dataset,
-        batch_size=2,
+        batch_size=params.batch_size,
         sampler=valid_sampler,
         shuffle=False,
         num_workers=1,
         prefetch_factor=1,
-        collate_fn=lambda features: data_collator(
-            features
-        ),
+        collate_fn=data_collator
     )

     if args.tensorboard and rank == 0:
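A condensed sketch of the new data path in run(): a Hugging Face dataset named by --dataset is shuffled, split, and fed through DistributedSampler into a StatefulDataLoader / DataLoader pair, with the module-level data_collator as collate_fn. To keep the sketch runnable offline it builds a tiny in-memory datasets.Dataset, uses a single-process sampler, and swaps in a toy collator and plain DataLoader; StatefulDataLoader takes the same arguments but can additionally save and restore its iteration state.

from datasets import Dataset
from torch.utils.data import DataLoader, DistributedSampler

# Toy in-memory dataset with fields the collator above reads (messages,
# durations and ids are omitted here; their exact keys are assumptions).
ds = Dataset.from_dict({
    "code": [[1, 2], [3], [4, 5, 6], [7]],
    "language": ["en", "en", "zh", "en"],
    "dnsmos": [3.2, 3.5, 3.1, 3.4],
})
ds = ds.shuffle(seed=42)
split = ds.train_test_split(test_size=1, seed=42)
train_dataset, eval_dataset = split["train"], split["test"]

def toy_collator(batch):
    # The real data_collator also gathers speech_tokens as "code", etc.
    return {"speech_tokens": [item["code"] for item in batch],
            "lang": [item["language"] for item in batch],
            "dnsmos": [item["dnsmos"] for item in batch]}

sampler = DistributedSampler(train_dataset, num_replicas=1, rank=0)
train_dl = DataLoader(train_dataset, batch_size=2, sampler=sampler,
                      shuffle=False, num_workers=0, collate_fn=toy_collator)
for batch in train_dl:
    print(batch["lang"], batch["dnsmos"])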
@@ -533,7 +554,6 @@ def run(rank, world_size, args):

     logging.info("Done!")

-
 def main():
     parser = get_parser()
     args = parser.parse_args()