From 7aa6c80ddb7a634fc36cbc36173005a24896be12 Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 21 May 2025 21:54:59 -0700
Subject: [PATCH] add multi-GPU processing

---
 egs/speech_llm/SPEECH2SPEECH/prepare.sh       | 90 ++++++++++++++-----
 .../SPEECH2SPEECH/qwen_omni/client.py         | 10 ++-
 .../SPEECH2SPEECH/qwen_omni/train.py          | 17 +++-
 3 files changed, 93 insertions(+), 24 deletions(-)

diff --git a/egs/speech_llm/SPEECH2SPEECH/prepare.sh b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
index b86288c5f..98c6ced9b 100644
--- a/egs/speech_llm/SPEECH2SPEECH/prepare.sh
+++ b/egs/speech_llm/SPEECH2SPEECH/prepare.sh
@@ -239,20 +239,57 @@ fi
 
 if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
   log "stage 14: Client"
-  datasets=(alpacaeval_full wildvoice mmsu advbench bbh ifeval commoneval openbookqa sd-qa)
-  datasets=(openbookqa commoneval)
-  for dataset in ${datasets[@]}; do
-    # sd-qa should use usa split
-    if [ $dataset == "sd-qa" ]; then
-      split_name="usa"
-    else
-      split_name="test"
-    fi
-    echo $dataset $split_name
-    python3 ./qwen_omni/client.py \
-      --subset-name $dataset --split-name $split_name \
-      --output-dir result_adapter_librispeech_kl_div_qa_template
+  exp_dir=./qwen_omni/exp_speech2text_first_libri_continuation_second_ce
+  # Full benchmark suite (the second declaration below overrides it with the subset actually run):
+  # (alpacaeval_full wildvoice mmsu advbench bbh ifeval commoneval openbookqa sd-qa)
+  declare -a target_datasets=("alpacaeval_full" "wildvoice" "ifeval" "commoneval" "openbookqa" "sd-qa" "advbench" "bbh" "mmsu")
+  declare -a target_datasets=("openbookqa" "ifeval" "sd-qa" "commoneval" "alpacaeval_full")
+
+  NUM_CLIENT_JOBS=4 # Number of parallel client jobs
+  BASE_PORT=8000 # Base port for servers
+
+  log "Starting $NUM_CLIENT_JOBS parallel client jobs to process ${#target_datasets[@]} datasets."
+
+  for job_id in $(seq 0 $(($NUM_CLIENT_JOBS - 1)))
+  do
+    ( # Start a subshell for backgrounding this client job's tasks
+      current_port=$(expr $BASE_PORT + $job_id)
+      log "Client Job $job_id: Initializing. Will connect to port $current_port."
+
+      processed_count_for_this_job=0
+      # Iterate over all datasets using their indices
+      for i in "${!target_datasets[@]}"; do
+        # Assign dataset to job_id in a round-robin fashion
+        if [ $(($i % $NUM_CLIENT_JOBS)) -eq $job_id ]; then
+          dataset="${target_datasets[$i]}"
+
+          # Determine split_name based on dataset
+          if [ "$dataset" == "sd-qa" ]; then
+            split_name="usa"
+          else
+            split_name="test"
+          fi
+
+          log "Client Job $job_id (Port $current_port): Processing dataset '$dataset' (split '$split_name')"
+          python3 ./qwen_omni/client.py \
+            --subset-name "$dataset" \
+            --split-name "$split_name" \
+            --output-dir "$exp_dir/results" \
+            --port "$current_port" # connect to the server bound to this job's port
+
+          if [ $? -ne 0 ]; then
+            log "Client Job $job_id (Port $current_port): ERROR processing dataset '$dataset'."
+          fi
+          processed_count_for_this_job=$(($processed_count_for_this_job + 1))
+        fi
+      done
+      log "Client Job $job_id (Port $current_port): Finished. Processed $processed_count_for_this_job datasets."
+    ) & # Run this client job's subshell in the background
   done
+
+  log "All client jobs launched. Waiting for completion..."
+  wait # Wait for all backgrounded client jobs to complete
+  log "All client jobs have completed."
 fi
 
 if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
@@ -324,15 +361,26 @@ fi
 
 if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
+  # pip install gradio sherpa-onnx
   log "stage 17: Server for adapter only speech continuation"
-  exp_dir=./qwen_omni/exp_speech2text
-  python3 ./qwen_omni/server.py \
-    --speech-encoder-path-or-name models/large-v2.pt \
-    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
-    --checkpoint-path $exp_dir/epoch-6/pytorch_model.bin \
-    --use-flash-attn True \
-    --enable-speech-output False \
-    --use-lora False --prompt-template continuation
+  exp_dir=./qwen_omni/exp_speech2text_first_libri_continuation_second_ce
+
+  N_GPUS=4 # Define the number of GPUs/processes you want to launch
+
+  for id in $(seq 0 $(($N_GPUS - 1)))
+  do
+    log "Launching server on GPU $id with port $(expr 8000 + $id)"
+    CUDA_VISIBLE_DEVICES=$id python3 ./qwen_omni/server.py \
+      --speech-encoder-path-or-name models/large-v2.pt \
+      --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
+      --checkpoint-path $exp_dir/epoch-10/pytorch_model.bin \
+      --use-flash-attn True \
+      --enable-speech-output False \
+      --port $(expr 8000 + $id) \
+      --use-lora True &
+  done
+
+  wait # Wait for all background processes to complete
 fi
 
 if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/client.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/client.py
index 05c363979..7dc279e48 100644
--- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/client.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/client.py
@@ -13,9 +13,15 @@ def get_args():
     parser.add_argument(
         "--server-url",
         type=str,
-        default="http://localhost:8000",
+        default="http://localhost",
         help="URL of the FastAPI server",
     )
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=8000,
+        help="Port of the FastAPI server",
+    )
     parser.add_argument(
         "--dataset-name",
         type=str,
@@ -48,7 +54,7 @@ def main():
         args.output_dir,
         f"{args.subset_name}-{args.split_name}.jsonl",
     )
-    server_decode_url = f"{args.server_url}/decode"
+    server_decode_url = f"{args.server_url}:{args.port}/decode"
 
     print("Loading dataset...")
     if args.subset_name != "mmsu":
diff --git a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
index 81aac84e5..d5a2f7cf9 100755
--- a/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
+++ b/egs/speech_llm/SPEECH2SPEECH/qwen_omni/train.py
@@ -380,6 +380,19 @@ def process_batch_speech_continuation(batch: dict):
         messages.append(message)
     return messages
 
+def process_batch_asr(batch: dict):
+    messages = []
+    for i in range(len(batch["supervisions"]["text"])):
+        transcript = batch["supervisions"]["cut"][i].custom["text"]
+        message = [
+            {
+                "role": "user",
+                "content": f"Transcribe the following audio into text:\n\n{DEFAULT_SPEECH_TOKEN}",
+            },
+            {"role": "assistant", "content": transcript},
+        ]
+        messages.append(message)
+    return messages
 
 def process_batch_text_continuation(batch: dict):
     messages = []
@@ -548,6 +561,8 @@ def compute_loss(
         messages = process_batch_speech_continuation(batch)
         if params.loss_type == "kl_div":
             messages_text = process_batch_text_continuation(batch)
+    elif params.dataset_format == "asr":
+        messages = process_batch_asr(batch)
     else:
         raise ValueError(f"Unknown dataset format: {params.dataset_format}")
 
@@ -1020,7 +1035,7 @@ def run(rank, world_size, args):
     elif params.dataset_format == "vocalnet":
         train_cuts = data_module.train_cuts_en_vocalnet()
         valid_cuts = data_module.valid_cuts_en_vocalnet()
-    elif params.dataset_format == "speech_continuation":
+    elif params.dataset_format == "speech_continuation" or params.dataset_format == "asr":
         if params.dataset == "multi_en":
             train_cuts = data_module.train_cuts_ultravox()
         elif params.dataset == "librispeech":
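
Usage note: under this layout one client invocation is paired with the server bound to the same port. A minimal sketch, not part of the patch itself, assuming the stage 17 servers are already listening on ports 8000-8003 and using openbookqa as an example subset:

  exp_dir=./qwen_omni/exp_speech2text_first_libri_continuation_second_ce
  # Query the server started with CUDA_VISIBLE_DEVICES=2 (port 8000 + 2).
  python3 ./qwen_omni/client.py \
    --subset-name openbookqa \
    --split-name test \
    --output-dir "$exp_dir/results" \
    --port 8002

Each client writes its results to $exp_dir/results/<subset-name>-<split-name>.jsonl, following the output path built in client.py.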