2025-05-13 05:37:11 +00:00

143 lines
4.8 KiB
Python

# client.py
import argparse
import json
import os
import requests
from datasets import load_dataset
from tqdm import tqdm
def get_args():
parser = argparse.ArgumentParser(description="Speech-to-Text Client")
parser.add_argument(
"--server-url",
type=str,
default="http://localhost:8000",
help="URL of the FastAPI server",
)
parser.add_argument(
"--dataset-name",
type=str,
default="hlt-lab/voicebench",
help="Hugging Face dataset name",
)
parser.add_argument(
"--subset-name",
type=str,
default="commoneval", # Adjust as needed
help="Dataset subset name",
)
parser.add_argument(
"--split-name",
type=str,
default="test", # Adjust as needed
help="Dataset split name",
)
parser.add_argument(
"--output-dir", required=True, type=str, help="Directory to save results"
)
args = parser.parse_args()
return args
def main():
args = get_args()
os.makedirs(args.output_dir, exist_ok=True)
output_filename = os.path.join(
args.output_dir,
f"{args.subset_name}-{args.split_name}.jsonl",
)
server_decode_url = f"{args.server_url}/decode"
print("Loading dataset...")
dataset = load_dataset(
args.dataset_name,
args.subset_name,
split=args.split_name,
trust_remote_code=True,
)
print(f"Dataset loaded with {len(dataset)} samples.")
print(f"Sending requests to {server_decode_url}...")
print(f"Saving results to {output_filename}")
with open(output_filename, "w", encoding="utf-8") as outfile:
# Iterate directly over the dataset
progress_bar = tqdm(dataset, desc="Processing", unit="samples")
for item in progress_bar:
audio_info = item.get("audio")
assert (
audio_info["sampling_rate"] == 16000
), f"Sampling rate is {audio_info['sampling_rate']}, not 16khz"
# Prepare data for JSON serialization and server request
audio_array = audio_info["array"].tolist() # Convert numpy array to list
result_dict = {}
for key in item.keys():
if key != "audio":
# Ensure other fields are JSON serializable
try:
# Attempt to serialize to catch issues early (optional)
json.dumps(item[key])
result_dict[key] = item[key]
except (TypeError, OverflowError):
print(
f"Warning: Converting non-serializable key '{key}' to string."
)
result_dict[key] = str(
item[key]
) # Convert problematic types to string
payload = {
"audio": audio_array,
"sampling_rate": 16000,
}
try:
response = requests.post(server_decode_url, json=payload, timeout=60)
response.raise_for_status()
server_response = response.json()
decoded_text = server_response.get("text", "")
# Add the response to the result dictionary
result_dict["response"] = decoded_text
print(result_dict)
# Write result to JSONL file
json.dump(result_dict, outfile, ensure_ascii=False)
outfile.write("\n")
except requests.exceptions.RequestException as e:
print(f"\nError sending request for an item: {e}")
error_entry = result_dict # Use the data prepared so far
error_entry["error"] = str(e)
error_entry["response"] = ""
json.dump(error_entry, outfile, ensure_ascii=False)
outfile.write("\n")
except json.JSONDecodeError:
print("\nError decoding server response for an item.")
error_entry = result_dict
error_entry["error"] = "Invalid JSON response from server"
error_entry["response"] = ""
json.dump(error_entry, outfile, ensure_ascii=False)
outfile.write("\n")
except Exception as e:
print(f"\nUnexpected error processing an item: {e}")
error_entry = result_dict
error_entry["error"] = f"Unexpected error: {str(e)}"
error_entry["response"] = ""
json.dump(error_entry, outfile, ensure_ascii=False)
outfile.write("\n")
# Progress bar updates automatically by iterating over tqdm(dataset)
# No need to close progress_bar explicitly when iterating directly
print("Processing finished.")
if __name__ == "__main__":
main()