Commit 11bd3c9ad8 (parent 360f0aa397) in https://github.com/k2-fsa/icefall.git
lint
@@ -23,67 +23,33 @@ The following table lists the folders for different tasks.
 Command for training is:
 ```bash
-pip install -r whisper_llm_zh/requirements.txt
-
-pip install huggingface_hub['cli']
-mkdir -p models/whisper models/qwen
-
-# For aishell fine-tuned whisper model
-huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
-# For multi-hans fine-tuned whisper model
-# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
-
-# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
-huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
-
-# First, we only train the projector and freeze other modules.
-torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
---max-duration 200 \
---exp-dir ./whisper_llm_zh/exp_test \
---speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
---llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
---manifest-dir data/fbank \
---deepspeed \
---deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
---use-flash-attn True \
---use-lora False --unfreeze-llm False
-
-# Then we jointly train the projector and LLM LoRA modules.
-torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
---max-duration 200 \
---exp-dir ./whisper_llm_zh/exp_test \
---speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
---llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
---manifest-dir data/fbank \
---deepspeed \
---deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
---use-flash-attn True \
---use-lora True --unfreeze-llm True
---pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt
+torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
+--max-duration 50 \
+--enable-musan False \
+--exp-dir $exp_dir \
+--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+--manifest-dir data/fbank \
+--deepspeed \
+--deepspeed_config ./qwen_omni/ds_config_zero1.json \
+--use-flash-attn True \
+--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
 ```
 
-Command for decoding:
+Command for decoding is:
 ```bash
-mkdir -p models/whisper models/qwen models/checkpoint
-huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
-
-# For aishell fine-tuned whisper model
-huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
-# For multi-hans fine-tuned whisper model
-# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
-
-huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
-
-mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
-ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
-
-python3 ./whisper_llm_zh/decode.py \
---max-duration 80 \
---exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
---speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
---llm-path-or-name models/qwen \
---epoch 999 --avg 1 \
---manifest-dir data/fbank \
---use-flash-attn True \
---use-lora True --dataset aishell
+python3 ./qwen_omni/decode.py \
+--max-duration 1 \
+--exp-dir $exp_dir \
+--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+--epoch 999 --avg 1 \
+--manifest-dir data/fbank \
+--use-flash-attn True \
+--method e2e-epoch10_speech2speech \
+--enable-speech-output True \
+--token2wav-path models/CosyVoice-300M-SFT \
+--use-lora True
 ```
 
+Please see [`prepare.sh`](./prepare.sh) for more details.

@@ -165,7 +165,7 @@ def compute_fbank(args):
 storage_type=LilcomChunkyWriter,
 overwrite=True,
 )
-cuts_path = f"{in_out_dir}/{args.prefix}_cuts.{idx}.jsonl.gz"
+cuts_path = f"{in_out_dir}/cuts_{args.prefix}.{idx}.jsonl.gz"
 logging.info(f"Saving to {cuts_path}")
 # see https://github.com/lhotse-speech/lhotse/issues/1125
 cut_set.drop_recordings().to_file(cuts_path)
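The hunk above only renames the shard manifests to `cuts_{prefix}.{idx}.jsonl.gz` and keeps the `drop_recordings()` call (see lhotse issue #1125) before writing. As a reading aid, here is a minimal sketch of how one of these shards could be reloaded afterwards; the directory, prefix, and index values are illustrative assumptions, not part of the commit.

```python
# Sketch (assumption-based): reload one fbank shard written by compute_whisper_fbank.py.
from lhotse import load_manifest_lazy

in_out_dir = "data/fbank_test"  # assumed, mirrors --out-dir above
prefix = "belle"                # assumed, mirrors --prefix above
idx = 0

# New naming convention introduced by this commit: cuts_{prefix}.{idx}.jsonl.gz
cuts_path = f"{in_out_dir}/cuts_{prefix}.{idx}.jsonl.gz"

# The cuts were saved after drop_recordings(), so they carry precomputed fbank
# features but no recording/audio references.
cuts = load_manifest_lazy(cuts_path)
for cut in cuts:
    print(cut.id, cut.duration)
    break
```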
@@ -2,7 +2,9 @@
 
 # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
 export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
-export PYTHONPATH=$PYTHONPATH:/workspace/slam/icefall_omni
+export PYTHONPATH=$PYTHONPATH:/workspace/icefall
 
 set -eou pipefail
 
 stage=$1

@@ -19,18 +21,37 @@ log() {
 
 
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
-log "stage 0: "
-#pip uninstall lhotse
-#cd /workspace/slam/lhotse
-#git config --global --add safe.directory /workspace/slam/lhotse
-#pip install -e '.[dev]'
-cd -
-pip install -r slam_omni/requirements.txt
+log "stage 0: Clone CosyVoice repo and install requirements inside the container"
+# docker: ghcr.io/swivid/f5-tts:main
+pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
+git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git /workspace/CosyVoice
+cd /workspace/CosyVoice
+# If you failed to clone submodule due to network failures, please run following command until success
+git submodule update --init --recursive
+pip install -r qwen_omni/requirements.txt
+pip install -r qwen_omni/requirements-cosyvoice.txt
+
+# For Chinese only dataset, you can use the following command to download the Chinese fine-tuned whisper model.
+huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
+# Cosyvoice pretrained model for speech token2wav module
+huggingface-cli download --local-dir models/CosyVoice-300M-SFT FunAudioLLM/CosyVoice-300M-SFT
+# Qwen Pretrained model
+huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
+# Qwen-Omni like speech2speech model trained on worstchan/Belle_1.4M-SLAM-Omni
+huggingface-cli download --local-dir models/qwen-omni-like-speech2speech-belle-1.4M yuekai/qwen-omni-like-speech2speech-belle-1.4M
+
+# For Gradio demo, we follow https://arxiv.org/abs/2412.15649 to use ASR model to decode the history speech as context.
+pip install sherpa-onnx
+model_path=local/sherpa-onnx-paraformer-zh-2023-09-14
+if [ ! -d $model_path ]; then
+wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
+tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local
+fi
 fi
+export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
 
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
-log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface"
+log "stage 1: Compute fbank feature from huggingface"
 
 python3 local/compute_whisper_fbank.py \
 --num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
 --out-dir data/fbank_test \

@@ -39,26 +60,42 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
 --prefix belle
 fi
 
 
 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
 log "Stage 2: Combine features"
 manifest_dir=data/fbank
 if [ ! -f $manifest_dir/cuts_belle_00001-01600.jsonl.gz ]; then
+mv $manifest_dir/cuts_belle.00000.jsonl.gz ./
+# exclude cust_belle_00000.jsonl.gz for valid and test set
 pieces=$(find $manifest_dir -name "cuts_belle.*.jsonl.gz" | sort)
-# # remove cust_belle_00000.jsonl.gz from pieces
-# pieces=$(echo $pieces | sed 's/cuts_belle.00000.jsonl.gz//g')
 echo $pieces | wc
 lhotse combine $pieces data/fbank/cuts_belle_00001-01600.jsonl.gz
-cd $manifest_dir && ln -s cuts_belle_00001-01600.jsonl.gz cuts_belle_train.jsonl.gz && cd -
+mv ./cuts_belle.00000.jsonl.gz $manifest_dir # put it back
+cd $manifest_dir && ln -s cuts_belle_00001-01600.jsonl.gz cuts_belle_train.jsonl.gz
+ln -s cuts_belle.00000.jsonl.gz cuts_belle_test.jsonl.gz && cd -
 fi
 fi
 
+ngpu=8
+exp_dir=./qwen_omni/exp_speech2speech
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-log "stage 3: "
-exp_dir=./slam_omni/exp_speech2speech_rerun
-export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
-python3 ./slam_omni/decode.py \
+log "stage 3: Training Speech2Speech Model"
+torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
+--max-duration 50 \
+--enable-musan False \
+--exp-dir $exp_dir \
+--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+--manifest-dir data/fbank \
+--deepspeed \
+--deepspeed_config ./qwen_omni/ds_config_zero1.json \
+--use-flash-attn True \
+--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+log "stage 4: Decoding, only support batch_size=1 for now."
+cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
+python3 ./qwen_omni/decode.py \
 --max-duration 1 \
 --exp-dir $exp_dir \
 --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
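Stage 2 above holds shard `cuts_belle.00000.jsonl.gz` out of the combined training manifest and symlinks it as the test split. The sketch below shows the same step through lhotse's Python API instead of the `lhotse combine` CLI; the paths mirror the script and are assumptions for illustration, not part of the recipe.

```python
# Sketch (assumption-based): combine all belle shards except 00000 for training.
from pathlib import Path

from lhotse import CutSet, combine

manifest_dir = Path("data/fbank")
shards = sorted(manifest_dir.glob("cuts_belle.*.jsonl.gz"))

# Hold out shard 00000 for valid/test, like the mv/ln -s dance in stage 2.
test_shard = manifest_dir / "cuts_belle.00000.jsonl.gz"
train_shards = [p for p in shards if p != test_shard]

train_cuts = combine(*[CutSet.from_file(p) for p in train_shards])
train_cuts.to_file(manifest_dir / "cuts_belle_00001-01600.jsonl.gz")
```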
@@ -66,78 +103,20 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
 --epoch 999 --avg 1 \
 --manifest-dir data/fbank \
 --use-flash-attn True \
---method e2e-epoch10_speech2speech_rerun \
+--method e2e-epoch10_speech2speech \
 --enable-speech-output True \
---token2wav-path /workspace/CosyVoice-300M-SFT \
---use-lora True # --on-the-fly-feats True
+--token2wav-path models/CosyVoice-300M-SFT \
+--use-lora True
 
 fi
 
-
-if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-log "stage 4: "
-ngpu=8
-torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
---max-duration 80 \
---enable-musan False \
---exp-dir ./slam_omni/exp_speech2text \
---speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
---llm-path-or-name models/Qwen2.5-0.5B-Instruct \
---manifest-dir data/fbank \
---deepspeed \
---deepspeed_config ./slam_omni/ds_config_zero1.json \
---use-flash-attn True \
---pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
---sampler-state-dict-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000-sampler.pt \
---use-lora True --unfreeze-llm True
-fi
-
-
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-log "stage 5: "
-ngpu=8
-exp_dir=./slam_omni/exp_speech2speech_rerun
-# exp_dir_new=./slam_omni/exp_s2s
-torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
---max-duration 50 \
---enable-musan False \
---exp-dir $exp_dir \
---speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
---llm-path-or-name models/Qwen2.5-0.5B-Instruct \
---manifest-dir data/fbank \
---deepspeed \
---deepspeed_config ./slam_omni/ds_config_zero1.json \
---use-flash-attn True \
---pretrained-model-path $exp_dir/epoch-1-checkpoint-15000.pt/pytorch_model.bin \
---sampler-state-dict-path $exp_dir/epoch-1-checkpoint-15000-sampler.pt \
---use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
-# --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
-# --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-35000-sampler.pt \
-
-fi
-
-if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-log "stage 6: "
-export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
-exp_dir=./slam_omni/exp_speech2speech_rerun
-python3 ./slam_omni/web_demo.py \
+log "stage 5: Gradio Demo"
+python3 ./qwen_omni/web_demo.py \
 --speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
 --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
---checkpoint-path $exp_dir/epoch-998.pt \
+--checkpoint-path $exp_dir/epoch-999.pt \
 --use-flash-attn True \
 --enable-speech-output True \
 --asr-model-dir local/sherpa-onnx-paraformer-zh-2023-09-14 \
 --use-lora True --token2wav-path /workspace/CosyVoice-300M-SFT --share
 
-fi
-
-if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
-log "stage 7: "
-model_path=local/sherpa-onnx-paraformer-zh-2023-09-14
-
-if [ ! -d $model_path ]; then
-pip install sherpa-onnx
-wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
-tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local
-fi
 fi

@@ -116,28 +116,6 @@ class AsrDataModule:
 help="The number of buckets for the DynamicBucketingSampler"
 "(you might want to increase it for larger datasets).",
 )
-# group.add_argument(
-# "--concatenate-cuts",
-# type=str2bool,
-# default=False,
-# help="When enabled, utterances (cuts) will be concatenated "
-# "to minimize the amount of padding.",
-# )
-# group.add_argument(
-# "--duration-factor",
-# type=float,
-# default=1.0,
-# help="Determines the maximum duration of a concatenated cut "
-# "relative to the duration of the longest cut in a batch.",
-# )
-# group.add_argument(
-# "--gap",
-# type=float,
-# default=1.0,
-# help="The amount of padding (in seconds) inserted between "
-# "concatenated cuts. This padding is filled with noise when "
-# "noise augmentation is used.",
-# )
 group.add_argument(
 "--on-the-fly-feats",
 type=str2bool,

@@ -256,20 +234,6 @@ class AsrDataModule:
 else:
 logging.info("Disable MUSAN")
 
-# if self.args.concatenate_cuts:
-# logging.info(
-# f"Using cut concatenation with duration factor "
-# f"{self.args.duration_factor} and gap {self.args.gap}."
-# )
-# # Cut concatenation should be the first transform in the list,
-# # so that if we e.g. mix noise in, it will fill the gaps between
-# # different utterances.
-# transforms = [
-# CutConcatenate(
-# duration_factor=self.args.duration_factor, gap=self.args.gap
-# )
-# ] + transforms
-
 input_transforms = []
 if self.args.enable_spec_aug:
 logging.info("Enable SpecAugment")

@@ -426,32 +390,12 @@ class AsrDataModule:
 def test_cuts(self) -> CutSet:
 logging.info("About to get test cuts")
 if self.args.on_the_fly_feats:
-# dataset = load_dataset(self.args.huggingface_dataset_path_or_name, streaming=True, split=partition)
-i, num_digits = 0, 5
-idx = f"{i}".zfill(num_digits)
-parquet_files = [
-f"data/train-{idx}-of-01601.parquet",
-]
-parquet_files = [
-f"{self.args.huggingface_dataset_path_or_name}/{f}"
-for f in parquet_files
-]
-file_name = parquet_files[0]
-logging.info(f"Loading dataset from {file_name}")
-dataset = load_dataset(
-"parquet", data_files=parquet_files, streaming=True, split="train"
-)
-cut_set = CutSet.from_huggingface_dataset(
-dataset, audio_key=self.args.audio_key, text_key=self.args.text_key
-)
-if self.args.resample_to_16kHz:
-cut_set = cut_set.resample(16000)
-return {"test": cut_set}
+pass
 else:
-# return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")}
-# return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_test_small.jsonl.gz")}
 return {
-"test": load_manifest_lazy("data/fbank_test/belle_cuts.00000.jsonl.gz")
+"test": load_manifest_lazy(
+self.args.manifest_dir / "cuts_belle_test.jsonl.gz"
+)
 }
 
 @lru_cache()

@@ -461,7 +405,7 @@ class AsrDataModule:
 pass
 else:
 return load_manifest_lazy(
-self.args.manifest_dir / "cuts_belle.00000.jsonl.gz"
+self.args.manifest_dir / "cuts_belle_test.jsonl.gz"
 )
 
 @lru_cache()
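With the two hunks above, the data module no longer hard-codes `data/fbank_test/belle_cuts.00000.jsonl.gz`; both the test and valid paths resolve to `cuts_belle_test.jsonl.gz` under `--manifest-dir`. A minimal sketch of what that lookup boils down to is shown below; the `data/fbank` path is an assumption mirroring the scripts above.

```python
# Sketch (assumption-based): resolve the held-out test cuts the way the updated
# AsrDataModule does.
from pathlib import Path

from lhotse import load_manifest_lazy

manifest_dir = Path("data/fbank")  # assumed --manifest-dir

# cuts_belle_test.jsonl.gz is the symlink to shard 00000 created in stage 2 of
# prepare.sh, so the held-out shard serves as both valid and test set here.
test_cuts = load_manifest_lazy(manifest_dir / "cuts_belle_test.jsonl.gz")
print(next(iter(test_cuts)).id)
```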
@@ -20,30 +20,27 @@
 """
 Usage:
 # Command for decoding using fine-tuned models:
+huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
+# Cosyvoice pretrained model for speech token2wav module
+huggingface-cli download --local-dir models/CosyVoice-300M-SFT FunAudioLLM/CosyVoice-300M-SFT
+# Qwen Pretrained model
+huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
+# Qwen-Omni like speech2speech model trained on worstchan/Belle_1.4M-SLAM-Omni
+huggingface-cli download --local-dir models/qwen-omni-like-speech2speech-belle-1.4M yuekai/qwen-omni-like-speech2speech-belle-1.4M
 
-pip install huggingface_hub['cli']
-mkdir -p models/whisper models/qwen models/checkpoint
-huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
-
-# For aishell fine-tuned whisper model
-huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
-# For multi-hans fine-tuned whisper model
-# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
-
-huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
-
-mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
-ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
-
-python3 ./whisper_llm_zh/decode.py \
---max-duration 80 \
---exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
---speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
---llm-path-or-name models/qwen \
---epoch 999 --avg 1 \
---manifest-dir data/fbank \
---use-flash-attn True \
---use-lora True --dataset aishell
+cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
+python3 ./qwen_omni/decode.py \
+--max-duration 1 \
+--exp-dir $exp_dir \
+--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+--epoch 999 --avg 1 \
+--manifest-dir data/fbank \
+--use-flash-attn True \
+--method e2e-epoch10_speech2speech \
+--enable-speech-output True \
+--token2wav-path models/CosyVoice-300M-SFT \
+--use-lora True
 """
 
 import argparse

@@ -183,11 +180,6 @@ def get_model(params, device):
 attn_implementation = "eager"
 torch_dtype = torch.float16
 
-# codec_lm = AutoModelForCausalLM.from_pretrained(
-# params.llm_path_or_name,
-# attn_implementation=attn_implementation,
-# torch_dtype=torch_dtype,
-# )
 codec_vocab_size = 4096 + 4
 config = Qwen2Config(
 vocab_size=codec_vocab_size,

@@ -198,39 +190,19 @@ def get_model(params, device):
 intermediate_size=2048,
 max_position_embeddings=4096,
 )
-# codec_lm = Qwen2ForCausalLM(config=config)
-# Pass attn_implementation and torch_dtype to the constructor
-# Use AutoModelForCausalLM.from_config for more generality
 codec_lm = AutoModelForCausalLM.from_config(
 config=config,
 attn_implementation=attn_implementation,
 torch_dtype=torch_dtype,
 )
-# cosyvoice2_token_size = 6561
 codec_lm.resize_token_embeddings(codec_vocab_size)
 codec_lm.vocab_size = codec_vocab_size
 codec_lm.config.pad_token_id = codec_vocab_size - 1
 codec_lm.config.eos_token_id = codec_vocab_size - 2
 codec_lm.config.bos_token_id = codec_vocab_size - 3
 codec_lm.config.mask_token_id = codec_vocab_size - 4
-# if params.use_lora:
-# lora_config = LoraConfig(
-# r=64,
-# lora_alpha=16,
-# target_modules=[
-# "q_proj",
-# "k_proj",
-# "v_proj",
-# "o_proj",
-# "up_proj",
-# "gate_proj",
-# "down_proj",
-# ],
-# lora_dropout=0.05,
-# task_type="CAUSAL_LM",
-# )
-# codec_lm = get_peft_model(codec_lm, lora_config)
-# codec_lm.print_trainable_parameters()
 else:
 codec_lm = None
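The `get_model()` hunk above keeps the codec LM vocabulary at `4096 + 4` and pins the four control tokens to the top of that range. The small sketch below is not part of the commit; it just spells out the resulting id layout so the `codec_vocab_size - 1 ... - 4` assignments are easier to read.

```python
# Sketch (not from the commit): special-token layout implied by codec_vocab_size = 4096 + 4.
codec_vocab_size = 4096 + 4  # 4096 CosyVoice codec tokens + 4 control tokens

special_tokens = {
    "pad_token_id": codec_vocab_size - 1,   # 4099
    "eos_token_id": codec_vocab_size - 2,   # 4098
    "bos_token_id": codec_vocab_size - 3,   # 4097
    "mask_token_id": codec_vocab_size - 4,  # 4096
}

# Real codec tokens therefore occupy ids [0, 4095]; anything sampled in
# [4096, 4099] is a control token, and eos (4098) terminates speech generation.
assert all(4096 <= v <= 4099 for v in special_tokens.values())
print(special_tokens)
```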
@@ -373,13 +345,6 @@ def get_parser():
 default="/workspace/CosyVoice-300M-SFT",
 help="The path to the token2wav model",
 )
-# parser.add_argument(
-# "--dataset",
-# type=str,
-# default="aishell",
-# choices=["aishell", "speechio", "wenetspeech_test_meeting", "multi_hans_zh"],
-# help="The dataset to decode",
-# )
 
 add_model_arguments(parser)
 return parser

@@ -474,12 +439,6 @@ def decode_one_batch(
 
 chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
 
-# messages = [
-# [
-# {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
-# {"role": "assistant", "content": ""},
-# ]
-# ] * len(feature)
 questions_with_history = [
 cut.custom["question"] for cut in batch["supervisions"]["cut"]
 ]

@@ -496,7 +455,6 @@ def decode_one_batch(
 history_question_answer = history_contexts[i].split("USER:")
 history_question_answer = [item for item in history_question_answer if item]
 for j in range(total_round - 1):
-# USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
 question_answer = history_question_answer[j].split("ASSISTANT:")
 message += [
 {"role": "user", "content": question_answer[0].strip()},

@@ -504,7 +462,6 @@ def decode_one_batch(
 ]
 message += [
 {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
-# {"role": "user", "content": f"{last_questions[i]}"},
 {"role": "assistant", "content": ""},
 ]
 print(f"message: {message}, batch_size {len(chat_rounds)}")

@@ -525,13 +482,6 @@ def decode_one_batch(
 audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
 audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
 sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 22050)
-# with open(speech_token_file_name, 'w') as f:
-# # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
-# #torchaudio.save(save_path, speech_output.cpu(), 16000)
-# # print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
-# save_str = " ".join([str(i) for i in generated_speech_output])
-# f.write(f"{cut_id}|{save_str}\n")
 
 else:
 generated_ids = model.decode(
 feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)

@@ -560,43 +510,6 @@ def decode_dataset(
 Returns:
 Return a dict, whose key may be "beam-search".
 """
 
-def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
-"""
-Text normalization similar to M2MeT challenge baseline.
-See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
-"""
-if normalize == "none":
-return text
-elif normalize == "m2met":
-import re
-
-text = text.replace(" ", "")
-text = text.replace("<sil>", "")
-text = text.replace("<%>", "")
-text = text.replace("<->", "")
-text = text.replace("<$>", "")
-text = text.replace("<#>", "")
-text = text.replace("<_>", "")
-text = text.replace("<space>", "")
-text = text.replace("`", "")
-text = text.replace("&", "")
-text = text.replace(",", "")
-if re.search("[a-zA-Z]", text):
-text = text.upper()
-text = text.replace("A", "A")
-text = text.replace("a", "A")
-text = text.replace("b", "B")
-text = text.replace("c", "C")
-text = text.replace("k", "K")
-text = text.replace("t", "T")
-text = text.replace(",", "")
-text = text.replace("丶", "")
-text = text.replace("。", "")
-text = text.replace("、", "")
-text = text.replace("?", "")
-return text
-
 results = []
 
 num_cuts = 0

@@ -634,7 +547,6 @@ def decode_dataset(
 this_batch = []
 assert len(hyps) == len(texts)
 for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
-# ref_text = normalize_text_alimeeting(ref_text)
 ref_words = ref_text.split()
 print(f"ref: {ref_text}")
 print(f"hyp: {''.join(hyp_words)}")

@@ -673,7 +585,6 @@ def save_results(
 errs_filename = (
 params.log_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
 )
-# we compute CER for aishell dataset.
 results_char = []
 for res in results:
 results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))

@@ -732,11 +643,8 @@ def main():
 num_param = sum([p.numel() for p in model.parameters()])
 logging.info(f"Number of model parameters: {num_param}")
 
-# we need cut ids to display recognition results.
 args.return_cuts = True
 
 data_module = AsrDataModule(args)
-# data_module = MultiDataset(args.manifest_dir)
 
 def remove_long_utt(c: Cut):
 # Keep only utterances with duration in 30 seconds

@@ -748,13 +656,6 @@ def main():
 return False
 return True
 
-# if params.dataset == "aishell":
-# test_sets_cuts = data_module.aishell_test_cuts()
-# elif params.dataset == "speechio":
-# test_sets_cuts = data_module.speechio_test_cuts()
-# elif params.dataset == "wenetspeech_test_meeting":
-# test_sets_cuts = data_module.wenetspeech_test_meeting_cuts()
-# else:
 test_sets_cuts = data_module.test_cuts()
 
 test_sets = test_sets_cuts.keys()

@@ -1,4 +1,4 @@
-from typing import List, Tuple # Added for type hints
+from typing import List, Tuple
 
 import torch
 from torch import nn

@@ -78,7 +78,6 @@ class SPEECH_LLM(nn.Module):
 self.codec_lm_head = nn.Linear(
 self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
 )
-# to torch.float16
 self.speech_token_projector = self.speech_token_projector.to(
 dtype=torch.float16
 )

@@ -498,20 +497,6 @@ class SPEECH_LLM(nn.Module):
 pad_token_id=self.llm.config.pad_token_id,
 )
 
-# generated_ids = self.llm.generate(
-# inputs_embeds=inputs_embeds,
-# max_new_tokens=kwargs.get("max_new_tokens", 200),
-# num_beams=kwargs.get("num_beams", 1),
-# do_sample=kwargs.get("do_sample", False),
-# min_length=kwargs.get("min_length", 1),
-# top_p=kwargs.get("top_p", 1.0),
-# repetition_penalty=kwargs.get("repetition_penalty", 1.0),
-# temperature=kwargs.get("temperature", 1.0),
-# length_penalty=kwargs.get("length_penalty", 1.0),
-# bos_token_id=self.llm.config.bos_token_id,
-# eos_token_id=self.llm.config.eos_token_id,
-# pad_token_id=self.llm.config.pad_token_id,
-# )
 return generated_ids
 
 def decode_with_speech_output(

@@ -520,7 +505,7 @@ class SPEECH_LLM(nn.Module):
 input_ids: torch.LongTensor = None, # Prompt input_ids
 attention_mask: torch.Tensor = None, # Prompt attention_mask
 max_text_new_tokens: int = 1024,
-max_speech_new_tokens: int = 1024, # Max length for speech tokens
+max_speech_new_tokens: int = 2048, # Max length for speech tokens
 llm_kwargs: dict = None, # Kwargs for text LLM generate
 codec_lm_kwargs: dict = None, # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET
 ) -> Tuple[torch.LongTensor, List[List[int]]]:

@@ -602,7 +587,7 @@ class SPEECH_LLM(nn.Module):
 eos_token_id = self.llm.config.eos_token_id
 eos_token_embedding = self.llm.get_input_embeddings()(
 torch.tensor([[eos_token_id]], device=device)
-) # 1,D
+)
 assert (
 generated_text_ids[0, -1] == eos_token_id
 ), f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"

@@ -610,7 +595,7 @@ class SPEECH_LLM(nn.Module):
 token_hidden_states[0].to(self.llm.device)
 for token_hidden_states in text_outputs.hidden_states
 ]
-# shift one for thinker token_embeds, drop the first embeds, and add the eos token
 first_thinker_token_embed = torch.cat(
 [
 thinker_token_embeds_org[0][:, 1:],

@@ -628,7 +613,7 @@ class SPEECH_LLM(nn.Module):
 token_hidden_states[-1].to(self.llm.device)
 for token_hidden_states in text_outputs.hidden_states
 ]
-# thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
 thinker_reply_part = [
 torch.cat(
 [

@@ -651,12 +636,8 @@ class SPEECH_LLM(nn.Module):
 dim=-1,
 )
 
-thinker_prompt_part = self.speech_token_projector(
-thinker_prompt_part
-) # [B, S_full, D_codec]
-thinker_reply_part = self.speech_token_projector(
-thinker_reply_part
-) # [B, S_full, D_codec]
+thinker_prompt_part = self.speech_token_projector(thinker_prompt_part)
+thinker_reply_part = self.speech_token_projector(thinker_reply_part)
 
 thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
 talker_input_ids = torch.full(

@@ -666,9 +647,7 @@ class SPEECH_LLM(nn.Module):
 device=self.llm.device,
 )
 talker_input_ids[:, -1] = self.codec_lm.config.bos_token_id
-talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
-talker_input_ids
-) # [B, S_full, D_codec]
+talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids)
 thinker_input_embeds = torch.cat(
 [
 thinker_prompt_part,

@@ -677,68 +656,43 @@ class SPEECH_LLM(nn.Module):
 dim=1,
 )
 talker_inputs_embeds += thinker_input_embeds
-thinker_reply_part = thinker_reply_part[
-:, delay_step + 1 :, :
-] # [B, S_full, D_codec]
+thinker_reply_part = thinker_reply_part[:, delay_step + 1 :, :]
 
 past_key_values = None
-# generated_speech_tokens_list = [[] for _ in range(batch_size)]
-# unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
 generated_speech_tokens_list = []
 next_token_ids = None
-# text_context_len = projected_text_embeds.shape[1] # S_full
 for t in range(max_speech_new_tokens):
-# Get embedding for the *current* input token ID (initially BOS, then generated tokens)
-# current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids) # [B, 1, D_codec]
 if t > 0:
 talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
 next_token_ids
-) # [B, 1, D_codec]
+)
 if thinker_reply_part.shape[1] > 0:
 talker_inputs_embeds += thinker_reply_part[:, :1, :]
-thinker_reply_part = thinker_reply_part[
-:, 1:, :
-] # Remove the first token for next step
-# # Add the projected text embedding corresponding to the current timestep `t`
-# if t < text_context_len:
-# # Text context from the full generated text sequence
-# current_text_context_embed = projected_text_embeds[:, t:t+1, :] # [B, 1, D_codec]
-# inputs_embeds = current_speech_embeds + current_text_context_embed
-# else:
-# # No more text context to add
-# inputs_embeds = current_speech_embeds
-
-# Forward pass through codec LM for one step
-# We provide inputs_embeds directly, bypassing prepare_inputs_for_generation
+thinker_reply_part = thinker_reply_part[:, 1:, :]
 codec_outputs = self.codec_lm(
-inputs_embeds=talker_inputs_embeds, # Combined embedding for this step
+inputs_embeds=talker_inputs_embeds,
 past_key_values=past_key_values,
 use_cache=True,
 return_dict=True,
 output_hidden_states=True,
-# No attention mask needed here when using past_key_values and single token input
 )
-last_token_hidden_state = codec_outputs.hidden_states[-1][
-:, -1, :
-] # [B, D_codec] #TODO: check shape here
-# Get logits for the *last* token generated in this step
-next_token_logits = self.codec_lm_head(
-last_token_hidden_state
-) # Use -1 index
-# suppress tokens between 4096:len(vocab)-3
-# next_token_logits[:, 4096:-3] = -float("Inf") # TODO: where we should supress tokens?
+last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :]
+next_token_logits = self.codec_lm_head(last_token_hidden_state)
 next_token_ids = topk_sampling(
 next_token_logits,
 )
-# print(next_token_ids, "next_token_ids", t, next_token_ids.shape)
 if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id:
 break
-# current_speech_input_ids = next_token_ids # Use the newly generated token ID as input for next step
 past_key_values = codec_outputs.past_key_values # Update KV cache
 generated_speech_tokens_list.append(
 next_token_ids.squeeze(1).cpu().tolist()[0]
 )
-# --- 6. Return Results ---
 return generated_text_ids, generated_speech_tokens_list
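The cleaned-up talker loop above samples each codec token with `topk_sampling(next_token_logits)`, a helper that is not shown in this diff. Below is a minimal sketch of what such a function typically looks like; the name, `top_k`, and `temperature` defaults are assumptions, not the repository's actual implementation.

```python
# Sketch (assumption): a typical top-k sampler for the codec LM logits above.
import torch


def topk_sampling(
    logits: torch.Tensor, top_k: int = 50, temperature: float = 1.0
) -> torch.Tensor:
    """Return sampled token ids of shape [B, 1] from logits of shape [B, V]."""
    logits = logits / temperature
    # Keep only the k most likely tokens and renormalize over them.
    topk_values, topk_indices = torch.topk(
        logits, k=min(top_k, logits.size(-1)), dim=-1
    )
    probs = torch.softmax(topk_values, dim=-1)
    # Sample within the top-k set, then map back to vocabulary ids.
    sampled = torch.multinomial(probs, num_samples=1)  # [B, 1]
    return topk_indices.gather(-1, sampled)            # [B, 1]


if __name__ == "__main__":
    demo_logits = torch.randn(1, 4100)  # codec_vocab_size = 4096 + 4
    print(topk_sampling(demo_logits).shape)  # torch.Size([1, 1])
```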
@@ -17,28 +17,22 @@
 # limitations under the License.
 """
 Usage:
-# fine-tuning with whisper and Qwen2
-pip install huggingface_hub['cli']
-mkdir -p models/whisper models/qwen
-
-# For aishell fine-tuned whisper model
-huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
-# For multi-hans fine-tuned whisper model
-# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
-
-# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
-huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
-
-torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
---max-duration 200 \
---exp-dir ./whisper_llm_zh/exp_test \
---speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
---llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
---manifest-dir data/fbank \
---deepspeed \
---deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
---use-flash-attn True \
---use-lora True --unfreeze-llm True
+# For Chinese dataset, you can use the following command to download the Chinese fine-tuned whisper model.
+huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
+# Qwen Pretrained model
+huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
+
+torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
+--max-duration 50 \
+--enable-musan False \
+--exp-dir $exp_dir \
+--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
+--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
+--manifest-dir data/fbank \
+--deepspeed \
+--deepspeed_config ./qwen_omni/ds_config_zero1.json \
+--use-flash-attn True \
+--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
 """
 
 import argparse

@@ -52,7 +46,6 @@ from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import deepspeed
-import k2
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn

@@ -66,8 +59,6 @@ from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
-
-# from multi_dataset import MultiDataset
 from peft import LoraConfig, get_peft_model
 from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter

@@ -330,9 +321,6 @@ def compute_loss(
 truncation=False,
 )
 )
-# padding texts to the same length, texts is a list of list, padding with tokenzier.pad_token_id
-# remove too long text
-# texts = [ text for text in texts if len(text) < 1024 ]
 if len(texts) != len(messages):
 logging.warning(f"Remove too long text, {messages} ")
 max_len_texts = max([len(text) for text in texts])
|
|||||||
for text in texts
|
for text in texts
|
||||||
]
|
]
|
||||||
input_ids = torch.tensor(texts, dtype=torch.int)
|
input_ids = torch.tensor(texts, dtype=torch.int)
|
||||||
# response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
|
|
||||||
target_ids = input_ids.clone()
|
target_ids = input_ids.clone()
|
||||||
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
|
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
|
||||||
# mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
|
# mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
|
||||||
@ -396,8 +384,6 @@ def compute_loss(
|
|||||||
history_contexts = [
|
history_contexts = [
|
||||||
question.rsplit("<USER>:", 1)[0].strip() for question in questions_with_history
|
question.rsplit("<USER>:", 1)[0].strip() for question in questions_with_history
|
||||||
]
|
]
|
||||||
# USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。<USER>: 告诉我如何烹饪鸡肉
|
|
||||||
# <USER>: 对以下句子进行鉴赏:他心地善良。输出结果为"他是一个有善心的人。
|
|
||||||
|
|
||||||
messages = []
|
messages = []
|
||||||
for i, total_round in enumerate(chat_rounds):
|
for i, total_round in enumerate(chat_rounds):
|
||||||
@ -406,7 +392,6 @@ def compute_loss(
|
|||||||
history_question_answer = history_contexts[i].split("USER:")
|
history_question_answer = history_contexts[i].split("USER:")
|
||||||
history_question_answer = [item for item in history_question_answer if item]
|
history_question_answer = [item for item in history_question_answer if item]
|
||||||
for j in range(total_round - 1):
|
for j in range(total_round - 1):
|
||||||
# USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
|
|
||||||
question_answer = history_question_answer[j].split("ASSISTANT:")
|
question_answer = history_question_answer[j].split("ASSISTANT:")
|
||||||
message += [
|
message += [
|
||||||
{"role": "user", "content": question_answer[0].strip()},
|
{"role": "user", "content": question_answer[0].strip()},
|
||||||
@ -683,7 +668,6 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
if params.use_flash_attn:
|
if params.use_flash_attn:
|
||||||
attn_implementation = "flash_attention_2"
|
attn_implementation = "flash_attention_2"
|
||||||
# torch_dtype=torch.bfloat16 FIX ME
|
|
||||||
torch_dtype = torch.float16
|
torch_dtype = torch.float16
|
||||||
tokenizer.padding_side = "left"
|
tokenizer.padding_side = "left"
|
||||||
|
|
||||||
@ -724,14 +708,6 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
|
special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
|
||||||
tokenizer.add_special_tokens(special_tokens_dict)
|
tokenizer.add_special_tokens(special_tokens_dict)
|
||||||
# original_tokenizer_vocab_size = len(tokenizer)
|
|
||||||
# cosyvoice2_token_size = 6561
|
|
||||||
# new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [
|
|
||||||
# "<|SPEECH_GENERATION_START|>"
|
|
||||||
# ]
|
|
||||||
# num_added_tokens = tokenizer.add_tokens(new_tokens)
|
|
||||||
# model.resize_token_embeddings(len(tokenizer))
|
|
||||||
# model.vocab_size = len(tokenizer)
|
|
||||||
|
|
||||||
llm.config.pad_token_id = tokenizer.pad_token_id
|
llm.config.pad_token_id = tokenizer.pad_token_id
|
||||||
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
|
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
|
||||||
@ -755,11 +731,6 @@ def run(rank, world_size, args):
|
|||||||
attn_implementation = "eager"
|
attn_implementation = "eager"
|
||||||
torch_dtype = torch.float16
|
torch_dtype = torch.float16
|
||||||
|
|
||||||
# codec_lm = AutoModelForCausalLM.from_pretrained(
|
|
||||||
# params.llm_path_or_name,
|
|
||||||
# attn_implementation=attn_implementation,
|
|
||||||
# torch_dtype=torch_dtype,
|
|
||||||
# )
|
|
||||||
codec_vocab_size = 4096 + 4
|
codec_vocab_size = 4096 + 4
|
||||||
# TODO: modify above vocab size or supress_tokens when decoding
|
# TODO: modify above vocab size or supress_tokens when decoding
|
||||||
config = Qwen2Config(
|
config = Qwen2Config(
|
||||||
@ -771,39 +742,19 @@ def run(rank, world_size, args):
|
|||||||
intermediate_size=2048,
|
intermediate_size=2048,
|
||||||
max_position_embeddings=4096,
|
max_position_embeddings=4096,
|
||||||
)
|
)
|
||||||
# codec_lm = Qwen2ForCausalLM(config=config)
|
|
||||||
# Pass attn_implementation and torch_dtype to the constructor
|
|
||||||
# Use AutoModelForCausalLM.from_config for more generality
|
|
||||||
codec_lm = AutoModelForCausalLM.from_config(
|
codec_lm = AutoModelForCausalLM.from_config(
|
||||||
config=config,
|
config=config,
|
||||||
attn_implementation=attn_implementation,
|
attn_implementation=attn_implementation,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
)
|
)
|
||||||
# cosyvoice2_token_size = 6561
|
|
||||||
codec_lm.resize_token_embeddings(codec_vocab_size)
|
codec_lm.resize_token_embeddings(codec_vocab_size)
|
||||||
codec_lm.vocab_size = codec_vocab_size
|
codec_lm.vocab_size = codec_vocab_size
|
||||||
codec_lm.config.pad_token_id = codec_vocab_size - 1
|
codec_lm.config.pad_token_id = codec_vocab_size - 1
|
||||||
codec_lm.config.eos_token_id = codec_vocab_size - 2
|
codec_lm.config.eos_token_id = codec_vocab_size - 2
|
||||||
codec_lm.config.bos_token_id = codec_vocab_size - 3
|
codec_lm.config.bos_token_id = codec_vocab_size - 3
|
||||||
codec_lm.config.mask_token_id = codec_vocab_size - 4
|
codec_lm.config.mask_token_id = codec_vocab_size - 4
|
||||||
# if params.use_lora:
|
|
||||||
# lora_config = LoraConfig(
|
|
||||||
# r=64,
|
|
||||||
# lora_alpha=16,
|
|
||||||
# target_modules=[
|
|
||||||
# "q_proj",
|
|
||||||
# "k_proj",
|
|
||||||
# "v_proj",
|
|
||||||
# "o_proj",
|
|
||||||
# "up_proj",
|
|
||||||
# "gate_proj",
|
|
||||||
# "down_proj",
|
|
||||||
# ],
|
|
||||||
# lora_dropout=0.05,
|
|
||||||
# task_type="CAUSAL_LM",
|
|
||||||
# )
|
|
||||||
# codec_lm = get_peft_model(codec_lm, lora_config)
|
|
||||||
# codec_lm.print_trainable_parameters()
|
|
||||||
else:
|
else:
|
||||||
codec_lm = None
|
codec_lm = None
|
||||||
|
|
||||||
@ -856,7 +807,6 @@ def run(rank, world_size, args):
|
|||||||
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||||
# )
|
# )
|
||||||
return False
|
return False
|
||||||
# cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]
|
|
||||||
codec_len = len(c.custom["answer_cosyvoice_speech_token"])
|
codec_len = len(c.custom["answer_cosyvoice_speech_token"])
|
||||||
if codec_len > 2200:
|
if codec_len > 2200:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
@ -873,7 +823,7 @@ def run(rank, world_size, args):
|
|||||||
if params.sampler_state_dict_path:
|
if params.sampler_state_dict_path:
|
||||||
sampler_state_dict = torch.load(params.sampler_state_dict_path)
|
sampler_state_dict = torch.load(params.sampler_state_dict_path)
|
||||||
sampler_state_dict["max_duration"] = params.max_duration
|
sampler_state_dict["max_duration"] = params.max_duration
|
||||||
# TODO: load sampler state dict
|
|
||||||
train_dl = data_module.train_dataloaders(
|
train_dl = data_module.train_dataloaders(
|
||||||
train_cuts, sampler_state_dict=sampler_state_dict
|
train_cuts, sampler_state_dict=sampler_state_dict
|
||||||
)
|
)
|
||||||
|