2025-06-02 23:16:03 -07:00

670 lines
23 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
# 2024 Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
# For Chinese dataset, you can use the following command to download the Chinese fine-tuned whisper model.
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
# Qwen Pretrained model
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 50 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
"""
import argparse
import copy
import logging
import os
import random
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import deepspeed
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import transformers
from datasets import load_dataset
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss
from lhotse.utils import fix_random_seed
from model import IGNORE_TOKEN_ID, SPEECH_LLM
from peft import LoraConfig, get_peft_model
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Qwen2Config,
Qwen2ForCausalLM,
)
from torchdata.stateful_dataloader import StatefulDataLoader
from torch.utils.data import DistributedSampler, DataLoader
from pathlib import Path
from train import add_model_arguments, add_training_arguments, get_params, get_model
from utils import ( # filter_uneven_sized_batch,
AttributeDict,
MetricsTracker,
get_local_rank,
get_rank,
get_world_size,
setup_logger,
str2bool,
)
# Placeholder appended to the user turn by the collators below; preprocess()
# uses it to locate the end of the prompt and then strips it from the inputs.
DEFAULT_SPEECH_TOKEN = "<speech>"

# Request the "spawn" start method for worker processes. set_start_method
# raises RuntimeError if a start method was already fixed for this
# interpreter, in which case the existing choice is kept.
try:
    mp.set_start_method("spawn")
except RuntimeError:
    pass
def get_parser():
    """Build the command-line parser for this training script.

    It registers ``--batch-size``, DeepSpeed's config arguments, and the
    model/training arguments defined in the ``train`` module, in that order.
    """
    arg_parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    arg_parser.add_argument(
        "--batch-size",
        type=int,
        default=16,
        help="The batch size to use.",
    )
    arg_parser = deepspeed.add_config_arguments(arg_parser)
    for register_arguments in (add_model_arguments, add_training_arguments):
        register_arguments(arg_parser)
    return arg_parser
def preprocess(
    messages,
    tokenizer: transformers.PreTrainedTokenizer,
    mask_prompt: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Tokenize chat conversations into padded inputs, attention mask and targets.

    Each conversation is rendered with a minimal ``<|im_start|>``/``<|im_end|>``
    chat template and padded to the longest one on the side given by
    ``tokenizer.padding_side``. The ``DEFAULT_SPEECH_TOKEN`` placeholder is
    removed from both the inputs and the targets.

    Args:
        messages:
            List of conversations, each a list of ``{"role", "content"}``
            dicts. Every conversation must contain the same number of
            ``DEFAULT_SPEECH_TOKEN`` placeholders (the collators in this file
            add exactly one); otherwise the final reshape fails.
        tokenizer:
            Tokenizer providing ``apply_chat_template``, ``pad_token_id``,
            ``padding_side`` and an id for ``DEFAULT_SPEECH_TOKEN``.
        mask_prompt:
            When True (the default and the previously hard-coded behaviour),
            every target position up to and including the newline after the
            assistant role name is set to ``IGNORE_TOKEN_ID``, so only the
            assistant reply is supervised.

    Returns:
        ``(input_ids, attention_mask, target_ids)``. ``input_ids`` and
        ``target_ids`` are int32 tensors of shape ``(batch, seq_len)``.
        ``attention_mask`` is a bool tensor that is False where ``input_ids``
        equals the pad token. Padding positions in ``target_ids`` are
        ``IGNORE_TOKEN_ID``.

    Raises:
        ValueError: if ``messages`` is empty.
    """
    if not messages:
        raise ValueError("preprocess() needs at least one conversation")

    # Minimal chat template: no default system prompt and no newline after
    # the final <|im_end|>.
    TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
    texts = [
        tokenizer.apply_chat_template(
            msg,
            tokenize=True,
            chat_template=TEMPLATE,
            add_generation_prompt=False,
            padding="longest",  # a single conversation per call, so a no-op
            truncation=False,
        )
        for msg in messages
    ]

    # Pad every sequence to the longest one on the tokenizer's padding side.
    pad_id = tokenizer.pad_token_id
    max_len = max(len(text) for text in texts)
    if tokenizer.padding_side == "right":
        texts = [text + [pad_id] * (max_len - len(text)) for text in texts]
    else:
        texts = [[pad_id] * (max_len - len(text)) + text for text in texts]
    input_ids = torch.tensor(texts, dtype=torch.int)
    target_ids = input_ids.clone()
    target_ids[target_ids == pad_id] = IGNORE_TOKEN_ID

    # Look up the placeholder id unconditionally: it is needed for removal
    # below even when prompt masking is disabled.
    speech_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_SPEECH_TOKEN)
    speech_positions = input_ids == speech_token_id

    if mask_prompt:
        # With the template above, the placeholder (end of the user turn) is
        # followed by <|im_end|>, "\n", <|im_start|>, "assistant", "\n".
        # Masking the first col + 6 positions hides the whole prompt up to
        # and including the newline after "assistant", assuming each of
        # those pieces is a single token as with the Qwen2 tokenizer.
        # TODO: confirm this offset for other tokenizers (e.g. Qwen3).
        prompt_tail = 6
        rows, cols = torch.where(speech_positions)
        for row, col in zip(rows.tolist(), cols.tolist()):
            target_ids[row, : col + prompt_tail] = IGNORE_TOKEN_ID

    # Drop the placeholder from inputs and targets with one shared mask so the
    # two sequences stay position-aligned.
    batch_size = input_ids.size(0)
    keep = ~speech_positions
    input_ids = input_ids[keep].view(batch_size, -1)
    target_ids = target_ids[keep].view(batch_size, -1)
    attention_mask = input_ids.ne(pad_id)
    return input_ids, attention_mask, target_ids
def data_collator(batch):
    """Collate raw TTS records into the batch dict consumed by ``compute_loss``.

    Each record must provide ``code`` (speech codec tokens), ``text``,
    ``duration`` and ``language``, plus an identifier under ``index``, or
    ``id`` when ``index`` is absent. The user turn asks the model to speak
    ``text`` and ends with ``DEFAULT_SPEECH_TOKEN``; the assistant turn
    repeats ``text``.

    Returns:
        Dict with list-valued keys ``speech_tokens``, ``messages``,
        ``durations``, ``ids`` and ``lang``, each in batch order.
    """
    speech_tokens, messages, durations, ids, lang = [], [], [], [], []
    for item in batch:
        speech_tokens.append(item["code"])
        messages.append(
            [
                {"role": "user", "content": f"Generate a speech from the following text:\n\n{item['text']}{DEFAULT_SPEECH_TOKEN}"},
                {"role": "assistant", "content": item["text"]},
            ]
        )
        durations.append(item["duration"])
        ids.append(item["index"] if "index" in item else item["id"])
        lang.append(item["language"])
    return {
        "speech_tokens": speech_tokens,
        "messages": messages,
        "durations": durations,
        "ids": ids,
        "lang": lang,
    }
def data_collator_concate_items(
    batch, concat_items_num: int = 3, text_separator: str = ""
):
    """Concatenate *concat_items_num* consecutive dataset items into one.

    The function groups the incoming ``batch`` (a list of dataset items)
    into non-overlapping chunks of *concat_items_num*. For each group it
    concatenates the textual fields and speech codec tokens so that the
    model generates one longer utterance instead of several short ones.
    Any remainder (when ``len(batch)`` is not divisible by
    *concat_items_num*) is also kept as a smaller group.

    Args:
        batch: List of records with ``code``, ``text`` and ``duration``.
            ``index``/``id`` and ``language`` are optional.
        concat_items_num: Maximum number of records per group; must be >= 1.
        text_separator: String inserted between the texts of a group. The
            default ``""`` keeps the previous behaviour. Pass ``" "`` for
            space-delimited languages so words are not fused together.

    Returns:
        Dict with keys ``speech_tokens``, ``messages``, ``durations``,
        ``ids`` (one list of record ids per group) and ``lang`` (the first
        record's ``language`` per group, ``""`` if missing).

    Raises:
        ValueError: if ``concat_items_num`` is smaller than 1.
    """
    if concat_items_num < 1:
        # range() raises for a step of 0 and silently yields nothing for a
        # negative step (dropping the whole batch), so reject both up front.
        raise ValueError(
            f"concat_items_num must be >= 1, got {concat_items_num}"
        )

    grouped_speech_tokens, grouped_messages, grouped_durations = [], [], []
    grouped_ids, grouped_lang = [], []
    # With a positive step, every slice below is non-empty.
    for start_idx in range(0, len(batch), concat_items_num):
        group = batch[start_idx : start_idx + concat_items_num]

        # 1) Speech tokens ----------------------------------------------
        # ``item['code']`` can be a list[int] or a 1-D tensor. The first
        # element decides how to concatenate.
        codes = [item["code"] for item in group]
        if isinstance(codes[0], torch.Tensor):
            concat_code = torch.cat(codes, dim=0)
        else:
            # assume list / iterable of ints
            concat_code = [token for code in codes for token in code]

        # 2) Text -------------------------------------------------------
        concat_text = text_separator.join(item["text"] for item in group)

        # 3) Build chat template messages -------------------------------
        message_list_item = [
            {
                "role": "user",
                "content": f"Generate a speech from the following text:\n\n{concat_text}{DEFAULT_SPEECH_TOKEN}",
            },
            {"role": "assistant", "content": concat_text},
        ]

        # 4) Misc meta fields -------------------------------------------
        total_duration = sum(item["duration"] for item in group)
        group_ids = [item.get("index", item.get("id")) for item in group]
        language = group[0].get("language", "")

        # 5) Append to output lists -------------------------------------
        grouped_speech_tokens.append(concat_code)
        grouped_messages.append(message_list_item)
        grouped_durations.append(total_duration)
        grouped_ids.append(group_ids)
        grouped_lang.append(language)

    return {
        "speech_tokens": grouped_speech_tokens,
        "messages": grouped_messages,
        "durations": grouped_durations,
        "ids": grouped_ids,
        "lang": grouped_lang,
    }
def data_collator_ultra_chat(batch):
    """Collate cut-style records into the batch dict consumed by ``compute_loss``.

    Each record must provide ``custom["speech_token"]`` (speech codec
    tokens), a non-empty ``supervisions`` list (only the first supervision's
    ``text`` is used), ``duration`` and ``id``. The user turn asks the model
    to speak the text and ends with ``DEFAULT_SPEECH_TOKEN``; the assistant
    turn repeats the text.

    Returns:
        Dict with list-valued keys ``speech_tokens``, ``messages``,
        ``durations`` and ``ids`` in batch order (no ``lang`` entry).
    """
    speech_tokens, messages, durations, ids = [], [], [], []
    for item in batch:
        speech_tokens.append(item["custom"]["speech_token"])
        text = item["supervisions"][0]["text"]
        messages.append(
            [
                {"role": "user", "content": f"Generate a speech from the following text:\n\n{text}{DEFAULT_SPEECH_TOKEN}"},
                {"role": "assistant", "content": text},
            ]
        )
        durations.append(item["duration"])
        ids.append(item["id"])
    return {
        "speech_tokens": speech_tokens,
        "messages": messages,
        "durations": durations,
        "ids": ids,
    }
def compute_loss(
    params: AttributeDict,
    tokenizer: AutoTokenizer,
    model: nn.Module,
    batch: dict,
    is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
    """
    Compute the combined text + speech-codec loss for one batch.
    Args:
      params:
        It is returned by :func:`get_params`. Not read here; kept so all
        loss/eval helpers share one call signature.
      tokenizer:
        The tokenizer passed to :func:`preprocess` to encode the messages.
      model:
        The model for training; must provide ``forward_with_speech_output``.
      batch:
        A batch produced by one of this file's collate functions. Only the
        ``"messages"`` and ``"speech_tokens"`` entries are used.
      is_training:
        Whether it is training. Autograd is enabled only when True.
    Returns:
      A tuple ``(loss, info)``. ``loss`` is ``text_loss + codec_loss``.
      ``info`` is a MetricsTracker holding:
        - ``frames``: the number of utterances in the batch;
        - ``acc``, ``codec_acc``, ``codec_topk_acc``: the accuracies
          multiplied by that count;
        - ``loss``, ``codec_loss``, ``text_loss``: detached values as
          Python floats.
    """
    device = next(model.parameters()).device
    messages = batch["messages"]
    answer_cosyvoice_speech_token = batch["speech_tokens"]
    input_ids, attention_mask, target_ids = preprocess(messages, tokenizer)
    num_utts = len(messages)

    with torch.set_grad_enabled(is_training):
        (
            text_loss,
            acc,
            codec_loss,
            codec_acc,
            codec_topk_acc,
        ) = model.forward_with_speech_output(
            input_ids=input_ids.to(device=device, dtype=torch.long),
            attention_mask=attention_mask.to(device),
            labels=target_ids.to(device=device, dtype=torch.long),
            speech_codec_ids=answer_cosyvoice_speech_token,
        )
        loss = text_loss + codec_loss
    assert loss.requires_grad == is_training

    info = MetricsTracker()
    # Accuracies are weighted by the utterance count stored under "frames",
    # presumably so accumulated trackers can be normalised back to means.
    info["frames"] = num_utts
    info["acc"] = acc * num_utts
    info["codec_acc"] = codec_acc * num_utts
    info["codec_topk_acc"] = codec_topk_acc * num_utts
    info["loss"] = loss.detach().cpu().item()
    info["codec_loss"] = codec_loss.detach().cpu().item()
    info["text_loss"] = text_loss.detach().cpu().item()
    return loss, info
def compute_validation_loss(
    params: AttributeDict,
    tokenizer: AutoTokenizer,
    model: nn.Module,
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
) -> MetricsTracker:
    """Run the validation process.

    Puts ``model`` in eval mode; the caller is responsible for switching it
    back. The function sums the per-batch metrics from :func:`compute_loss`
    and, when ``world_size > 1``, reduces them across processes. If the
    accumulated ``"loss"`` is below ``params.best_valid_loss``, it updates
    ``params.best_valid_loss`` and ``params.best_valid_epoch``.

    Returns:
      The accumulated MetricsTracker.
    """
    model.eval()
    # Reduce on the model's device rather than on the last batch's loss, so
    # the collective still runs even if this rank received no validation
    # batches (``loss`` would be unbound in that case).
    device = next(model.parameters()).device
    tot_loss = MetricsTracker()
    for batch in valid_dl:
        with torch.amp.autocast("cuda", enabled=params.use_fp16):
            loss, loss_info = compute_loss(
                params=params,
                tokenizer=tokenizer,
                model=model,
                batch=batch,
                is_training=False,
            )
        assert loss.requires_grad is False
        tot_loss = tot_loss + loss_info

    if world_size > 1:
        tot_loss.reduce(device)

    # NOTE(review): this is the summed loss over all validation batches, not
    # a per-utterance mean; comparisons assume a fixed validation set.
    loss_value = tot_loss["loss"]
    if loss_value < params.best_valid_loss:
        params.best_valid_epoch = params.cur_epoch
        params.best_valid_loss = loss_value
    return tot_loss
def _set_train_mode(model: nn.Module, params: AttributeDict) -> None:
    """Put ``model`` in training mode, keeping a frozen LLM in eval mode."""
    model.train()
    if not params.unfreeze_llm:
        model.llm.eval()


def _save_step_checkpoint(
    params: AttributeDict,
    model: nn.Module,
    train_dl: torch.utils.data.DataLoader,
    rank: int,
) -> None:
    """Save a mid-epoch checkpoint tagged with ``params.batch_idx_train``.

    All ranks write the sharded DeepSpeed checkpoint
    ``<exp_dir>/zero-checkpoint-<step>``. Rank 0 then converts it into an
    fp32 state dict under ``<exp_dir>/checkpoint-<step>``, saves
    ``train_dl.state_dict()`` there as ``sampler.pt``, and deletes the
    sharded directory.
    """
    import shutil

    step = params.batch_idx_train
    tag = f"zero-checkpoint-{step}"
    model.save_checkpoint(
        save_dir=params.exp_dir,
        tag=tag,
        client_state={},
        exclude_frozen_parameters=True,
    )
    if rank != 0:
        return
    checkpoint_dir = f"{params.exp_dir}/checkpoint-{step}"
    convert_zero_checkpoint_to_fp32_state_dict(
        params.exp_dir,
        checkpoint_dir,
        tag=tag,
        exclude_frozen_parameters=True,
    )
    # Save the dataloader (sampler) state into the checkpoint.
    torch.save(train_dl.state_dict(), f"{checkpoint_dir}/sampler.pt")
    # shutil.rmtree instead of `rm -rf` through os.system: no shell is
    # involved, so exp_dir paths with spaces or shell metacharacters are safe.
    shutil.rmtree(f"{params.exp_dir}/{tag}", ignore_errors=True)


def train_one_epoch(
    params: AttributeDict,
    tokenizer: AutoTokenizer,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,
    rank: int = 0,
) -> None:
    """Train the model for one epoch.

    The function keeps a decayed running sum of the training metrics in
    ``tot_loss`` (decay factor ``1 - 1 / params.reset_interval``). At the
    end of the epoch, its ``"loss"`` is stored in ``params.train_loss``,
    and ``params.best_train_*`` is updated when it improves.

    Every ``params.valid_interval`` batches, including the first batch of
    the epoch, the validation loss is computed. For every such batch except
    the first, a checkpoint is saved.

    Args:
      params:
        It is returned by :func:`get_params`.
      tokenizer:
        Tokenizer forwarded to :func:`compute_loss`.
      model:
        The DeepSpeed-wrapped model; its ``backward``/``step`` are called.
      optimizer:
        Not referenced here; the engine's ``step()`` is called instead.
      scheduler:
        Only queried through ``get_last_lr()`` for logging.
      train_dl:
        Dataloader for the training dataset. Its ``state_dict()`` is saved
        with each checkpoint.
      valid_dl:
        Dataloader for the validation dataset.
      tb_writer:
        Writer to write log messages to tensorboard, or None.
      world_size:
        Number of processes in distributed training.
      rank:
        The rank of this process. Only rank 0 converts and exports
        checkpoints.
    """
    _set_train_mode(model, params)
    tot_loss = MetricsTracker()

    for batch_idx, batch in enumerate(train_dl):
        params.batch_idx_train += 1
        batch_size = len(batch["durations"])

        if batch_idx % params.valid_interval == 0:
            logging.info("Computing validation loss")
            valid_info = compute_validation_loss(
                params=params,
                tokenizer=tokenizer,
                model=model,
                valid_dl=valid_dl,
                world_size=world_size,
            )
            _set_train_mode(model, params)
            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
            logging.info(
                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
            )
            if tb_writer is not None:
                valid_info.write_summary(
                    tb_writer, "train/valid_", params.batch_idx_train
                )
            if batch_idx != 0:
                _save_step_checkpoint(params, model, train_dl, rank)

        with torch.amp.autocast("cuda", enabled=params.use_fp16):
            loss, loss_info = compute_loss(
                params=params,
                tokenizer=tokenizer,
                model=model,
                batch=batch,
                is_training=True,
            )
        # summary stats
        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

        # The DeepSpeed engine owns the backward pass and the optimizer step.
        model.backward(loss)
        model.step()

        if batch_idx % params.log_interval == 0:
            try:
                cur_lr = scheduler.get_last_lr()[0]
            except Exception:  # presumably scheduler is None when none is configured
                cur_lr = 0.0
            logging.info(
                f"Epoch {params.cur_epoch}, "
                f"batch {batch_idx}, loss[{loss_info}], "
                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                f"lr: {cur_lr:.2e}, "
            )
            if tb_writer is not None:
                tb_writer.add_scalar(
                    "train/learning_rate", cur_lr, params.batch_idx_train
                )
                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
                )
                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)

    loss_value = tot_loss["loss"]
    params.train_loss = loss_value
    if params.train_loss < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss
def _load_datasets(params: AttributeDict):
    """Load the train/eval datasets and select the matching collate function.

    ``params.dataset == "ultra_chat_voice_assistant"`` reads fixed JSONL cut
    files under ``data/fbank`` and uses ``data_collator_ultra_chat``. Any
    other value is treated as a directory of ``*.jsonl`` files: 1000 shuffled
    records are held out for evaluation, and ``data_collator`` is used.

    Returns:
      ``(train_dataset, eval_dataset, collate_fn)``.

    Raises:
      FileNotFoundError: if the dataset directory contains no ``*.jsonl``.
    """
    if params.dataset == "ultra_chat_voice_assistant":
        train_files = [
            "data/fbank/cuts_voice_assistant_00001-00049.jsonl",
            "data/fbank/cuts_ultrachat_train.jsonl.gz",
        ]
        ds = load_dataset("json", data_files=train_files, split="train")
        # shuffle the dataset
        train_dataset = ds.shuffle(seed=42)
        eval_dataset = load_dataset(
            "json",
            data_files=["data/fbank/cuts_voice_assistant.00000.jsonl"],
            split="train",
        )
        return train_dataset, eval_dataset, data_collator_ultra_chat

    data_dir = Path(params.dataset)
    # Sort so the file order, and hence the seeded shuffle and train/test
    # split, does not depend on the filesystem's directory-listing order.
    json_files = sorted(str(file) for file in data_dir.glob("*.jsonl"))
    if not json_files:
        raise FileNotFoundError(f"No *.jsonl files found in {data_dir}")
    ds = load_dataset("json", data_files=json_files, split="train")
    # shuffle the dataset
    ds = ds.shuffle(seed=42)
    split = ds.train_test_split(test_size=1000, seed=42)
    return split["train"], split["test"], data_collator


def run(rank, world_size, args):
    """
    Build the model and dataloaders, then train for the configured epochs.

    Args:
      rank:
        It is a value between 0 and `world_size-1` identifying this process
        (:func:`main` obtains it from ``get_rank()``). The process with
        rank 0 writes the log file, TensorBoard summaries and consolidated
        checkpoints.
      world_size:
        Number of processes in distributed training; must be > 1.
      args:
        The return value of get_parser().parse_args()
    """
    import shutil

    params = get_params()
    params.update(vars(args))
    # NOTE: hard-coded; overrides any valid_interval from get_params()/args.
    params.valid_interval = 2000
    fix_random_seed(params.seed)
    if rank == 0:
        setup_logger(f"{params.exp_dir}/log/log-train")
    logging.info(params)

    logging.info("About to create model")
    model, tokenizer = get_model(params)
    if torch.cuda.is_available():
        device = torch.device("cuda", get_local_rank())
    else:
        device = torch.device("cpu")
    logging.info(f"Device: {device}")
    model.to(device)

    assert params.deepspeed and world_size > 1, (
        "this script requires --deepspeed and more than one process"
    )
    logging.info("Using DeepSpeed")
    model, optimizer, _, scheduler = deepspeed.initialize(
        args=params, model=model, model_parameters=model.parameters()
    )

    sampler_state_dict = None
    if params.sampler_state_dict_path:
        sampler_state_dict = torch.load(params.sampler_state_dict_path)

    train_dataset, eval_dataset, collate_fn = _load_datasets(params)

    sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    train_dl = StatefulDataLoader(
        train_dataset,
        batch_size=params.batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=4,
        prefetch_factor=2,
        collate_fn=collate_fn,
    )
    # Only restore dataloader state when resuming from a saved checkpoint.
    if sampler_state_dict is not None:
        train_dl.load_state_dict(sampler_state_dict)

    valid_sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank)
    valid_dl = DataLoader(
        eval_dataset,
        batch_size=params.batch_size,
        sampler=valid_sampler,
        shuffle=False,
        num_workers=1,
        prefetch_factor=1,
        collate_fn=collate_fn,
    )

    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
    else:
        tb_writer = None

    logging.info(f"start training from epoch {params.start_epoch}")
    for epoch in range(params.start_epoch, params.num_epochs + 1):
        fix_random_seed(params.seed + epoch - 1)
        train_dl.sampler.set_epoch(epoch - 1)
        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
        params.cur_epoch = epoch

        train_one_epoch(
            params=params,
            tokenizer=tokenizer,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            train_dl=train_dl,
            valid_dl=valid_dl,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
        )

        # End-of-epoch checkpoint: every rank writes its shard; rank 0
        # converts it to fp32, stores the dataloader state and removes the
        # sharded directory.
        epoch_tag = f"zero-epoch-{params.cur_epoch}"
        model.save_checkpoint(
            save_dir=params.exp_dir,
            tag=epoch_tag,
            client_state={},
            exclude_frozen_parameters=True,
        )
        if rank == 0:
            epoch_dir = f"{params.exp_dir}/epoch-{params.cur_epoch}"
            convert_zero_checkpoint_to_fp32_state_dict(
                params.exp_dir,
                epoch_dir,
                tag=epoch_tag,
                exclude_frozen_parameters=True,
            )
            # save sampler state dict into checkpoint
            torch.save(train_dl.state_dict(), f"{epoch_dir}/sampler.pt")
            # No shell involved, so unusual exp_dir paths are handled safely.
            shutil.rmtree(f"{params.exp_dir}/{epoch_tag}", ignore_errors=True)

    logging.info("Done!")
def main():
    """Parse the command line and launch training for this process's rank."""
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)
    world_size, rank = get_world_size(), get_rank()

    # Keep intra-/inter-op CPU threading at one thread per process.
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    warnings.filterwarnings("ignore", category=FutureWarning)

    run(rank=rank, world_size=world_size, args=args)


if __name__ == "__main__":
    main()