From 11bd3c9ad828ab29e068c6f7346d2827d0e34b87 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 29 Apr 2025 09:46:44 +0000 Subject: [PATCH] lint --- egs/speech_llm/SPEECH2SPEECH/README.md | 86 ++++------- .../local/compute_whisper_fbank.py | 2 +- egs/speech_llm/SPEECH2SPEECH/prepare.sh | 145 ++++++++---------- .../SPEECH2SPEECH/qwen_omni/data_module.py | 66 +------- .../SPEECH2SPEECH/qwen_omni/decode.py | 143 +++-------------- .../SPEECH2SPEECH/qwen_omni/model.py | 84 +++------- .../SPEECH2SPEECH/qwen_omni/train.py | 88 +++-------- 7 files changed, 154 insertions(+), 460 deletions(-) diff --git a/egs/speech_llm/SPEECH2SPEECH/README.md b/egs/speech_llm/SPEECH2SPEECH/README.md index e4738eeef..cc5e60063 100644 --- a/egs/speech_llm/SPEECH2SPEECH/README.md +++ b/egs/speech_llm/SPEECH2SPEECH/README.md @@ -23,67 +23,33 @@ The following table lists the folders for different tasks. Command for training is: ```bash -pip install -r whisper_llm_zh/requirements.txt - -pip install huggingface_hub['cli'] -mkdir -p models/whisper models/qwen - -# For aishell fine-tuned whisper model -huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt -# For multi-hans fine-tuned whisper model -# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt - -# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct -huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct - -# First, we only train the projector and freeze other modules. -torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ - --max-duration 200 \ - --exp-dir ./whisper_llm_zh/exp_test \ - --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \ - --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \ - --manifest-dir data/fbank \ - --deepspeed \ - --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ - --use-flash-attn True \ - --use-lora False --unfreeze-llm False - -# Then we jointly train the projector and LLM LoRA modules. 
-torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ - --max-duration 200 \ - --exp-dir ./whisper_llm_zh/exp_test \ - --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \ - --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \ - --manifest-dir data/fbank \ - --deepspeed \ - --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ - --use-flash-attn True \ - --use-lora True --unfreeze-llm True - --pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt +torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \ + --max-duration 50 \ + --enable-musan False \ + --exp-dir $exp_dir \ + --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \ + --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \ + --manifest-dir data/fbank \ + --deepspeed \ + --deepspeed_config ./qwen_omni/ds_config_zero1.json \ + --use-flash-attn True \ + --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True ``` -Command for decoding: +Command for decoding is: ```bash -mkdir -p models/whisper models/qwen models/checkpoint -huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B - -# For aishell fine-tuned whisper model -huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt -# For multi-hans fine-tuned whisper model -# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt - -huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct - -mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B -ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt - -python3 ./whisper_llm_zh/decode.py \ - --max-duration 80 \ - --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \ - --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \ - --llm-path-or-name models/qwen \ - --epoch 999 --avg 1 \ - --manifest-dir data/fbank \ - --use-flash-attn True \ - --use-lora True --dataset aishell +python3 ./qwen_omni/decode.py \ + --max-duration 1 \ + --exp-dir $exp_dir \ + --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \ + --llm-path-or-name models/Qwen2.5-0.5B-Instruct \ + --epoch 999 --avg 1 \ + --manifest-dir data/fbank \ + --use-flash-attn True \ + --method e2e-epoch10_speech2speech \ + --enable-speech-output True \ + --token2wav-path models/CosyVoice-300M-SFT \ + --use-lora True ``` + +Please see [`prepare.sh`](./prepare.sh) for more details. 
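Note: the `--use-lora True --unfreeze-llm True` flags above fine-tune the Qwen LLM through LoRA adapters rather than updating all of its weights. Below is a minimal sketch of how such wrapping is typically done with PEFT; the hyperparameters mirror the `LoraConfig` that appears (commented out) elsewhere in this recipe, so the exact values applied to the text LLM inside `train.py` may differ.

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Stand-alone example; in the recipe the model comes from --llm-path-or-name.
llm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

lora_config = LoraConfig(
    r=64,                 # values taken from the LoraConfig shown (commented out) in this recipe
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "up_proj", "gate_proj", "down_proj",
    ],
    task_type="CAUSAL_LM",
)
llm = get_peft_model(llm, lora_config)
llm.print_trainable_parameters()  # only the adapter weights receive gradients
```

LoRA keeps the number of trainable LLM parameters small, while the speech projector (and, with `--enable-speech-output True`, the talker/codec LM) is updated directly.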
diff --git a/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py b/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py index 4bc5e5a82..f67324ba3 100755 --- a/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py +++ b/egs/speech_llm/SPEECH2SPEECH/local/compute_whisper_fbank.py @@ -165,7 +165,7 @@ def compute_fbank(args): storage_type=LilcomChunkyWriter, overwrite=True, ) - cuts_path = f"{in_out_dir}/{args.prefix}_cuts.{idx}.jsonl.gz" + cuts_path = f"{in_out_dir}/cuts_{args.prefix}.{idx}.jsonl.gz" logging.info(f"Saving to {cuts_path}") # see https://github.com/lhotse-speech/lhotse/issues/1125 cut_set.drop_recordings().to_file(cuts_path) diff --git a/egs/speech_llm/SPEECH2SPEECH/prepare.sh b/egs/speech_llm/SPEECH2SPEECH/prepare.sh index 47320ab66..42c9b4eaa 100644 --- a/egs/speech_llm/SPEECH2SPEECH/prepare.sh +++ b/egs/speech_llm/SPEECH2SPEECH/prepare.sh @@ -2,7 +2,9 @@ # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python -export PYTHONPATH=$PYTHONPATH:/workspace/slam/icefall_omni + +export PYTHONPATH=$PYTHONPATH:/workspace/icefall + set -eou pipefail stage=$1 @@ -19,18 +21,37 @@ log() { if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "stage 0: " - #pip uninstall lhotse - #cd /workspace/slam/lhotse - #git config --global --add safe.directory /workspace/slam/lhotse - #pip install -e '.[dev]' - cd - - pip install -r slam_omni/requirements.txt + log "stage 0: Clone CosyVoice repo and install requirements inside the container" + # docker: ghcr.io/swivid/f5-tts:main + pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html + git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git /workspace/CosyVoice + cd /workspace/CosyVoice + # If you failed to clone submodule due to network failures, please run following command until success + git submodule update --init --recursive + pip install -r qwen_omni/requirements.txt + pip install -r qwen_omni/requirements-cosyvoice.txt + + # For Chinese only dataset, you can use the following command to download the Chinese fine-tuned whisper model. + huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper + # Cosyvoice pretrained model for speech token2wav module + huggingface-cli download --local-dir models/CosyVoice-300M-SFT FunAudioLLM/CosyVoice-300M-SFT + # Qwen Pretrained model + huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct + # Qwen-Omni like speech2speech model trained on worstchan/Belle_1.4M-SLAM-Omni + huggingface-cli download --local-dir models/qwen-omni-like-speech2speech-belle-1.4M yuekai/qwen-omni-like-speech2speech-belle-1.4M + + # For Gradio demo, we follow https://arxiv.org/abs/2412.15649 to use ASR model to decode the history speech as context. + pip install sherpa-onnx + model_path=local/sherpa-onnx-paraformer-zh-2023-09-14 + if [ ! 
-d $model_path ]; then + wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 + tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local + fi fi +export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface" - + log "stage 1: Compute fbank feature from huggingface" python3 local/compute_whisper_fbank.py \ --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \ --out-dir data/fbank_test \ @@ -39,26 +60,42 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then --prefix belle fi - if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Combine features" manifest_dir=data/fbank if [ ! -f $manifest_dir/cuts_belle_00001-01600.jsonl.gz ]; then + mv $manifest_dir/cuts_belle.00000.jsonl.gz ./ + # exclude cust_belle_00000.jsonl.gz for valid and test set pieces=$(find $manifest_dir -name "cuts_belle.*.jsonl.gz" | sort) - # # remove cust_belle_00000.jsonl.gz from pieces - # pieces=$(echo $pieces | sed 's/cuts_belle.00000.jsonl.gz//g') echo $pieces | wc lhotse combine $pieces data/fbank/cuts_belle_00001-01600.jsonl.gz - cd $manifest_dir && ln -s cuts_belle_00001-01600.jsonl.gz cuts_belle_train.jsonl.gz && cd - + mv ./cuts_belle.00000.jsonl.gz $manifest_dir # put it back + cd $manifest_dir && ln -s cuts_belle_00001-01600.jsonl.gz cuts_belle_train.jsonl.gz + ln -s cuts_belle.00000.jsonl.gz cuts_belle_test.jsonl.gz && cd - fi fi - +ngpu=8 +exp_dir=./qwen_omni/exp_speech2speech if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "stage 3: " - exp_dir=./slam_omni/exp_speech2speech_rerun - export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice - python3 ./slam_omni/decode.py \ + log "stage 3: Training Speech2Speech Model" + torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \ + --max-duration 50 \ + --enable-musan False \ + --exp-dir $exp_dir \ + --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \ + --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \ + --manifest-dir data/fbank \ + --deepspeed \ + --deepspeed_config ./qwen_omni/ds_config_zero1.json \ + --use-flash-attn True \ + --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "stage 4: Decoding, only support batch_size=1 for now." 
+ cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd - + python3 ./qwen_omni/decode.py \ --max-duration 1 \ --exp-dir $exp_dir \ --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \ @@ -66,78 +103,20 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then --epoch 999 --avg 1 \ --manifest-dir data/fbank \ --use-flash-attn True \ - --method e2e-epoch10_speech2speech_rerun \ + --method e2e-epoch10_speech2speech \ --enable-speech-output True \ - --token2wav-path /workspace/CosyVoice-300M-SFT \ - --use-lora True # --on-the-fly-feats True - + --token2wav-path models/CosyVoice-300M-SFT \ + --use-lora True fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "stage 4: " - ngpu=8 -torchrun --nproc_per_node $ngpu ./slam_omni/train.py \ - --max-duration 80 \ - --enable-musan False \ - --exp-dir ./slam_omni/exp_speech2text \ - --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \ - --llm-path-or-name models/Qwen2.5-0.5B-Instruct \ - --manifest-dir data/fbank \ - --deepspeed \ - --deepspeed_config ./slam_omni/ds_config_zero1.json \ - --use-flash-attn True \ - --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \ - --sampler-state-dict-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000-sampler.pt \ - --use-lora True --unfreeze-llm True -fi - - if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "stage 5: " - ngpu=8 - exp_dir=./slam_omni/exp_speech2speech_rerun - # exp_dir_new=./slam_omni/exp_s2s - torchrun --nproc_per_node $ngpu ./slam_omni/train.py \ - --max-duration 50 \ - --enable-musan False \ - --exp-dir $exp_dir \ - --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \ - --llm-path-or-name models/Qwen2.5-0.5B-Instruct \ - --manifest-dir data/fbank \ - --deepspeed \ - --deepspeed_config ./slam_omni/ds_config_zero1.json \ - --use-flash-attn True \ - --pretrained-model-path $exp_dir/epoch-1-checkpoint-15000.pt/pytorch_model.bin \ - --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-15000-sampler.pt \ - --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True - # --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \ - # --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-35000-sampler.pt \ - -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "stage 6: " - export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice - exp_dir=./slam_omni/exp_speech2speech_rerun - python3 ./slam_omni/web_demo.py \ + log "stage 5: Gradio Demo" + python3 ./qwen_omni/web_demo.py \ --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \ --llm-path-or-name models/Qwen2.5-0.5B-Instruct \ - --checkpoint-path $exp_dir/epoch-998.pt \ + --checkpoint-path $exp_dir/epoch-999.pt \ --use-flash-attn True \ --enable-speech-output True \ --asr-model-dir local/sherpa-onnx-paraformer-zh-2023-09-14 \ --use-lora True --token2wav-path /workspace/CosyVoice-300M-SFT --share - -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "stage 7: " - model_path=local/sherpa-onnx-paraformer-zh-2023-09-14 - - if [ ! 
-d $model_path ]; then - pip install sherpa-onnx - wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 - tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local - fi fi diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py index 7cab52f73..dc38f32bd 100644 --- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py +++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/data_module.py @@ -116,28 +116,6 @@ class AsrDataModule: help="The number of buckets for the DynamicBucketingSampler" "(you might want to increase it for larger datasets).", ) - # group.add_argument( - # "--concatenate-cuts", - # type=str2bool, - # default=False, - # help="When enabled, utterances (cuts) will be concatenated " - # "to minimize the amount of padding.", - # ) - # group.add_argument( - # "--duration-factor", - # type=float, - # default=1.0, - # help="Determines the maximum duration of a concatenated cut " - # "relative to the duration of the longest cut in a batch.", - # ) - # group.add_argument( - # "--gap", - # type=float, - # default=1.0, - # help="The amount of padding (in seconds) inserted between " - # "concatenated cuts. This padding is filled with noise when " - # "noise augmentation is used.", - # ) group.add_argument( "--on-the-fly-feats", type=str2bool, @@ -256,20 +234,6 @@ class AsrDataModule: else: logging.info("Disable MUSAN") - # if self.args.concatenate_cuts: - # logging.info( - # f"Using cut concatenation with duration factor " - # f"{self.args.duration_factor} and gap {self.args.gap}." - # ) - # # Cut concatenation should be the first transform in the list, - # # so that if we e.g. mix noise in, it will fill the gaps between - # # different utterances. 
- # transforms = [ - # CutConcatenate( - # duration_factor=self.args.duration_factor, gap=self.args.gap - # ) - # ] + transforms - input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") @@ -426,32 +390,12 @@ class AsrDataModule: def test_cuts(self) -> CutSet: logging.info("About to get test cuts") if self.args.on_the_fly_feats: - # dataset = load_dataset(self.args.huggingface_dataset_path_or_name, streaming=True, split=partition) - i, num_digits = 0, 5 - idx = f"{i}".zfill(num_digits) - parquet_files = [ - f"data/train-{idx}-of-01601.parquet", - ] - parquet_files = [ - f"{self.args.huggingface_dataset_path_or_name}/{f}" - for f in parquet_files - ] - file_name = parquet_files[0] - logging.info(f"Loading dataset from {file_name}") - dataset = load_dataset( - "parquet", data_files=parquet_files, streaming=True, split="train" - ) - cut_set = CutSet.from_huggingface_dataset( - dataset, audio_key=self.args.audio_key, text_key=self.args.text_key - ) - if self.args.resample_to_16kHz: - cut_set = cut_set.resample(16000) - return {"test": cut_set} + pass else: - # return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")} - # return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_test_small.jsonl.gz")} return { - "test": load_manifest_lazy("data/fbank_test/belle_cuts.00000.jsonl.gz") + "test": load_manifest_lazy( + self.args.manifest_dir / "cuts_belle_test.jsonl.gz" + ) } @lru_cache() @@ -461,7 +405,7 @@ class AsrDataModule: pass else: return load_manifest_lazy( - self.args.manifest_dir / "cuts_belle.00000.jsonl.gz" + self.args.manifest_dir / "cuts_belle_test.jsonl.gz" ) @lru_cache() diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode.py index acd882d18..e4dccf081 100755 --- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode.py +++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode.py @@ -20,30 +20,27 @@ """ Usage: # Command for decoding using fine-tuned models: +huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper +# Cosyvoice pretrained model for speech token2wav module +huggingface-cli download --local-dir models/CosyVoice-300M-SFT FunAudioLLM/CosyVoice-300M-SFT +# Qwen Pretrained model +huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct +# Qwen-Omni like speech2speech model trained on worstchan/Belle_1.4M-SLAM-Omni +huggingface-cli download --local-dir models/qwen-omni-like-speech2speech-belle-1.4M yuekai/qwen-omni-like-speech2speech-belle-1.4M -pip install huggingface_hub['cli'] -mkdir -p models/whisper models/qwen models/checkpoint -huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B - -# For aishell fine-tuned whisper model -huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt -# For multi-hans fine-tuned whisper model -# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt - -huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct - -mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B -ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt - -python3 ./whisper_llm_zh/decode.py \ - --max-duration 80 \ - --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \ - --speech-encoder-path-or-name 
models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \ - --llm-path-or-name models/qwen \ - --epoch 999 --avg 1 \ - --manifest-dir data/fbank \ - --use-flash-attn True \ - --use-lora True --dataset aishell +cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd - +python3 ./qwen_omni/decode.py \ +--max-duration 1 \ +--exp-dir $exp_dir \ +--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \ +--llm-path-or-name models/Qwen2.5-0.5B-Instruct \ +--epoch 999 --avg 1 \ +--manifest-dir data/fbank \ +--use-flash-attn True \ +--method e2e-epoch10_speech2speech \ +--enable-speech-output True \ +--token2wav-path models/CosyVoice-300M-SFT \ +--use-lora True """ import argparse @@ -183,11 +180,6 @@ def get_model(params, device): attn_implementation = "eager" torch_dtype = torch.float16 - # codec_lm = AutoModelForCausalLM.from_pretrained( - # params.llm_path_or_name, - # attn_implementation=attn_implementation, - # torch_dtype=torch_dtype, - # ) codec_vocab_size = 4096 + 4 config = Qwen2Config( vocab_size=codec_vocab_size, @@ -198,39 +190,19 @@ def get_model(params, device): intermediate_size=2048, max_position_embeddings=4096, ) - # codec_lm = Qwen2ForCausalLM(config=config) - # Pass attn_implementation and torch_dtype to the constructor - # Use AutoModelForCausalLM.from_config for more generality + codec_lm = AutoModelForCausalLM.from_config( config=config, attn_implementation=attn_implementation, torch_dtype=torch_dtype, ) - # cosyvoice2_token_size = 6561 + codec_lm.resize_token_embeddings(codec_vocab_size) codec_lm.vocab_size = codec_vocab_size codec_lm.config.pad_token_id = codec_vocab_size - 1 codec_lm.config.eos_token_id = codec_vocab_size - 2 codec_lm.config.bos_token_id = codec_vocab_size - 3 codec_lm.config.mask_token_id = codec_vocab_size - 4 - # if params.use_lora: - # lora_config = LoraConfig( - # r=64, - # lora_alpha=16, - # target_modules=[ - # "q_proj", - # "k_proj", - # "v_proj", - # "o_proj", - # "up_proj", - # "gate_proj", - # "down_proj", - # ], - # lora_dropout=0.05, - # task_type="CAUSAL_LM", - # ) - # codec_lm = get_peft_model(codec_lm, lora_config) - # codec_lm.print_trainable_parameters() else: codec_lm = None @@ -373,13 +345,6 @@ def get_parser(): default="/workspace/CosyVoice-300M-SFT", help="The path to the token2wav model", ) - # parser.add_argument( - # "--dataset", - # type=str, - # default="aishell", - # choices=["aishell", "speechio", "wenetspeech_test_meeting", "multi_hans_zh"], - # help="The dataset to decode", - # ) add_model_arguments(parser) return parser @@ -474,12 +439,6 @@ def decode_one_batch( chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]] - # messages = [ - # [ - # {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"}, - # {"role": "assistant", "content": ""}, - # ] - # ] * len(feature) questions_with_history = [ cut.custom["question"] for cut in batch["supervisions"]["cut"] ] @@ -496,7 +455,6 @@ def decode_one_batch( history_question_answer = history_contexts[i].split("USER:") history_question_answer = [item for item in history_question_answer if item] for j in range(total_round - 1): - # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。 question_answer = history_question_answer[j].split("ASSISTANT:") message += [ {"role": "user", "content": question_answer[0].strip()}, @@ -504,7 +462,6 @@ def decode_one_batch( ] message += [ {"role": "user", "content": 
f"{DEFAULT_SPEECH_TOKEN}"}, - # {"role": "user", "content": f"{last_questions[i]}"}, {"role": "assistant", "content": ""}, ] print(f"message: {message}, batch_size {len(chat_rounds)}") @@ -525,13 +482,6 @@ def decode_one_batch( audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0) audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model) sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 22050) - # with open(speech_token_file_name, 'w') as f: - # # save_path = params.exp_dir / f"speech_output/{cut_id}.wav" - # #torchaudio.save(save_path, speech_output.cpu(), 16000) - # # print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}") - # save_str = " ".join([str(i) for i in generated_speech_output]) - # f.write(f"{cut_id}|{save_str}\n") - else: generated_ids = model.decode( feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device) @@ -560,43 +510,6 @@ def decode_dataset( Returns: Return a dict, whose key may be "beam-search". """ - - def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str: - """ - Text normalization similar to M2MeT challenge baseline. - See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl - """ - if normalize == "none": - return text - elif normalize == "m2met": - import re - - text = text.replace(" ", "") - text = text.replace("", "") - text = text.replace("<%>", "") - text = text.replace("<->", "") - text = text.replace("<$>", "") - text = text.replace("<#>", "") - text = text.replace("<_>", "") - text = text.replace("", "") - text = text.replace("`", "") - text = text.replace("&", "") - text = text.replace(",", "") - if re.search("[a-zA-Z]", text): - text = text.upper() - text = text.replace("A", "A") - text = text.replace("a", "A") - text = text.replace("b", "B") - text = text.replace("c", "C") - text = text.replace("k", "K") - text = text.replace("t", "T") - text = text.replace(",", "") - text = text.replace("丶", "") - text = text.replace("。", "") - text = text.replace("、", "") - text = text.replace("?", "") - return text - results = [] num_cuts = 0 @@ -634,7 +547,6 @@ def decode_dataset( this_batch = [] assert len(hyps) == len(texts) for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - # ref_text = normalize_text_alimeeting(ref_text) ref_words = ref_text.split() print(f"ref: {ref_text}") print(f"hyp: {''.join(hyp_words)}") @@ -673,7 +585,6 @@ def save_results( errs_filename = ( params.log_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" ) - # we compute CER for aishell dataset. results_char = [] for res in results: results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) @@ -732,11 +643,8 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - # we need cut ids to display recognition results. 
args.return_cuts = True - data_module = AsrDataModule(args) - # data_module = MultiDataset(args.manifest_dir) def remove_long_utt(c: Cut): # Keep only utterances with duration in 30 seconds @@ -748,13 +656,6 @@ def main(): return False return True - # if params.dataset == "aishell": - # test_sets_cuts = data_module.aishell_test_cuts() - # elif params.dataset == "speechio": - # test_sets_cuts = data_module.speechio_test_cuts() - # elif params.dataset == "wenetspeech_test_meeting": - # test_sets_cuts = data_module.wenetspeech_test_meeting_cuts() - # else: test_sets_cuts = data_module.test_cuts() test_sets = test_sets_cuts.keys() diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py index 97870337d..a0efbd319 100644 --- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py +++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/model.py @@ -1,4 +1,4 @@ -from typing import List, Tuple # Added for type hints +from typing import List, Tuple import torch from torch import nn @@ -78,7 +78,6 @@ class SPEECH_LLM(nn.Module): self.codec_lm_head = nn.Linear( self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size ) - # to torch.float16 self.speech_token_projector = self.speech_token_projector.to( dtype=torch.float16 ) @@ -498,20 +497,6 @@ class SPEECH_LLM(nn.Module): pad_token_id=self.llm.config.pad_token_id, ) - # generated_ids = self.llm.generate( - # inputs_embeds=inputs_embeds, - # max_new_tokens=kwargs.get("max_new_tokens", 200), - # num_beams=kwargs.get("num_beams", 1), - # do_sample=kwargs.get("do_sample", False), - # min_length=kwargs.get("min_length", 1), - # top_p=kwargs.get("top_p", 1.0), - # repetition_penalty=kwargs.get("repetition_penalty", 1.0), - # temperature=kwargs.get("temperature", 1.0), - # length_penalty=kwargs.get("length_penalty", 1.0), - # bos_token_id=self.llm.config.bos_token_id, - # eos_token_id=self.llm.config.eos_token_id, - # pad_token_id=self.llm.config.pad_token_id, - # ) return generated_ids def decode_with_speech_output( @@ -520,7 +505,7 @@ class SPEECH_LLM(nn.Module): input_ids: torch.LongTensor = None, # Prompt input_ids attention_mask: torch.Tensor = None, # Prompt attention_mask max_text_new_tokens: int = 1024, - max_speech_new_tokens: int = 1024, # Max length for speech tokens + max_speech_new_tokens: int = 2048, # Max length for speech tokens llm_kwargs: dict = None, # Kwargs for text LLM generate codec_lm_kwargs: dict = None, # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET ) -> Tuple[torch.LongTensor, List[List[int]]]: @@ -602,7 +587,7 @@ class SPEECH_LLM(nn.Module): eos_token_id = self.llm.config.eos_token_id eos_token_embedding = self.llm.get_input_embeddings()( torch.tensor([[eos_token_id]], device=device) - ) # 1,D + ) assert ( generated_text_ids[0, -1] == eos_token_id ), f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}" @@ -610,7 +595,7 @@ class SPEECH_LLM(nn.Module): token_hidden_states[0].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states ] - # shift one for thinker token_embeds, drop the first embeds, and add the eos token + first_thinker_token_embed = torch.cat( [ thinker_token_embeds_org[0][:, 1:], @@ -628,7 +613,7 @@ class SPEECH_LLM(nn.Module): token_hidden_states[-1].to(self.llm.device) for token_hidden_states in text_outputs.hidden_states ] - # thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1) + thinker_reply_part = [ torch.cat( [ @@ -651,12 +636,8 @@ class 
SPEECH_LLM(nn.Module): dim=-1, ) - thinker_prompt_part = self.speech_token_projector( - thinker_prompt_part - ) # [B, S_full, D_codec] - thinker_reply_part = self.speech_token_projector( - thinker_reply_part - ) # [B, S_full, D_codec] + thinker_prompt_part = self.speech_token_projector(thinker_prompt_part) + thinker_reply_part = self.speech_token_projector(thinker_reply_part) thinker_prompt_part_seq_len = thinker_prompt_part.shape[1] talker_input_ids = torch.full( @@ -666,9 +647,7 @@ class SPEECH_LLM(nn.Module): device=self.llm.device, ) talker_input_ids[:, -1] = self.codec_lm.config.bos_token_id - talker_inputs_embeds = self.codec_lm.get_input_embeddings()( - talker_input_ids - ) # [B, S_full, D_codec] + talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids) thinker_input_embeds = torch.cat( [ thinker_prompt_part, @@ -677,68 +656,43 @@ class SPEECH_LLM(nn.Module): dim=1, ) talker_inputs_embeds += thinker_input_embeds - thinker_reply_part = thinker_reply_part[ - :, delay_step + 1 :, : - ] # [B, S_full, D_codec] + thinker_reply_part = thinker_reply_part[:, delay_step + 1 :, :] past_key_values = None - # generated_speech_tokens_list = [[] for _ in range(batch_size)] - # unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device) + generated_speech_tokens_list = [] next_token_ids = None - # text_context_len = projected_text_embeds.shape[1] # S_full + for t in range(max_speech_new_tokens): - # Get embedding for the *current* input token ID (initially BOS, then generated tokens) - # current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids) # [B, 1, D_codec] if t > 0: talker_inputs_embeds = self.codec_lm.get_input_embeddings()( next_token_ids - ) # [B, 1, D_codec] + ) if thinker_reply_part.shape[1] > 0: talker_inputs_embeds += thinker_reply_part[:, :1, :] - thinker_reply_part = thinker_reply_part[ - :, 1:, : - ] # Remove the first token for next step - # # Add the projected text embedding corresponding to the current timestep `t` - # if t < text_context_len: - # # Text context from the full generated text sequence - # current_text_context_embed = projected_text_embeds[:, t:t+1, :] # [B, 1, D_codec] - # inputs_embeds = current_speech_embeds + current_text_context_embed - # else: - # # No more text context to add - # inputs_embeds = current_speech_embeds + thinker_reply_part = thinker_reply_part[:, 1:, :] - # Forward pass through codec LM for one step - # We provide inputs_embeds directly, bypassing prepare_inputs_for_generation codec_outputs = self.codec_lm( - inputs_embeds=talker_inputs_embeds, # Combined embedding for this step + inputs_embeds=talker_inputs_embeds, past_key_values=past_key_values, use_cache=True, return_dict=True, output_hidden_states=True, - # No attention mask needed here when using past_key_values and single token input ) - last_token_hidden_state = codec_outputs.hidden_states[-1][ - :, -1, : - ] # [B, D_codec] #TODO: check shape here - # Get logits for the *last* token generated in this step - next_token_logits = self.codec_lm_head( - last_token_hidden_state - ) # Use -1 index - # suppress tokens between 4096:len(vocab)-3 - # next_token_logits[:, 4096:-3] = -float("Inf") # TODO: where we should supress tokens? 
+ last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :] + next_token_logits = self.codec_lm_head(last_token_hidden_state) + next_token_ids = topk_sampling( next_token_logits, ) - # print(next_token_ids, "next_token_ids", t, next_token_ids.shape) if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id: break - # current_speech_input_ids = next_token_ids # Use the newly generated token ID as input for next step + past_key_values = codec_outputs.past_key_values # Update KV cache generated_speech_tokens_list.append( next_token_ids.squeeze(1).cpu().tolist()[0] ) - # --- 6. Return Results --- + return generated_text_ids, generated_speech_tokens_list diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py index 1438a2624..95ce16d0e 100755 --- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py +++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py @@ -17,28 +17,22 @@ # limitations under the License. """ Usage: -# fine-tuning with whisper and Qwen2 -pip install huggingface_hub['cli'] -mkdir -p models/whisper models/qwen +# For Chinese dataset, you can use the following command to download the Chinese fine-tuned whisper model. +huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper +# Qwen Pretrained model +huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct -# For aishell fine-tuned whisper model -huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt -# For multi-hans fine-tuned whisper model -# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt - -# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct -huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct - -torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ - --max-duration 200 \ - --exp-dir ./whisper_llm_zh/exp_test \ - --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \ - --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \ - --manifest-dir data/fbank \ - --deepspeed \ - --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ - --use-flash-attn True \ - --use-lora True --unfreeze-llm True +torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \ + --max-duration 50 \ + --enable-musan False \ + --exp-dir $exp_dir \ + --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \ + --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \ + --manifest-dir data/fbank \ + --deepspeed \ + --deepspeed_config ./qwen_omni/ds_config_zero1.json \ + --use-flash-attn True \ + --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True """ import argparse @@ -52,7 +46,6 @@ from shutil import copyfile from typing import Any, Dict, List, Optional, Tuple, Union import deepspeed -import k2 import torch import torch.multiprocessing as mp import torch.nn as nn @@ -66,8 +59,6 @@ from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector - -# from multi_dataset import MultiDataset from peft import LoraConfig, get_peft_model from torch import Tensor from torch.utils.tensorboard import SummaryWriter @@ -330,9 +321,6 @@ def compute_loss( truncation=False, ) ) - 
# padding texts to the same length, texts is a list of list, padding with tokenzier.pad_token_id - # remove too long text - # texts = [ text for text in texts if len(text) < 1024 ] if len(texts) != len(messages): logging.warning(f"Remove too long text, {messages} ") max_len_texts = max([len(text) for text in texts]) @@ -347,7 +335,7 @@ def compute_loss( for text in texts ] input_ids = torch.tensor(texts, dtype=torch.int) - # response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0] + target_ids = input_ids.clone() target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID # mask all tokens before token_id 151646 with IGNORE_TOKEN_ID @@ -396,8 +384,6 @@ def compute_loss( history_contexts = [ question.rsplit(":", 1)[0].strip() for question in questions_with_history ] - # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。: 告诉我如何烹饪鸡肉 - # : 对以下句子进行鉴赏:他心地善良。输出结果为"他是一个有善心的人。 messages = [] for i, total_round in enumerate(chat_rounds): @@ -406,7 +392,6 @@ def compute_loss( history_question_answer = history_contexts[i].split("USER:") history_question_answer = [item for item in history_question_answer if item] for j in range(total_round - 1): - # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。 question_answer = history_question_answer[j].split("ASSISTANT:") message += [ {"role": "user", "content": question_answer[0].strip()}, @@ -683,7 +668,6 @@ def run(rank, world_size, args): if params.use_flash_attn: attn_implementation = "flash_attention_2" - # torch_dtype=torch.bfloat16 FIX ME torch_dtype = torch.float16 tokenizer.padding_side = "left" @@ -724,14 +708,6 @@ def run(rank, world_size, args): special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]} tokenizer.add_special_tokens(special_tokens_dict) - # original_tokenizer_vocab_size = len(tokenizer) - # cosyvoice2_token_size = 6561 - # new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [ - # "<|SPEECH_GENERATION_START|>" - # ] - # num_added_tokens = tokenizer.add_tokens(new_tokens) - # model.resize_token_embeddings(len(tokenizer)) - # model.vocab_size = len(tokenizer) llm.config.pad_token_id = tokenizer.pad_token_id llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids( @@ -755,11 +731,6 @@ def run(rank, world_size, args): attn_implementation = "eager" torch_dtype = torch.float16 - # codec_lm = AutoModelForCausalLM.from_pretrained( - # params.llm_path_or_name, - # attn_implementation=attn_implementation, - # torch_dtype=torch_dtype, - # ) codec_vocab_size = 4096 + 4 # TODO: modify above vocab size or supress_tokens when decoding config = Qwen2Config( @@ -771,39 +742,19 @@ def run(rank, world_size, args): intermediate_size=2048, max_position_embeddings=4096, ) - # codec_lm = Qwen2ForCausalLM(config=config) - # Pass attn_implementation and torch_dtype to the constructor - # Use AutoModelForCausalLM.from_config for more generality + codec_lm = AutoModelForCausalLM.from_config( config=config, attn_implementation=attn_implementation, torch_dtype=torch_dtype, ) - # cosyvoice2_token_size = 6561 + codec_lm.resize_token_embeddings(codec_vocab_size) codec_lm.vocab_size = codec_vocab_size codec_lm.config.pad_token_id = codec_vocab_size - 1 codec_lm.config.eos_token_id = codec_vocab_size - 2 codec_lm.config.bos_token_id = codec_vocab_size - 3 codec_lm.config.mask_token_id = codec_vocab_size - 4 - # if params.use_lora: - # lora_config = LoraConfig( - # r=64, - # lora_alpha=16, - # 
target_modules=[ - # "q_proj", - # "k_proj", - # "v_proj", - # "o_proj", - # "up_proj", - # "gate_proj", - # "down_proj", - # ], - # lora_dropout=0.05, - # task_type="CAUSAL_LM", - # ) - # codec_lm = get_peft_model(codec_lm, lora_config) - # codec_lm.print_trainable_parameters() else: codec_lm = None @@ -856,7 +807,6 @@ def run(rank, world_size, args): # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" # ) return False - # cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"] codec_len = len(c.custom["answer_cosyvoice_speech_token"]) if codec_len > 2200: logging.warning( @@ -873,7 +823,7 @@ def run(rank, world_size, args): if params.sampler_state_dict_path: sampler_state_dict = torch.load(params.sampler_state_dict_path) sampler_state_dict["max_duration"] = params.max_duration - # TODO: load sampler state dict + train_dl = data_module.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict )
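Note: the talker decoding loop in `model.py` draws each codec token with `topk_sampling`, whose definition is not part of this patch. The sketch below is a generic top-k sampler that fits the call site (logits of shape `[batch, vocab]` in, token ids of shape `[batch, 1]` out); it is an illustrative implementation under that assumption, not necessarily the repo's exact one.

```python
import torch


def topk_sampling(
    logits: torch.Tensor, top_k: int = 50, temperature: float = 1.0
) -> torch.Tensor:
    """Sample one token id per row from the top-k logits.

    Args:
        logits: [batch, vocab] unnormalized next-token scores.
    Returns:
        [batch, 1] sampled token ids, the shape consumed by the
        decode_with_speech_output loop.
    """
    logits = logits / temperature
    top_k = min(top_k, logits.size(-1))
    # Keep only the k largest logits per row; mask the rest to -inf.
    topk_values, _ = torch.topk(logits, top_k, dim=-1)
    cutoff = topk_values[..., -1, None]
    logits = logits.masked_fill(logits < cutoff, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


# Example: two candidate steps over the codec vocabulary (4096 + 4 special tokens).
next_ids = topk_sampling(torch.randn(2, 4100))
print(next_ids.shape)  # torch.Size([2, 1])
```

A greedy variant would simply be `logits.argmax(dim=-1, keepdim=True)`; sampling is used here because the codec LM emits speech tokens, where a little randomness usually yields more natural-sounding audio.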