fix tts stage decode

root 2025-05-28 02:34:07 +00:00
parent 5a7c72cb47
commit 49256fa917
4 changed files with 252 additions and 45 deletions

View File

@@ -2,8 +2,8 @@
 # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
 export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
-export HF_HOME="/lustre/fsw/general_sa/yuekaiz/.cache/huggingface"
+export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice

 set -eou pipefail

 stage=$1
@@ -121,7 +121,6 @@ if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
     $train_cmd_args
 fi

-export HF_HOME="/lustre/fsw/general_sa/yuekaiz/.cache/huggingface"
 if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
   log "stage 19: Training TTS Model"
   exp_dir=./qwen_omni/exp_tts
@@ -218,16 +217,17 @@ if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
     $train_cmd_args
 fi

 if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
   log "stage 21: TTS Decoding Test Set"
   exp_dir=./qwen_omni/exp_tts
-  torchrun --nproc_per_node=4 python3 ./qwen_omni/decode_tts.py \
+  torchrun --nproc_per_node=2 ./qwen_omni/decode_tts.py \
     --exp-dir $exp_dir \
     --speech-encoder-path-or-name models/large-v2.pt \
     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
     --pretrained-model-path $exp_dir/checkpoint-32001/pytorch_model.bin \
     --use-flash-attn True \
     --enable-speech-output True \
-    --token2wav-path /lustre/fsw/general_sa/yuekaiz/s2s/CosyVoice2-0.5B \
+    --token2wav-path /workspace/CosyVoice2-0.5B \
     --use-lora True
 fi
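The stage 21 fix above has two parts: the old command passed python3 as the script argument to torchrun, which cannot work because torchrun spawns the worker interpreters itself and expects the training script directly; and the token2wav checkpoint now lives under /workspace instead of the lustre path. The worker count also drops from 4 to 2. As a minimal sketch (not part of the commit) of what a script launched this way can rely on, torchrun exports RANK, WORLD_SIZE and LOCAL_RANK for every worker, which is exactly what get_world_size()/get_rank() in utils.py (last file below) read back:

# sketch.py -- launch with: torchrun --nproc_per_node=2 sketch.py
import os

rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
print(f"worker {rank}/{world_size} on local GPU {local_rank}")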

View File

@@ -63,16 +63,9 @@ from model import SPEECH_LLM, EncoderProjector
 from peft import LoraConfig, get_peft_model
 from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
 from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
+from utils import AttributeDict, setup_logger, store_transcripts, write_error_stats
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from icefall.env import get_env_info
-from icefall.utils import (
-    AttributeDict,
-    setup_logger,
-    store_transcripts,
-    write_error_stats,
-)

 sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
@@ -418,11 +411,7 @@
 def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            "env_info": get_env_info(),
-        }
-    )
+    params = AttributeDict({})
     return params
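Dropping env_info removes the last use of icefall.env, matching the import change above: get_params() now returns an empty AttributeDict that the caller fills from the command line. A two-line sketch of the behaviour this relies on (illustrative value, using the AttributeDict defined in utils.py in the last file below):

params = AttributeDict({})
params.update({"exp_dir": "./qwen_omni/exp_tts"})  # stand-in for params.update(vars(args))
assert params.exp_dir == params["exp_dir"]  # __getattr__ falls back to dict lookup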

View File

@@ -40,36 +40,34 @@ import copy
 import logging
 import os
 import random
+import sys
 import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple, Union

+import soundfile as sf
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import transformers
-from datasets import load_dataset
-from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+from cosyvoice.cli.cosyvoice import CosyVoice2
+from datasets import Audio, load_dataset
+from decode import audio_decode_cosyvoice2
 from label_smoothing import LabelSmoothingLoss
 from lhotse.utils import fix_random_seed
 from model import IGNORE_TOKEN_ID, SPEECH_LLM
 from peft import LoraConfig, get_peft_model
 from torch import Tensor
+from torch.utils.data import DataLoader, DistributedSampler
 from torch.utils.tensorboard import SummaryWriter
+from train import add_model_arguments, add_training_arguments, get_model, get_params
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     Qwen2Config,
     Qwen2ForCausalLM,
 )
-from torchdata.stateful_dataloader import StatefulDataLoader
-from torch.utils.data import DistributedSampler, DataLoader
-from train import add_model_arguments, add_training_arguments, get_params, get_model
-from decode import audio_decode_cosyvoice2
 from utils import (  # filter_uneven_sized_batch,
     AttributeDict,
     MetricsTracker,
@@ -79,9 +77,9 @@ from utils import (  # filter_uneven_sized_batch,
     setup_logger,
     str2bool,
 )
-from cosyvoice.cli.cosyvoice import CosyVoice2

-sys.path.append("/lustre/fsw/general_sa/yuekaiz/s2s/CosyVoice/third_party/Matcha-TTS")
+# sys.path.append("/lustre/fsw/general_sa/yuekaiz/s2s/CosyVoice/third_party/Matcha-TTS")
+sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")

 DEFAULT_SPEECH_TOKEN = "<speech>"
 try:
     torch.multiprocessing.set_start_method("spawn")
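Note the ordering these two hunks create: from cosyvoice.cli.cosyvoice import CosyVoice2 now runs at the top of the import block, before the sys.path.append for Matcha-TTS, so the CosyVoice repo itself must already be importable; that is what the export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice added to the shell script provides. A sketch of that bootstrap order (paths are the ones used in this commit; that cosyvoice defers its Matcha-TTS imports until CosyVoice2(...) is constructed is an assumption, not something the diff shows):

# assumes /workspace/CosyVoice is already on PYTHONPATH (see the shell script above)
from cosyvoice.cli.cosyvoice import CosyVoice2

import sys

# appended afterwards, for modules cosyvoice loads lazily at model-construction time
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")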
@@ -116,9 +114,11 @@
     )
     add_model_arguments(parser)
+    add_training_arguments(parser)

     return parser


 def preprocess(
     messages,
     tokenizer: transformers.PreTrainedTokenizer,
@@ -177,30 +177,41 @@ def preprocess(
     attention_mask = input_ids.ne(tokenizer.pad_token_id)

     return input_ids, attention_mask, target_ids


 def data_collator(batch):
-    prompt_texts, prompt_speech_16k, messages, ids = [], [], [], []
+    prompt_texts, prompt_speech_16k, messages, ids, target_texts = [], [], [], [], []
     for i, item in enumerate(batch):
         # speech_tokens.append(item["prompt_audio_cosy2_tokens"])
         message_list_item = []
         message_list_item += [
-            {"role": "user", "content": f"Generate a speech from the following text:\n\n{item['target_text']}{DEFAULT_SPEECH_TOKEN}"},
+            {
+                "role": "user",
+                "content": f"Generate a speech from the following text:\n\n{item['target_text']}{DEFAULT_SPEECH_TOKEN}",
+            },
             {"role": "assistant", "content": ""},
         ]
         messages.append(message_list_item)
+        target_texts.append(item["target_text"])
         ids.append(item["id"])
         prompt_texts.append(item["prompt_text"])
-        prompt_speech_16k.append(item["prompt_audio"])
+        speech_org = item["prompt_audio"]
+        print(item["prompt_audio"], 233333333333333333)
+        speech_org = torch.tensor(speech_org["array"], dtype=torch.float32).unsqueeze(0)
+        speech_org = speech_org.mean(dim=0, keepdim=True)
+        prompt_speech_16k.append(speech_org)
+        # resample to 16k
     return {
         "prompt_texts": prompt_texts,
+        "target_texts": target_texts,
         "prompt_speech_16k": prompt_speech_16k,
         "messages": messages,
         "ids": ids,
     }


 def run(rank, world_size, args):
     """
     Args:
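In the collator, item["prompt_audio"] is no longer appended raw: once the column is cast to Audio(sampling_rate=16000) (see the cast_column call added in run() below), each value is a dict with "array" and "sampling_rate" keys, and the new code turns the array into the (1, num_samples) float32 tensor that CosyVoice2 expects as prompt audio. The mean(dim=0, keepdim=True) is a no-op for the 1-D mono arrays datasets produces. A self-contained sketch of the same conversion with a dummy one-second clip standing in for a dataset row (the stray debug print in the diff is left out):

import numpy as np
import torch

# stand-in for one dataset row after cast_column(..., Audio(sampling_rate=16000))
item = {"prompt_audio": {"array": np.zeros(16000), "sampling_rate": 16000}}

speech = torch.tensor(item["prompt_audio"]["array"], dtype=torch.float32).unsqueeze(0)
speech = speech.mean(dim=0, keepdim=True)  # keeps shape (1, 16000) for mono input
assert speech.shape == (1, 16000)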
@@ -215,7 +226,7 @@
     """
     params = get_params()
     params.update(vars(args))
-    params.log_dir = Path(params.exp_dir) / f"log-results-wav"
+    params.log_dir = Path(params.exp_dir) / "log-results-wav"
     params.log_dir.mkdir(parents=True, exist_ok=True)

     fix_random_seed(params.seed)
@@ -232,11 +243,9 @@
     logging.info(f"Device: {device}")
     model.to(device)

-    assert params.deepspeed and world_size > 1
-    logging.info("Using DeepSpeed")
     dataset = load_dataset("yuekai/seed_tts_cosy2", split=params.split_name)
+    dataset = dataset.cast_column("prompt_audio", Audio(sampling_rate=16000))
     sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
     data_loader = DataLoader(
         dataset,
@@ -245,7 +254,7 @@
         shuffle=False,
         num_workers=1,
         prefetch_factor=1,
-        collate_fn=data_collator
+        collate_fn=data_collator,
     )
     token2wav_model = CosyVoice2(
         params.token2wav_path, load_jit=False, load_trt=False, fp16=False
@@ -254,6 +263,7 @@
         messages = batch["messages"]
         prompt_texts = batch["prompt_texts"]
         prompt_speech_16k = batch["prompt_speech_16k"]
+        target_texts = batch["target_texts"]
         ids = batch["ids"]
         input_ids, attention_mask, _ = preprocess(messages, tokenizer)
         generated_ids, generated_speech_output = model.decode_with_speech_output(
@@ -262,8 +272,13 @@
         generated_speech_output = [
             generated_speech_output
         ]  # WAR: only support batch = 1 for now
-        for cut_id, audio_tokens, prompt_text, prompt_speech in zip(ids, generated_speech_output, prompt_texts, prompt_speech_16k):
+        for cut_id, audio_tokens, prompt_text, prompt_speech, target_text in zip(
+            ids, generated_speech_output, prompt_texts, prompt_speech_16k, target_texts
+        ):
             speech_file_name = params.log_dir / f"{cut_id}.wav"
+            # save target_text to file
+            with open(params.log_dir / f"{cut_id}.txt", "w") as f:
+                f.write(f"{target_text}\n")
             audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
             if "CosyVoice2" in params.token2wav_path:
                 audio_hat = audio_decode_cosyvoice2(
@@ -276,6 +291,7 @@
     logging.info("Done!")


 def main():
     parser = get_parser()
     args = parser.parse_args()
@@ -285,7 +301,7 @@
     rank = get_rank()
     torch.set_num_threads(1)
-    torch.set_num_interop_threads(1)
+    # torch.set_num_interop_threads(1)
     warnings.filterwarnings("ignore", category=FutureWarning)
     run(rank=rank, world_size=world_size, args=args)
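With the world size now two (matching --nproc_per_node=2 in stage 21), each worker decodes a disjoint shard of the test set through the DistributedSampler built in run(). A small sketch of that sharding with an illustrative six-item dataset (shuffle=False is used here for a deterministic printout; the call in run() keeps the sampler's default shuffled index order):

from torch.utils.data import DistributedSampler

dataset = list(range(6))  # stands in for six test utterances
for rank in range(2):
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False)
    print(rank, list(sampler))  # rank 0 -> [0, 2, 4], rank 1 -> [1, 3, 5]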

View File

@@ -11,15 +11,16 @@ from collections import defaultdict
 from dataclasses import dataclass
 from datetime import datetime
 from pathlib import Path
 from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union

+import kaldialign
 import torch
 import torch.distributed as dist
 from torch.utils.tensorboard import SummaryWriter

 Pathlike = Union[str, Path]


 def get_world_size():
     if "WORLD_SIZE" in os.environ:
         return int(os.environ["WORLD_SIZE"])
@@ -37,6 +38,7 @@ def get_rank():
     else:
         return 0

+
 def get_local_rank():
     if "LOCAL_RANK" in os.environ:
         return int(os.environ["LOCAL_RANK"])
@@ -45,6 +47,7 @@ def get_local_rank():
     else:
         return 0

+
 def str2bool(v):
     """Used in argparse.ArgumentParser.add_argument to indicate
     that a type is a bool type and user can enter
@@ -63,6 +66,7 @@ def str2bool(v):
     else:
         raise argparse.ArgumentTypeError("Boolean value expected.")

+
 class AttributeDict(dict):
     def __getattr__(self, key):
         if key in self:
@@ -87,6 +91,7 @@ class AttributeDict(dict):
             tmp[k] = v
         return json.dumps(tmp, indent=indent, sort_keys=True)

+
 def setup_logger(
     log_filename: Pathlike,
     log_level: str = "info",
@@ -139,6 +144,7 @@ def setup_logger(
     console.setFormatter(logging.Formatter(formatter))
     logging.getLogger("").addHandler(console)

+
 class MetricsTracker(collections.defaultdict):
     def __init__(self):
         # Passing the type 'int' to the base-class constructor
@@ -228,4 +234,200 @@ class MetricsTracker(collections.defaultdict):
           batch_idx: The current batch index, used as the x-axis of the plot.
         """
         for k, v in self.norm_items():
             tb_writer.add_scalar(prefix + k, v, batch_idx)
+
+
+def store_transcripts(
+    filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
+) -> None:
+    """Save predicted results and reference transcripts to a file.
+
+    Args:
+      filename:
+        File to save the results to.
+      texts:
+        An iterable of tuples. The first element is the cut_id, the second is
+        the reference transcript and the third element is the predicted result.
+        If it is a multi-talker ASR system, the ref and hyp may also be lists of
+        strings.
+    Returns:
+      Return None.
+    """
+    with open(filename, "w", encoding="utf8") as f:
+        for cut_id, ref, hyp in texts:
+            if char_level:
+                ref = list("".join(ref))
+                hyp = list("".join(hyp))
+            print(f"{cut_id}:\tref={ref}", file=f)
+            print(f"{cut_id}:\thyp={hyp}", file=f)
+
+
+def write_error_stats(
+    f: TextIO,
+    test_set_name: str,
+    results: List[Tuple[str, str]],
+    enable_log: bool = True,
+    compute_CER: bool = False,
+    sclite_mode: bool = False,
+) -> float:
+    """Write statistics based on predicted results and reference transcripts.
+
+    It will write the following to the given file:
+
+        - WER
+        - number of insertions, deletions, substitutions, corrects and total
+          reference words. For example::
+
+              Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
+              reference words (2337 correct)
+
+        - The difference between the reference transcript and predicted result.
+          An instance is given below::
+
+              THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
+
+          The above example shows that the reference word is `EDISON`,
+          but it is predicted to `ADDISON` (a substitution error).
+
+          Another example is::
+
+              FOR THE FIRST DAY (SIR->*) I THINK
+
+          The reference word `SIR` is missing in the predicted
+          results (a deletion error).
+      results:
+        An iterable of tuples. The first element is the cut_id, the second is
+        the reference transcript and the third element is the predicted result.
+      enable_log:
+        If True, also print detailed WER to the console.
+        Otherwise, it is written only to the given file.
+    Returns:
+      Return the total error rate (WER) as a float.
+    """
+    subs: Dict[Tuple[str, str], int] = defaultdict(int)
+    ins: Dict[str, int] = defaultdict(int)
+    dels: Dict[str, int] = defaultdict(int)
+
+    # `words` stores counts per word, as follows:
+    #   corr, ref_sub, hyp_sub, ins, dels
+    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
+    num_corr = 0
+    ERR = "*"
+
+    if compute_CER:
+        for i, res in enumerate(results):
+            cut_id, ref, hyp = res
+            ref = list("".join(ref))
+            hyp = list("".join(hyp))
+            results[i] = (cut_id, ref, hyp)
+
+    for cut_id, ref, hyp in results:
+        ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
+        for ref_word, hyp_word in ali:
+            if ref_word == ERR:
+                ins[hyp_word] += 1
+                words[hyp_word][3] += 1
+            elif hyp_word == ERR:
+                dels[ref_word] += 1
+                words[ref_word][4] += 1
+            elif hyp_word != ref_word:
+                subs[(ref_word, hyp_word)] += 1
+                words[ref_word][1] += 1
+                words[hyp_word][2] += 1
+            else:
+                words[ref_word][0] += 1
+                num_corr += 1
+    ref_len = sum([len(r) for _, r, _ in results])
+    sub_errs = sum(subs.values())
+    ins_errs = sum(ins.values())
+    del_errs = sum(dels.values())
+    tot_errs = sub_errs + ins_errs + del_errs
+    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
+
+    if enable_log:
+        logging.info(
+            f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
+            f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
+            f"{del_errs} del, {sub_errs} sub ]"
+        )
+
+    print(f"%WER = {tot_err_rate}", file=f)
+    print(
+        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
+        f"{sub_errs} substitutions, over {ref_len} reference "
+        f"words ({num_corr} correct)",
+        file=f,
+    )
+    print(
+        "Search below for sections starting with PER-UTT DETAILS:, "
+        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
+        file=f,
+    )
+
+    print("", file=f)
+    print("PER-UTT DETAILS: corr or (ref->hyp)", file=f)
+    for cut_id, ref, hyp in results:
+        ali = kaldialign.align(ref, hyp, ERR)
+        combine_successive_errors = True
+        if combine_successive_errors:
+            ali = [[[x], [y]] for x, y in ali]
+            for i in range(len(ali) - 1):
+                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
+                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
+                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
+                    ali[i] = [[], []]
+            ali = [
+                [
+                    list(filter(lambda a: a != ERR, x)),
+                    list(filter(lambda a: a != ERR, y)),
+                ]
+                for x, y in ali
+            ]
+            ali = list(filter(lambda x: x != [[], []], ali))
+            ali = [
+                [
+                    ERR if x == [] else " ".join(x),
+                    ERR if y == [] else " ".join(y),
+                ]
+                for x, y in ali
+            ]
+
+        print(
+            f"{cut_id}:\t"
+            + " ".join(
+                (
+                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
+                    for ref_word, hyp_word in ali
+                )
+            ),
+            file=f,
+        )
+
+    print("", file=f)
+    print("SUBSTITUTIONS: count ref -> hyp", file=f)
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
+        print(f"{count} {ref} -> {hyp}", file=f)
+
+    print("", file=f)
+    print("DELETIONS: count ref", file=f)
+    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
+        print(f"{count} {ref}", file=f)
+
+    print("", file=f)
+    print("INSERTIONS: count hyp", file=f)
+    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
+        print(f"{count} {hyp}", file=f)
+
+    print("", file=f)
+    print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
+    for _, word, counts in sorted(
+        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
+    ):
+        (corr, ref_sub, hyp_sub, ins, dels) = counts
+        tot_errs = ref_sub + hyp_sub + ins + dels
+        ref_count = corr + ref_sub + dels
+        hyp_count = corr + hyp_sub + ins
+        print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
+    return float(tot_err_rate)
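The two functions appended above are the standard icefall transcript and WER helpers, copied into the local utils module so the decode script no longer imports them from icefall (see the import hunk in the second file). A usage sketch with toy transcripts and illustrative file names; the alignment comes from kaldialign, imported at the top of this file:

results = [
    ("utt1", "the cat sat".split(), "the cat sat".split()),
    ("utt2", "on the mat".split(), "on a mat".split()),
]
store_transcripts("recogs-test.txt", results)
with open("errs-test.txt", "w", encoding="utf8") as f:
    wer = write_error_stats(f, "test", results)  # one substitution over six words -> 16.67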