Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)

commit 7aa6c80ddb (parent ca84aff5d6)
add multi gpu processing

@@ -239,20 +239,57 @@ fi

if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
  log "stage 14: Client"
  datasets=(alpacaeval_full wildvoice mmsu advbench bbh ifeval commoneval openbookqa sd-qa)
  datasets=(openbookqa commoneval)
  for dataset in ${datasets[@]}; do
    # sd-qa should use usa split
    if [ $dataset == "sd-qa" ]; then
      split_name="usa"
    else
      split_name="test"
    fi
    echo $dataset $split_name
    python3 ./qwen_omni/client.py \
      --subset-name $dataset --split-name $split_name \
      --output-dir result_adapter_librispeech_kl_div_qa_template
  exp_dir=./qwen_omni/exp_speech2text_first_libri_continuation_second_ce
  # The final assignment of datasets in the original script is used here:
  # (alpacaeval_full wildvoice mmsu advbench bbh ifeval commoneval openbookqa sd-qa)
  declare -a target_datasets=("alpacaeval_full" "wildvoice" "ifeval" "commoneval" "openbookqa" "sd-qa" "advbench" "bbh" "mmsu")
  declare -a target_datasets=("openbookqa" "ifeval" "sd-qa" "commoneval" "alpacaeval_full")

  NUM_CLIENT_JOBS=4 # Number of parallel client jobs
  BASE_PORT=8000    # Base port for servers

  log "Starting $NUM_CLIENT_JOBS parallel client jobs to process ${#target_datasets[@]} datasets."

  for job_id in $(seq 0 $(($NUM_CLIENT_JOBS - 1)))
  do
    ( # Start a subshell for backgrounding this client job's tasks
      current_port=$(expr $BASE_PORT + $job_id)
      log "Client Job $job_id: Initializing. Will connect to port $current_port."

      processed_count_for_this_job=0
      # Iterate over all datasets using their indices
      for i in "${!target_datasets[@]}"; do
        # Assign dataset to job_id in a round-robin fashion
        if [ $(($i % $NUM_CLIENT_JOBS)) -eq $job_id ]; then
          dataset="${target_datasets[$i]}"

          # local split_name # Determine split_name based on dataset
          if [ "$dataset" == "sd-qa" ]; then
            split_name="usa"
          else
            split_name="test"
          fi

          log "Client Job $job_id (Port $current_port): Processing dataset '$dataset' (split '$split_name')"
          python3 ./qwen_omni/client.py \
            --subset-name "$dataset" \
            --split-name "$split_name" \
            --output-dir "$exp_dir/results" \
            --port "$current_port" # Assuming client.py accepts --port

          if [ $? -ne 0 ]; then
            log "Client Job $job_id (Port $current_port): ERROR processing dataset '$dataset'."
          fi
          processed_count_for_this_job=$(($processed_count_for_this_job + 1))
        fi
      done
      log "Client Job $job_id (Port $current_port): Finished. Processed $processed_count_for_this_job datasets."
    ) & # Run this client job's subshell in the background
  done

  log "All client jobs launched. Waiting for completion..."
  wait # Wait for all backgrounded client jobs to complete
  log "All client jobs have completed."
fi

if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
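In the new stage 14, dataset i is handed to client job i % NUM_CLIENT_JOBS, and job j sends its requests to the server listening on port BASE_PORT + j. Below is a minimal Python sketch of that round-robin split, purely for illustration; the variable names mirror the script and nothing here is part of the repo.

# Round-robin split used by stage 14 above (illustrative only; not part of the repo).
target_datasets = [
    "alpacaeval_full", "wildvoice", "ifeval", "commoneval",
    "openbookqa", "sd-qa", "advbench", "bbh", "mmsu",
]
NUM_CLIENT_JOBS = 4
BASE_PORT = 8000

for job_id in range(NUM_CLIENT_JOBS):
    port = BASE_PORT + job_id
    # Job j takes the datasets at indices j, j + NUM_CLIENT_JOBS, j + 2*NUM_CLIENT_JOBS, ...
    assigned = [d for i, d in enumerate(target_datasets) if i % NUM_CLIENT_JOBS == job_id]
    print(f"job {job_id} -> port {port}: {assigned}")
# job 0 -> port 8000: ['alpacaeval_full', 'openbookqa', 'mmsu']
# job 1 -> port 8001: ['wildvoice', 'sd-qa']
# job 2 -> port 8002: ['ifeval', 'advbench']
# job 3 -> port 8003: ['commoneval', 'bbh']
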
@@ -324,15 +361,26 @@ fi


if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
  # pip install gradio sherpa-onnx
  log "stage 17: Server for adapter only speech continuation"
  exp_dir=./qwen_omni/exp_speech2text
  python3 ./qwen_omni/server.py \
    --speech-encoder-path-or-name models/large-v2.pt \
    --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
    --checkpoint-path $exp_dir/epoch-6/pytorch_model.bin \
    --use-flash-attn True \
    --enable-speech-output False \
    --use-lora False --prompt-template continuation
  exp_dir=./qwen_omni/exp_speech2text_first_libri_continuation_second_ce

  N_GPUS=4 # Define the number of GPUs/processes you want to launch

  for id in $(seq 0 $(($N_GPUS - 1)))
  do
    log "Launching server on GPU $id with port $(expr 8000 + $id)"
    CUDA_VISIBLE_DEVICES=$id python3 ./qwen_omni/server.py \
      --speech-encoder-path-or-name models/large-v2.pt \
      --llm-path-or-name models/Qwen2.5-0.5B-Instruct \
      --checkpoint-path $exp_dir/epoch-10/pytorch_model.bin \
      --use-flash-attn True \
      --enable-speech-output False \
      --port $(expr 8000 + $id) \
      --use-lora True &
  done

  wait # Wait for all background processes to complete
fi

if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
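Stage 17 now launches one server per GPU, pinning each process to a device via CUDA_VISIBLE_DEVICES, giving it its own port (8000 + id), and waiting on all of them. A rough Python equivalent of that loop, sketched with subprocess under the assumption that server.py accepts exactly the flags shown above:

import os
import subprocess

exp_dir = "./qwen_omni/exp_speech2text_first_libri_continuation_second_ce"
N_GPUS = 4       # assumed number of visible GPUs
BASE_PORT = 8000

procs = []
for gpu_id in range(N_GPUS):
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu_id))  # pin this server to one GPU
    cmd = [
        "python3", "./qwen_omni/server.py",
        "--speech-encoder-path-or-name", "models/large-v2.pt",
        "--llm-path-or-name", "models/Qwen2.5-0.5B-Instruct",
        "--checkpoint-path", f"{exp_dir}/epoch-10/pytorch_model.bin",
        "--use-flash-attn", "True",
        "--enable-speech-output", "False",
        "--port", str(BASE_PORT + gpu_id),
        "--use-lora", "True",
    ]
    procs.append(subprocess.Popen(cmd, env=env))  # start in the background, like `&`

for p in procs:
    p.wait()  # mirrors the shell `wait`
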
@@ -13,9 +13,15 @@ def get_args():
    parser.add_argument(
        "--server-url",
        type=str,
        default="http://localhost:8000",
        default="http://localhost",
        help="URL of the FastAPI server",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port of the FastAPI server",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
@@ -48,7 +54,7 @@ def main():
        args.output_dir,
        f"{args.subset_name}-{args.split_name}.jsonl",
    )
    server_decode_url = f"{args.server_url}/decode"
    server_decode_url = f"{args.server_url}:{args.port}/decode"

    print("Loading dataset...")
    if args.subset_name != "mmsu":
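With --server-url and --port now separate arguments, main() joins them back into the decode endpoint. A hedged sketch of how a caller might hit that endpoint with the requests library; the JSON payload fields are invented for illustration, since the request body actually built by client.py is not part of this diff:

import requests

server_url = "http://localhost"   # --server-url (new default)
port = 8001                       # --port, e.g. BASE_PORT + job_id from stage 14
server_decode_url = f"{server_url}:{port}/decode"  # same construction as in main()

# Illustrative payload only; the real fields used by client.py are not shown in this diff.
resp = requests.post(server_decode_url, json={"audio": "<base64 waveform>"}, timeout=600)
resp.raise_for_status()
print(resp.json())
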
@@ -380,6 +380,19 @@ def process_batch_speech_continuation(batch: dict):
        messages.append(message)
    return messages

def process_batch_asr(batch: dict):
    messages = []
    for i in range(len(batch["supervisions"]["text"])):
        transcript = batch["supervisions"]["cut"][i].custom["text"]
        message = [
            {
                "role": "user",
                "content": f"Transcribe the following audio into text:\n\n{DEFAULT_SPEECH_TOKEN}",
            },
            {"role": "assistant", "content": transcript},
        ]
        messages.append(message)
    return messages

def process_batch_text_continuation(batch: dict):
    messages = []
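The new process_batch_asr builds one chat-style message pair per utterance: a user turn embedding the speech placeholder token and an assistant turn carrying the reference transcript from the cut's custom["text"]. A small illustrative check follows, assuming process_batch_asr and DEFAULT_SPEECH_TOKEN from the file above are in scope; the batch layout is a guess at the lhotse-style fields the function actually reads:

from types import SimpleNamespace

# Hypothetical two-utterance batch with only the fields process_batch_asr touches.
# Assumes process_batch_asr (and DEFAULT_SPEECH_TOKEN) from the training script above are in scope.
batch = {
    "supervisions": {
        "text": ["ignored", "ignored"],  # only its length is used
        "cut": [
            SimpleNamespace(custom={"text": "hello world"}),
            SimpleNamespace(custom={"text": "good morning"}),
        ],
    }
}

messages = process_batch_asr(batch)
# messages[0] ==
# [
#     {"role": "user",
#      "content": f"Transcribe the following audio into text:\n\n{DEFAULT_SPEECH_TOKEN}"},
#     {"role": "assistant", "content": "hello world"},
# ]
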
@@ -548,6 +561,8 @@ def compute_loss(
        messages = process_batch_speech_continuation(batch)
        if params.loss_type == "kl_div":
            messages_text = process_batch_text_continuation(batch)
    elif params.dataset_format == "asr":
        messages = process_batch_asr(batch)
    else:
        raise ValueError(f"Unknown dataset format: {params.dataset_format}")

@@ -1020,7 +1035,7 @@ def run(rank, world_size, args):
    elif params.dataset_format == "vocalnet":
        train_cuts = data_module.train_cuts_en_vocalnet()
        valid_cuts = data_module.valid_cuts_en_vocalnet()
    elif params.dataset_format == "speech_continuation":
    elif params.dataset_format == "speech_continuation" or params.dataset_format == "asr":
        if params.dataset == "multi_en":
            train_cuts = data_module.train_cuts_ultravox()
        elif params.dataset == "librispeech":