mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
add multi gpu processing
This commit is contained in:
parent
ca84aff5d6
commit
7aa6c80ddb
@ -239,20 +239,57 @@ fi
|
|||||||
|
|
||||||
if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
|
if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
|
||||||
log "stage 14: Client"
|
log "stage 14: Client"
|
||||||
datasets=(alpacaeval_full wildvoice mmsu advbench bbh ifeval commoneval openbookqa sd-qa)
|
exp_dir=./qwen_omni/exp_speech2text_first_libri_continuation_second_ce
|
||||||
datasets=(openbookqa commoneval)
|
# The final assignment of datasets in the original script is used here:
|
||||||
for dataset in ${datasets[@]}; do
|
# (alpacaeval_full wildvoice mmsu advbench bbh ifeval commoneval openbookqa sd-qa)
|
||||||
# sd-qa should use usa split
|
declare -a target_datasets=("alpacaeval_full" "wildvoice" "ifeval" "commoneval" "openbookqa" "sd-qa" "advbench" "bbh" "mmsu")
|
||||||
if [ $dataset == "sd-qa" ]; then
|
declare -a target_datasets=("openbookqa" "ifeval" "sd-qa" "commoneval" "alpacaeval_full")
|
||||||
split_name="usa"
|
|
||||||
else
|
NUM_CLIENT_JOBS=4 # Number of parallel client jobs
|
||||||
split_name="test"
|
BASE_PORT=8000 # Base port for servers
|
||||||
fi
|
|
||||||
echo $dataset $split_name
|
log "Starting $NUM_CLIENT_JOBS parallel client jobs to process ${#target_datasets[@]} datasets."
|
||||||
python3 ./qwen_omni/client.py \
|
|
||||||
--subset-name $dataset --split-name $split_name \
|
for job_id in $(seq 0 $(($NUM_CLIENT_JOBS - 1)))
|
||||||
--output-dir result_adapter_librispeech_kl_div_qa_template
|
do
|
||||||
|
( # Start a subshell for backgrounding this client job's tasks
|
||||||
|
current_port=$(expr $BASE_PORT + $job_id)
|
||||||
|
log "Client Job $job_id: Initializing. Will connect to port $current_port."
|
||||||
|
|
||||||
|
processed_count_for_this_job=0
|
||||||
|
# Iterate over all datasets using their indices
|
||||||
|
for i in "${!target_datasets[@]}"; do
|
||||||
|
# Assign dataset to job_id in a round-robin fashion
|
||||||
|
if [ $(($i % $NUM_CLIENT_JOBS)) -eq $job_id ]; then
|
||||||
|
dataset="${target_datasets[$i]}"
|
||||||
|
|
||||||
|
# local split_name # Determine split_name based on dataset
|
||||||
|
if [ "$dataset" == "sd-qa" ]; then
|
||||||
|
split_name="usa"
|
||||||
|
else
|
||||||
|
split_name="test"
|
||||||
|
fi
|
||||||
|
|
||||||
|
log "Client Job $job_id (Port $current_port): Processing dataset '$dataset' (split '$split_name')"
|
||||||
|
python3 ./qwen_omni/client.py \
|
||||||
|
--subset-name "$dataset" \
|
||||||
|
--split-name "$split_name" \
|
||||||
|
--output-dir "$exp_dir/results" \
|
||||||
|
--port "$current_port" # Assuming client.py accepts --port
|
||||||
|
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
log "Client Job $job_id (Port $current_port): ERROR processing dataset '$dataset'."
|
||||||
|
fi
|
||||||
|
processed_count_for_this_job=$(($processed_count_for_this_job + 1))
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
log "Client Job $job_id (Port $current_port): Finished. Processed $processed_count_for_this_job datasets."
|
||||||
|
) & # Run this client job's subshell in the background
|
||||||
done
|
done
|
||||||
|
|
||||||
|
log "All client jobs launched. Waiting for completion..."
|
||||||
|
wait # Wait for all backgrounded client jobs to complete
|
||||||
|
log "All client jobs have completed."
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
|
if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
|
||||||
@ -324,15 +361,26 @@ fi
|
|||||||
|
|
||||||
|
|
||||||
if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
|
if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
|
||||||
|
# pip install gradio sherpa-onnx
|
||||||
log "stage 17: Server for adapter only speech continuation"
|
log "stage 17: Server for adapter only speech continuation"
|
||||||
exp_dir=./qwen_omni/exp_speech2text
|
exp_dir=./qwen_omni/exp_speech2text_first_libri_continuation_second_ce
|
||||||
python3 ./qwen_omni/server.py \
|
|
||||||
--speech-encoder-path-or-name models/large-v2.pt \
|
N_GPUS=4 # Define the number of GPUs/processes you want to launch
|
||||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
|
||||||
--checkpoint-path $exp_dir/epoch-6/pytorch_model.bin \
|
for id in $(seq 0 $(($N_GPUS - 1)))
|
||||||
--use-flash-attn True \
|
do
|
||||||
--enable-speech-output False \
|
log "Launching server on GPU $id with port $(expr 8000 + $id)"
|
||||||
--use-lora False --prompt-template continuation
|
CUDA_VISIBLE_DEVICES=$id python3 ./qwen_omni/server.py \
|
||||||
|
--speech-encoder-path-or-name models/large-v2.pt \
|
||||||
|
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||||
|
--checkpoint-path $exp_dir/epoch-10/pytorch_model.bin \
|
||||||
|
--use-flash-attn True \
|
||||||
|
--enable-speech-output False \
|
||||||
|
--port $(expr 8000 + $id) \
|
||||||
|
--use-lora True &
|
||||||
|
done
|
||||||
|
|
||||||
|
wait # Wait for all background processes to complete
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
|
if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
|
||||||
|
@ -13,9 +13,15 @@ def get_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--server-url",
|
"--server-url",
|
||||||
type=str,
|
type=str,
|
||||||
default="http://localhost:8000",
|
default="http://localhost",
|
||||||
help="URL of the FastAPI server",
|
help="URL of the FastAPI server",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--port",
|
||||||
|
type=int,
|
||||||
|
default=8000,
|
||||||
|
help="Port of the FastAPI server",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dataset-name",
|
"--dataset-name",
|
||||||
type=str,
|
type=str,
|
||||||
@ -48,7 +54,7 @@ def main():
|
|||||||
args.output_dir,
|
args.output_dir,
|
||||||
f"{args.subset_name}-{args.split_name}.jsonl",
|
f"{args.subset_name}-{args.split_name}.jsonl",
|
||||||
)
|
)
|
||||||
server_decode_url = f"{args.server_url}/decode"
|
server_decode_url = f"{args.server_url}:{args.port}/decode"
|
||||||
|
|
||||||
print("Loading dataset...")
|
print("Loading dataset...")
|
||||||
if args.subset_name != "mmsu":
|
if args.subset_name != "mmsu":
|
||||||
|
@ -380,6 +380,19 @@ def process_batch_speech_continuation(batch: dict):
|
|||||||
messages.append(message)
|
messages.append(message)
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
def process_batch_asr(batch: dict):
|
||||||
|
messages = []
|
||||||
|
for i in range(len(batch["supervisions"]["text"])):
|
||||||
|
transcript = batch["supervisions"]["cut"][i].custom["text"]
|
||||||
|
message = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"Transcribe the following audio into text:\n\n{DEFAULT_SPEECH_TOKEN}",
|
||||||
|
},
|
||||||
|
{"role": "assistant", "content": transcript},
|
||||||
|
]
|
||||||
|
messages.append(message)
|
||||||
|
return messages
|
||||||
|
|
||||||
def process_batch_text_continuation(batch: dict):
|
def process_batch_text_continuation(batch: dict):
|
||||||
messages = []
|
messages = []
|
||||||
@ -548,6 +561,8 @@ def compute_loss(
|
|||||||
messages = process_batch_speech_continuation(batch)
|
messages = process_batch_speech_continuation(batch)
|
||||||
if params.loss_type == "kl_div":
|
if params.loss_type == "kl_div":
|
||||||
messages_text = process_batch_text_continuation(batch)
|
messages_text = process_batch_text_continuation(batch)
|
||||||
|
elif params.dataset_format == "asr":
|
||||||
|
messages = process_batch_asr(batch)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown dataset format: {params.dataset_format}")
|
raise ValueError(f"Unknown dataset format: {params.dataset_format}")
|
||||||
|
|
||||||
@ -1020,7 +1035,7 @@ def run(rank, world_size, args):
|
|||||||
elif params.dataset_format == "vocalnet":
|
elif params.dataset_format == "vocalnet":
|
||||||
train_cuts = data_module.train_cuts_en_vocalnet()
|
train_cuts = data_module.train_cuts_en_vocalnet()
|
||||||
valid_cuts = data_module.valid_cuts_en_vocalnet()
|
valid_cuts = data_module.valid_cuts_en_vocalnet()
|
||||||
elif params.dataset_format == "speech_continuation":
|
elif params.dataset_format == "speech_continuation" or params.dataset_format == "asr":
|
||||||
if params.dataset == "multi_en":
|
if params.dataset == "multi_en":
|
||||||
train_cuts = data_module.train_cuts_ultravox()
|
train_cuts = data_module.train_cuts_ultravox()
|
||||||
elif params.dataset == "librispeech":
|
elif params.dataset == "librispeech":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user