2025-10-15 10:51:30 +03:30

165 lines
5.5 KiB
Python

#!/usr/bin/env python3
"""
run.py — demo runner for TritonGrpcClient service methods.
Usage:
python run.py # auto-find an audio file in cwd and run both modes
python run.py /path/to/file # run both modes on a specific file
python run.py --mode sse # run only SSE-mode
python run.py --mode ws # run only WebSocket-mode
"""
import argparse
import asyncio
import json
import os
import sys
import traceback
from pathlib import Path
from typing import List

# Adjust import to where your client class lives
# e.g. from triton_client import TritonGrpcClient
from .service import TritonGrpcClient
# Small helper to find an audio file if none provided
AUDIO_EXTS = [".wav", ".mp3", ".m4a", ".flac", ".ogg", ".aac"]
def find_first_audio_in_cwd() -> Path | None:
cwd = Path.cwd()
for p in cwd.iterdir():
if p.suffix.lower() in AUDIO_EXTS and p.is_file():
return p
return None
# ---------- Fake WebSocket for exercising websocket_stream_from_websocket ----------
class FakeWebSocket:
"""
Minimal fake WebSocket implementing the methods used by websocket_stream_from_websocket:
- accept()
- receive() -> dict with "bytes" or "text"
- send_json(obj)
- close()
It streams the provided bytes in small binary frames, then a JSON text frame {"event":"end"}.
"""
def __init__(self, data: bytes, frame_size: int = 16 * 1024):
self._data = data
self._frame_size = frame_size
self._offset = 0
self._sent_end = False
self.sent_messages: List[dict] = []
self.closed = False
async def accept(self):
# server expects to optionally call accept; nothing to do
print("[FakeWebSocket] accept() called")
async def receive(self):
"""
Return one frame at a time:
- {"bytes": b"..."} while data remains
- then {"text": json.dumps({"event":"end"})}
After that, sleep forever (server won't call receive again in your code).
"""
if self._offset < len(self._data):
end = min(len(self._data), self._offset + self._frame_size)
chunk = self._data[self._offset : end]
self._offset = end
# mimic the WebSocket dict shape used in your code
return {"bytes": chunk}
if not self._sent_end:
self._sent_end = True
return {"text": json.dumps({"event": "end"})}
# Block a bit — server should have stopped receiving after 'end'
await asyncio.sleep(3600)
return {}
async def send_json(self, obj):
# server sends results here; capture & print
self.sent_messages.append(obj)
print("[FakeWebSocket] send_json:", json.dumps(obj, ensure_ascii=False))
async def close(self):
self.closed = True
print("[FakeWebSocket] close() called")
# ---------- Demo runners ----------
async def asr_sse_mode(client: TritonGrpcClient, audio_path: Path):
    """Run ``client.event_stream_from_bytes`` on an audio file and print each item.

    The whole file is read into memory and passed to the client in one call.
    Every item produced by the returned async iterator is printed as JSON.

    An exception raised while streaming is reported (repr plus traceback on
    stderr) instead of propagating, so the caller can continue with other
    modes. Errors reading the file itself are not caught.
    """
    print("\n=== SSE MODE (event_stream_from_bytes) ===")
    raw = audio_path.read_bytes()
    try:
        # event_stream_from_bytes returns an async generator — iterate it
        async for event in client.event_stream_from_bytes(raw=raw, filename=audio_path.name):
            # presumably a dict-like, JSON-serialisable object from the SSE generator
            print("[SSE] OUT:", json.dumps(event, ensure_ascii=False))
    except Exception as exc:
        # repr() keeps the exception type visible even when str(exc) is empty
        # (e.g. a bare TimeoutError()); the traceback shows where it was raised.
        print("[SSE] Exception while streaming SSE:", repr(exc))
        traceback.print_exc()
async def asr_ws_mode(client: TritonGrpcClient, audio_path: Path):
    """Run ``client.websocket_stream_from_websocket`` against an in-memory FakeWebSocket.

    The audio file is replayed to the handler as 16 KiB binary frames followed
    by an {"event": "end"} text frame.

    Everything the handler sent via ``send_json`` is printed after the handler
    returns, including when it raised. An exception from the handler is
    reported (repr plus traceback on stderr) rather than propagated.
    """
    print("\n=== WEBSOCKET MODE (websocket_stream_from_websocket) ===")
    raw = audio_path.read_bytes()
    fake_ws = FakeWebSocket(raw, frame_size=16 * 1024)
    try:
        # The handler drives the fake: it pulls frames via receive() and pushes
        # results via send_json().
        await client.websocket_stream_from_websocket(fake_ws, filename=audio_path.name)
    except Exception as exc:
        # repr() keeps the exception type visible even when str(exc) is empty.
        print("[WS] Exception while running websocket_stream_from_websocket:", repr(exc))
        traceback.print_exc()
    finally:
        print("[WS] Collected messages from server (send_json calls):")
        for i, m in enumerate(fake_ws.sent_messages, 1):
            print(f" [{i}] {json.dumps(m, ensure_ascii=False)}")
async def main_async(audio_path: Path, modes: List[str]):
    """Create one client and run the requested demo modes against *audio_path*.

    Modes always run in the fixed order SSE, then WebSocket, regardless of
    their order in *modes*. Unrecognised entries in *modes* are ignored.
    """
    # init; if your client needs args, adjust here
    client = TritonGrpcClient()
    # NOTE(review): the client is never closed. If TritonGrpcClient owns a gRPC
    # channel, consider closing it after the runs (confirm its API).
    ordered_runners = (("sse", asr_sse_mode), ("ws", asr_ws_mode))
    for mode_name, runner in ordered_runners:
        if mode_name in modes:
            await runner(client, audio_path)
def parse_args():
p = argparse.ArgumentParser()
p.add_argument("file", nargs="?", help="audio file path (optional). If omitted, searches cwd.")
p.add_argument(
"--mode",
choices=["sse", "ws", "both"],
default="both",
help="Which method(s) to run against the service",
)
return p.parse_args()
def main():
    """Command-line entry point: pick the audio file, then run the selected demo modes.

    The audio file is the path given on the command line or, when omitted, the
    first audio file found in the current working directory.

    Exits with status 2 (message on stderr) when the given path is not a
    regular file or when no audio file is found. Ctrl-C during the run prints
    "Interrupted." instead of a traceback.
    """
    args = parse_args()
    if args.file:
        audio_path = Path(args.file)
        # is_file() rather than exists(): a directory would pass exists() and
        # then fail later, uncaught, in read_bytes().
        if not audio_path.is_file():
            print("Audio file does not exist or is not a regular file:", audio_path, file=sys.stderr)
            sys.exit(2)
    else:
        found = find_first_audio_in_cwd()
        if found is None:
            # Derive the list from AUDIO_EXTS so the hint matches what is searched.
            exts = "/".join(ext.lstrip(".") for ext in AUDIO_EXTS)
            print(
                f"No audio file found in cwd. Place an audio file ({exts}) here or pass a path.",
                file=sys.stderr,
            )
            sys.exit(2)
        audio_path = found
    modes = ["sse", "ws"] if args.mode == "both" else [args.mode]
    print(f"Using audio file: {audio_path} — running modes: {modes}")
    try:
        asyncio.run(main_async(audio_path, modes))
    except KeyboardInterrupt:
        print("Interrupted.")
# Script entry point.
# NOTE: the client is imported with a relative import
# (`from .service import TritonGrpcClient`), so this file must run as part of
# its package, e.g. `python -m <package>.run`. Running `python run.py` directly
# raises "ImportError: attempted relative import with no known parent package".
if __name__ == "__main__":
    main()