mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
lint
This commit is contained in:
parent
360f0aa397
commit
11bd3c9ad8
@ -23,67 +23,33 @@ The following table lists the folders for different tasks.
|
||||
|
||||
Command for training is:
|
||||
```bash
|
||||
pip install -r whisper_llm_zh/requirements.txt
|
||||
|
||||
pip install huggingface_hub['cli']
|
||||
mkdir -p models/whisper models/qwen
|
||||
|
||||
# For aishell fine-tuned whisper model
|
||||
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
|
||||
# For multi-hans fine-tuned whisper model
|
||||
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
|
||||
|
||||
# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
|
||||
huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
|
||||
|
||||
# First, we only train the projector and freeze other modules.
|
||||
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
|
||||
--max-duration 200 \
|
||||
--exp-dir ./whisper_llm_zh/exp_test \
|
||||
--speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
|
||||
--llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
|
||||
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
|
||||
--max-duration 50 \
|
||||
--enable-musan False \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
|
||||
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--manifest-dir data/fbank \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
|
||||
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True \
|
||||
--use-lora False --unfreeze-llm False
|
||||
|
||||
# Then we jointly train the projector and LLM LoRA modules.
|
||||
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
|
||||
--max-duration 200 \
|
||||
--exp-dir ./whisper_llm_zh/exp_test \
|
||||
--speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
|
||||
--llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
|
||||
--manifest-dir data/fbank \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
|
||||
--use-flash-attn True \
|
||||
--use-lora True --unfreeze-llm True
|
||||
--pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt
|
||||
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
|
||||
```
|
||||
|
||||
Command for decoding:
|
||||
Command for decoding is:
|
||||
```bash
|
||||
mkdir -p models/whisper models/qwen models/checkpoint
|
||||
huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
|
||||
|
||||
# For aishell fine-tuned whisper model
|
||||
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
|
||||
# For multi-hans fine-tuned whisper model
|
||||
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
|
||||
|
||||
huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
|
||||
|
||||
mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
|
||||
ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
|
||||
|
||||
python3 ./whisper_llm_zh/decode.py \
|
||||
--max-duration 80 \
|
||||
--exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
|
||||
--speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
|
||||
--llm-path-or-name models/qwen \
|
||||
python3 ./qwen_omni/decode.py \
|
||||
--max-duration 1 \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--epoch 999 --avg 1 \
|
||||
--manifest-dir data/fbank \
|
||||
--use-flash-attn True \
|
||||
--use-lora True --dataset aishell
|
||||
--method e2e-epoch10_speech2speech \
|
||||
--enable-speech-output True \
|
||||
--token2wav-path models/CosyVoice-300M-SFT \
|
||||
--use-lora True
|
||||
```
|
||||
|
||||
Please see [`prepare.sh`](./prepare.sh) for more details.
|
||||
|
@ -165,7 +165,7 @@ def compute_fbank(args):
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
cuts_path = f"{in_out_dir}/{args.prefix}_cuts.{idx}.jsonl.gz"
|
||||
cuts_path = f"{in_out_dir}/cuts_{args.prefix}.{idx}.jsonl.gz"
|
||||
logging.info(f"Saving to {cuts_path}")
|
||||
# see https://github.com/lhotse-speech/lhotse/issues/1125
|
||||
cut_set.drop_recordings().to_file(cuts_path)
|
||||
|
@ -2,7 +2,9 @@
|
||||
|
||||
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
export PYTHONPATH=$PYTHONPATH:/workspace/slam/icefall_omni
|
||||
|
||||
export PYTHONPATH=$PYTHONPATH:/workspace/icefall
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
stage=$1
|
||||
@ -19,18 +21,37 @@ log() {
|
||||
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "stage 0: "
|
||||
#pip uninstall lhotse
|
||||
#cd /workspace/slam/lhotse
|
||||
#git config --global --add safe.directory /workspace/slam/lhotse
|
||||
#pip install -e '.[dev]'
|
||||
cd -
|
||||
pip install -r slam_omni/requirements.txt
|
||||
log "stage 0: Clone CosyVoice repo and install requirements inside the container"
|
||||
# docker: ghcr.io/swivid/f5-tts:main
|
||||
pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
|
||||
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git /workspace/CosyVoice
|
||||
cd /workspace/CosyVoice
|
||||
# If you failed to clone submodule due to network failures, please run following command until success
|
||||
git submodule update --init --recursive
|
||||
pip install -r qwen_omni/requirements.txt
|
||||
pip install -r qwen_omni/requirements-cosyvoice.txt
|
||||
|
||||
# For Chinese only dataset, you can use the following command to download the Chinese fine-tuned whisper model.
|
||||
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
|
||||
# Cosyvoice pretrained model for speech token2wav module
|
||||
huggingface-cli download --local-dir models/CosyVoice-300M-SFT FunAudioLLM/CosyVoice-300M-SFT
|
||||
# Qwen Pretrained model
|
||||
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
|
||||
# Qwen-Omni like speech2speech model trained on worstchan/Belle_1.4M-SLAM-Omni
|
||||
huggingface-cli download --local-dir models/qwen-omni-like-speech2speech-belle-1.4M yuekai/qwen-omni-like-speech2speech-belle-1.4M
|
||||
|
||||
# For Gradio demo, we follow https://arxiv.org/abs/2412.15649 to use ASR model to decode the history speech as context.
|
||||
pip install sherpa-onnx
|
||||
model_path=local/sherpa-onnx-paraformer-zh-2023-09-14
|
||||
if [ ! -d $model_path ]; then
|
||||
wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
|
||||
tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local
|
||||
fi
|
||||
fi
|
||||
export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface"
|
||||
|
||||
log "stage 1: Compute fbank feature from huggingface"
|
||||
python3 local/compute_whisper_fbank.py \
|
||||
--num-mel-bins 80 --whisper-fbank True --resample-to-16kHz True --speed-perturb False \
|
||||
--out-dir data/fbank_test \
|
||||
@ -39,26 +60,42 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
--prefix belle
|
||||
fi
|
||||
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Combine features"
|
||||
manifest_dir=data/fbank
|
||||
if [ ! -f $manifest_dir/cuts_belle_00001-01600.jsonl.gz ]; then
|
||||
mv $manifest_dir/cuts_belle.00000.jsonl.gz ./
|
||||
# exclude cust_belle_00000.jsonl.gz for valid and test set
|
||||
pieces=$(find $manifest_dir -name "cuts_belle.*.jsonl.gz" | sort)
|
||||
# # remove cust_belle_00000.jsonl.gz from pieces
|
||||
# pieces=$(echo $pieces | sed 's/cuts_belle.00000.jsonl.gz//g')
|
||||
echo $pieces | wc
|
||||
lhotse combine $pieces data/fbank/cuts_belle_00001-01600.jsonl.gz
|
||||
cd $manifest_dir && ln -s cuts_belle_00001-01600.jsonl.gz cuts_belle_train.jsonl.gz && cd -
|
||||
mv ./cuts_belle.00000.jsonl.gz $manifest_dir # put it back
|
||||
cd $manifest_dir && ln -s cuts_belle_00001-01600.jsonl.gz cuts_belle_train.jsonl.gz
|
||||
ln -s cuts_belle.00000.jsonl.gz cuts_belle_test.jsonl.gz && cd -
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
ngpu=8
|
||||
exp_dir=./qwen_omni/exp_speech2speech
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "stage 3: "
|
||||
exp_dir=./slam_omni/exp_speech2speech_rerun
|
||||
export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
|
||||
python3 ./slam_omni/decode.py \
|
||||
log "stage 3: Training Speech2Speech Model"
|
||||
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
|
||||
--max-duration 50 \
|
||||
--enable-musan False \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
|
||||
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--manifest-dir data/fbank \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True \
|
||||
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "stage 4: Decoding, only support batch_size=1 for now."
|
||||
cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
|
||||
python3 ./qwen_omni/decode.py \
|
||||
--max-duration 1 \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
|
||||
@ -66,78 +103,20 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
--epoch 999 --avg 1 \
|
||||
--manifest-dir data/fbank \
|
||||
--use-flash-attn True \
|
||||
--method e2e-epoch10_speech2speech_rerun \
|
||||
--method e2e-epoch10_speech2speech \
|
||||
--enable-speech-output True \
|
||||
--token2wav-path /workspace/CosyVoice-300M-SFT \
|
||||
--use-lora True # --on-the-fly-feats True
|
||||
|
||||
--token2wav-path models/CosyVoice-300M-SFT \
|
||||
--use-lora True
|
||||
fi
|
||||
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "stage 4: "
|
||||
ngpu=8
|
||||
torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
|
||||
--max-duration 80 \
|
||||
--enable-musan False \
|
||||
--exp-dir ./slam_omni/exp_speech2text \
|
||||
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--manifest-dir data/fbank \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./slam_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True \
|
||||
--pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
|
||||
--sampler-state-dict-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000-sampler.pt \
|
||||
--use-lora True --unfreeze-llm True
|
||||
fi
|
||||
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "stage 5: "
|
||||
ngpu=8
|
||||
exp_dir=./slam_omni/exp_speech2speech_rerun
|
||||
# exp_dir_new=./slam_omni/exp_s2s
|
||||
torchrun --nproc_per_node $ngpu ./slam_omni/train.py \
|
||||
--max-duration 50 \
|
||||
--enable-musan False \
|
||||
--exp-dir $exp_dir \
|
||||
log "stage 5: Gradio Demo"
|
||||
python3 ./qwen_omni/web_demo.py \
|
||||
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--manifest-dir data/fbank \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./slam_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True \
|
||||
--pretrained-model-path $exp_dir/epoch-1-checkpoint-15000.pt/pytorch_model.bin \
|
||||
--sampler-state-dict-path $exp_dir/epoch-1-checkpoint-15000-sampler.pt \
|
||||
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
|
||||
# --pretrained-model-path slam_omni/exp_speech2text/epoch-1-checkpoint-5000.pt/pytorch_model.bin \
|
||||
# --sampler-state-dict-path $exp_dir/epoch-1-checkpoint-35000-sampler.pt \
|
||||
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "stage 6: "
|
||||
export PYTHONPATH=$PYTHONPATH:/workspace/CosyVoice
|
||||
exp_dir=./slam_omni/exp_speech2speech_rerun
|
||||
python3 ./slam_omni/web_demo.py \
|
||||
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--checkpoint-path $exp_dir/epoch-998.pt \
|
||||
--checkpoint-path $exp_dir/epoch-999.pt \
|
||||
--use-flash-attn True \
|
||||
--enable-speech-output True \
|
||||
--asr-model-dir local/sherpa-onnx-paraformer-zh-2023-09-14 \
|
||||
--use-lora True --token2wav-path /workspace/CosyVoice-300M-SFT --share
|
||||
|
||||
fi
|
||||
|
||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
log "stage 7: "
|
||||
model_path=local/sherpa-onnx-paraformer-zh-2023-09-14
|
||||
|
||||
if [ ! -d $model_path ]; then
|
||||
pip install sherpa-onnx
|
||||
wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
|
||||
tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local
|
||||
fi
|
||||
fi
|
||||
|
@ -116,28 +116,6 @@ class AsrDataModule:
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
# group.add_argument(
|
||||
# "--concatenate-cuts",
|
||||
# type=str2bool,
|
||||
# default=False,
|
||||
# help="When enabled, utterances (cuts) will be concatenated "
|
||||
# "to minimize the amount of padding.",
|
||||
# )
|
||||
# group.add_argument(
|
||||
# "--duration-factor",
|
||||
# type=float,
|
||||
# default=1.0,
|
||||
# help="Determines the maximum duration of a concatenated cut "
|
||||
# "relative to the duration of the longest cut in a batch.",
|
||||
# )
|
||||
# group.add_argument(
|
||||
# "--gap",
|
||||
# type=float,
|
||||
# default=1.0,
|
||||
# help="The amount of padding (in seconds) inserted between "
|
||||
# "concatenated cuts. This padding is filled with noise when "
|
||||
# "noise augmentation is used.",
|
||||
# )
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
@ -256,20 +234,6 @@ class AsrDataModule:
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
||||
# if self.args.concatenate_cuts:
|
||||
# logging.info(
|
||||
# f"Using cut concatenation with duration factor "
|
||||
# f"{self.args.duration_factor} and gap {self.args.gap}."
|
||||
# )
|
||||
# # Cut concatenation should be the first transform in the list,
|
||||
# # so that if we e.g. mix noise in, it will fill the gaps between
|
||||
# # different utterances.
|
||||
# transforms = [
|
||||
# CutConcatenate(
|
||||
# duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
# )
|
||||
# ] + transforms
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
@ -426,32 +390,12 @@ class AsrDataModule:
|
||||
def test_cuts(self) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
if self.args.on_the_fly_feats:
|
||||
# dataset = load_dataset(self.args.huggingface_dataset_path_or_name, streaming=True, split=partition)
|
||||
i, num_digits = 0, 5
|
||||
idx = f"{i}".zfill(num_digits)
|
||||
parquet_files = [
|
||||
f"data/train-{idx}-of-01601.parquet",
|
||||
]
|
||||
parquet_files = [
|
||||
f"{self.args.huggingface_dataset_path_or_name}/{f}"
|
||||
for f in parquet_files
|
||||
]
|
||||
file_name = parquet_files[0]
|
||||
logging.info(f"Loading dataset from {file_name}")
|
||||
dataset = load_dataset(
|
||||
"parquet", data_files=parquet_files, streaming=True, split="train"
|
||||
)
|
||||
cut_set = CutSet.from_huggingface_dataset(
|
||||
dataset, audio_key=self.args.audio_key, text_key=self.args.text_key
|
||||
)
|
||||
if self.args.resample_to_16kHz:
|
||||
cut_set = cut_set.resample(16000)
|
||||
return {"test": cut_set}
|
||||
pass
|
||||
else:
|
||||
# return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_belle.00000.jsonl.gz")}
|
||||
# return {'test':load_manifest_lazy(self.args.manifest_dir / "cuts_test_small.jsonl.gz")}
|
||||
return {
|
||||
"test": load_manifest_lazy("data/fbank_test/belle_cuts.00000.jsonl.gz")
|
||||
"test": load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_belle_test.jsonl.gz"
|
||||
)
|
||||
}
|
||||
|
||||
@lru_cache()
|
||||
@ -461,7 +405,7 @@ class AsrDataModule:
|
||||
pass
|
||||
else:
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_belle.00000.jsonl.gz"
|
||||
self.args.manifest_dir / "cuts_belle_test.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
|
@ -20,30 +20,27 @@
|
||||
"""
|
||||
Usage:
|
||||
# Command for decoding using fine-tuned models:
|
||||
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
|
||||
# Cosyvoice pretrained model for speech token2wav module
|
||||
huggingface-cli download --local-dir models/CosyVoice-300M-SFT FunAudioLLM/CosyVoice-300M-SFT
|
||||
# Qwen Pretrained model
|
||||
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
|
||||
# Qwen-Omni like speech2speech model trained on worstchan/Belle_1.4M-SLAM-Omni
|
||||
huggingface-cli download --local-dir models/qwen-omni-like-speech2speech-belle-1.4M yuekai/qwen-omni-like-speech2speech-belle-1.4M
|
||||
|
||||
pip install huggingface_hub['cli']
|
||||
mkdir -p models/whisper models/qwen models/checkpoint
|
||||
huggingface-cli download --local-dir models/checkpoint yuekai/icefall_asr_aishell_whisper_qwen2_1.5B
|
||||
|
||||
# For aishell fine-tuned whisper model
|
||||
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
|
||||
# For multi-hans fine-tuned whisper model
|
||||
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
|
||||
|
||||
huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
|
||||
|
||||
mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
|
||||
ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
|
||||
|
||||
python3 ./whisper_llm_zh/decode.py \
|
||||
--max-duration 80 \
|
||||
--exp-dir whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B \
|
||||
--speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
|
||||
--llm-path-or-name models/qwen \
|
||||
cd $exp_dir && ln -s ../../models/qwen-omni-like-speech2speech-belle-1.4M/pytorch_model.bin epoch-999.pt && cd -
|
||||
python3 ./qwen_omni/decode.py \
|
||||
--max-duration 1 \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--epoch 999 --avg 1 \
|
||||
--manifest-dir data/fbank \
|
||||
--use-flash-attn True \
|
||||
--use-lora True --dataset aishell
|
||||
--method e2e-epoch10_speech2speech \
|
||||
--enable-speech-output True \
|
||||
--token2wav-path models/CosyVoice-300M-SFT \
|
||||
--use-lora True
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@ -183,11 +180,6 @@ def get_model(params, device):
|
||||
attn_implementation = "eager"
|
||||
torch_dtype = torch.float16
|
||||
|
||||
# codec_lm = AutoModelForCausalLM.from_pretrained(
|
||||
# params.llm_path_or_name,
|
||||
# attn_implementation=attn_implementation,
|
||||
# torch_dtype=torch_dtype,
|
||||
# )
|
||||
codec_vocab_size = 4096 + 4
|
||||
config = Qwen2Config(
|
||||
vocab_size=codec_vocab_size,
|
||||
@ -198,39 +190,19 @@ def get_model(params, device):
|
||||
intermediate_size=2048,
|
||||
max_position_embeddings=4096,
|
||||
)
|
||||
# codec_lm = Qwen2ForCausalLM(config=config)
|
||||
# Pass attn_implementation and torch_dtype to the constructor
|
||||
# Use AutoModelForCausalLM.from_config for more generality
|
||||
|
||||
codec_lm = AutoModelForCausalLM.from_config(
|
||||
config=config,
|
||||
attn_implementation=attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
# cosyvoice2_token_size = 6561
|
||||
|
||||
codec_lm.resize_token_embeddings(codec_vocab_size)
|
||||
codec_lm.vocab_size = codec_vocab_size
|
||||
codec_lm.config.pad_token_id = codec_vocab_size - 1
|
||||
codec_lm.config.eos_token_id = codec_vocab_size - 2
|
||||
codec_lm.config.bos_token_id = codec_vocab_size - 3
|
||||
codec_lm.config.mask_token_id = codec_vocab_size - 4
|
||||
# if params.use_lora:
|
||||
# lora_config = LoraConfig(
|
||||
# r=64,
|
||||
# lora_alpha=16,
|
||||
# target_modules=[
|
||||
# "q_proj",
|
||||
# "k_proj",
|
||||
# "v_proj",
|
||||
# "o_proj",
|
||||
# "up_proj",
|
||||
# "gate_proj",
|
||||
# "down_proj",
|
||||
# ],
|
||||
# lora_dropout=0.05,
|
||||
# task_type="CAUSAL_LM",
|
||||
# )
|
||||
# codec_lm = get_peft_model(codec_lm, lora_config)
|
||||
# codec_lm.print_trainable_parameters()
|
||||
else:
|
||||
codec_lm = None
|
||||
|
||||
@ -373,13 +345,6 @@ def get_parser():
|
||||
default="/workspace/CosyVoice-300M-SFT",
|
||||
help="The path to the token2wav model",
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--dataset",
|
||||
# type=str,
|
||||
# default="aishell",
|
||||
# choices=["aishell", "speechio", "wenetspeech_test_meeting", "multi_hans_zh"],
|
||||
# help="The dataset to decode",
|
||||
# )
|
||||
|
||||
add_model_arguments(parser)
|
||||
return parser
|
||||
@ -474,12 +439,6 @@ def decode_one_batch(
|
||||
|
||||
chat_rounds = [cut.custom["round"] for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
# messages = [
|
||||
# [
|
||||
# {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
|
||||
# {"role": "assistant", "content": ""},
|
||||
# ]
|
||||
# ] * len(feature)
|
||||
questions_with_history = [
|
||||
cut.custom["question"] for cut in batch["supervisions"]["cut"]
|
||||
]
|
||||
@ -496,7 +455,6 @@ def decode_one_batch(
|
||||
history_question_answer = history_contexts[i].split("USER:")
|
||||
history_question_answer = [item for item in history_question_answer if item]
|
||||
for j in range(total_round - 1):
|
||||
# USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
|
||||
question_answer = history_question_answer[j].split("ASSISTANT:")
|
||||
message += [
|
||||
{"role": "user", "content": question_answer[0].strip()},
|
||||
@ -504,7 +462,6 @@ def decode_one_batch(
|
||||
]
|
||||
message += [
|
||||
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
|
||||
# {"role": "user", "content": f"{last_questions[i]}"},
|
||||
{"role": "assistant", "content": ""},
|
||||
]
|
||||
print(f"message: {message}, batch_size {len(chat_rounds)}")
|
||||
@ -525,13 +482,6 @@ def decode_one_batch(
|
||||
audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).unsqueeze(0)
|
||||
audio_hat = audio_decode_cosyvoice(audio_tokens, token2wav_model)
|
||||
sf.write(speech_file_name, audio_hat.squeeze(0).cpu().numpy(), 22050)
|
||||
# with open(speech_token_file_name, 'w') as f:
|
||||
# # save_path = params.exp_dir / f"speech_output/{cut_id}.wav"
|
||||
# #torchaudio.save(save_path, speech_output.cpu(), 16000)
|
||||
# # print(f"speech_output: {generated_speech_output}, cut_id: {cut_id}")
|
||||
# save_str = " ".join([str(i) for i in generated_speech_output])
|
||||
# f.write(f"{cut_id}|{save_str}\n")
|
||||
|
||||
else:
|
||||
generated_ids = model.decode(
|
||||
feature, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
|
||||
@ -560,43 +510,6 @@ def decode_dataset(
|
||||
Returns:
|
||||
Return a dict, whose key may be "beam-search".
|
||||
"""
|
||||
|
||||
def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
|
||||
"""
|
||||
Text normalization similar to M2MeT challenge baseline.
|
||||
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
|
||||
"""
|
||||
if normalize == "none":
|
||||
return text
|
||||
elif normalize == "m2met":
|
||||
import re
|
||||
|
||||
text = text.replace(" ", "")
|
||||
text = text.replace("<sil>", "")
|
||||
text = text.replace("<%>", "")
|
||||
text = text.replace("<->", "")
|
||||
text = text.replace("<$>", "")
|
||||
text = text.replace("<#>", "")
|
||||
text = text.replace("<_>", "")
|
||||
text = text.replace("<space>", "")
|
||||
text = text.replace("`", "")
|
||||
text = text.replace("&", "")
|
||||
text = text.replace(",", "")
|
||||
if re.search("[a-zA-Z]", text):
|
||||
text = text.upper()
|
||||
text = text.replace("A", "A")
|
||||
text = text.replace("a", "A")
|
||||
text = text.replace("b", "B")
|
||||
text = text.replace("c", "C")
|
||||
text = text.replace("k", "K")
|
||||
text = text.replace("t", "T")
|
||||
text = text.replace(",", "")
|
||||
text = text.replace("丶", "")
|
||||
text = text.replace("。", "")
|
||||
text = text.replace("、", "")
|
||||
text = text.replace("?", "")
|
||||
return text
|
||||
|
||||
results = []
|
||||
|
||||
num_cuts = 0
|
||||
@ -634,7 +547,6 @@ def decode_dataset(
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
# ref_text = normalize_text_alimeeting(ref_text)
|
||||
ref_words = ref_text.split()
|
||||
print(f"ref: {ref_text}")
|
||||
print(f"hyp: {''.join(hyp_words)}")
|
||||
@ -673,7 +585,6 @@ def save_results(
|
||||
errs_filename = (
|
||||
params.log_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
|
||||
@ -732,11 +643,8 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
|
||||
data_module = AsrDataModule(args)
|
||||
# data_module = MultiDataset(args.manifest_dir)
|
||||
|
||||
def remove_long_utt(c: Cut):
|
||||
# Keep only utterances with duration in 30 seconds
|
||||
@ -748,13 +656,6 @@ def main():
|
||||
return False
|
||||
return True
|
||||
|
||||
# if params.dataset == "aishell":
|
||||
# test_sets_cuts = data_module.aishell_test_cuts()
|
||||
# elif params.dataset == "speechio":
|
||||
# test_sets_cuts = data_module.speechio_test_cuts()
|
||||
# elif params.dataset == "wenetspeech_test_meeting":
|
||||
# test_sets_cuts = data_module.wenetspeech_test_meeting_cuts()
|
||||
# else:
|
||||
test_sets_cuts = data_module.test_cuts()
|
||||
|
||||
test_sets = test_sets_cuts.keys()
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import List, Tuple # Added for type hints
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -78,7 +78,6 @@ class SPEECH_LLM(nn.Module):
|
||||
self.codec_lm_head = nn.Linear(
|
||||
self.codec_lm.config.hidden_size, self.codec_lm.config.vocab_size
|
||||
)
|
||||
# to torch.float16
|
||||
self.speech_token_projector = self.speech_token_projector.to(
|
||||
dtype=torch.float16
|
||||
)
|
||||
@ -498,20 +497,6 @@ class SPEECH_LLM(nn.Module):
|
||||
pad_token_id=self.llm.config.pad_token_id,
|
||||
)
|
||||
|
||||
# generated_ids = self.llm.generate(
|
||||
# inputs_embeds=inputs_embeds,
|
||||
# max_new_tokens=kwargs.get("max_new_tokens", 200),
|
||||
# num_beams=kwargs.get("num_beams", 1),
|
||||
# do_sample=kwargs.get("do_sample", False),
|
||||
# min_length=kwargs.get("min_length", 1),
|
||||
# top_p=kwargs.get("top_p", 1.0),
|
||||
# repetition_penalty=kwargs.get("repetition_penalty", 1.0),
|
||||
# temperature=kwargs.get("temperature", 1.0),
|
||||
# length_penalty=kwargs.get("length_penalty", 1.0),
|
||||
# bos_token_id=self.llm.config.bos_token_id,
|
||||
# eos_token_id=self.llm.config.eos_token_id,
|
||||
# pad_token_id=self.llm.config.pad_token_id,
|
||||
# )
|
||||
return generated_ids
|
||||
|
||||
def decode_with_speech_output(
|
||||
@ -520,7 +505,7 @@ class SPEECH_LLM(nn.Module):
|
||||
input_ids: torch.LongTensor = None, # Prompt input_ids
|
||||
attention_mask: torch.Tensor = None, # Prompt attention_mask
|
||||
max_text_new_tokens: int = 1024,
|
||||
max_speech_new_tokens: int = 1024, # Max length for speech tokens
|
||||
max_speech_new_tokens: int = 2048, # Max length for speech tokens
|
||||
llm_kwargs: dict = None, # Kwargs for text LLM generate
|
||||
codec_lm_kwargs: dict = None, # Kwargs for codec LM (e.g., temperature for sampling) - NOT IMPLEMENTED YET
|
||||
) -> Tuple[torch.LongTensor, List[List[int]]]:
|
||||
@ -602,7 +587,7 @@ class SPEECH_LLM(nn.Module):
|
||||
eos_token_id = self.llm.config.eos_token_id
|
||||
eos_token_embedding = self.llm.get_input_embeddings()(
|
||||
torch.tensor([[eos_token_id]], device=device)
|
||||
) # 1,D
|
||||
)
|
||||
assert (
|
||||
generated_text_ids[0, -1] == eos_token_id
|
||||
), f"Last token is not EOS: {generated_text_ids[0, -1]} != {eos_token_id}"
|
||||
@ -610,7 +595,7 @@ class SPEECH_LLM(nn.Module):
|
||||
token_hidden_states[0].to(self.llm.device)
|
||||
for token_hidden_states in text_outputs.hidden_states
|
||||
]
|
||||
# shift one for thinker token_embeds, drop the first embeds, and add the eos token
|
||||
|
||||
first_thinker_token_embed = torch.cat(
|
||||
[
|
||||
thinker_token_embeds_org[0][:, 1:],
|
||||
@ -628,7 +613,7 @@ class SPEECH_LLM(nn.Module):
|
||||
token_hidden_states[-1].to(self.llm.device)
|
||||
for token_hidden_states in text_outputs.hidden_states
|
||||
]
|
||||
# thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
|
||||
|
||||
thinker_reply_part = [
|
||||
torch.cat(
|
||||
[
|
||||
@ -651,12 +636,8 @@ class SPEECH_LLM(nn.Module):
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
thinker_prompt_part = self.speech_token_projector(
|
||||
thinker_prompt_part
|
||||
) # [B, S_full, D_codec]
|
||||
thinker_reply_part = self.speech_token_projector(
|
||||
thinker_reply_part
|
||||
) # [B, S_full, D_codec]
|
||||
thinker_prompt_part = self.speech_token_projector(thinker_prompt_part)
|
||||
thinker_reply_part = self.speech_token_projector(thinker_reply_part)
|
||||
|
||||
thinker_prompt_part_seq_len = thinker_prompt_part.shape[1]
|
||||
talker_input_ids = torch.full(
|
||||
@ -666,9 +647,7 @@ class SPEECH_LLM(nn.Module):
|
||||
device=self.llm.device,
|
||||
)
|
||||
talker_input_ids[:, -1] = self.codec_lm.config.bos_token_id
|
||||
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
|
||||
talker_input_ids
|
||||
) # [B, S_full, D_codec]
|
||||
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(talker_input_ids)
|
||||
thinker_input_embeds = torch.cat(
|
||||
[
|
||||
thinker_prompt_part,
|
||||
@ -677,68 +656,43 @@ class SPEECH_LLM(nn.Module):
|
||||
dim=1,
|
||||
)
|
||||
talker_inputs_embeds += thinker_input_embeds
|
||||
thinker_reply_part = thinker_reply_part[
|
||||
:, delay_step + 1 :, :
|
||||
] # [B, S_full, D_codec]
|
||||
thinker_reply_part = thinker_reply_part[:, delay_step + 1 :, :]
|
||||
|
||||
past_key_values = None
|
||||
# generated_speech_tokens_list = [[] for _ in range(batch_size)]
|
||||
# unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
|
||||
|
||||
generated_speech_tokens_list = []
|
||||
next_token_ids = None
|
||||
# text_context_len = projected_text_embeds.shape[1] # S_full
|
||||
|
||||
for t in range(max_speech_new_tokens):
|
||||
# Get embedding for the *current* input token ID (initially BOS, then generated tokens)
|
||||
# current_speech_embeds = self.codec_lm.get_input_embeddings()(current_speech_input_ids) # [B, 1, D_codec]
|
||||
if t > 0:
|
||||
talker_inputs_embeds = self.codec_lm.get_input_embeddings()(
|
||||
next_token_ids
|
||||
) # [B, 1, D_codec]
|
||||
)
|
||||
if thinker_reply_part.shape[1] > 0:
|
||||
talker_inputs_embeds += thinker_reply_part[:, :1, :]
|
||||
thinker_reply_part = thinker_reply_part[
|
||||
:, 1:, :
|
||||
] # Remove the first token for next step
|
||||
# # Add the projected text embedding corresponding to the current timestep `t`
|
||||
# if t < text_context_len:
|
||||
# # Text context from the full generated text sequence
|
||||
# current_text_context_embed = projected_text_embeds[:, t:t+1, :] # [B, 1, D_codec]
|
||||
# inputs_embeds = current_speech_embeds + current_text_context_embed
|
||||
# else:
|
||||
# # No more text context to add
|
||||
# inputs_embeds = current_speech_embeds
|
||||
thinker_reply_part = thinker_reply_part[:, 1:, :]
|
||||
|
||||
# Forward pass through codec LM for one step
|
||||
# We provide inputs_embeds directly, bypassing prepare_inputs_for_generation
|
||||
codec_outputs = self.codec_lm(
|
||||
inputs_embeds=talker_inputs_embeds, # Combined embedding for this step
|
||||
inputs_embeds=talker_inputs_embeds,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
return_dict=True,
|
||||
output_hidden_states=True,
|
||||
# No attention mask needed here when using past_key_values and single token input
|
||||
)
|
||||
last_token_hidden_state = codec_outputs.hidden_states[-1][
|
||||
:, -1, :
|
||||
] # [B, D_codec] #TODO: check shape here
|
||||
# Get logits for the *last* token generated in this step
|
||||
next_token_logits = self.codec_lm_head(
|
||||
last_token_hidden_state
|
||||
) # Use -1 index
|
||||
# suppress tokens between 4096:len(vocab)-3
|
||||
# next_token_logits[:, 4096:-3] = -float("Inf") # TODO: where we should supress tokens?
|
||||
last_token_hidden_state = codec_outputs.hidden_states[-1][:, -1, :]
|
||||
next_token_logits = self.codec_lm_head(last_token_hidden_state)
|
||||
|
||||
next_token_ids = topk_sampling(
|
||||
next_token_logits,
|
||||
)
|
||||
# print(next_token_ids, "next_token_ids", t, next_token_ids.shape)
|
||||
if next_token_ids[0, 0] == self.codec_lm.config.eos_token_id:
|
||||
break
|
||||
# current_speech_input_ids = next_token_ids # Use the newly generated token ID as input for next step
|
||||
|
||||
past_key_values = codec_outputs.past_key_values # Update KV cache
|
||||
generated_speech_tokens_list.append(
|
||||
next_token_ids.squeeze(1).cpu().tolist()[0]
|
||||
)
|
||||
# --- 6. Return Results ---
|
||||
|
||||
return generated_text_ids, generated_speech_tokens_list
|
||||
|
||||
|
||||
|
@ -17,28 +17,22 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
# fine-tuning with whisper and Qwen2
|
||||
pip install huggingface_hub['cli']
|
||||
mkdir -p models/whisper models/qwen
|
||||
# For Chinese dataset, you can use the following command to download the Chinese fine-tuned whisper model.
|
||||
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper
|
||||
# Qwen Pretrained model
|
||||
huggingface-cli download --local-dir models/Qwen2.5-0.5B-Instruct Qwen/Qwen2.5-0.5B-Instruct
|
||||
|
||||
# For aishell fine-tuned whisper model
|
||||
huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishell_whisper exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt
|
||||
# For multi-hans fine-tuned whisper model
|
||||
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
|
||||
|
||||
# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
|
||||
huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
|
||||
|
||||
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
|
||||
--max-duration 200 \
|
||||
--exp-dir ./whisper_llm_zh/exp_test \
|
||||
--speech-encoder-path-or-name models/whisper/exp_large_v2/whisper-large-v2-aishell1-epoch-10-avg-6.pt \
|
||||
--llm-path-or-name Qwen/Qwen2-1.5B-Instruct \
|
||||
torchrun --nproc_per_node $ngpu ./qwen_omni/train.py \
|
||||
--max-duration 50 \
|
||||
--enable-musan False \
|
||||
--exp-dir $exp_dir \
|
||||
--speech-encoder-path-or-name models/whisper/v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt \
|
||||
--llm-path-or-name Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--manifest-dir data/fbank \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
|
||||
--deepspeed_config ./qwen_omni/ds_config_zero1.json \
|
||||
--use-flash-attn True \
|
||||
--use-lora True --unfreeze-llm True
|
||||
--use-lora True --unfreeze-llm True --unfreeze-speech-projector True --enable-speech-output True
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@ -52,7 +46,6 @@ from shutil import copyfile
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import deepspeed
|
||||
import k2
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
@ -66,8 +59,6 @@ from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
|
||||
|
||||
# from multi_dataset import MultiDataset
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from torch import Tensor
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
@ -330,9 +321,6 @@ def compute_loss(
|
||||
truncation=False,
|
||||
)
|
||||
)
|
||||
# padding texts to the same length, texts is a list of list, padding with tokenzier.pad_token_id
|
||||
# remove too long text
|
||||
# texts = [ text for text in texts if len(text) < 1024 ]
|
||||
if len(texts) != len(messages):
|
||||
logging.warning(f"Remove too long text, {messages} ")
|
||||
max_len_texts = max([len(text) for text in texts])
|
||||
@ -347,7 +335,7 @@ def compute_loss(
|
||||
for text in texts
|
||||
]
|
||||
input_ids = torch.tensor(texts, dtype=torch.int)
|
||||
# response = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0]
|
||||
|
||||
target_ids = input_ids.clone()
|
||||
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
|
||||
# mask all tokens before token_id 151646 with IGNORE_TOKEN_ID
|
||||
@ -396,8 +384,6 @@ def compute_loss(
|
||||
history_contexts = [
|
||||
question.rsplit("<USER>:", 1)[0].strip() for question in questions_with_history
|
||||
]
|
||||
# USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。<USER>: 告诉我如何烹饪鸡肉
|
||||
# <USER>: 对以下句子进行鉴赏:他心地善良。输出结果为"他是一个有善心的人。
|
||||
|
||||
messages = []
|
||||
for i, total_round in enumerate(chat_rounds):
|
||||
@ -406,7 +392,6 @@ def compute_loss(
|
||||
history_question_answer = history_contexts[i].split("USER:")
|
||||
history_question_answer = [item for item in history_question_answer if item]
|
||||
for j in range(total_round - 1):
|
||||
# USER: 生成一个关于夏天的诗歌。 ASSISTANT: 夏日炎炎,万物生长,阳光明媚,享受着夏日的美好时光。 USER: 给我列举一些新闻头条。 ASSISTANT: 当今社会的新闻永远不会停。
|
||||
question_answer = history_question_answer[j].split("ASSISTANT:")
|
||||
message += [
|
||||
{"role": "user", "content": question_answer[0].strip()},
|
||||
@ -683,7 +668,6 @@ def run(rank, world_size, args):
|
||||
|
||||
if params.use_flash_attn:
|
||||
attn_implementation = "flash_attention_2"
|
||||
# torch_dtype=torch.bfloat16 FIX ME
|
||||
torch_dtype = torch.float16
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
@ -724,14 +708,6 @@ def run(rank, world_size, args):
|
||||
|
||||
special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
|
||||
tokenizer.add_special_tokens(special_tokens_dict)
|
||||
# original_tokenizer_vocab_size = len(tokenizer)
|
||||
# cosyvoice2_token_size = 6561
|
||||
# new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [
|
||||
# "<|SPEECH_GENERATION_START|>"
|
||||
# ]
|
||||
# num_added_tokens = tokenizer.add_tokens(new_tokens)
|
||||
# model.resize_token_embeddings(len(tokenizer))
|
||||
# model.vocab_size = len(tokenizer)
|
||||
|
||||
llm.config.pad_token_id = tokenizer.pad_token_id
|
||||
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
|
||||
@ -755,11 +731,6 @@ def run(rank, world_size, args):
|
||||
attn_implementation = "eager"
|
||||
torch_dtype = torch.float16
|
||||
|
||||
# codec_lm = AutoModelForCausalLM.from_pretrained(
|
||||
# params.llm_path_or_name,
|
||||
# attn_implementation=attn_implementation,
|
||||
# torch_dtype=torch_dtype,
|
||||
# )
|
||||
codec_vocab_size = 4096 + 4
|
||||
# TODO: modify above vocab size or supress_tokens when decoding
|
||||
config = Qwen2Config(
|
||||
@ -771,39 +742,19 @@ def run(rank, world_size, args):
|
||||
intermediate_size=2048,
|
||||
max_position_embeddings=4096,
|
||||
)
|
||||
# codec_lm = Qwen2ForCausalLM(config=config)
|
||||
# Pass attn_implementation and torch_dtype to the constructor
|
||||
# Use AutoModelForCausalLM.from_config for more generality
|
||||
|
||||
codec_lm = AutoModelForCausalLM.from_config(
|
||||
config=config,
|
||||
attn_implementation=attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
# cosyvoice2_token_size = 6561
|
||||
|
||||
codec_lm.resize_token_embeddings(codec_vocab_size)
|
||||
codec_lm.vocab_size = codec_vocab_size
|
||||
codec_lm.config.pad_token_id = codec_vocab_size - 1
|
||||
codec_lm.config.eos_token_id = codec_vocab_size - 2
|
||||
codec_lm.config.bos_token_id = codec_vocab_size - 3
|
||||
codec_lm.config.mask_token_id = codec_vocab_size - 4
|
||||
# if params.use_lora:
|
||||
# lora_config = LoraConfig(
|
||||
# r=64,
|
||||
# lora_alpha=16,
|
||||
# target_modules=[
|
||||
# "q_proj",
|
||||
# "k_proj",
|
||||
# "v_proj",
|
||||
# "o_proj",
|
||||
# "up_proj",
|
||||
# "gate_proj",
|
||||
# "down_proj",
|
||||
# ],
|
||||
# lora_dropout=0.05,
|
||||
# task_type="CAUSAL_LM",
|
||||
# )
|
||||
# codec_lm = get_peft_model(codec_lm, lora_config)
|
||||
# codec_lm.print_trainable_parameters()
|
||||
else:
|
||||
codec_lm = None
|
||||
|
||||
@ -856,7 +807,6 @@ def run(rank, world_size, args):
|
||||
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
||||
# )
|
||||
return False
|
||||
# cut.custom["answer_cosyvoice_speech_token"] for cut in batch["supervisions"]["cut"]
|
||||
codec_len = len(c.custom["answer_cosyvoice_speech_token"])
|
||||
if codec_len > 2200:
|
||||
logging.warning(
|
||||
@ -873,7 +823,7 @@ def run(rank, world_size, args):
|
||||
if params.sampler_state_dict_path:
|
||||
sampler_state_dict = torch.load(params.sampler_state_dict_path)
|
||||
sampler_state_dict["max_duration"] = params.max_duration
|
||||
# TODO: load sampler state dict
|
||||
|
||||
train_dl = data_module.train_dataloaders(
|
||||
train_cuts, sampler_state_dict=sampler_state_dict
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user