root 2025-04-29 09:46:44 +00:00
parent 360f0aa397
commit 11bd3c9ad8
7 changed files with 154 additions and 460 deletions

View File

@ -23,67 +23,33 @@ The following table lists the folders for different tasks.
 Command for training is:
 ```bash
-pip install -r whisper_llm_zh/requirements.txt
-
-pip install huggingface_hub['cli']
-mkdir -p models/whisper models/qwen
-
-# For aishell fine-tuned whisper model
-huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
-# For multi-hans fine-tuned whisper model
-# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
-
-# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
-huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
-
-# First, we only train the projector and freeze other modules.
-torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
-  --max-duration 200 \
-  --exp-dir ./whisper_llm_zh/exp_test \
-  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
-  --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
+torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
+  --max-duration 50 \
+  --enable-musan False \
+  --exp-dir $exp_dir \
+  --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+  --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
   --manifest-dir data/fbank \
   --deepspeed \
-  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
+  --deepspeed_config ./qwen_omni/ds_config_zero1.json \
   --use-flash-attn True \
-  --use-lora False --unfreeze-llm False
-
-# Then we jointly train the projector and LLM LoRA modules.
-torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
-  --max-duration 200 \
-  --exp-dir ./whisper_llm_zh/exp_test \
-  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
-  --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
-  --manifest-dir data/fbank \
-  --deepspeed \
-  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
-  --use-flash-attn True \
-  --use-lora True --unfreeze-llm True
-  --pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt
+  --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
 ```
-Command for decoding:
+Command for decoding is:
 ```bash
-mkdir -p models/whisper models/qwen models/checkpoint
-huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
-
-# For aishell fine-tuned whisper model
-huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
-# For multi-hans fine-tuned whisper model
-# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
-
-huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
-
-mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
-ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
-
-python3 ./whisper_llm_zh/decode.py \
-  --max-duration 80 \
-  --exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
-  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
-  --llm-path-or-name models/qwen \
+python3 ./qwen_omni/decode.py \
+  --max-duration 1 \
+  --exp-dir $exp_dir \
+  --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+  --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
   --epoch 999 --avg 1 \
   --manifest-dir data/fbank \
   --use-flash-attn True \
-  --use-lora True --dataset aishell
+  --method e2e-epoch10_speech2speech \
+  --enable-speech-output True \
+  --token2wav-path models/CosyVoice-300M-SFT \
+  --use-lora True
 ```
+Please see [`prepare.sh`](./prepare.sh) for more details.
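The training and decoding commands above assume `$ngpu`, `$exp_dir`, and locally downloaded models. A minimal sketch of that setup, with values mirroring `prepare.sh` from this commit:

```bash
# Sketch of the environment the commands above expect; values follow prepare.sh.
ngpu=8
exp_dir=./qwen_omni/exp_speech2speech

# Chinese fine-tuned Whisper encoder, CosyVoice token2wav model, and the Qwen2.5 LLM.
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
huggingface-cli download --local-dir models/CosyVoice-300M-SFT FunAudioLLM/CosyVoice-300M-SFT
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
```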

View File

@ -165,7 +165,7 @@ def compute_fbank(args):
             storage_type=LilcomChunkyWriter,
             overwrite=True,
         )
-        cuts_path = f"{in_out_dir}/{args.prefix}_cuts.{idx}.jsonl.gz"
+        cuts_path = f"{in_out_dir}/cuts_{args.prefix}.{idx}.jsonl.gz"
         logging.info(f"Saving to {cuts_path}")
         # see https://github.com/lhotse-speech/lhotse/issues/1125
         cut_set.drop_recordings().to_file(cuts_path)

View File

@ -2,7 +2,9 @@
 # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
 export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
-export PYTHONPATH=$PYTHONPATH:/workspace/slam/icefall_omni
+export PYTHONPATH=$PYTHONPATH:/workspace/icefall

 set -eou pipefail

 stage=$1
@ -19,18 +21,37 @@ log() {
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
-  log "stage 0: "
-  #pip uninstall lhotse
-  #cd /workspace/slam/lhotse
-  #git config --global --add safe.directory /workspace/slam/lhotse
-  #pip install -e '.[dev]'
-  cd -
-  pip install -r slam_omni/requirements.txt
+  log "stage 0: Clone the CosyVoice repo and install requirements inside the container"
+  # docker: ghcr.io/swivid/f5-tts:main
+  pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
+  git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git /workspace/CosyVoice
+  cd /workspace/CosyVoice
+  # If the submodule clone fails due to network errors, rerun the following command until it succeeds
+  git submodule update --init --recursive
+  pip install -r qwen_omni/requirements.txt
+  pip install -r qwen_omni/requirements-cosyvoice.txt
+  # For a Chinese-only dataset, the following command downloads the Chinese fine-tuned whisper model.
+  huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
+  # CosyVoice pretrained model for the speech token2wav module
+  huggingface-cli download --local-dir models/CosyVoice-300M-SFT FunAudioLLM/CosyVoice-300M-SFT
+  # Qwen pretrained model
+  huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
+  # Qwen-Omni-like speech2speech model trained on worstchan/Belle_1.4M-SLAM-Omni
+  huggingface-cli download --local-dir models/qwen-omni-like-speech2speech-belle-1.4M yuekai/qwen-omni-like-speech2speech-belle-1.4M
+  # For the Gradio demo, we follow https://arxiv.org/abs/2412.15649 and use an ASR model to decode the history speech as context.
+  pip install sherpa-onnx
+  model_path=local/sherpa-onnx-paraformer-zh-2023-09-14
+  if [ ! -d $model_path ]; then
+    wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
+    tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local
+  fi
 fi

+export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice

 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
-  log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface"
+  log "stage 1: Compute whisper fbank features for the huggingface dataset"
   python3 local/compute_whisper_fbank.py \
     --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
     --out-dir data/fbank_test \
@ -39,26 +60,42 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
     --prefix belle
 fi

 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   log "Stage 2: Combine features"
   manifest_dir=data/fbank
   if [ ! -f $manifest_dir/cuts_belle_00001-01600.jsonl.gz ]; then
+    mv $manifest_dir/cuts_belle.00000.jsonl.gz ./
+    # exclude cuts_belle.00000.jsonl.gz, which is reserved for the valid and test sets
     pieces=$(find $manifest_dir -name "cuts_belle.*.jsonl.gz" | sort)
-    # # remove cuts_belle.00000.jsonl.gz from pieces
-    # pieces=$(echo $pieces | sed 's/cuts_belle.00000.jsonl.gz//g')
     echo $pieces | wc
     lhotse combine $pieces data/fbank/cuts_belle_00001-01600.jsonl.gz
-    cd $manifest_dir && ln -s cuts_belle_00001-01600.jsonl.gz cuts_belle_train.jsonl.gz && cd -
+    mv ./cuts_belle.00000.jsonl.gz $manifest_dir # put it back
+    cd $manifest_dir && ln -s cuts_belle_00001-01600.jsonl.gz cuts_belle_train.jsonl.gz
+    ln -s cuts_belle.00000.jsonl.gz cuts_belle_test.jsonl.gz && cd -
   fi
 fi

+ngpu=8
+exp_dir=./qwen_omni/exp_speech2speech

 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "stage 3: "
-  exp_dir=./slam_omni/exp_speech2speech_rerun
-  export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
-  python3 ./slam_omni/decode.py \
+  log "stage 3: Train the Speech2Speech model"
+  torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
+    --max-duration 50 \
+    --enable-musan False \
+    --exp-dir $exp_dir \
+    --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+    --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+    --manifest-dir data/fbank \
+    --deepspeed \
+    --deepspeed_config ./qwen_omni/ds_config_zero1.json \
+    --use-flash-attn True \
+    --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "stage 4: Decode (only batch_size=1 is supported for now)"
+  cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
+  python3 ./qwen_omni/decode.py \
     --max-duration 1 \
     --exp-dir $exp_dir \
     --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
@ -66,78 +103,20 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
     --epoch 999 --avg 1 \
     --manifest-dir data/fbank \
     --use-flash-attn True \
-    --method e2e-epoch10_speech2speech_rerun \
+    --method e2e-epoch10_speech2speech \
     --enable-speech-output True \
-    --token2wav-path /workspace/CosyVoice-300M-SFT \
-    --use-lora True # --on-the-fly-feats True
+    --token2wav-path models/CosyVoice-300M-SFT \
+    --use-lora True
 fi

-if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "stage 4: "
-  ngpu=8
-  torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
-    --max-duration 80 \
-    --enable-musan False \
-    --exp-dir ./slam_omni/exp_speech2text \
-    --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
-    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
-    --manifest-dir data/fbank \
-    --deepspeed \
-    --deepspeed_config ./slam_omni/ds_config_zero1.json \
-    --use-flash-attn True \
-    --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
-    --sampler-state-dict-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000-sampler.pt \
-    --use-lora True --unfreeze-llm True
-fi
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-  log "stage 5: "
-  ngpu=8
-  exp_dir=./slam_omni/exp_speech2speech_rerun
-  # exp_dir_new=./slam_omni/exp_s2s
-  torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
-    --max-duration 50 \
-    --enable-musan False \
-    --exp-dir $exp_dir \
+  log "stage 5: Gradio Demo"
+  python3 ./qwen_omni/web_demo.py \
     --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
     --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
-    --manifest-dir data/fbank \
-    --deepspeed \
-    --deepspeed_config ./slam_omni/ds_config_zero1.json \
-    --use-flash-attn True \
-    --pretrained-model-path $exp_dir/epoch-1-checkpoint-15000.pt/pytorch_model.bin \
-    --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-15000-sampler.pt \
-    --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
-  # --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
-  # --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-35000-sampler.pt \
-fi
-if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-  log "stage 6: "
-  export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
-  exp_dir=./slam_omni/exp_speech2speech_rerun
-  python3 ./slam_omni/web_demo.py \
-    --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
-    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
-    --checkpoint-path $exp_dir/epoch-998.pt \
+    --checkpoint-path $exp_dir/epoch-999.pt \
     --use-flash-attn True \
     --enable-speech-output True \
     --asr-model-dir local/sherpa-onnx-paraformer-zh-2023-09-14 \
     --use-lora True --token2wav-path /workspace/CosyVoice-300M-SFT --share
-fi
-if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
-  log "stage 7: "
-  model_path=local/sherpa-onnx-paraformer-zh-2023-09-14
-  if [ ! -d $model_path ]; then
-    pip install sherpa-onnx
-    wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
-    tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local
-  fi
 fi
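Assuming the script reads `stage` and `stop_stage` as its two positional arguments (the `stage=$1` line near the top suggests this), the stages above can be run end to end or individually, for example:

```bash
# Run model download, feature extraction, manifest combination, training and decoding (stages 0-4).
bash prepare.sh 0 4

# Or run a single stage, e.g. only the batch-size-1 decoding stage.
bash prepare.sh 4 4
```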

View File

@ -116,28 +116,6 @@ class AsrDataModule:
help="The number of buckets for the DynamicBucketingSampler" help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).", "(you might want to increase it for larger datasets).",
) )
# group.add_argument(
# "--concatenate-cuts",
# type=str2bool,
# default=False,
# help="When enabled, utterances (cuts) will be concatenated "
# "to minimize the amount of padding.",
# )
# group.add_argument(
# "--duration-factor",
# type=float,
# default=1.0,
# help="Determines the maximum duration of a concatenated cut "
# "relative to the duration of the longest cut in a batch.",
# )
# group.add_argument(
# "--gap",
# type=float,
# default=1.0,
# help="The amount of padding (in seconds) inserted between "
# "concatenated cuts. This padding is filled with noise when "
# "noise augmentation is used.",
# )
group.add_argument( group.add_argument(
"--on-the-fly-feats", "--on-the-fly-feats",
type=str2bool, type=str2bool,
@ -256,20 +234,6 @@ class AsrDataModule:
         else:
             logging.info("Disable MUSAN")

-        # if self.args.concatenate_cuts:
-        #     logging.info(
-        #         f"Using cut concatenation with duration factor "
-        #         f"{self.args.duration_factor} and gap {self.args.gap}."
-        #     )
-        #     # Cut concatenation should be the first transform in the list,
-        #     # so that if we e.g. mix noise in, it will fill the gaps between
-        #     # different utterances.
-        #     transforms = [
-        #         CutConcatenate(
-        #             duration_factor=self.args.duration_factor, gap=self.args.gap
-        #         )
-        #     ] + transforms

         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
@ -426,32 +390,12 @@ class AsrDataModule:
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
         if self.args.on_the_fly_feats:
-            # dataset = load_dataset(self.args.huggingface_dataset_path_or_name, streaming=True, split=partition)
-            i, num_digits = 0, 5
-            idx = f"{i}".zfill(num_digits)
-            parquet_files = [
-                f"data/train-{idx}-of-01601.parquet",
-            ]
-            parquet_files = [
-                f"{self.args.huggingface_dataset_path_or_name}/{f}"
-                for f in parquet_files
-            ]
-            file_name = parquet_files[0]
-            logging.info(f"Loading dataset from {file_name}")
-            dataset = load_dataset(
-                "parquet", data_files=parquet_files, streaming=True, split="train"
-            )
-            cut_set = CutSet.from_huggingface_dataset(
-                dataset, audio_key=self.args.audio_key, text_key=self.args.text_key
-            )
-            if self.args.resample_to_16kHz:
-                cut_set = cut_set.resample(16000)
-            return {"test": cut_set}
+            pass
         else:
-            # return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")}
-            # return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_test_small.jsonl.gz")}
             return {
-                "test": load_manifest_lazy("data/fbank_test/belle_cuts.00000.jsonl.gz")
+                "test": load_manifest_lazy(
+                    self.args.manifest_dir / "cuts_belle_test.jsonl.gz"
+                )
             }

     @lru_cache()
@ -461,7 +405,7 @@ class AsrDataModule:
             pass
         else:
             return load_manifest_lazy(
-                self.args.manifest_dir / "cuts_belle.00000.jsonl.gz"
+                self.args.manifest_dir / "cuts_belle_test.jsonl.gz"
             )

     @lru_cache()
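The data module now reads `cuts_belle_test.jsonl.gz` from `--manifest-dir`. A small sketch for sanity-checking the manifests produced by `prepare.sh` stage 2 (assumes lhotse is installed and the symlinks exist):

```python
# Sketch: inspect the combined train manifest and the held-out test manifest.
from lhotse import load_manifest_lazy

for name in ["cuts_belle_train.jsonl.gz", "cuts_belle_test.jsonl.gz"]:
    cuts = load_manifest_lazy(f"data/fbank/{name}")
    cut = next(iter(cuts))
    # "round" and "question" are the custom fields this recipe reads in train.py/decode.py.
    print(name, round(cut.duration, 2), (cut.custom or {}).get("round"))
```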

View File

@ -20,30 +20,27 @@
""" """
Usage: Usage:
# Command for decoding using fine-tuned models: # Command for decoding using fine-tuned models:
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
# Cosyvoice pretrained model for speech token2wav module
huggingface-cli download --local-dir models/CosyVoice-300M-SFT FunAudioLLM/CosyVoice-300M-SFT
# Qwen Pretrained model
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
# Qwen-Omni like speech2speech model trained on worstchan/Belle_1.4M-SLAM-Omni
huggingface-cli download --local-dir models/qwen-omni-like-speech2speech-belle-1.4M yuekai/qwen-omni-like-speech2speech-belle-1.4M
pip install huggingface_hub['cli'] cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
mkdir -p models/whisper models/qwen models/checkpoint python3 ./qwen_omni/decode.py \
huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B --max-duration 1 \
--exp-dir $exp_dir \
# For aishell fine-tuned whisper model --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
# For multi-hans fine-tuned whisper model --epoch 999 --avg 1 \
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt --manifest-dir data/fbank \
--use-flash-attn True \
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct --method e2e-epoch10_speech2speech \
--enable-speech-output True \
mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B --token2wav-path models/CosyVoice-300M-SFT \
ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt --use-lora True
python3 ./whisper_llm_zh/decode.py \
--max-duration 80 \
--exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
--speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
--llm-path-or-name models/qwen \
--epoch 999 --avg 1 \
--manifest-dir data/fbank \
--use-flash-attn True \
--use-lora True --dataset aishell
""" """
import argparse import argparse
@ -183,11 +180,6 @@ def get_model(params, device):
attn_implementation = "eager" attn_implementation = "eager"
torch_dtype = torch.float16 torch_dtype = torch.float16
# codec_lm = AutoModelForCausalLM.from_pretrained(
# params.llm_path_or_name,
# attn_implementation=attn_implementation,
# torch_dtype=torch_dtype,
# )
codec_vocab_size = 4096 + 4 codec_vocab_size = 4096 + 4
config = Qwen2Config( config = Qwen2Config(
vocab_size=codec_vocab_size, vocab_size=codec_vocab_size,
@ -198,39 +190,19 @@ def get_model(params, device):
             intermediate_size=2048,
             max_position_embeddings=4096,
         )
-        # codec_lm = Qwen2ForCausalLM(config=config)
-        # Pass attn_implementation and torch_dtype to the constructor
-        # Use AutoModelForCausalLM.from_config for more generality
         codec_lm = AutoModelForCausalLM.from_config(
             config=config,
             attn_implementation=attn_implementation,
             torch_dtype=torch_dtype,
         )
-        # cosyvoice2_token_size = 6561
         codec_lm.resize_token_embeddings(codec_vocab_size)
         codec_lm.vocab_size = codec_vocab_size
         codec_lm.config.pad_token_id = codec_vocab_size - 1
         codec_lm.config.eos_token_id = codec_vocab_size - 2
         codec_lm.config.bos_token_id = codec_vocab_size - 3
         codec_lm.config.mask_token_id = codec_vocab_size - 4
-        # if params.use_lora:
-        #     lora_config = LoraConfig(
-        #         r=64,
-        #         lora_alpha=16,
-        #         target_modules=[
-        #             "q_proj",
-        #             "k_proj",
-        #             "v_proj",
-        #             "o_proj",
-        #             "up_proj",
-        #             "gate_proj",
-        #             "down_proj",
-        #         ],
-        #         lora_dropout=0.05,
-        #         task_type="CAUSAL_LM",
-        #     )
-        #     codec_lm = get_peft_model(codec_lm, lora_config)
-        #     codec_lm.print_trainable_parameters()
     else:
         codec_lm = None
@ -373,13 +345,6 @@ def get_parser():
default="/workspace/CosyVoice-300M-SFT", default="/workspace/CosyVoice-300M-SFT",
help="The path to the token2wav model", help="The path to the token2wav model",
) )
# parser.add_argument(
# "--dataset",
# type=str,
# default="aishell",
# choices=["aishell", "speechio", "wenetspeech_test_meeting", "multi_hans_zh"],
# help="The dataset to decode",
# )
add_model_arguments(parser) add_model_arguments(parser)
return parser return parser
@ -474,12 +439,6 @@ def decode_one_batch(
     chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]

-    # messages = [
-    #     [
-    #         {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
-    #         {"role": "assistant", "content": ""},
-    #     ]
-    # ] * len(feature)
     questions_with_history = [
         cut.custom["question"] for cut in batch["supervisions"]["cut"]
     ]
@ -496,7 +455,6 @@ def decode_one_batch(
             history_question_answer = history_contexts[i].split("USER:")
             history_question_answer = [item for item in history_question_answer if item]
             for j in range(total_round - 1):
-                # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
                 question_answer = history_question_answer[j].split("ASSISTANT:")
                 message += [
                     {"role": "user", "content": question_answer[0].strip()},
@ -504,7 +462,6 @@ def decode_one_batch(
                 ]
         message += [
             {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
-            # {"role": "user", "content": f"{last_questions[i]}"},
             {"role": "assistant", "content": ""},
         ]
         print(f"message: {message}, batch_size {len(chat_rounds)}")
@ -525,13 +482,6 @@ def decode_one_batch(
             audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
             audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
             sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 22050)
-            # with open(speech_token_file_name, 'w') as f:
-            #     # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
-            #     # torchaudio.save(save_path, speech_output.cpu(), 16000)
-            #     # print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
-            #     save_str = " ".join([str(i) for i in generated_speech_output])
-            #     f.write(f"{cut_id}|{save_str}\n")
     else:
         generated_ids = model.decode(
             feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
@ -560,43 +510,6 @@ def decode_dataset(
     Returns:
       Return a dict, whose key may be "beam-search".
     """

-    def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
-        """
-        Text normalization similar to M2MeT challenge baseline.
-        See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
-        """
-        if normalize == "none":
-            return text
-        elif normalize == "m2met":
-            import re
-
-            text = text.replace(" ", "")
-            text = text.replace("<sil>", "")
-            text = text.replace("<%>", "")
-            text = text.replace("<->", "")
-            text = text.replace("<$>", "")
-            text = text.replace("<#>", "")
-            text = text.replace("<_>", "")
-            text = text.replace("<space>", "")
-            text = text.replace("`", "")
-            text = text.replace("&", "")
-            text = text.replace(",", "")
-            if re.search("[a-zA-Z]", text):
-                text = text.upper()
-            text = text.replace("Ａ", "A")
-            text = text.replace("ａ", "A")
-            text = text.replace("ｂ", "B")
-            text = text.replace("ｃ", "C")
-            text = text.replace("ｋ", "K")
-            text = text.replace("ｔ", "T")
-            text = text.replace("，", "")
-            text = text.replace("丶", "")
-            text = text.replace("。", "")
-            text = text.replace("、", "")
-            text = text.replace("？", "")
-            return text
-
     results = []

     num_cuts = 0
@ -634,7 +547,6 @@ def decode_dataset(
             this_batch = []
             assert len(hyps) == len(texts)
             for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
-                # ref_text = normalize_text_alimeeting(ref_text)
                 ref_words = ref_text.split()
                 print(f"ref: {ref_text}")
                 print(f"hyp: {''.join(hyp_words)}")
@ -673,7 +585,6 @@ def save_results(
         errs_filename = (
             params.log_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
         )
-        # we compute CER for aishell dataset.
         results_char = []
         for res in results:
             results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
@ -732,11 +643,8 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    # we need cut ids to display recognition results.
     args.return_cuts = True
     data_module = AsrDataModule(args)
-    # data_module = MultiDataset(args.manifest_dir)

     def remove_long_utt(c: Cut):
         # Keep only utterances with duration in 30 seconds
@ -748,13 +656,6 @@ def main():
             return False
         return True

-    # if params.dataset == "aishell":
-    #     test_sets_cuts = data_module.aishell_test_cuts()
-    # elif params.dataset == "speechio":
-    #     test_sets_cuts = data_module.speechio_test_cuts()
-    # elif params.dataset == "wenetspeech_test_meeting":
-    #     test_sets_cuts = data_module.wenetspeech_test_meeting_cuts()
-    # else:
     test_sets_cuts = data_module.test_cuts()

     test_sets = test_sets_cuts.keys()
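For reference, the codec LM configured in `get_model` above uses 4096 CosyVoice speech-token IDs plus four control tokens placed at the top of the vocabulary. The ID layout implied by the `codec_vocab_size = 4096 + 4` block works out as follows (a sketch; the dictionary below is only illustrative):

```python
# Codec-LM vocabulary layout: IDs 0..4095 are CosyVoice speech tokens,
# the top four IDs are control tokens, matching the assignments in this file.
CODEC_VOCAB_SIZE = 4096 + 4  # 4100

control_tokens = {
    "pad":  CODEC_VOCAB_SIZE - 1,  # 4099
    "eos":  CODEC_VOCAB_SIZE - 2,  # 4098
    "bos":  CODEC_VOCAB_SIZE - 3,  # 4097
    "mask": CODEC_VOCAB_SIZE - 4,  # 4096
}
print(control_tokens)
```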

View File

@ -1,4 +1,4 @@
-from typing import List, Tuple  # Added for type hints
+from typing import List, Tuple

 import torch
 from torch import nn
@ -78,7 +78,6 @@ class SPEECH_LLM(nn.Module):
             self.codec_lm_head = nn.Linear(
                 self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
             )
-            # to torch.float16
             self.speech_token_projector = self.speech_token_projector.to(
                 dtype=torch.float16
             )
@ -498,20 +497,6 @@ class SPEECH_LLM(nn.Module):
             pad_token_id=self.llm.config.pad_token_id,
         )
-        # generated_ids = self.llm.generate(
-        #     inputs_embeds=inputs_embeds,
-        #     max_new_tokens=kwargs.get("max_new_tokens", 200),
-        #     num_beams=kwargs.get("num_beams", 1),
-        #     do_sample=kwargs.get("do_sample", False),
-        #     min_length=kwargs.get("min_length", 1),
-        #     top_p=kwargs.get("top_p", 1.0),
-        #     repetition_penalty=kwargs.get("repetition_penalty", 1.0),
-        #     temperature=kwargs.get("temperature", 1.0),
-        #     length_penalty=kwargs.get("length_penalty", 1.0),
-        #     bos_token_id=self.llm.config.bos_token_id,
-        #     eos_token_id=self.llm.config.eos_token_id,
-        #     pad_token_id=self.llm.config.pad_token_id,
-        # )

         return generated_ids

     def decode_with_speech_output(
@ -520,7 +505,7 @@ class SPEECH_LLM(nn.Module):
         input_ids: torch.LongTensor = None,  # Prompt input_ids
         attention_mask: torch.Tensor = None,  # Prompt attention_mask
         max_text_new_tokens: int = 1024,
-        max_speech_new_tokens: int = 1024,  # Max length for speech tokens
+        max_speech_new_tokens: int = 2048,  # Max length for speech tokens
         llm_kwargs: dict = None,  # Kwargs for text LLM generate
         codec_lm_kwargs: dict = None,  # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET
     ) -> Tuple[torch.LongTensor, List[List[int]]]:
@ -602,7 +587,7 @@ class SPEECH_LLM(nn.Module):
         eos_token_id = self.llm.config.eos_token_id
         eos_token_embedding = self.llm.get_input_embeddings()(
             torch.tensor([[eos_token_id]], device=device)
-        )  # 1,D
+        )
         assert (
             generated_text_ids[0, -1] == eos_token_id
         ), f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"
@ -610,7 +595,7 @@ class SPEECH_LLM(nn.Module):
             token_hidden_states[0].to(self.llm.device)
             for token_hidden_states in text_outputs.hidden_states
         ]
-        # shift one for thinker token_embeds, drop the first embeds, and add the eos token
         first_thinker_token_embed = torch.cat(
             [
                 thinker_token_embeds_org[0][:, 1:],
@ -628,7 +613,7 @@ class SPEECH_LLM(nn.Module):
             token_hidden_states[-1].to(self.llm.device)
             for token_hidden_states in text_outputs.hidden_states
         ]
-        # thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
         thinker_reply_part = [
             torch.cat(
                 [
@ -651,12 +636,8 @@ class SPEECH_LLM(nn.Module):
             dim=-1,
         )

-        thinker_prompt_part = self.speech_token_projector(
-            thinker_prompt_part
-        )  # [B, S_full, D_codec]
-        thinker_reply_part = self.speech_token_projector(
-            thinker_reply_part
-        )  # [B, S_full, D_codec]
+        thinker_prompt_part = self.speech_token_projector(thinker_prompt_part)
+        thinker_reply_part = self.speech_token_projector(thinker_reply_part)

         thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
         talker_input_ids = torch.full(
@ -666,9 +647,7 @@ class SPEECH_LLM(nn.Module):
             device=self.llm.device,
         )
         talker_input_ids[:, -1] = self.codec_lm.config.bos_token_id
-        talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
-            talker_input_ids
-        )  # [B, S_full, D_codec]
+        talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids)
         thinker_input_embeds = torch.cat(
             [
                 thinker_prompt_part,
@ -677,68 +656,43 @@ class SPEECH_LLM(nn.Module):
             dim=1,
         )
         talker_inputs_embeds += thinker_input_embeds
-        thinker_reply_part = thinker_reply_part[
-            :, delay_step + 1 :, :
-        ]  # [B, S_full, D_codec]
+        thinker_reply_part = thinker_reply_part[:, delay_step + 1 :, :]

         past_key_values = None
-        # generated_speech_tokens_list = [[] for _ in range(batch_size)]
-        # unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
         generated_speech_tokens_list = []
         next_token_ids = None
-        # text_context_len = projected_text_embeds.shape[1] # S_full

         for t in range(max_speech_new_tokens):
-            # Get embedding for the *current* input token ID (initially BOS, then generated tokens)
-            # current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids) # [B, 1, D_codec]
             if t > 0:
                 talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
                     next_token_ids
-                )  # [B, 1, D_codec]
+                )
                 if thinker_reply_part.shape[1] > 0:
                     talker_inputs_embeds += thinker_reply_part[:, :1, :]
-                    thinker_reply_part = thinker_reply_part[
-                        :, 1:, :
-                    ]  # Remove the first token for next step
+                    thinker_reply_part = thinker_reply_part[:, 1:, :]

-            # # Add the projected text embedding corresponding to the current timestep `t`
-            # if t < text_context_len:
-            #     # Text context from the full generated text sequence
-            #     current_text_context_embed = projected_text_embeds[:, t:t+1, :] # [B, 1, D_codec]
-            #     inputs_embeds = current_speech_embeds + current_text_context_embed
-            # else:
-            #     # No more text context to add
-            #     inputs_embeds = current_speech_embeds
-
-            # Forward pass through codec LM for one step
-            # We provide inputs_embeds directly, bypassing prepare_inputs_for_generation
             codec_outputs = self.codec_lm(
-                inputs_embeds=talker_inputs_embeds,  # Combined embedding for this step
+                inputs_embeds=talker_inputs_embeds,
                 past_key_values=past_key_values,
                 use_cache=True,
                 return_dict=True,
                 output_hidden_states=True,
-                # No attention mask needed here when using past_key_values and single token input
             )
-            last_token_hidden_state = codec_outputs.hidden_states[-1][
-                :, -1, :
-            ]  # [B, D_codec] #TODO: check shape here
-            # Get logits for the *last* token generated in this step
-            next_token_logits = self.codec_lm_head(
-                last_token_hidden_state
-            )  # Use -1 index
-            # suppress tokens between 4096:len(vocab)-3
-            # next_token_logits[:, 4096:-3] = -float("Inf") # TODO: where we should supress tokens?
+            last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :]
+            next_token_logits = self.codec_lm_head(last_token_hidden_state)
             next_token_ids = topk_sampling(
                 next_token_logits,
             )
-            # print(next_token_ids, "next_token_ids", t, next_token_ids.shape)

             if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id:
                 break
-            # current_speech_input_ids = next_token_ids # Use the newly generated token ID as input for next step
             past_key_values = codec_outputs.past_key_values  # Update KV cache
             generated_speech_tokens_list.append(
                 next_token_ids.squeeze(1).cpu().tolist()[0]
             )

-        # --- 6. Return Results ---
         return generated_text_ids, generated_speech_tokens_list
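The speech-token generation loop above calls `topk_sampling`, which is defined elsewhere in the file and not shown in this hunk. A minimal sketch of what such a helper could look like, assuming plain top-k sampling with temperature (the repository's actual implementation may differ in defaults and in how it suppresses invalid IDs):

```python
import torch


def topk_sampling(logits: torch.Tensor, top_k: int = 50, temperature: float = 1.0) -> torch.Tensor:
    """Sample one token id per row from the top-k logits; returns a [B, 1] LongTensor."""
    logits = logits / max(temperature, 1e-5)
    top_k = min(top_k, logits.size(-1))
    topk_values, topk_indices = torch.topk(logits, k=top_k, dim=-1)  # [B, k]
    probs = torch.softmax(topk_values, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)  # [B, 1] index into the top-k set
    return topk_indices.gather(-1, choice)  # [B, 1] vocabulary ids
```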

View File

@ -17,28 +17,22 @@
 # limitations under the License.
 """
 Usage:
-# fine-tuning with whisper and Qwen2
-pip install huggingface_hub['cli']
-mkdir -p models/whisper models/qwen
-
-# For aishell fine-tuned whisper model
-huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
-# For multi-hans fine-tuned whisper model
-# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
-
-# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
-huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
-
-torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
-  --max-duration 200 \
-  --exp-dir ./whisper_llm_zh/exp_test \
-  --speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
-  --llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
+# For a Chinese dataset, the following command downloads the Chinese fine-tuned whisper model.
+huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
+# Qwen pretrained model
+huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
+
+torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
+  --max-duration 50 \
+  --enable-musan False \
+  --exp-dir $exp_dir \
+  --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+  --llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
   --manifest-dir data/fbank \
   --deepspeed \
-  --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
+  --deepspeed_config ./qwen_omni/ds_config_zero1.json \
   --use-flash-attn True \
-  --use-lora True --unfreeze-llm True
+  --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
 """

 import argparse
@ -52,7 +46,6 @@ from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple, Union

 import deepspeed
-import k2
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
@ -66,8 +59,6 @@ from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
-
-# from multi_dataset import MultiDataset
 from peft import LoraConfig, get_peft_model
 from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter
@ -330,9 +321,6 @@ def compute_loss(
                 truncation=False,
             )
         )
-    # padding texts to the same length, texts is a list of list, padding with tokenzier.pad_token_id
-    # remove too long text
-    # texts = [ text for text in texts if len(text) < 1024 ]
     if len(texts) != len(messages):
         logging.warning(f"Remove too long text, {messages} ")
     max_len_texts = max([len(text) for text in texts])
@ -347,7 +335,7 @@ def compute_loss(
         for text in texts
     ]
     input_ids = torch.tensor(texts, dtype=torch.int)
-    # response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
     target_ids = input_ids.clone()
     target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
     # mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
@ -396,8 +384,6 @@ def compute_loss(
     history_contexts = [
         question.rsplit("<USER>:", 1)[0].strip() for question in questions_with_history
     ]
-    # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。<USER>: 告诉我如何烹饪鸡肉
-    # <USER>: 对以下句子进行鉴赏:他心地善良。输出结果为"他是一个有善心的人。
     messages = []

     for i, total_round in enumerate(chat_rounds):
@ -406,7 +392,6 @@ def compute_loss(
             history_question_answer = history_contexts[i].split("USER:")
             history_question_answer = [item for item in history_question_answer if item]
             for j in range(total_round - 1):
-                # USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
                 question_answer = history_question_answer[j].split("ASSISTANT:")
                 message += [
                     {"role": "user", "content": question_answer[0].strip()},
@ -683,7 +668,6 @@ def run(rank, world_size, args):
     if params.use_flash_attn:
         attn_implementation = "flash_attention_2"
-        # torch_dtype=torch.bfloat16 FIX ME
         torch_dtype = torch.float16
         tokenizer.padding_side = "left"
@ -724,14 +708,6 @@ def run(rank, world_size, args):
     special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
     tokenizer.add_special_tokens(special_tokens_dict)
-    # original_tokenizer_vocab_size = len(tokenizer)
-    # cosyvoice2_token_size = 6561
-    # new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [
-    #     "<|SPEECH_GENERATION_START|>"
-    # ]
-    # num_added_tokens = tokenizer.add_tokens(new_tokens)
-    # model.resize_token_embeddings(len(tokenizer))
-    # model.vocab_size = len(tokenizer)

     llm.config.pad_token_id = tokenizer.pad_token_id
     llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
@ -755,11 +731,6 @@ def run(rank, world_size, args):
attn_implementation = "eager" attn_implementation = "eager"
torch_dtype = torch.float16 torch_dtype = torch.float16
# codec_lm = AutoModelForCausalLM.from_pretrained(
# params.llm_path_or_name,
# attn_implementation=attn_implementation,
# torch_dtype=torch_dtype,
# )
codec_vocab_size = 4096 + 4 codec_vocab_size = 4096 + 4
# TODO: modify above vocab size or supress_tokens when decoding # TODO: modify above vocab size or supress_tokens when decoding
config = Qwen2Config( config = Qwen2Config(
@ -771,39 +742,19 @@ def run(rank, world_size, args):
             intermediate_size=2048,
             max_position_embeddings=4096,
         )
-        # codec_lm = Qwen2ForCausalLM(config=config)
-        # Pass attn_implementation and torch_dtype to the constructor
-        # Use AutoModelForCausalLM.from_config for more generality
         codec_lm = AutoModelForCausalLM.from_config(
             config=config,
             attn_implementation=attn_implementation,
             torch_dtype=torch_dtype,
         )
-        # cosyvoice2_token_size = 6561
         codec_lm.resize_token_embeddings(codec_vocab_size)
         codec_lm.vocab_size = codec_vocab_size
         codec_lm.config.pad_token_id = codec_vocab_size - 1
         codec_lm.config.eos_token_id = codec_vocab_size - 2
         codec_lm.config.bos_token_id = codec_vocab_size - 3
         codec_lm.config.mask_token_id = codec_vocab_size - 4
-        # if params.use_lora:
-        #     lora_config = LoraConfig(
-        #         r=64,
-        #         lora_alpha=16,
-        #         target_modules=[
-        #             "q_proj",
-        #             "k_proj",
-        #             "v_proj",
-        #             "o_proj",
-        #             "up_proj",
-        #             "gate_proj",
-        #             "down_proj",
-        #         ],
-        #         lora_dropout=0.05,
-        #         task_type="CAUSAL_LM",
-        #     )
-        #     codec_lm = get_peft_model(codec_lm, lora_config)
-        #     codec_lm.print_trainable_parameters()
     else:
         codec_lm = None
@ -856,7 +807,6 @@ def run(rank, world_size, args):
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# ) # )
return False return False
# cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]
codec_len = len(c.custom["answer_cosyvoice_speech_token"]) codec_len = len(c.custom["answer_cosyvoice_speech_token"])
if codec_len > 2200: if codec_len > 2200:
logging.warning( logging.warning(
@ -873,7 +823,7 @@ def run(rank, world_size, args):
     if params.sampler_state_dict_path:
         sampler_state_dict = torch.load(params.sampler_state_dict_path)
         sampler_state_dict["max_duration"] = params.max_duration
-    # TODO: load sampler state dict
     train_dl = data_module.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
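The sampler-state handling above is what allows resuming from a mid-epoch checkpoint. A hedged sketch of such a resume run (the checkpoint file names below are illustrative placeholders; the flags mirror those used elsewhere in this recipe):

```bash
# Resume training from a saved model + sampler state (file names are examples only).
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
  --max-duration 50 \
  --enable-musan False \
  --exp-dir $exp_dir \
  --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
  --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
  --manifest-dir data/fbank \
  --deepspeed \
  --deepspeed_config ./qwen_omni/ds_config_zero1.json \
  --use-flash-attn True \
  --pretrained-model-path $exp_dir/epoch-1-checkpoint-15000.pt/pytorch_model.bin \
  --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-15000-sampler.pt \
  --use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
```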