fix tts stage decode

root 2025-05-28 02:34:07 +00:00
parent 5a7c72cb47
commit 49256fa917
4 changed files with 252 additions and 45 deletions

View File

@ -2,8 +2,8 @@
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
export HF_HOME="/lustre/fsw/general_sa/yuekaiz/.cache/huggingface"
set -eou pipefail
stage=$1
@ -121,7 +121,6 @@ if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
$train_cmd_args
fi
export HF_HOME="/lustre/fsw/general_sa/yuekaiz/.cache/huggingface"
if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
log "stage 19: Training TTS Model"
exp_dir=./qwen_omni/exp_tts
@ -218,16 +217,17 @@ if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
$train_cmd_args
fi
if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
log "stage 21: TTS Decoding Test Set"
exp_dir=./qwen_omni/exp_tts
torchrun --nproc_per_node=4 python3 ./qwen_omni/decode_tts.py \
torchrun --nproc_per_node=2 ./qwen_omni/decode_tts.py \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/large-v2.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--pretrained-model-path $exp_dir/checkpoint-32001/pytorch_model.bin \
--use-flash-attn True \
--enable-speech-output True \
--token2wav-path /lustre/fsw/general_sa/yuekaiz/s2s/CosyVoice2-0.5B \
--token2wav-path /workspace/CosyVoice2-0.5B \
--use-lora True
fi
fi
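The stage 21 fix drops the stray `python3` between `torchrun` and the script: `torchrun` launches the Python file directly and advertises the process layout to it through environment variables. A minimal sketch of how the decode side picks that layout up, mirroring the `get_world_size()`/`get_rank()` helpers in the utilities diff below (the variable names are the standard torchrun ones):

import os

# torchrun exports WORLD_SIZE/RANK for every worker it spawns;
# the defaults cover a plain single-process launch
world_size = int(os.environ.get("WORLD_SIZE", 1))
rank = int(os.environ.get("RANK", 0))
print(f"worker {rank} of {world_size}")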

View File

@ -63,16 +63,9 @@ from model import SPEECH_LLM, EncoderProjector
from peft import LoraConfig, get_peft_model
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
from utils import AttributeDict, setup_logger, store_transcripts, write_error_stats
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
@ -418,11 +411,7 @@ def get_parser():
def get_params() -> AttributeDict:
params = AttributeDict(
{
"env_info": get_env_info(),
}
)
params = AttributeDict({})
return params
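With `env_info` gone, `get_params()` returns an empty `AttributeDict` that is then filled from the CLI via `params.update(vars(args))` in `run()`. A small sketch of the pattern, mirroring the `AttributeDict` defined in the utilities file below (the example key is made up):

class AttributeDict(dict):
    # dict whose keys are also readable as attributes
    def __getattr__(self, key):
        if key in self:
            return self[key]
        raise AttributeError(f"No such attribute '{key}'")

params = AttributeDict({})
params.update({"exp_dir": "./qwen_omni/exp_tts"})  # stands in for vars(args)
assert params.exp_dir == "./qwen_omni/exp_tts"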

View File

@ -40,36 +40,34 @@ import copy
import logging
import os
import random
import sys
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import soundfile as sf
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import transformers
from datasets import load_dataset
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from cosyvoice.cli.cosyvoice import CosyVoice2
from datasets import Audio, load_dataset
from decode import audio_decode_cosyvoice2
from label_smoothing import LabelSmoothingLoss
from lhotse.utils import fix_random_seed
from model import IGNORE_TOKEN_ID, SPEECH_LLM
from peft import LoraConfig, get_peft_model
from torch import Tensor
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from train import add_model_arguments, add_training_arguments, get_model, get_params
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Qwen2Config,
Qwen2ForCausalLM,
)
from torchdata.stateful_dataloader import StatefulDataLoader
from torch.utils.data import DistributedSampler, DataLoader
from train import add_model_arguments, add_training_arguments, get_params, get_model
from decode import audio_decode_cosyvoice2
from utils import ( # filter_uneven_sized_batch,
AttributeDict,
MetricsTracker,
@ -79,9 +77,9 @@ from utils import ( # filter_uneven_sized_batch,
setup_logger,
str2bool,
)
from cosyvoice.cli.cosyvoice import CosyVoice2
sys.path.append("/lustre/fsw/general_sa/yuekaiz/s2s/CosyVoice/third_party/Matcha-TTS")
# sys.path.append("/lustre/fsw/general_sa/yuekaiz/s2s/CosyVoice/third_party/Matcha-TTS")
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
DEFAULT_SPEECH_TOKEN = "<speech>"
try:
torch.multiprocessing.set_start_method("spawn")
@ -116,9 +114,11 @@ def get_parser():
)
add_model_arguments(parser)
add_training_arguments(parser)
return parser
def preprocess(
messages,
tokenizer: transformers.PreTrainedTokenizer,
@ -177,30 +177,41 @@ def preprocess(
attention_mask = input_ids.ne(tokenizer.pad_token_id)
return input_ids, attention_mask, target_ids
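The tail of `preprocess` derives the attention mask straight from the padded ids. A toy illustration (the pad id of 0 is hypothetical; the real value comes from the tokenizer):

import torch

pad_token_id = 0  # hypothetical pad id
input_ids = torch.tensor([[17, 23, 5, 0, 0],
                          [9, 31, 42, 8, 0]])
attention_mask = input_ids.ne(pad_token_id)  # True wherever a real token sits
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True, False]])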
def data_collator(batch):
prompt_texts, prompt_speech_16k, messages, ids = [], [], [], []
prompt_texts, prompt_speech_16k, messages, ids, target_texts = [], [], [], [], []
for i, item in enumerate(batch):
# speech_tokens.append(item["prompt_audio_cosy2_tokens"])
message_list_item = []
message_list_item += [
{"role": "user", "content": f"Generate a speech from the following text:\n\n{item['target_text']}{DEFAULT_SPEECH_TOKEN}"},
{
"role": "user",
"content": f"Generate a speech from the following text:\n\n{item['target_text']}{DEFAULT_SPEECH_TOKEN}",
},
{"role": "assistant", "content": ""},
]
messages.append(message_list_item)
target_texts.append(item["target_text"])
ids.append(item["id"])
prompt_texts.append(item["prompt_text"])
prompt_speech_16k.append(item["prompt_audio"])
print(item["prompt_audio"], 233333333333333333)
speech_org = item["prompt_audio"]
speech_org = torch.tensor(speech_org["array"], dtype=torch.float32).unsqueeze(0)
speech_org = speech_org.mean(dim=0, keepdim=True)
prompt_speech_16k.append(speech_org)
# resample to 16k
return {
"prompt_texts": prompt_texts,
"target_texts": target_texts,
"prompt_speech_16k": prompt_speech_16k,
"messages": messages,
"ids": ids,
}
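A hedged usage sketch of the collator, assuming the `data_collator` above is in scope; the item below is fabricated but carries the fields the collator reads, in the form they take after the `Audio(sampling_rate=16000)` cast applied in `run()`:

import numpy as np

fake_item = {
    "id": "cut-0001",                  # made-up utterance id
    "target_text": "hello world",      # text to synthesize
    "prompt_text": "a short prompt",   # transcript of the prompt audio
    # after cast_column(..., Audio(sampling_rate=16000)) the audio column
    # decodes to a dict holding the raw float samples
    "prompt_audio": {"array": np.zeros(16000, dtype=np.float32),
                     "sampling_rate": 16000},
}
batch = data_collator([fake_item])
assert batch["prompt_speech_16k"][0].shape == (1, 16000)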
def run(rank, world_size, args):
"""
Args:
@ -215,7 +226,7 @@ def run(rank, world_size, args):
"""
params = get_params()
params.update(vars(args))
params.log_dir = Path(params.exp_dir) / f"log-results-wav"
params.log_dir = Path(params.exp_dir) / "log-results-wav"
params.log_dir.mkdir(parents=True, exist_ok=True)
fix_random_seed(params.seed)
@ -232,11 +243,9 @@ def run(rank, world_size, args):
logging.info(f"Device: {device}")
model.to(device)
assert params.deepspeed and world_size > 1
logging.info("Using DeepSpeed")
dataset = load_dataset("yuekai/seed_tts_cosy2", split=params.split_name)
dataset = dataset.cast_column("prompt_audio", Audio(sampling_rate=16000))
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
data_loader = DataLoader(
dataset,
@ -245,7 +254,7 @@ def run(rank, world_size, args):
shuffle=False,
num_workers=1,
prefetch_factor=1,
collate_fn=data_collator
collate_fn=data_collator,
)
token2wav_model = CosyVoice2(
params.token2wav_path, load_jit=False, load_trt=False, fp16=False
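With a `DistributedSampler` feeding the loader, each torchrun worker decodes a disjoint, strided shard of the test set. A minimal sketch on a toy dataset (two replicas, matching `--nproc_per_node=2`; `shuffle=False` is passed to the sampler here only to make the illustration deterministic):

import torch
from torch.utils.data import DataLoader, DistributedSampler

toy = list(range(10))  # stands in for the seed_tts_cosy2 split
for rank in range(2):
    sampler = DistributedSampler(toy, num_replicas=2, rank=rank, shuffle=False)
    print(rank, [int(x) for x in DataLoader(toy, batch_size=1, sampler=sampler)])
# 0 [0, 2, 4, 6, 8]
# 1 [1, 3, 5, 7, 9]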
@ -254,6 +263,7 @@ def run(rank, world_size, args):
messages = batch["messages"]
prompt_texts = batch["prompt_texts"]
prompt_speech_16k = batch["prompt_speech_16k"]
target_texts = batch["target_texts"]
ids = batch["ids"]
input_ids, attention_mask, _ = preprocess(messages, tokenizer)
generated_ids, generated_speech_output = model.decode_with_speech_output(
@ -262,8 +272,13 @@ def run(rank, world_size, args):
generated_speech_output = [
generated_speech_output
] # WAR: only support batch = 1 for now
for cut_id, audio_tokens, prompt_text, prompt_speech in zip(ids, generated_speech_output, prompt_texts, prompt_speech_16k):
for cut_id, audio_tokens, prompt_text, prompt_speech, target_text in zip(
ids, generated_speech_output, prompt_texts, prompt_speech_16k, target_texts
):
speech_file_name = params.log_dir / f"{cut_id}.wav"
# save target_text to file
with open(params.log_dir / f"{cut_id}.txt", "w") as f:
f.write(f"{target_text}\n")
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
if "CosyVoice2" in params.token2wav_path:
audio_hat = audio_decode_cosyvoice2(
@ -276,6 +291,7 @@ def run(rank, world_size, args):
logging.info("Done!")
def main():
parser = get_parser()
args = parser.parse_args()
@ -285,7 +301,7 @@ def main():
rank = get_rank()
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
# torch.set_num_interop_threads(1)
warnings.filterwarnings("ignore", category=FutureWarning)
run(rank=rank, world_size=world_size, args=args)

View File

@ -11,15 +11,16 @@ from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
import kaldialign
import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
Pathlike = Union[str, Path]
def get_world_size():
if "WORLD_SIZE" in os.environ:
return int(os.environ["WORLD_SIZE"])
@ -37,6 +38,7 @@ def get_rank():
else:
return 0
def get_local_rank():
if "LOCAL_RANK" in os.environ:
return int(os.environ["LOCAL_RANK"])
@ -45,6 +47,7 @@ def get_local_rank():
else:
return 0
def str2bool(v):
"""Used in argparse.ArgumentParser.add_argument to indicate
that a type is a bool type and user can enter
@ -63,6 +66,7 @@ def str2bool(v):
else:
raise argparse.ArgumentTypeError("Boolean value expected.")
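This helper is what lets the run script pass `--use-lora True` and `--use-flash-attn True` and have them arrive as real booleans. A usage sketch, assuming the `str2bool` above is in scope:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use-lora", type=str2bool, default=False)
args = parser.parse_args(["--use-lora", "True"])
assert args.use_lora is True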
class AttributeDict(dict):
def __getattr__(self, key):
if key in self:
@ -87,6 +91,7 @@ class AttributeDict(dict):
tmp[k] = v
return json.dumps(tmp, indent=indent, sort_keys=True)
def setup_logger(
log_filename: Pathlike,
log_level: str = "info",
@ -139,6 +144,7 @@ def setup_logger(
console.setFormatter(logging.Formatter(formatter))
logging.getLogger("").addHandler(console)
class MetricsTracker(collections.defaultdict):
def __init__(self):
# Passing the type 'int' to the base-class constructor
@ -228,4 +234,200 @@ class MetricsTracker(collections.defaultdict):
batch_idx: The current batch index, used as the x-axis of the plot.
"""
for k, v in self.norm_items():
tb_writer.add_scalar(prefix + k, v, batch_idx)
tb_writer.add_scalar(prefix + k, v, batch_idx)
def store_transcripts(
filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
) -> None:
"""Save predicted results and reference transcripts to a file.
Args:
filename:
File to save the results to.
texts:
An iterable of tuples. The first element is the cut_id, the second is
the reference transcript and the third element is the predicted result.
If it is a multi-talker ASR system, the ref and hyp may also be lists of
strings.
Returns:
Return None.
"""
with open(filename, "w", encoding="utf8") as f:
for cut_id, ref, hyp in texts:
if char_level:
ref = list("".join(ref))
hyp = list("".join(hyp))
print(f"{cut_id}:\tref={ref}", file=f)
print(f"{cut_id}:\thyp={hyp}", file=f)
def write_error_stats(
f: TextIO,
test_set_name: str,
results: List[Tuple[str, List[str], List[str]]],
enable_log: bool = True,
compute_CER: bool = False,
sclite_mode: bool = False,
) -> float:
"""Write statistics based on predicted results and reference transcripts.
It will write the following to the given file:
- WER
- number of insertions, deletions, substitutions, corrects and total
reference words. For example::
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
reference words (2337 correct)
- The difference between the reference transcript and predicted result.
An instance is given below::
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
The above example shows that the reference word is `EDISON`,
but it is predicted as `ADDISON` (a substitution error).
Another example is::
FOR THE FIRST DAY (SIR->*) I THINK
The reference word `SIR` is missing in the predicted
results (a deletion error).
Args:
results:
An iterable of tuples. The first element is the cut_id, the second is
the reference transcript and the third element is the predicted result.
enable_log:
If True, also print detailed WER to the console.
Otherwise, it is written only to the given file.
Returns:
Return the overall error rate as a float: WER, or CER if compute_CER is True.
"""
subs: Dict[Tuple[str, str], int] = defaultdict(int)
ins: Dict[str, int] = defaultdict(int)
dels: Dict[str, int] = defaultdict(int)
# `words` stores counts per word, as follows:
# corr, ref_sub, hyp_sub, ins, dels
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
num_corr = 0
ERR = "*"
if compute_CER:
for i, res in enumerate(results):
cut_id, ref, hyp = res
ref = list("".join(ref))
hyp = list("".join(hyp))
results[i] = (cut_id, ref, hyp)
for cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
for ref_word, hyp_word in ali:
if ref_word == ERR:
ins[hyp_word] += 1
words[hyp_word][3] += 1
elif hyp_word == ERR:
dels[ref_word] += 1
words[ref_word][4] += 1
elif hyp_word != ref_word:
subs[(ref_word, hyp_word)] += 1
words[ref_word][1] += 1
words[hyp_word][2] += 1
else:
words[ref_word][0] += 1
num_corr += 1
ref_len = sum([len(r) for _, r, _ in results])
sub_errs = sum(subs.values())
ins_errs = sum(ins.values())
del_errs = sum(dels.values())
tot_errs = sub_errs + ins_errs + del_errs
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
if enable_log:
logging.info(
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
f"{del_errs} del, {sub_errs} sub ]"
)
print(f"%WER = {tot_err_rate}", file=f)
print(
f"Errors: {ins_errs} insertions, {del_errs} deletions, "
f"{sub_errs} substitutions, over {ref_len} reference "
f"words ({num_corr} correct)",
file=f,
)
print(
"Search below for sections starting with PER-UTT DETAILS:, "
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
file=f,
)
print("", file=f)
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
for cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR)
combine_successive_errors = True
if combine_successive_errors:
ali = [[[x], [y]] for x, y in ali]
for i in range(len(ali) - 1):
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
ali[i] = [[], []]
ali = [
[
list(filter(lambda a: a != ERR, x)),
list(filter(lambda a: a != ERR, y)),
]
for x, y in ali
]
ali = list(filter(lambda x: x != [[], []], ali))
ali = [
[
ERR if x == [] else " ".join(x),
ERR if y == [] else " ".join(y),
]
for x, y in ali
]
print(
f"{cut_id}:\t"
+ " ".join(
(
ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
for ref_word, hyp_word in ali
)
),
file=f,
)
print("", file=f)
print("SUBSTITUTIONS: count ref -> hyp", file=f)
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
print(f"{count} {ref} -> {hyp}", file=f)
print("", file=f)
print("DELETIONS: count ref", file=f)
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
print(f"{count} {ref}", file=f)
print("", file=f)
print("INSERTIONS: count hyp", file=f)
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
print(f"{count} {hyp}", file=f)
print("", file=f)
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
for _, word, counts in sorted(
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
):
(corr, ref_sub, hyp_sub, ins, dels) = counts
tot_errs = ref_sub + hyp_sub + ins + dels
ref_count = corr + ref_sub + dels
hyp_count = corr + hyp_sub + ins
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
return float(tot_err_rate)
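A usage sketch, assuming the `write_error_stats` above is in scope; refs and hyps are word lists, and the returned float is the overall WER:

results = [
    ("cut-0001", ["hello", "world"], ["hello", "word"]),
    ("cut-0002", ["good", "morning"], ["good", "morning"]),
]
with open("errs-test.txt", "w") as f:
    wer = write_error_stats(f, test_set_name="test", results=results)
print(wer)  # 25.0 -> one substitution over four reference words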