This commit is contained in:
root 2025-04-29 09:46:44 +00:00
parent 360f0aa397
commit 11bd3c9ad8
7 changed files with 154 additions and 460 deletions

View File

@ -23,67 +23,33 @@ The following table lists the folders for different tasks.
Command for training is:
```bash
pip install -r whisper_llm_zh/requirements.txt
pip install huggingface_hub['cli']
mkdir -p models/whisper models/qwen
# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
# First, we only train the projector and freeze other modules.
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--max-duration 200 \
--exp-dir ./whisper_llm_zh/exp_test \
--speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
--llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 50 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--use-lora False --unfreeze-llm False
# Then we jointly train the projector and LLM LoRA modules.
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--max-duration 200 \
--exp-dir ./whisper_llm_zh/exp_test \
--speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
--llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--use-flash-attn True \
--use-lora True --unfreeze-llm True
--pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
```
Command for decoding:
Command for decoding is:
```bash
mkdir -p models/whisper models/qwen models/checkpoint
huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
python3 ./whisper_llm_zh/decode.py \
--max-duration 80 \
--exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
--speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
--llm-path-or-name models/qwen \
python3 ./qwen_omni/decode.py \
--max-duration 1 \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--epoch 999 --avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
--use-lora True --dataset aishell
--method e2e-epoch10_speech2speech \
--enable-speech-output True \
--token2wav-path models/CosyVoice-300M-SFT \
--use-lora True
```
Please see [`prepare.sh`](./prepare.sh) for more details.

View File

@ -165,7 +165,7 @@ def compute_fbank(args):
storage_type=LilcomChunkyWriter,
overwrite=True,
)
cuts_path = f"{in_out_dir}/{args.prefix}_cuts.{idx}.jsonl.gz"
cuts_path = f"{in_out_dir}/cuts_{args.prefix}.{idx}.jsonl.gz"
logging.info(f"Saving to {cuts_path}")
# see https://github.com/lhotse-speech/lhotse/issues/1125
cut_set.drop_recordings().to_file(cuts_path)

View File

@ -2,7 +2,9 @@
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
export PYTHONPATH=$PYTHONPATH:/workspace/slam/icefall_omni
export PYTHONPATH=$PYTHONPATH:/workspace/icefall
set -eou pipefail
stage=$1
@ -19,18 +21,37 @@ log() {
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "stage 0: "
#pip uninstall lhotse
#cd /workspace/slam/lhotse
#git config --global --add safe.directory /workspace/slam/lhotse
#pip install -e '.[dev]'
cd -
pip install -r slam_omni/requirements.txt
log "stage 0: Clone CosyVoice repo and install requirements inside the container"
# docker: ghcr.io/swivid/f5-tts:main
pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git /workspace/CosyVoice
cd /workspace/CosyVoice
# If you failed to clone submodule due to network failures, please run following command until success
git submodule update --init --recursive
pip install -r qwen_omni/requirements.txt
pip install -r qwen_omni/requirements-cosyvoice.txt
# For Chinese only dataset, you can use the following command to download the Chinese fine-tuned whisper model.
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
# Cosyvoice pretrained model for speech token2wav module
huggingface-cli download --local-dir models/CosyVoice-300M-SFT FunAudioLLM/CosyVoice-300M-SFT
# Qwen Pretrained model
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
# Qwen-Omni like speech2speech model trained on worstchan/Belle_1.4M-SLAM-Omni
huggingface-cli download --local-dir models/qwen-omni-like-speech2speech-belle-1.4M yuekai/qwen-omni-like-speech2speech-belle-1.4M
# For Gradio demo, we follow https://arxiv.org/abs/2412.15649 to use ASR model to decode the history speech as context.
pip install sherpa-onnx
model_path=local/sherpa-onnx-paraformer-zh-2023-09-14
if [ ! -d $model_path ]; then
wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local
fi
fi
export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface"
log "stage 1: Compute fbank feature from huggingface"
python3 local/compute_whisper_fbank.py \
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
--out-dir data/fbank_test \
@ -39,26 +60,42 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
--prefix belle
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Combine features"
manifest_dir=data/fbank
if [ ! -f $manifest_dir/cuts_belle_00001-01600.jsonl.gz ]; then
mv $manifest_dir/cuts_belle.00000.jsonl.gz ./
# exclude cust_belle_00000.jsonl.gz for valid and test set
pieces=$(find $manifest_dir -name "cuts_belle.*.jsonl.gz" | sort)
# # remove cust_belle_00000.jsonl.gz from pieces
# pieces=$(echo $pieces | sed 's/cuts_belle.00000.jsonl.gz//g')
echo $pieces | wc
lhotse combine $pieces data/fbank/cuts_belle_00001-01600.jsonl.gz
cd $manifest_dir && ln -s cuts_belle_00001-01600.jsonl.gz cuts_belle_train.jsonl.gz && cd -
mv ./cuts_belle.00000.jsonl.gz $manifest_dir # put it back
cd $manifest_dir && ln -s cuts_belle_00001-01600.jsonl.gz cuts_belle_train.jsonl.gz
ln -s cuts_belle.00000.jsonl.gz cuts_belle_test.jsonl.gz && cd -
fi
fi
ngpu=8
exp_dir=./qwen_omni/exp_speech2speech
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "stage 3: "
exp_dir=./slam_omni/exp_speech2speech_rerun
export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
python3 ./slam_omni/decode.py \
log "stage 3: Training Speech2Speech Model"
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 50 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "stage 4: Decoding, only support batch_size=1 for now."
cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
python3 ./qwen_omni/decode.py \
--max-duration 1 \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
@ -66,78 +103,20 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
--epoch 999 --avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
--method e2e-epoch10_speech2speech_rerun \
--method e2e-epoch10_speech2speech \
--enable-speech-output True \
--token2wav-path /workspace/CosyVoice-300M-SFT \
--use-lora True # --on-the-fly-feats True
--token2wav-path models/CosyVoice-300M-SFT \
--use-lora True
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "stage 4: "
ngpu=8
torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
--max-duration 80 \
--enable-musan False \
--exp-dir ./slam_omni/exp_speech2text \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./slam_omni/ds_config_zero1.json \
--use-flash-attn True \
--pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
--sampler-state-dict-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000-sampler.pt \
--use-lora True --unfreeze-llm True
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "stage 5: "
ngpu=8
exp_dir=./slam_omni/exp_speech2speech_rerun
# exp_dir_new=./slam_omni/exp_s2s
torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
--max-duration 50 \
--enable-musan False \
--exp-dir $exp_dir \
log "stage 5: Gradio Demo"
python3 ./qwen_omni/web_demo.py \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./slam_omni/ds_config_zero1.json \
--use-flash-attn True \
--pretrained-model-path $exp_dir/epoch-1-checkpoint-15000.pt/pytorch_model.bin \
--sampler-state-dict-path $exp_dir/epoch-1-checkpoint-15000-sampler.pt \
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
# --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
# --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-35000-sampler.pt \
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "stage 6: "
export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
exp_dir=./slam_omni/exp_speech2speech_rerun
python3 ./slam_omni/web_demo.py \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--checkpoint-path $exp_dir/epoch-998.pt \
--checkpoint-path $exp_dir/epoch-999.pt \
--use-flash-attn True \
--enable-speech-output True \
--asr-model-dir local/sherpa-onnx-paraformer-zh-2023-09-14 \
--use-lora True --token2wav-path /workspace/CosyVoice-300M-SFT --share
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "stage 7: "
model_path=local/sherpa-onnx-paraformer-zh-2023-09-14
if [ ! -d $model_path ]; then
pip install sherpa-onnx
wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local
fi
fi

View File

@ -116,28 +116,6 @@ class AsrDataModule:
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
# group.add_argument(
# "--concatenate-cuts",
# type=str2bool,
# default=False,
# help="When enabled, utterances (cuts) will be concatenated "
# "to minimize the amount of padding.",
# )
# group.add_argument(
# "--duration-factor",
# type=float,
# default=1.0,
# help="Determines the maximum duration of a concatenated cut "
# "relative to the duration of the longest cut in a batch.",
# )
# group.add_argument(
# "--gap",
# type=float,
# default=1.0,
# help="The amount of padding (in seconds) inserted between "
# "concatenated cuts. This padding is filled with noise when "
# "noise augmentation is used.",
# )
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
@ -256,20 +234,6 @@ class AsrDataModule:
else:
logging.info("Disable MUSAN")
# if self.args.concatenate_cuts:
# logging.info(
# f"Using cut concatenation with duration factor "
# f"{self.args.duration_factor} and gap {self.args.gap}."
# )
# # Cut concatenation should be the first transform in the list,
# # so that if we e.g. mix noise in, it will fill the gaps between
# # different utterances.
# transforms = [
# CutConcatenate(
# duration_factor=self.args.duration_factor, gap=self.args.gap
# )
# ] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
@ -426,32 +390,12 @@ class AsrDataModule:
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
if self.args.on_the_fly_feats:
# dataset = load_dataset(self.args.huggingface_dataset_path_or_name, streaming=True, split=partition)
i, num_digits = 0, 5
idx = f"{i}".zfill(num_digits)
parquet_files = [
f"data/train-{idx}-of-01601.parquet",
]
parquet_files = [
f"{self.args.huggingface_dataset_path_or_name}/{f}"
for f in parquet_files
]
file_name = parquet_files[0]
logging.info(f"Loading dataset from {file_name}")
dataset = load_dataset(
"parquet", data_files=parquet_files, streaming=True, split="train"
)
cut_set = CutSet.from_huggingface_dataset(
dataset, audio_key=self.args.audio_key, text_key=self.args.text_key
)
if self.args.resample_to_16kHz:
cut_set = cut_set.resample(16000)
return {"test": cut_set}
pass
else:
# return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")}
# return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_test_small.jsonl.gz")}
return {
"test": load_manifest_lazy("data/fbank_test/belle_cuts.00000.jsonl.gz")
"test": load_manifest_lazy(
self.args.manifest_dir / "cuts_belle_test.jsonl.gz"
)
}
@lru_cache()
@ -461,7 +405,7 @@ class AsrDataModule:
pass
else:
return load_manifest_lazy(
self.args.manifest_dir / "cuts_belle.00000.jsonl.gz"
self.args.manifest_dir / "cuts_belle_test.jsonl.gz"
)
@lru_cache()

View File

@ -20,30 +20,27 @@
"""
Usage:
# Command for decoding using fine-tuned models:
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
# Cosyvoice pretrained model for speech token2wav module
huggingface-cli download --local-dir models/CosyVoice-300M-SFT FunAudioLLM/CosyVoice-300M-SFT
# Qwen Pretrained model
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
# Qwen-Omni like speech2speech model trained on worstchan/Belle_1.4M-SLAM-Omni
huggingface-cli download --local-dir models/qwen-omni-like-speech2speech-belle-1.4M yuekai/qwen-omni-like-speech2speech-belle-1.4M
pip install huggingface_hub['cli']
mkdir -p models/whisper models/qwen models/checkpoint
huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
python3 ./whisper_llm_zh/decode.py \
--max-duration 80 \
--exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
--speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
--llm-path-or-name models/qwen \
--epoch 999 --avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
--use-lora True --dataset aishell
cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
python3 ./qwen_omni/decode.py \
--max-duration 1 \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
--epoch 999 --avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
--method e2e-epoch10_speech2speech \
--enable-speech-output True \
--token2wav-path models/CosyVoice-300M-SFT \
--use-lora True
"""
import argparse
@ -183,11 +180,6 @@ def get_model(params, device):
attn_implementation = "eager"
torch_dtype = torch.float16
# codec_lm = AutoModelForCausalLM.from_pretrained(
# params.llm_path_or_name,
# attn_implementation=attn_implementation,
# torch_dtype=torch_dtype,
# )
codec_vocab_size = 4096 + 4
config = Qwen2Config(
vocab_size=codec_vocab_size,
@ -198,39 +190,19 @@ def get_model(params, device):
intermediate_size=2048,
max_position_embeddings=4096,
)
# codec_lm = Qwen2ForCausalLM(config=config)
# Pass attn_implementation and torch_dtype to the constructor
# Use AutoModelForCausalLM.from_config for more generality
codec_lm = AutoModelForCausalLM.from_config(
config=config,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
# cosyvoice2_token_size = 6561
codec_lm.resize_token_embeddings(codec_vocab_size)
codec_lm.vocab_size = codec_vocab_size
codec_lm.config.pad_token_id = codec_vocab_size - 1
codec_lm.config.eos_token_id = codec_vocab_size - 2
codec_lm.config.bos_token_id = codec_vocab_size - 3
codec_lm.config.mask_token_id = codec_vocab_size - 4
# if params.use_lora:
# lora_config = LoraConfig(
# r=64,
# lora_alpha=16,
# target_modules=[
# "q_proj",
# "k_proj",
# "v_proj",
# "o_proj",
# "up_proj",
# "gate_proj",
# "down_proj",
# ],
# lora_dropout=0.05,
# task_type="CAUSAL_LM",
# )
# codec_lm = get_peft_model(codec_lm, lora_config)
# codec_lm.print_trainable_parameters()
else:
codec_lm = None
@ -373,13 +345,6 @@ def get_parser():
default="/workspace/CosyVoice-300M-SFT",
help="The path to the token2wav model",
)
# parser.add_argument(
# "--dataset",
# type=str,
# default="aishell",
# choices=["aishell", "speechio", "wenetspeech_test_meeting", "multi_hans_zh"],
# help="The dataset to decode",
# )
add_model_arguments(parser)
return parser
@ -474,12 +439,6 @@ def decode_one_batch(
chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
# messages = [
# [
# {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
# {"role": "assistant", "content": ""},
# ]
# ] * len(feature)
questions_with_history = [
cut.custom["question"] for cut in batch["supervisions"]["cut"]
]
@ -496,7 +455,6 @@ def decode_one_batch(
history_question_answer = history_contexts[i].split("USER:")
history_question_answer = [item for item in history_question_answer if item]
for j in range(total_round - 1):
# USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
question_answer = history_question_answer[j].split("ASSISTANT:")
message += [
{"role": "user", "content": question_answer[0].strip()},
@ -504,7 +462,6 @@ def decode_one_batch(
]
message += [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
# {"role": "user", "content": f"{last_questions[i]}"},
{"role": "assistant", "content": ""},
]
print(f"message: {message}, batch_size {len(chat_rounds)}")
@ -525,13 +482,6 @@ def decode_one_batch(
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 22050)
# with open(speech_token_file_name, 'w') as f:
# # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
# #torchaudio.save(save_path, speech_output.cpu(), 16000)
# # print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
# save_str = " ".join([str(i) for i in generated_speech_output])
# f.write(f"{cut_id}|{save_str}\n")
else:
generated_ids = model.decode(
feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
@ -560,43 +510,6 @@ def decode_dataset(
Returns:
Return a dict, whose key may be "beam-search".
"""
def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
"""
Text normalization similar to M2MeT challenge baseline.
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
"""
if normalize == "none":
return text
elif normalize == "m2met":
import re
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
text = text.replace("<->", "")
text = text.replace("<$>", "")
text = text.replace("<#>", "")
text = text.replace("<_>", "")
text = text.replace("<space>", "")
text = text.replace("`", "")
text = text.replace("&", "")
text = text.replace(",", "")
if re.search("[a-zA-Z]", text):
text = text.upper()
text = text.replace("", "A")
text = text.replace("", "A")
text = text.replace("", "B")
text = text.replace("", "C")
text = text.replace("", "K")
text = text.replace("", "T")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
return text
results = []
num_cuts = 0
@ -634,7 +547,6 @@ def decode_dataset(
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
# ref_text = normalize_text_alimeeting(ref_text)
ref_words = ref_text.split()
print(f"ref: {ref_text}")
print(f"hyp: {''.join(hyp_words)}")
@ -673,7 +585,6 @@ def save_results(
errs_filename = (
params.log_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
@ -732,11 +643,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
data_module = AsrDataModule(args)
# data_module = MultiDataset(args.manifest_dir)
def remove_long_utt(c: Cut):
# Keep only utterances with duration in 30 seconds
@ -748,13 +656,6 @@ def main():
return False
return True
# if params.dataset == "aishell":
# test_sets_cuts = data_module.aishell_test_cuts()
# elif params.dataset == "speechio":
# test_sets_cuts = data_module.speechio_test_cuts()
# elif params.dataset == "wenetspeech_test_meeting":
# test_sets_cuts = data_module.wenetspeech_test_meeting_cuts()
# else:
test_sets_cuts = data_module.test_cuts()
test_sets = test_sets_cuts.keys()

View File

@ -1,4 +1,4 @@
from typing import List, Tuple # Added for type hints
from typing import List, Tuple
import torch
from torch import nn
@ -78,7 +78,6 @@ class SPEECH_LLM(nn.Module):
self.codec_lm_head = nn.Linear(
self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
)
# to torch.float16
self.speech_token_projector = self.speech_token_projector.to(
dtype=torch.float16
)
@ -498,20 +497,6 @@ class SPEECH_LLM(nn.Module):
pad_token_id=self.llm.config.pad_token_id,
)
# generated_ids = self.llm.generate(
# inputs_embeds=inputs_embeds,
# max_new_tokens=kwargs.get("max_new_tokens", 200),
# num_beams=kwargs.get("num_beams", 1),
# do_sample=kwargs.get("do_sample", False),
# min_length=kwargs.get("min_length", 1),
# top_p=kwargs.get("top_p", 1.0),
# repetition_penalty=kwargs.get("repetition_penalty", 1.0),
# temperature=kwargs.get("temperature", 1.0),
# length_penalty=kwargs.get("length_penalty", 1.0),
# bos_token_id=self.llm.config.bos_token_id,
# eos_token_id=self.llm.config.eos_token_id,
# pad_token_id=self.llm.config.pad_token_id,
# )
return generated_ids
def decode_with_speech_output(
@ -520,7 +505,7 @@ class SPEECH_LLM(nn.Module):
input_ids: torch.LongTensor = None, # Prompt input_ids
attention_mask: torch.Tensor = None, # Prompt attention_mask
max_text_new_tokens: int = 1024,
max_speech_new_tokens: int = 1024, # Max length for speech tokens
max_speech_new_tokens: int = 2048, # Max length for speech tokens
llm_kwargs: dict = None, # Kwargs for text LLM generate
codec_lm_kwargs: dict = None, # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET
) -> Tuple[torch.LongTensor, List[List[int]]]:
@ -602,7 +587,7 @@ class SPEECH_LLM(nn.Module):
eos_token_id = self.llm.config.eos_token_id
eos_token_embedding = self.llm.get_input_embeddings()(
torch.tensor([[eos_token_id]], device=device)
) # 1,D
)
assert (
generated_text_ids[0, -1] == eos_token_id
), f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"
@ -610,7 +595,7 @@ class SPEECH_LLM(nn.Module):
token_hidden_states[0].to(self.llm.device)
for token_hidden_states in text_outputs.hidden_states
]
# shift one for thinker token_embeds, drop the first embeds, and add the eos token
first_thinker_token_embed = torch.cat(
[
thinker_token_embeds_org[0][:, 1:],
@ -628,7 +613,7 @@ class SPEECH_LLM(nn.Module):
token_hidden_states[-1].to(self.llm.device)
for token_hidden_states in text_outputs.hidden_states
]
# thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
thinker_reply_part = [
torch.cat(
[
@ -651,12 +636,8 @@ class SPEECH_LLM(nn.Module):
dim=-1,
)
thinker_prompt_part = self.speech_token_projector(
thinker_prompt_part
) # [B, S_full, D_codec]
thinker_reply_part = self.speech_token_projector(
thinker_reply_part
) # [B, S_full, D_codec]
thinker_prompt_part = self.speech_token_projector(thinker_prompt_part)
thinker_reply_part = self.speech_token_projector(thinker_reply_part)
thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
talker_input_ids = torch.full(
@ -666,9 +647,7 @@ class SPEECH_LLM(nn.Module):
device=self.llm.device,
)
talker_input_ids[:, -1] = self.codec_lm.config.bos_token_id
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
talker_input_ids
) # [B, S_full, D_codec]
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids)
thinker_input_embeds = torch.cat(
[
thinker_prompt_part,
@ -677,68 +656,43 @@ class SPEECH_LLM(nn.Module):
dim=1,
)
talker_inputs_embeds += thinker_input_embeds
thinker_reply_part = thinker_reply_part[
:, delay_step + 1 :, :
] # [B, S_full, D_codec]
thinker_reply_part = thinker_reply_part[:, delay_step + 1 :, :]
past_key_values = None
# generated_speech_tokens_list = [[] for _ in range(batch_size)]
# unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
generated_speech_tokens_list = []
next_token_ids = None
# text_context_len = projected_text_embeds.shape[1] # S_full
for t in range(max_speech_new_tokens):
# Get embedding for the *current* input token ID (initially BOS, then generated tokens)
# current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids) # [B, 1, D_codec]
if t > 0:
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
next_token_ids
) # [B, 1, D_codec]
)
if thinker_reply_part.shape[1] > 0:
talker_inputs_embeds += thinker_reply_part[:, :1, :]
thinker_reply_part = thinker_reply_part[
:, 1:, :
] # Remove the first token for next step
# # Add the projected text embedding corresponding to the current timestep `t`
# if t < text_context_len:
# # Text context from the full generated text sequence
# current_text_context_embed = projected_text_embeds[:, t:t+1, :] # [B, 1, D_codec]
# inputs_embeds = current_speech_embeds + current_text_context_embed
# else:
# # No more text context to add
# inputs_embeds = current_speech_embeds
thinker_reply_part = thinker_reply_part[:, 1:, :]
# Forward pass through codec LM for one step
# We provide inputs_embeds directly, bypassing prepare_inputs_for_generation
codec_outputs = self.codec_lm(
inputs_embeds=talker_inputs_embeds, # Combined embedding for this step
inputs_embeds=talker_inputs_embeds,
past_key_values=past_key_values,
use_cache=True,
return_dict=True,
output_hidden_states=True,
# No attention mask needed here when using past_key_values and single token input
)
last_token_hidden_state = codec_outputs.hidden_states[-1][
:, -1, :
] # [B, D_codec] #TODO: check shape here
# Get logits for the *last* token generated in this step
next_token_logits = self.codec_lm_head(
last_token_hidden_state
) # Use -1 index
# suppress tokens between 4096:len(vocab)-3
# next_token_logits[:, 4096:-3] = -float("Inf") # TODO: where we should supress tokens?
last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :]
next_token_logits = self.codec_lm_head(last_token_hidden_state)
next_token_ids = topk_sampling(
next_token_logits,
)
# print(next_token_ids, "next_token_ids", t, next_token_ids.shape)
if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id:
break
# current_speech_input_ids = next_token_ids # Use the newly generated token ID as input for next step
past_key_values = codec_outputs.past_key_values # Update KV cache
generated_speech_tokens_list.append(
next_token_ids.squeeze(1).cpu().tolist()[0]
)
# --- 6. Return Results ---
return generated_text_ids, generated_speech_tokens_list

View File

@ -17,28 +17,22 @@
# limitations under the License.
"""
Usage:
# fine-tuning with whisper and Qwen2
pip install huggingface_hub['cli']
mkdir -p models/whisper models/qwen
# For Chinese dataset, you can use the following command to download the Chinese fine-tuned whisper model.
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
# Qwen Pretrained model
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
# For aishell fine-tuned whisper model
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
# For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--max-duration 200 \
--exp-dir ./whisper_llm_zh/exp_test \
--speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
--llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
--max-duration 50 \
--enable-musan False \
--exp-dir $exp_dir \
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
--manifest-dir data/fbank \
--deepspeed \
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
--use-flash-attn True \
--use-lora True --unfreeze-llm True
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
"""
import argparse
@ -52,7 +46,6 @@ from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import deepspeed
import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
@ -66,8 +59,6 @@ from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
# from multi_dataset import MultiDataset
from peft import LoraConfig, get_peft_model
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
@ -330,9 +321,6 @@ def compute_loss(
truncation=False,
)
)
# padding texts to the same length, texts is a list of list, padding with tokenzier.pad_token_id
# remove too long text
# texts = [ text for text in texts if len(text) < 1024 ]
if len(texts) != len(messages):
logging.warning(f"Remove too long text, {messages} ")
max_len_texts = max([len(text) for text in texts])
@ -347,7 +335,7 @@ def compute_loss(
for text in texts
]
input_ids = torch.tensor(texts, dtype=torch.int)
# response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
target_ids = input_ids.clone()
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
# mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
@ -396,8 +384,6 @@ def compute_loss(
history_contexts = [
question.rsplit("<USER>:", 1)[0].strip() for question in questions_with_history
]
# USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。<USER>: 告诉我如何烹饪鸡肉
# <USER>: 对以下句子进行鉴赏:他心地善良。输出结果为"他是一个有善心的人。
messages = []
for i, total_round in enumerate(chat_rounds):
@ -406,7 +392,6 @@ def compute_loss(
history_question_answer = history_contexts[i].split("USER:")
history_question_answer = [item for item in history_question_answer if item]
for j in range(total_round - 1):
# USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
question_answer = history_question_answer[j].split("ASSISTANT:")
message += [
{"role": "user", "content": question_answer[0].strip()},
@ -683,7 +668,6 @@ def run(rank, world_size, args):
if params.use_flash_attn:
attn_implementation = "flash_attention_2"
# torch_dtype=torch.bfloat16 FIX ME
torch_dtype = torch.float16
tokenizer.padding_side = "left"
@ -724,14 +708,6 @@ def run(rank, world_size, args):
special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
tokenizer.add_special_tokens(special_tokens_dict)
# original_tokenizer_vocab_size = len(tokenizer)
# cosyvoice2_token_size = 6561
# new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [
# "<|SPEECH_GENERATION_START|>"
# ]
# num_added_tokens = tokenizer.add_tokens(new_tokens)
# model.resize_token_embeddings(len(tokenizer))
# model.vocab_size = len(tokenizer)
llm.config.pad_token_id = tokenizer.pad_token_id
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
@ -755,11 +731,6 @@ def run(rank, world_size, args):
attn_implementation = "eager"
torch_dtype = torch.float16
# codec_lm = AutoModelForCausalLM.from_pretrained(
# params.llm_path_or_name,
# attn_implementation=attn_implementation,
# torch_dtype=torch_dtype,
# )
codec_vocab_size = 4096 + 4
# TODO: modify above vocab size or supress_tokens when decoding
config = Qwen2Config(
@ -771,39 +742,19 @@ def run(rank, world_size, args):
intermediate_size=2048,
max_position_embeddings=4096,
)
# codec_lm = Qwen2ForCausalLM(config=config)
# Pass attn_implementation and torch_dtype to the constructor
# Use AutoModelForCausalLM.from_config for more generality
codec_lm = AutoModelForCausalLM.from_config(
config=config,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
# cosyvoice2_token_size = 6561
codec_lm.resize_token_embeddings(codec_vocab_size)
codec_lm.vocab_size = codec_vocab_size
codec_lm.config.pad_token_id = codec_vocab_size - 1
codec_lm.config.eos_token_id = codec_vocab_size - 2
codec_lm.config.bos_token_id = codec_vocab_size - 3
codec_lm.config.mask_token_id = codec_vocab_size - 4
# if params.use_lora:
# lora_config = LoraConfig(
# r=64,
# lora_alpha=16,
# target_modules=[
# "q_proj",
# "k_proj",
# "v_proj",
# "o_proj",
# "up_proj",
# "gate_proj",
# "down_proj",
# ],
# lora_dropout=0.05,
# task_type="CAUSAL_LM",
# )
# codec_lm = get_peft_model(codec_lm, lora_config)
# codec_lm.print_trainable_parameters()
else:
codec_lm = None
@ -856,7 +807,6 @@ def run(rank, world_size, args):
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# )
return False
# cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]
codec_len = len(c.custom["answer_cosyvoice_speech_token"])
if codec_len > 2200:
logging.warning(
@ -873,7 +823,7 @@ def run(rank, world_size, args):
if params.sampler_state_dict_path:
sampler_state_dict = torch.load(params.sampler_state_dict_path)
sampler_state_dict["max_duration"] = params.max_duration
# TODO: load sampler state dict
train_dl = data_module.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)