mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)

fix tts stage decode

This commit is contained in:
parent 5a7c72cb47
commit 49256fa917
@@ -2,8 +2,8 @@

# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
export HF_HOME="/lustre/fsw/general_sa/yuekaiz/.cache/huggingface"
set -eou pipefail

stage=$1
@@ -121,7 +121,6 @@ if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
    $train_cmd_args
fi

-export HF_HOME="/lustre/fsw/general_sa/yuekaiz/.cache/huggingface"
if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
  log "stage 19: Training TTS Model"
  exp_dir=./qwen_omni/exp_tts
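
HF_HOME controls where huggingface_hub and datasets cache downloads, so exporting it once near the top of the script lets every stage share one cache instead of re-exporting it per stage. A minimal check from Python:

    import os

    # Falls back to the library default (~/.cache/huggingface) when unset.
    print(os.environ.get("HF_HOME", "~/.cache/huggingface"))
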
@@ -218,16 +217,17 @@ if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
    $train_cmd_args
fi


if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
  log "stage 21: TTS Decoding Test Set"
  exp_dir=./qwen_omni/exp_tts
-  torchrun --nproc_per_node=4 python3 ./qwen_omni/decode_tts.py \
+  torchrun --nproc_per_node=2 ./qwen_omni/decode_tts.py \
    --exp-dir $exp_dir \
    --speech-encoder-path-or-name models/large-v2.pt \
    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
    --pretrained-model-path $exp_dir/checkpoint-32001/pytorch_model.bin \
    --use-flash-attn True \
    --enable-speech-output True \
-    --token2wav-path /lustre/fsw/general_sa/yuekaiz/s2s/CosyVoice2-0.5B \
+    --token2wav-path /workspace/CosyVoice2-0.5B \
    --use-lora True
fi
fi
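
The key fix here: torchrun executes the given script itself, so the stray "python3" argument would have been treated as the program to launch. Each spawned worker then reads its rank from environment variables, which is exactly what the get_rank()/get_world_size() helpers later in this diff do. A minimal sketch:

    import os

    # torchrun sets WORLD_SIZE, RANK and LOCAL_RANK for every worker it spawns.
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    rank = int(os.environ.get("RANK", 0))

    # Under `torchrun --nproc_per_node=2 ./qwen_omni/decode_tts.py`,
    # worker 0 sees (2, 0) and worker 1 sees (2, 1).
    print(world_size, rank)
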
@@ -63,16 +63,9 @@ from model import SPEECH_LLM, EncoderProjector
from peft import LoraConfig, get_peft_model
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
+from utils import AttributeDict, setup_logger, store_transcripts, write_error_stats
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

-from icefall.env import get_env_info
-from icefall.utils import (
-    AttributeDict,
-    setup_logger,
-    store_transcripts,
-    write_error_stats,
-)
-
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
@@ -418,11 +411,7 @@ def get_parser():


def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            "env_info": get_env_info(),
-        }
-    )
+    params = AttributeDict({})
    return params
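
After this change, every field in params comes from the argparse namespace rather than from get_env_info(). A minimal sketch of the resulting flow, assuming this recipe's utils.py (which defines AttributeDict, shown later in this diff) is importable:

    from types import SimpleNamespace

    from utils import AttributeDict

    # Hypothetical stand-in for get_parser().parse_args().
    args = SimpleNamespace(exp_dir="./qwen_omni/exp_tts", seed=42)

    params = AttributeDict({})
    params.update(vars(args))
    print(params.exp_dir)  # attribute-style access: "./qwen_omni/exp_tts"
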
@@ -40,36 +40,34 @@ import copy
import logging
import os
import random
import sys
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union

import soundfile as sf
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import transformers
-from datasets import load_dataset

-from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+from cosyvoice.cli.cosyvoice import CosyVoice2
+from datasets import Audio, load_dataset
+from decode import audio_decode_cosyvoice2
from label_smoothing import LabelSmoothingLoss

from lhotse.utils import fix_random_seed
from model import IGNORE_TOKEN_ID, SPEECH_LLM
from peft import LoraConfig, get_peft_model
from torch import Tensor
+from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.tensorboard import SummaryWriter
+from train import add_model_arguments, add_training_arguments, get_model, get_params
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Qwen2Config,
    Qwen2ForCausalLM,
)
-from torchdata.stateful_dataloader import StatefulDataLoader
-from torch.utils.data import DistributedSampler, DataLoader

-from train import add_model_arguments, add_training_arguments, get_params, get_model
-from decode import audio_decode_cosyvoice2
from utils import (  # filter_uneven_sized_batch,
    AttributeDict,
    MetricsTracker,
@@ -79,9 +77,9 @@ from utils import (  # filter_uneven_sized_batch,
    setup_logger,
    str2bool,
)
-from cosyvoice.cli.cosyvoice import CosyVoice2
-sys.path.append("/lustre/fsw/general_sa/yuekaiz/s2s/CosyVoice/third_party/Matcha-TTS")

+# sys.path.append("/lustre/fsw/general_sa/yuekaiz/s2s/CosyVoice/third_party/Matcha-TTS")
+sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
DEFAULT_SPEECH_TOKEN = "<speech>"
try:
    torch.multiprocessing.set_start_method("spawn")
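
The path append makes CosyVoice's vendored copy of Matcha-TTS importable before CosyVoice2 is constructed. A minimal sketch; the /workspace layout is this recipe's container path (adjust to your checkout), and the package name `matcha` is an assumption based on the Matcha-TTS repository layout:

    import sys

    # Must run before anything that imports the vendored package.
    sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
    import matcha  # noqa: E402  # resolves from the appended path
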
@@ -116,9 +114,11 @@ def get_parser():
    )

    add_model_arguments(parser)
    add_training_arguments(parser)

    return parser


def preprocess(
    messages,
    tokenizer: transformers.PreTrainedTokenizer,
@@ -177,30 +177,41 @@ def preprocess(
    attention_mask = input_ids.ne(tokenizer.pad_token_id)
    return input_ids, attention_mask, target_ids


def data_collator(batch):
-    prompt_texts, prompt_speech_16k, messages, ids = [], [], [], []
+    prompt_texts, prompt_speech_16k, messages, ids, target_texts = [], [], [], [], []
    for i, item in enumerate(batch):
        # speech_tokens.append(item["prompt_audio_cosy2_tokens"])
        message_list_item = []
        message_list_item += [
-            {"role": "user", "content": f"Generate a speech from the following text:\n\n{item['target_text']}{DEFAULT_SPEECH_TOKEN}"},
+            {
+                "role": "user",
+                "content": f"Generate a speech from the following text:\n\n{item['target_text']}{DEFAULT_SPEECH_TOKEN}",
+            },
            {"role": "assistant", "content": ""},
        ]
        messages.append(message_list_item)
+        target_texts.append(item["target_text"])

        ids.append(item["id"])
        prompt_texts.append(item["prompt_text"])
-        prompt_speech_16k.append(item["prompt_audio"])
-        print(item["prompt_audio"], 233333333333333333)
+        speech_org = item["prompt_audio"]

+        speech_org = torch.tensor(speech_org["array"], dtype=torch.float32).unsqueeze(0)
+        speech_org = speech_org.mean(dim=0, keepdim=True)
+        prompt_speech_16k.append(speech_org)

        # resample to 16k

    return {
        "prompt_texts": prompt_texts,
+        "target_texts": target_texts,
        "prompt_speech_16k": prompt_speech_16k,
        "messages": messages,
        "ids": ids,
    }


def run(rank, world_size, args):
    """
    Args:
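
The collator change drops the leftover debug print and stores the prompt audio as a mono float32 tensor. After dataset.cast_column("prompt_audio", Audio(sampling_rate=16000)), each item["prompt_audio"] decodes to a dict with "array" and "sampling_rate" keys. A minimal sketch with synthetic data:

    import numpy as np
    import torch

    # Hypothetical row, shaped like one dataset item after the Audio cast.
    item = {"prompt_audio": {"array": np.zeros(16000), "sampling_rate": 16000}}

    speech = torch.tensor(item["prompt_audio"]["array"], dtype=torch.float32)
    speech = speech.unsqueeze(0)               # (1, num_samples)
    speech = speech.mean(dim=0, keepdim=True)  # keep a single mono channel
    print(speech.shape)                        # torch.Size([1, 16000])
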
@@ -215,7 +226,7 @@ def run(rank, world_size, args):
    """
    params = get_params()
    params.update(vars(args))
-    params.log_dir = Path(params.exp_dir) / f"log-results-wav"
+    params.log_dir = Path(params.exp_dir) / "log-results-wav"
    params.log_dir.mkdir(parents=True, exist_ok=True)

    fix_random_seed(params.seed)
@@ -232,11 +243,9 @@ def run(rank, world_size, args):
    logging.info(f"Device: {device}")
    model.to(device)

-    assert params.deepspeed and world_size > 1
-    logging.info("Using DeepSpeed")

    dataset = load_dataset("yuekai/seed_tts_cosy2", split=params.split_name)
+    dataset = dataset.cast_column("prompt_audio", Audio(sampling_rate=16000))

    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    data_loader = DataLoader(
        dataset,
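
DistributedSampler gives each torchrun worker a disjoint shard of the test set, so decoding parallelizes without duplicating cuts. A self-contained sketch of the partitioning (two replicas, no shuffling):

    from torch.utils.data import DistributedSampler

    data = list(range(10))  # stand-in for the HF dataset
    for rank in range(2):
        sampler = DistributedSampler(data, num_replicas=2, rank=rank, shuffle=False)
        print(rank, list(sampler))
    # 0 [0, 2, 4, 6, 8]
    # 1 [1, 3, 5, 7, 9]
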
@@ -245,7 +254,7 @@ def run(rank, world_size, args):
        shuffle=False,
        num_workers=1,
        prefetch_factor=1,
-        collate_fn=data_collator
+        collate_fn=data_collator,
    )
    token2wav_model = CosyVoice2(
        params.token2wav_path, load_jit=False, load_trt=False, fp16=False
@@ -254,6 +263,7 @@ def run(rank, world_size, args):
        messages = batch["messages"]
        prompt_texts = batch["prompt_texts"]
        prompt_speech_16k = batch["prompt_speech_16k"]
+        target_texts = batch["target_texts"]
        ids = batch["ids"]
        input_ids, attention_mask, _ = preprocess(messages, tokenizer)
        generated_ids, generated_speech_output = model.decode_with_speech_output(
@@ -262,8 +272,13 @@ def run(rank, world_size, args):
        generated_speech_output = [
            generated_speech_output
        ]  # WAR: only supports batch size 1 for now
-        for cut_id, audio_tokens, prompt_text, prompt_speech in zip(ids, generated_speech_output, prompt_texts, prompt_speech_16k):
+        for cut_id, audio_tokens, prompt_text, prompt_speech, target_text in zip(
+            ids, generated_speech_output, prompt_texts, prompt_speech_16k, target_texts
+        ):
            speech_file_name = params.log_dir / f"{cut_id}.wav"
+            # save target_text to file
+            with open(params.log_dir / f"{cut_id}.txt", "w") as f:
+                f.write(f"{target_text}\n")
            audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
            if "CosyVoice2" in params.token2wav_path:
                audio_hat = audio_decode_cosyvoice2(
@@ -276,6 +291,7 @@ def run(rank, world_size, args):

    logging.info("Done!")


def main():
    parser = get_parser()
    args = parser.parse_args()
@@ -285,7 +301,7 @@ def main():
    rank = get_rank()

    torch.set_num_threads(1)
-    torch.set_num_interop_threads(1)
+    # torch.set_num_interop_threads(1)
    warnings.filterwarnings("ignore", category=FutureWarning)
    run(rank=rank, world_size=world_size, args=args)
@@ -11,15 +11,16 @@ from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path

from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union

import kaldialign
import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter

Pathlike = Union[str, Path]


def get_world_size():
    if "WORLD_SIZE" in os.environ:
        return int(os.environ["WORLD_SIZE"])
@@ -37,6 +38,7 @@ def get_rank():
    else:
        return 0


def get_local_rank():
    if "LOCAL_RANK" in os.environ:
        return int(os.environ["LOCAL_RANK"])
@@ -45,6 +47,7 @@ def get_local_rank():
    else:
        return 0


def str2bool(v):
    """Used in argparse.ArgumentParser.add_argument to indicate
    that a type is a bool type and the user can enter
@@ -63,6 +66,7 @@ def str2bool(v):
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")


class AttributeDict(dict):
    def __getattr__(self, key):
        if key in self:
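
Usage sketch for str2bool above: it lets boolean flags such as --use-flash-attn and --use-lora accept textual true/false values on the command line (assuming this utils.py is importable):

    import argparse

    from utils import str2bool

    parser = argparse.ArgumentParser()
    parser.add_argument("--use-lora", type=str2bool, default=True)

    print(parser.parse_args(["--use-lora", "false"]).use_lora)  # False
    print(parser.parse_args([]).use_lora)                       # True
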
@@ -87,6 +91,7 @@ class AttributeDict(dict):
            tmp[k] = v
        return json.dumps(tmp, indent=indent, sort_keys=True)


def setup_logger(
    log_filename: Pathlike,
    log_level: str = "info",
@@ -139,6 +144,7 @@ def setup_logger(
    console.setFormatter(logging.Formatter(formatter))
    logging.getLogger("").addHandler(console)


class MetricsTracker(collections.defaultdict):
    def __init__(self):
        # Passing the type 'int' to the base-class constructor
@@ -228,4 +234,200 @@ class MetricsTracker(collections.defaultdict):
          batch_idx: The current batch index, used as the x-axis of the plot.
        """
        for k, v in self.norm_items():
            tb_writer.add_scalar(prefix + k, v, batch_idx)

def store_transcripts(
    filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
) -> None:
    """Save predicted results and reference transcripts to a file.

    Args:
      filename:
        File to save the results to.
      texts:
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript and the third element is the predicted result.
        If it is a multi-talker ASR system, the ref and hyp may also be lists of
        strings.
    Returns:
      Return None.
    """
    with open(filename, "w", encoding="utf8") as f:
        for cut_id, ref, hyp in texts:
            if char_level:
                ref = list("".join(ref))
                hyp = list("".join(hyp))
            print(f"{cut_id}:\tref={ref}", file=f)
            print(f"{cut_id}:\thyp={hyp}", file=f)

def write_error_stats(
|
||||
f: TextIO,
|
||||
test_set_name: str,
|
||||
results: List[Tuple[str, str]],
|
||||
enable_log: bool = True,
|
||||
compute_CER: bool = False,
|
||||
sclite_mode: bool = False,
|
||||
) -> float:
|
||||
"""Write statistics based on predicted results and reference transcripts.
|
||||
|
||||
It will write the following to the given file:
|
||||
|
||||
- WER
|
||||
- number of insertions, deletions, substitutions, corrects and total
|
||||
reference words. For example::
|
||||
|
||||
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
|
||||
reference words (2337 correct)
|
||||
|
||||
- The difference between the reference transcript and predicted result.
|
||||
An instance is given below::
|
||||
|
||||
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
|
||||
|
||||
The above example shows that the reference word is `EDISON`,
|
||||
but it is predicted to `ADDISON` (a substitution error).
|
||||
|
||||
Another example is::
|
||||
|
||||
FOR THE FIRST DAY (SIR->*) I THINK
|
||||
|
||||
The reference word `SIR` is missing in the predicted
|
||||
results (a deletion error).
|
||||
results:
|
||||
An iterable of tuples. The first element is the cut_id, the second is
|
||||
the reference transcript and the third element is the predicted result.
|
||||
enable_log:
|
||||
If True, also print detailed WER to the console.
|
||||
Otherwise, it is written only to the given file.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
subs: Dict[Tuple[str, str], int] = defaultdict(int)
|
||||
ins: Dict[str, int] = defaultdict(int)
|
||||
dels: Dict[str, int] = defaultdict(int)
|
||||
|
||||
# `words` stores counts per word, as follows:
|
||||
# corr, ref_sub, hyp_sub, ins, dels
|
||||
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
|
||||
num_corr = 0
|
||||
ERR = "*"
|
||||
|
||||
if compute_CER:
|
||||
for i, res in enumerate(results):
|
||||
cut_id, ref, hyp = res
|
||||
ref = list("".join(ref))
|
||||
hyp = list("".join(hyp))
|
||||
results[i] = (cut_id, ref, hyp)
|
||||
|
||||
for cut_id, ref, hyp in results:
|
||||
ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
|
||||
for ref_word, hyp_word in ali:
|
||||
if ref_word == ERR:
|
||||
ins[hyp_word] += 1
|
||||
words[hyp_word][3] += 1
|
||||
elif hyp_word == ERR:
|
||||
dels[ref_word] += 1
|
||||
words[ref_word][4] += 1
|
||||
elif hyp_word != ref_word:
|
||||
subs[(ref_word, hyp_word)] += 1
|
||||
words[ref_word][1] += 1
|
||||
words[hyp_word][2] += 1
|
||||
else:
|
||||
words[ref_word][0] += 1
|
||||
num_corr += 1
|
||||
ref_len = sum([len(r) for _, r, _ in results])
|
||||
sub_errs = sum(subs.values())
|
||||
ins_errs = sum(ins.values())
|
||||
del_errs = sum(dels.values())
|
||||
tot_errs = sub_errs + ins_errs + del_errs
|
||||
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
|
||||
|
||||
if enable_log:
|
||||
logging.info(
|
||||
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
|
||||
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
|
||||
f"{del_errs} del, {sub_errs} sub ]"
|
||||
)
|
||||
|
||||
print(f"%WER = {tot_err_rate}", file=f)
|
||||
print(
|
||||
f"Errors: {ins_errs} insertions, {del_errs} deletions, "
|
||||
f"{sub_errs} substitutions, over {ref_len} reference "
|
||||
f"words ({num_corr} correct)",
|
||||
file=f,
|
||||
)
|
||||
print(
|
||||
"Search below for sections starting with PER-UTT DETAILS:, "
|
||||
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
|
||||
file=f,
|
||||
)
|
||||
|
||||
print("", file=f)
|
||||
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
|
||||
for cut_id, ref, hyp in results:
|
||||
ali = kaldialign.align(ref, hyp, ERR)
|
||||
combine_successive_errors = True
|
||||
if combine_successive_errors:
|
||||
ali = [[[x], [y]] for x, y in ali]
|
||||
for i in range(len(ali) - 1):
|
||||
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
|
||||
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
|
||||
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
|
||||
ali[i] = [[], []]
|
||||
ali = [
|
||||
[
|
||||
list(filter(lambda a: a != ERR, x)),
|
||||
list(filter(lambda a: a != ERR, y)),
|
||||
]
|
||||
for x, y in ali
|
||||
]
|
||||
ali = list(filter(lambda x: x != [[], []], ali))
|
||||
ali = [
|
||||
[
|
||||
ERR if x == [] else " ".join(x),
|
||||
ERR if y == [] else " ".join(y),
|
||||
]
|
||||
for x, y in ali
|
||||
]
|
||||
|
||||
print(
|
||||
f"{cut_id}:\t"
|
||||
+ " ".join(
|
||||
(
|
||||
ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
|
||||
for ref_word, hyp_word in ali
|
||||
)
|
||||
),
|
||||
file=f,
|
||||
)
|
||||
|
||||
print("", file=f)
|
||||
print("SUBSTITUTIONS: count ref -> hyp", file=f)
|
||||
|
||||
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
|
||||
print(f"{count} {ref} -> {hyp}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print("DELETIONS: count ref", file=f)
|
||||
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
|
||||
print(f"{count} {ref}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print("INSERTIONS: count hyp", file=f)
|
||||
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
|
||||
print(f"{count} {hyp}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
|
||||
for _, word, counts in sorted(
|
||||
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
|
||||
):
|
||||
(corr, ref_sub, hyp_sub, ins, dels) = counts
|
||||
tot_errs = ref_sub + hyp_sub + ins + dels
|
||||
ref_count = corr + ref_sub + dels
|
||||
hyp_count = corr + hyp_sub + ins
|
||||
|
||||
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
||||
return float(tot_err_rate)
|
||||
|
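
A quick usage sketch for the two helpers added above (assuming this utils.py is importable and kaldialign is installed):

    from utils import store_transcripts, write_error_stats

    # One (cut_id, ref, hyp) triple; ref and hyp are word lists.
    results = [("utt1", "the cat sat".split(), "the cat sit".split())]

    store_transcripts("recogs-test.txt", results)
    with open("errs-test.txt", "w") as f:
        wer = write_error_stats(f, "test", results, enable_log=False)
    print(wer)  # 33.33 -> one substitution over three reference words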