Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-10 02:22:17 +00:00)

Commit 4c0396f8f2 (parent 49256fa917): support text2speech ultrachat
@@ -3,7 +3,7 @@
 # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
 export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
-export HF_HOME="/lustre/fsw/general_sa/yuekaiz/.cache/huggingface"
+# export HF_HOME="/lustre/fsw/general_sa/yuekaiz/.cache/huggingface"
 set -eou pipefail

 stage=$1
@@ -123,7 +123,9 @@ fi

 if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
 log "stage 19: Training TTS Model"
-exp_dir=./qwen_omni/exp_tts
+exp_dir=./qwen_omni/exp_tts_ultra_chat_voice_assistant
+exp_dir=./qwen_omni/exp_tts_emilia_en_tts_only_template
+exp_dir=./qwen_omni/exp_tts_emilia_en_tts_three_concat
 pretrained_dir=./qwen_omni/exp_speech2text
 ngpu=4

@@ -141,17 +143,16 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
 fi
 done
 fi
-train_cmd_args="--batch-size 64 \
+# --dataset ultra_chat_voice_assistant
+train_cmd_args="--batch-size 30 \
 --exp-dir $exp_dir \
---last-stage-model-path $pretrained_dir/checkpoint-58548/pytorch_model.bin \
 --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
 --enable-speech-input False \
 --deepspeed \
---dataset /lustre/fsw/general_sa/yuekaiz/s2s/emilia_en \
+--dataset /lustre/fsw/general_sa/yuekaiz/s2s/VoxBox/manifests_emilia_en \
 --deepspeed_config ./qwen_omni/ds_config_zero1.json \
 --use-flash-attn True \
---num-epochs 2 \
+--num-epochs 3 \
 --use-lora False --unfreeze-llm False --enable-speech-output True"

 if [ "$latest_checkpoint_step" -ge 0 ]; then
@@ -168,66 +169,66 @@ if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
 fi


-if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
+# if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
-log "stage 20: Training TTS Model"
+# log "stage 20: Training TTS Model"
-echo "cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -"
+# echo "cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -"
-if [ ! -L "/workspace/slam" ]; then
+# if [ ! -L "/workspace/slam" ]; then
-cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
+# cd /workspace && ln -s /lustre/fsw/general_sa/yuekaiz/s2s slam && cd -
-fi
+# fi
-exp_dir=./qwen_omni/exp_test
+# exp_dir=./qwen_omni/exp_test
-ngpu=4
+# ngpu=4

-latest_checkpoint_step=-1
+# latest_checkpoint_step=-1
-# Check if exp_dir exists and is a directory
+# # Check if exp_dir exists and is a directory
-if [ -d "$exp_dir" ]; then
+# if [ -d "$exp_dir" ]; then
-# List directories matching checkpoint-* and find the one with the largest step number
+# # List directories matching checkpoint-* and find the one with the largest step number
-for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
+# for checkpoint_dir in $(ls -d $exp_dir/checkpoint-*/ 2>/dev/null | sort -V); do
-checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
+# checkpoint_name=$(basename "$checkpoint_dir") # e.g., checkpoint-1000
-# Extract step number using parameter expansion
+# # Extract step number using parameter expansion
-current_step=${checkpoint_name#checkpoint-}
+# current_step=${checkpoint_name#checkpoint-}
-# Ensure current_step is a number
+# # Ensure current_step is a number
-if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
+# if [[ "$current_step" =~ ^[0-9]+$ ]] && [ "$current_step" -gt "$latest_checkpoint_step" ]; then
-latest_checkpoint_step=$current_step
+# latest_checkpoint_step=$current_step
-fi
+# fi
-done
+# done
-fi
+# fi

-train_cmd_args="--max-duration 150 \
+# train_cmd_args="--max-duration 150 \
---enable-musan False \
+# --enable-musan False \
---exp-dir $exp_dir \
+# --exp-dir $exp_dir \
---speech-encoder-path-or-name models/large-v2.pt \
+# --speech-encoder-path-or-name models/large-v2.pt \
---llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+# --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
---dataset vocalnet_ultrachat_voiceassistant \
+# --dataset vocalnet_ultrachat_voiceassistant \
---manifest-dir data/fbank \
+# --manifest-dir data/fbank \
---deepspeed \
+# --deepspeed \
---deepspeed_config ./qwen_omni/ds_config_zero1.json \
+# --deepspeed_config ./qwen_omni/ds_config_zero1.json \
---use-flash-attn True --on-the-fly-feats True \
+# --use-flash-attn True --on-the-fly-feats True \
---use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True"
+# --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True"

-if [ "$latest_checkpoint_step" -ge 0 ]; then
+# if [ "$latest_checkpoint_step" -ge 0 ]; then
-log "Continuing training from checkpoint-$latest_checkpoint_step"
+# log "Continuing training from checkpoint-$latest_checkpoint_step"
-step=$latest_checkpoint_step
+# step=$latest_checkpoint_step
-train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
+# train_cmd_args="$train_cmd_args --pretrained-model-path $exp_dir/checkpoint-${step}/pytorch_model.bin --sampler-state-dict-path $exp_dir/checkpoint-${step}/sampler.pt"
-else
+# else
-log "Starting training from scratch as no checkpoint was found in $exp_dir"
+# log "Starting training from scratch as no checkpoint was found in $exp_dir"
-# No pretrained model or sampler state dict needed for the first run
+# # No pretrained model or sampler state dict needed for the first run
-fi
+# fi

-torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
+# torchrun --nproc_per_node $ngpu --nnodes $SLURM_JOB_NUM_NODES --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --rdzv_backend c10d --rdzv_id $SLURM_JOBID ./qwen_omni/train.py \
-$train_cmd_args
+# $train_cmd_args
-fi
+# fi


-if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
+# if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
-log "stage 21: TTS Decoding Test Set"
+# log "stage 21: TTS Decoding Test Set"
-exp_dir=./qwen_omni/exp_tts
+# exp_dir=./qwen_omni/exp_tts
-torchrun --nproc_per_node=2 ./qwen_omni/decode_tts.py \
+# torchrun --nproc_per_node=2 ./qwen_omni/decode_tts.py \
---exp-dir $exp_dir \
+# --exp-dir $exp_dir \
---speech-encoder-path-or-name models/large-v2.pt \
+# --speech-encoder-path-or-name models/large-v2.pt \
---llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+# --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
---pretrained-model-path $exp_dir/checkpoint-32001/pytorch_model.bin \
+# --pretrained-model-path $exp_dir/checkpoint-32001/pytorch_model.bin \
---use-flash-attn True \
+# --use-flash-attn True \
---enable-speech-output True \
+# --enable-speech-output True \
---token2wav-path /workspace/CosyVoice2-0.5B \
+# --token2wav-path /workspace/CosyVoice2-0.5B \
---use-lora True
+# --use-lora True
-fi
+# fi
@@ -437,7 +437,8 @@ class SPEECH_LLM(nn.Module):
         audio_attention_mask = audio_codes.ne(self.codec_lm.config.pad_token_id)
         audio_embeddings = self.codec_lm.get_input_embeddings()(audio_codes)

-        text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
+        # text_last_hidden_lists, text_embeds_list, text_input_embeds_list = [], [], []
+        text_input_embeds_list = []
         for i in range(len(text_label_start_index_list)):
             text_last_hidden = model_outputs.hidden_states[-1][
                 i,
@@ -445,14 +446,14 @@ class SPEECH_LLM(nn.Module):
                 + input_seq_len[i]
                 - 1,
             ]
-            text_last_hidden_lists.append(text_last_hidden)
+            # text_last_hidden_lists.append(text_last_hidden)
             text_embed = inputs_embeds[
                 i,
                 text_input_start_index_list[i]
                 + 1 : text_input_start_index_list[i]
                 + input_seq_len[i],
             ] # exclude bos
-            text_embeds_list.append(text_embed)
+            # text_embeds_list.append(text_embed)

             text_input_embeds = torch.cat(
                 [
@@ -480,8 +481,9 @@ class SPEECH_LLM(nn.Module):
             ), f"start_idx: {start_idx}, start_idx_re_compute: {start_idx_re_compute}"
             if text_input_embeds.shape[0] > audio_embeddings.shape[1] - start_idx:
                 logging.warning(
-                    f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}"
+                    f"Truncate text_input_embeds: {text_input_embeds.shape} to {audio_embeddings.shape[1] - start_idx}\naudio_codes_lens: {audio_codes_lens[i]}\ninput_question_len_list: {input_question_len_list[i]}\ninput_seq_len: {input_seq_len[i]}\n"
                 )
+                # breakpoint()
                 text_input_embeds = text_input_embeds[
                     : audio_embeddings.shape[1] - start_idx
                 ]
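The guard in the last hunk clips each sample's text embeddings so they fit into the codec-LM positions that remain after start_idx; this commit only enriches the warning with the relevant lengths and parks a disabled breakpoint() for debugging. A minimal standalone sketch of that guard, with dummy shapes and an assumed (batch, codec_len, dim) layout rather than the recipe's real tensors:

    import logging

    import torch

    # Dummy stand-ins for the tensors in the diff; sizes are illustrative only.
    audio_embeddings = torch.zeros(1, 50, 8)   # (batch, codec_len, dim), assumed layout
    text_input_embeds = torch.zeros(60, 8)     # (text_len, dim) for one sample
    start_idx = 5                              # first codec position that receives text embeddings

    room = audio_embeddings.shape[1] - start_idx
    if text_input_embeds.shape[0] > room:
        logging.warning(f"Truncate text_input_embeds: {text_input_embeds.shape} to {room}")
        text_input_embeds = text_input_embeds[:room]
    assert text_input_embeds.shape[0] <= room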
@@ -68,6 +68,7 @@ from transformers import (
 )
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torch.utils.data import DistributedSampler, DataLoader
+from pathlib import Path

 from train import add_model_arguments, add_training_arguments, get_params, get_model
 from utils import ( # filter_uneven_sized_batch,
@@ -171,11 +172,14 @@ def data_collator(batch):
             {"role": "user", "content": f"Generate a speech from the following text:\n\n{item['text']}{DEFAULT_SPEECH_TOKEN}"},
             {"role": "assistant", "content": item["text"]},
         ]
+        # message_list_item += [
+        #     {"role": "user", "content": f"TTS{DEFAULT_SPEECH_TOKEN}"},
+        #     {"role": "assistant", "content": item["text"]},
+        # ]
         messages.append(message_list_item)
         durations.append(item["duration"])
-        ids.append(item["id"])
+        ids.append(item["index"] if "index" in item else item["id"])
         lang.append(item["language"])
-        dnsmos.append(item["dnsmos"])

     return {
         "speech_tokens": speech_tokens,
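Each collated item is a two-turn chat: a user prompt carrying the target text plus DEFAULT_SPEECH_TOKEN, and an assistant answer repeating the text. How train.py serializes these messages is not shown in this hunk; the sketch below uses the standard transformers chat-template API as an assumed stand-in, with a made-up DEFAULT_SPEECH_TOKEN value:

    from transformers import AutoTokenizer

    DEFAULT_SPEECH_TOKEN = "<speech>"  # assumed placeholder; the recipe defines its own value

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    messages = [
        {"role": "user", "content": f"Generate a speech from the following text:\n\nHello there.{DEFAULT_SPEECH_TOKEN}"},
        {"role": "assistant", "content": "Hello there."},
    ]
    # Render the two-turn conversation into the prompt/response string fed to the LLM.
    print(tokenizer.apply_chat_template(messages, tokenize=False))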
@@ -183,7 +187,92 @@ def data_collator(batch):
         "durations": durations,
         "ids": ids,
         "lang": lang,
-        "dnsmos": dnsmos,
     }


+def data_collator_concate_items(batch, concat_items_num: int = 3):
+    """Concatenate *concat_items_num* consecutive dataset items into one.
+
+    The function groups the incoming ``batch`` (a list of dataset items)
+    into non-overlapping chunks of *concat_items_num*. For each group it
+    concatenates the textual fields and speech codec tokens so that the
+    model generates one longer utterance instead of several short ones.
+
+    Any remainder (when ``len(batch)`` is not divisible by
+    *concat_items_num*) is also kept as a smaller group.
+    """
+    grouped_speech_tokens, grouped_messages, grouped_durations = [], [], []
+    grouped_ids, grouped_lang = [], []
+
+    # Iterate over the batch in strides of *concat_items_num*
+    for start_idx in range(0, len(batch), concat_items_num):
+        group = batch[start_idx : start_idx + concat_items_num]
+        if not group:
+            continue
+
+        # 1) Speech tokens --------------------------------------------------
+        # ``item['code']`` can be a list[int] or a 1-D tensor. Use the first
+        # element to decide how to concatenate.
+        first_code = group[0]["code"]
+        if isinstance(first_code, torch.Tensor):
+            concat_code = torch.cat([item["code"] for item in group], dim=0)
+        else:
+            # assume list / iterable of ints
+            concat_code = []
+            for item in group:
+                concat_code.extend(item["code"])
+
+        # 2) Text -----------------------------------------------------------
+        concat_text = "".join([item["text"] for item in group])
+
+        # 3) Build chat template messages -----------------------------------
+        message_list_item = [
+            {
+                "role": "user",
+                "content": f"Generate a speech from the following text:\n\n{concat_text}{DEFAULT_SPEECH_TOKEN}",
+            },
+            {"role": "assistant", "content": concat_text},
+        ]
+
+        # 4) Misc meta fields ----------------------------------------------
+        total_duration = sum(item["duration"] for item in group)
+        group_ids = [item.get("index", item.get("id")) for item in group]
+        language = group[0].get("language", "")
+
+        # 5) Append to output lists ----------------------------------------
+        grouped_speech_tokens.append(concat_code)
+        grouped_messages.append(message_list_item)
+        grouped_durations.append(total_duration)
+        grouped_ids.append(group_ids)
+        grouped_lang.append(language)
+
+    return {
+        "speech_tokens": grouped_speech_tokens,
+        "messages": grouped_messages,
+        "durations": grouped_durations,
+        "ids": grouped_ids,
+        "lang": grouped_lang,
+    }
+
+
+def data_collator_ultra_chat(batch):
+    speech_tokens, messages, durations, ids, lang, dnsmos = [], [], [], [], [], []
+    for i, item in enumerate(batch):
+        speech_tokens.append(item["custom"]["speech_token"])
+        text = item["supervisions"][0]["text"]
+        message_list_item = []
+        message_list_item += [
+            {"role": "user", "content": f"Generate a speech from the following text:\n\n{text}{DEFAULT_SPEECH_TOKEN}"},
+            {"role": "assistant", "content": text},
+        ]
+        messages.append(message_list_item)
+        durations.append(item["duration"])
+        ids.append(item["id"])
+
+    return {
+        "speech_tokens": speech_tokens,
+        "messages": messages,
+        "durations": durations,
+        "ids": ids,
+    }


 def compute_loss(
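A quick sanity check of the grouping behaviour described in the docstring of data_collator_concate_items, assuming the collator (and the DEFAULT_SPEECH_TOKEN it references) are in scope, e.g. when run inside the training script; the toy items carry only the fields the collator reads:

    # Seven toy items with concat_items_num=3 -> groups of 3, 3 and a remainder of 1.
    toy_batch = [
        {"code": [1, 2], "text": f"sentence {i}. ", "duration": 1.0, "id": f"utt-{i}", "language": "EN"}
        for i in range(7)
    ]
    out = data_collator_concate_items(toy_batch, concat_items_num=3)
    print([len(ids) for ids in out["ids"]])  # [3, 3, 1]
    print(out["durations"])                  # [3.0, 3.0, 1.0]
    print(out["speech_tokens"][0])           # [1, 2, 1, 2, 1, 2]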
@@ -470,13 +559,21 @@ def run(rank, world_size, args):
     sampler_state_dict = None
     if params.sampler_state_dict_path:
         sampler_state_dict = torch.load(params.sampler_state_dict_path)
-    # print(params.dataset)
-    ds = load_dataset(params.dataset, split="train")
-    # shuffle the dataset
-    ds = ds.shuffle(seed=42)
-    train_test_split = ds.train_test_split(test_size=1000, seed=42)
-    train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]
-    # train_dataset, eval_dataset = train_test_split["test"], train_test_split["test"]
+    if params.dataset == "ultra_chat_voice_assistant":
+        data_dir = "data/fbank"
+        json_file_lists = ["data/fbank/cuts_voice_assistant_00001-00049.jsonl", "data/fbank/cuts_ultrachat_train.jsonl.gz"]
+        ds = load_dataset("json", data_files=json_file_lists, split="train")
+        # shuffle the dataset
+        train_dataset = ds.shuffle(seed=42)
+        eval_dataset = load_dataset("json", data_files=["data/fbank/cuts_voice_assistant.00000.jsonl"], split="train")
+    else:
+        data_dir = Path(params.dataset)
+        json_file_lists = [str(file) for file in data_dir.glob("*.jsonl")]
+        ds = load_dataset("json", data_files=json_file_lists, split="train")
+        # shuffle the dataset
+        ds = ds.shuffle(seed=42)
+        train_test_split = ds.train_test_split(test_size=1000, seed=42)
+        train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]

     sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
     train_dl = StatefulDataLoader(
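For the non-UltraChat branch, --dataset points at a directory of *.jsonl manifests that load_dataset("json", ...) reads row by row. Judging from the fields the collators access ("code", "text", "duration", "id"/"index", "language"), one row is assumed to look roughly like the example below; the actual VoxBox/Emilia manifests may carry additional keys:

    import json

    # Hypothetical manifest row; field names are inferred from the collator code,
    # the values are made up.
    example_row = {
        "id": "emilia_en_000001",
        "index": 1,
        "text": "Hello, how are you today?",
        "code": [512, 87, 903, 12],   # speech codec tokens
        "duration": 1.84,
        "language": "EN",
    }
    print(json.dumps(example_row))    # one line of a *.jsonl manifest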
@@ -486,7 +583,7 @@ def run(rank, world_size, args):
         shuffle=False,
         num_workers=4,
         prefetch_factor=2,
-        collate_fn=data_collator
+        collate_fn=data_collator_ultra_chat if params.dataset == "ultra_chat_voice_assistant" else data_collator
     )
     train_dl.load_state_dict(sampler_state_dict)
     valid_sampler = DistributedSampler(eval_dataset, num_replicas=world_size, rank=rank)
@@ -497,7 +594,7 @@ def run(rank, world_size, args):
         shuffle=False,
         num_workers=1,
         prefetch_factor=1,
-        collate_fn=data_collator
+        collate_fn=data_collator_ultra_chat if params.dataset == "ultra_chat_voice_assistant" else data_collator
     )

     if args.tensorboard and rank == 0:
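The collate_fn switch above is paired with StatefulDataLoader so that a run can resume mid-epoch: train_dl.load_state_dict() restores the position that was captured earlier with state_dict() (presumably what the shell script stores as sampler.pt and passes back via --sampler-state-dict-path). A self-contained sketch with a toy dataset instead of the recipe's manifests:

    from torchdata.stateful_dataloader import StatefulDataLoader

    loader = StatefulDataLoader(list(range(100)), batch_size=10, num_workers=0)
    it = iter(loader)
    next(it)                      # consume the first batch (items 0-9)
    state = loader.state_dict()   # the recipe would torch.save this next to the checkpoint

    resumed = StatefulDataLoader(list(range(100)), batch_size=10, num_workers=0)
    resumed.load_state_dict(state)
    print(next(iter(resumed)))    # resumes with items 10-19 instead of starting over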
@@ -12,12 +12,12 @@ from dataclasses import dataclass
 from datetime import datetime
 from pathlib import Path
 from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
+from tqdm import tqdm
 import kaldialign
 import torch
 import torch.distributed as dist
 from torch.utils.tensorboard import SummaryWriter
+import numpy as np
 Pathlike = Union[str, Path]


@@ -431,3 +431,45 @@ def write_error_stats(

         print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
     return float(tot_err_rate)
+
+
+def get_dataset_statistics(dataset, save_filename):
+    speech_token_lengths = []
+    text_lengths = []
+    for item in tqdm(dataset):
+        if 'custom' not in item:
+            speech_token = item["code"]
+            text = item["text"]
+        else:
+            speech_token = item["custom"]["speech_token"]
+            text = item["supervisions"][0]["text"]
+        speech_token_lengths.append(len(speech_token))
+        text_lengths.append(len(text))
+    speech_token_length_array = np.array(speech_token_lengths)
+    text_length_array = np.array(text_lengths)
+
+    # Compute and store the length statistics
+    def get_length_stats(lengths_array):
+        length_stats = []
+        length_stats.append(["count", f"{len(lengths_array)}"])  # total number of items
+        length_stats.append(["mean", f"{np.mean(lengths_array):.1f}"])
+        length_stats.append(["std", f"{np.std(lengths_array):.1f}"])
+        length_stats.append(["min", f"{np.min(lengths_array):.1f}"])
+        length_stats.append(["25%", f"{np.percentile(lengths_array, 25):.1f}"])
+        length_stats.append(["50% (median)", f"{np.median(lengths_array):.1f}"])  # median equals the 50th percentile
+        length_stats.append(["75%", f"{np.percentile(lengths_array, 75):.1f}"])
+        length_stats.append(["99%", f"{np.percentile(lengths_array, 99):.1f}"])
+        length_stats.append(["99.5%", f"{np.percentile(lengths_array, 99.5):.1f}"])
+        length_stats.append(["99.9%", f"{np.percentile(lengths_array, 99.9):.1f}"])
+        length_stats.append(["max", f"{np.max(lengths_array):.1f}"])
+        return length_stats
+
+    speech_length_stats = get_length_stats(speech_token_length_array)
+    text_length_stats = get_length_stats(text_length_array)
+    with open(save_filename, "w") as f:
+        print("speech_tokens length statistics:", file=f)
+        for stat_name, stat_value in speech_length_stats:
+            print(f"{stat_name:<15}: {stat_value}", file=f)
+        print("\ntext length statistics:", file=f)
+        for stat_name, stat_value in text_length_stats:
+            print(f"{stat_name:<15}: {stat_value}", file=f)
+
+    return speech_token_lengths, text_lengths
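A possible way to use the new helper, for example to eyeball token and text length percentiles before picking a batch size: load one of the manifests referenced above and point it at an output file (the output path here is illustrative):

    from datasets import load_dataset

    from utils import get_dataset_statistics

    ds = load_dataset(
        "json",
        data_files=["data/fbank/cuts_voice_assistant.00000.jsonl"],
        split="train",
    )
    speech_lens, text_lens = get_dataset_statistics(ds, "dataset_statistics.txt")
    print(f"{len(speech_lens)} items, longest utterance: {max(speech_lens)} speech tokens")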