165 lines
5.5 KiB
Python
165 lines
5.5 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
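run.py — demo runner for TritonGrpcClient service methods.

Usage:
    python run.py                  # auto-find an audio file in cwd and run both modes
    python run.py /path/to/file    # run both modes on a specific file
    python run.py --mode sse       # run only SSE mode
    python run.py --mode ws        # run only WebSocket mode

Note: TritonGrpcClient is imported relatively (from .service import ...),
which only resolves when this file runs as part of its package, e.g.
python -m <package>.run. Switch that import to an absolute one before
invoking the file directly as a plain script.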
|
|
"""
|
|
|
|
import argparse
import asyncio
import json
import os
import sys
import traceback
from pathlib import Path
from typing import List
|
|
|
|
# Adjust import to where your client class lives
|
|
# e.g. from triton_client import TritonGrpcClient
|
|
from .service import TritonGrpcClient
|
|
|
|
# Filename suffixes (lowercase, leading dot included) that count as audio when
# auto-discovering an input file in the current directory.
AUDIO_EXTS = ".wav .mp3 .m4a .flac .ogg .aac".split()
|
|
|
|
|
|
def find_first_audio_in_cwd() -> Path | None:
|
|
cwd = Path.cwd()
|
|
for p in cwd.iterdir():
|
|
if p.suffix.lower() in AUDIO_EXTS and p.is_file():
|
|
return p
|
|
return None
|
|
|
|
|
|
# ---------- Fake WebSocket for exercising websocket_stream_from_websocket ----------
|
|
class FakeWebSocket:
|
|
"""
|
|
Minimal fake WebSocket implementing the methods used by websocket_stream_from_websocket:
|
|
- accept()
|
|
- receive() -> dict with "bytes" or "text"
|
|
- send_json(obj)
|
|
- close()
|
|
It streams the provided bytes in small binary frames, then a JSON text frame {"event":"end"}.
|
|
"""
|
|
|
|
def __init__(self, data: bytes, frame_size: int = 16 * 1024):
|
|
self._data = data
|
|
self._frame_size = frame_size
|
|
self._offset = 0
|
|
self._sent_end = False
|
|
self.sent_messages: List[dict] = []
|
|
self.closed = False
|
|
|
|
async def accept(self):
|
|
# server expects to optionally call accept; nothing to do
|
|
print("[FakeWebSocket] accept() called")
|
|
|
|
async def receive(self):
|
|
"""
|
|
Return one frame at a time:
|
|
- {"bytes": b"..."} while data remains
|
|
- then {"text": json.dumps({"event":"end"})}
|
|
After that, sleep forever (server won't call receive again in your code).
|
|
"""
|
|
if self._offset < len(self._data):
|
|
end = min(len(self._data), self._offset + self._frame_size)
|
|
chunk = self._data[self._offset : end]
|
|
self._offset = end
|
|
# mimic the WebSocket dict shape used in your code
|
|
return {"bytes": chunk}
|
|
if not self._sent_end:
|
|
self._sent_end = True
|
|
return {"text": json.dumps({"event": "end"})}
|
|
# Block a bit — server should have stopped receiving after 'end'
|
|
await asyncio.sleep(3600)
|
|
return {}
|
|
|
|
async def send_json(self, obj):
|
|
# server sends results here; capture & print
|
|
self.sent_messages.append(obj)
|
|
print("[FakeWebSocket] send_json:", json.dumps(obj, ensure_ascii=False))
|
|
|
|
async def close(self):
|
|
self.closed = True
|
|
print("[FakeWebSocket] close() called")
|
|
|
|
|
|
# ---------- Demo runners ----------
|
|
async def asr_sse_mode(client: "TritonGrpcClient", audio_path: Path):
    """Stream ``client.event_stream_from_bytes`` for *audio_path* and print each item.

    The whole file is read into memory and passed as ``raw`` together with its
    base name as ``filename``. Each yielded item is printed as JSON. Exceptions
    raised while streaming are reported (``repr`` on stdout, traceback on
    stderr) and not re-raised; errors from reading the file itself propagate.
    """
    print("\n=== SSE MODE (event_stream_from_bytes) ===")
    raw = audio_path.read_bytes()
    try:
        # The result is consumed with `async for`, so it must be an async iterator.
        async for event in client.event_stream_from_bytes(raw=raw, filename=audio_path.name):
            # Items are assumed JSON-serializable (dict-like) — confirm against the service.
            print("[SSE] OUT:", json.dumps(event, ensure_ascii=False))
    except Exception as exc:
        # repr() keeps the exception type visible even when str(exc) is empty
        # (e.g. TimeoutError()); the traceback shows where it was raised.
        print("[SSE] Exception while streaming SSE:", repr(exc))
        traceback.print_exc()
|
|
|
|
|
|
async def asr_ws_mode(client: "TritonGrpcClient", audio_path: Path):
    """Drive ``client.websocket_stream_from_websocket`` with an in-memory FakeWebSocket.

    The whole file is fed to the handler as 16 KiB binary frames followed by an
    ``{"event": "end"}`` text frame. Messages the handler sends are printed as
    they arrive and listed again, numbered, once the handler returns or fails.
    Exceptions from the handler are reported (``repr`` on stdout, traceback on
    stderr) and not re-raised.
    """
    print("\n=== WEBSOCKET MODE (websocket_stream_from_websocket) ===")
    raw = audio_path.read_bytes()
    fake_ws = FakeWebSocket(raw, frame_size=16 * 1024)

    # The handler reads frames via fake_ws.receive() and replies via fake_ws.send_json().
    try:
        await client.websocket_stream_from_websocket(fake_ws, filename=audio_path.name)
    except Exception as exc:
        # repr() keeps the exception type visible even when str(exc) is empty.
        print("[WS] Exception while running websocket_stream_from_websocket:", repr(exc))
        traceback.print_exc()
    finally:
        print("[WS] Collected messages from server (send_json calls):")
        for i, m in enumerate(fake_ws.sent_messages, 1):
            print(f" [{i}] {json.dumps(m, ensure_ascii=False)}")
|
|
|
|
|
|
async def main_async(audio_path: Path, modes: List[str]):
    """Create one client and run each requested demo, always SSE before WS."""
    client = TritonGrpcClient()  # adjust constructor arguments here if your client needs them

    # Fixed order keeps SSE ahead of WS regardless of how `modes` is ordered.
    ordered_runners = (("sse", asr_sse_mode), ("ws", asr_ws_mode))
    for mode_name, runner in ordered_runners:
        if mode_name in modes:
            await runner(client, audio_path)
|
|
|
|
|
|
def parse_args():
    """Parse the command line and return the resulting namespace.

    ``file`` is an optional positional path (``None`` when omitted); ``--mode``
    is one of ``sse``, ``ws`` or ``both`` and defaults to ``both``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "file",
        nargs="?",
        help="audio file path (optional). If omitted, searches cwd.",
    )
    mode_choices = ["sse", "ws", "both"]
    parser.add_argument(
        "--mode",
        default="both",
        choices=mode_choices,
        help="Which method(s) to run against the service",
    )
    namespace = parser.parse_args()
    return namespace
|
|
|
|
|
|
def main():
    """Command-line entry point: pick the audio file and run the requested mode(s).

    Uses the positional ``file`` argument when given, otherwise the first audio
    file found in the current working directory. Exits with status 2 (message on
    stderr) when the given path is not a regular file or nothing suitable is
    found. A Ctrl-C during the async run prints ``Interrupted.`` instead of a
    traceback.
    """
    args = parse_args()

    if args.file:
        audio_path = Path(args.file)
        # is_file() rather than exists(): a directory passes exists() but would
        # later crash in read_bytes() with IsADirectoryError.
        if not audio_path.is_file():
            print("Audio file does not exist or is not a regular file:", audio_path, file=sys.stderr)
            sys.exit(2)
    else:
        found = find_first_audio_in_cwd()
        if found is None:
            # Build the hint from AUDIO_EXTS so it always matches what was searched.
            kinds = "/".join(ext.lstrip(".") for ext in AUDIO_EXTS)
            print(
                f"No audio file found in cwd. Place an audio file ({kinds}) here or pass a path.",
                file=sys.stderr,
            )
            sys.exit(2)
        audio_path = found

    modes = ["sse", "ws"] if args.mode == "both" else [args.mode]
    print(f"Using audio file: {audio_path} — running modes: {modes}")

    try:
        asyncio.run(main_async(audio_path, modes))
    except KeyboardInterrupt:
        print("Interrupted.")
|
|
|
|
|
|
# Run the CLI only when this file is executed, not when it is imported.
if __name__ == "__main__":
    main()
|