add multi gpu processing

root 2025-05-21 21:54:59 -07:00
parent ca84aff5d6
commit 7aa6c80ddb
3 changed files with 93 additions and 24 deletions

View File

@@ -239,20 +239,57 @@ fi
 if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
   log "stage 14: Client"
-  datasets=(alpacaeval_full wildvoice mmsu advbench bbh ifeval commoneval openbookqa sd-qa)
-  datasets=(openbookqa commoneval)
-  for dataset in ${datasets[@]}; do
-    # sd-qa should use usa split
-    if [ $dataset == "sd-qa" ]; then
+  exp_dir=./qwen_omni/exp_speech2text_first_libri_continuation_second_ce
+  # Full benchmark list, kept for reference:
+  # (alpacaeval_full wildvoice mmsu advbench bbh ifeval commoneval openbookqa sd-qa)
+  declare -a target_datasets=("alpacaeval_full" "wildvoice" "ifeval" "commoneval" "openbookqa" "sd-qa" "advbench" "bbh" "mmsu")
+  # The second declaration overrides the first; only these datasets are processed.
+  declare -a target_datasets=("openbookqa" "ifeval" "sd-qa" "commoneval" "alpacaeval_full")
+  NUM_CLIENT_JOBS=4  # Number of parallel client jobs
+  BASE_PORT=8000     # Base port for servers
+  log "Starting $NUM_CLIENT_JOBS parallel client jobs to process ${#target_datasets[@]} datasets."
+  for job_id in $(seq 0 $(($NUM_CLIENT_JOBS - 1)))
+  do
+    (  # Start a subshell for backgrounding this client job's tasks
+      current_port=$(expr $BASE_PORT + $job_id)
+      log "Client Job $job_id: Initializing. Will connect to port $current_port."
+      processed_count_for_this_job=0
+      # Iterate over all datasets using their indices
+      for i in "${!target_datasets[@]}"; do
+        # Assign dataset to job_id in a round-robin fashion
+        if [ $(($i % $NUM_CLIENT_JOBS)) -eq $job_id ]; then
+          dataset="${target_datasets[$i]}"
+          # Determine split_name based on dataset; sd-qa should use the usa split
+          if [ "$dataset" == "sd-qa" ]; then
             split_name="usa"
           else
             split_name="test"
           fi
-      echo $dataset $split_name
+          log "Client Job $job_id (Port $current_port): Processing dataset '$dataset' (split '$split_name')"
           python3 ./qwen_omni/client.py \
-        --subset-name $dataset --split-name $split_name \
-        --output-dir result_adapter_librispeech_kl_div_qa_template
+            --subset-name "$dataset" \
+            --split-name "$split_name" \
+            --output-dir "$exp_dir/results" \
+            --port "$current_port"  # Assuming client.py accepts --port
+          if [ $? -ne 0 ]; then
+            log "Client Job $job_id (Port $current_port): ERROR processing dataset '$dataset'."
+          fi
+          processed_count_for_this_job=$(($processed_count_for_this_job + 1))
+        fi
       done
+      log "Client Job $job_id (Port $current_port): Finished. Processed $processed_count_for_this_job datasets."
+    ) &  # Run this client job's subshell in the background
+  done
+  log "All client jobs launched. Waiting for completion..."
+  wait  # Wait for all backgrounded client jobs to complete
+  log "All client jobs have completed."
 fi

 if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
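The dataset-to-job mapping above is round-robin over array indices: client job j handles every dataset whose index i satisfies i % NUM_CLIENT_JOBS == j, so with five active datasets and four jobs, job 0 processes two datasets sequentially while jobs 1-3 handle one each. A standalone sketch of the resulting assignment:

    NUM_CLIENT_JOBS=4
    target_datasets=("openbookqa" "ifeval" "sd-qa" "commoneval" "alpacaeval_full")
    # Print which client job each dataset is assigned to
    for i in "${!target_datasets[@]}"; do
        echo "job $((i % NUM_CLIENT_JOBS)) -> ${target_datasets[$i]}"
    done
    # job 0 -> openbookqa; job 1 -> ifeval; job 2 -> sd-qa;
    # job 3 -> commoneval; job 0 -> alpacaeval_full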
@@ -324,15 +361,26 @@ fi
 if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
+  # pip install gradio sherpa-onnx
   log "stage 17: Server for adapter only speech continuation"
-  exp_dir=./qwen_omni/exp_speech2text
-  python3 ./qwen_omni/server.py \
+  exp_dir=./qwen_omni/exp_speech2text_first_libri_continuation_second_ce
+  N_GPUS=4  # Define the number of GPUs/processes you want to launch
+  for id in $(seq 0 $(($N_GPUS - 1)))
+  do
+    log "Launching server on GPU $id with port $(expr 8000 + $id)"
+    CUDA_VISIBLE_DEVICES=$id python3 ./qwen_omni/server.py \
       --speech-encoder-path-or-name models/large-v2.pt \
       --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
-    --checkpoint-path $exp_dir/epoch-6/pytorch_model.bin \
+      --checkpoint-path $exp_dir/epoch-10/pytorch_model.bin \
       --use-flash-attn True \
       --enable-speech-output False \
-    --use-lora False
+      --prompt-template continuation --port $(expr 8000 + $id) \
+      --use-lora True &
+  done
+  wait  # Wait for all background processes to complete
 fi

 if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
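Stage 17 backgrounds one server per GPU, and the stage 14 clients assume ports 8000 through 8003 are already serving. A minimal readiness check to run before launching clients, assuming the servers bind localhost and accept TCP connections once their checkpoints are loaded (netcat assumed available):

    BASE_PORT=8000
    N_GPUS=4
    for id in $(seq 0 $((N_GPUS - 1))); do
        port=$((BASE_PORT + id))
        # Poll until the server on this port accepts connections
        until nc -z localhost "$port" 2>/dev/null; do
            sleep 5
        done
        echo "Server on port $port is reachable."
    done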

View File

@ -13,9 +13,15 @@ def get_args():
parser.add_argument( parser.add_argument(
"--server-url", "--server-url",
type=str, type=str,
default="http://localhost:8000", default="http://localhost",
help="URL of the FastAPI server", help="URL of the FastAPI server",
) )
parser.add_argument(
"--port",
type=int,
default=8000,
help="Port of the FastAPI server",
)
parser.add_argument( parser.add_argument(
"--dataset-name", "--dataset-name",
type=str, type=str,
@@ -48,7 +54,7 @@ def main():
         args.output_dir,
         f"{args.subset_name}-{args.split_name}.jsonl",
     )
-    server_decode_url = f"{args.server_url}/decode"
+    server_decode_url = f"{args.server_url}:{args.port}/decode"

     print("Loading dataset...")
     if args.subset_name != "mmsu":
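With the new --port flag, pointing a single client at a specific server instance looks like this (the dataset and output directory are illustrative):

    python3 ./qwen_omni/client.py \
        --server-url http://localhost \
        --port 8001 \
        --subset-name openbookqa \
        --split-name test \
        --output-dir ./qwen_omni/exp_speech2text_first_libri_continuation_second_ce/results

The resulting request URL is http://localhost:8001/decode.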

View File

@@ -380,6 +380,19 @@ def process_batch_speech_continuation(batch: dict):
         messages.append(message)
     return messages

+def process_batch_asr(batch: dict):
+    messages = []
+    for i in range(len(batch["supervisions"]["text"])):
+        transcript = batch["supervisions"]["cut"][i].custom["text"]
+        message = [
+            {
+                "role": "user",
+                "content": f"Transcribe the following audio into text:\n\n{DEFAULT_SPEECH_TOKEN}",
+            },
+            {"role": "assistant", "content": transcript},
+        ]
+        messages.append(message)
+    return messages

 def process_batch_text_continuation(batch: dict):
     messages = []
@@ -548,6 +561,8 @@ def compute_loss(
         messages = process_batch_speech_continuation(batch)
         if params.loss_type == "kl_div":
             messages_text = process_batch_text_continuation(batch)
+    elif params.dataset_format == "asr":
+        messages = process_batch_asr(batch)
     else:
         raise ValueError(f"Unknown dataset format: {params.dataset_format}")
@@ -1020,7 +1035,7 @@ def run(rank, world_size, args):
     elif params.dataset_format == "vocalnet":
         train_cuts = data_module.train_cuts_en_vocalnet()
         valid_cuts = data_module.valid_cuts_en_vocalnet()
-    elif params.dataset_format == "speech_continuation":
+    elif params.dataset_format == "speech_continuation" or params.dataset_format == "asr":
         if params.dataset == "multi_en":
             train_cuts = data_module.train_cuts_ultravox()
         elif params.dataset == "librispeech":
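Given the dispatch above, enabling the new format presumably comes down to passing the two params values on the command line. A sketch only: the entry-point name and the flag spellings are assumptions inferred from params.dataset_format and params.dataset, not confirmed by this diff.

    # Hypothetical invocation; "train.py" and the flag names are assumptions.
    # Other required arguments (model paths, experiment directory, etc.) are omitted.
    python3 ./qwen_omni/train.py \
        --dataset-format asr \
        --dataset librispeech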