diff --git a/src/main.py b/src/main.py index 73285ef..5e2952f 100644 --- a/src/main.py +++ b/src/main.py @@ -12,19 +12,65 @@ client = TritonGrpcClient(TRITON_URL) @app.websocket("/asr") async def websocket_endpoint(ws: WebSocket): + print("Starting ASR2 WebSocket") + + try: + await client.websocket_stream_from_websocket(ws, filename="stream.webm", accept_in_client=True) + + except Exception as e: + print("ASR2 WebSocket exception:", e) + try: + await ws.close() + except: + pass + + + +@app.websocket("/asr_3") +async def websocket_endpoint(ws: WebSocket): + """ + WebSocket endpoint for real-time ASR. + Accumulates audio and processes only once when complete. + """ print("Starting ASR service") await ws.accept() audio_buffer = bytearray() try: + # Receive all audio chunks while True: - data = await ws.receive_bytes() - audio_buffer.extend(data) + try: + # Try to receive bytes + data = await ws.receive_bytes() + audio_buffer.extend(data) + except: + # Try to receive text (for EOF signal) + message = await ws.receive_text() + msg_data = json.loads(message) - async for s in client.event_stream_from_bytes(raw=bytes(audio_buffer), filename="stream.webm"): + # Check for EOF signal + if msg_data.get("eof"): + print("Received EOF signal, processing audio...") + break + + # Process complete audio buffer ONCE + if len(audio_buffer) > 0: + print(f"Processing {len(audio_buffer)} bytes of audio") + async for s in client.event_stream_from_bytes( + raw=bytes(audio_buffer), + filename="stream.webm" + ): await ws.send_text(json.dumps(s, ensure_ascii=False)) except Exception as e: - print("WebSocket error:", e) - await ws.close() + print(f"WebSocket error: {e}") + await ws.send_text(json.dumps({ + "event": "error", + "detail": str(e) + }, ensure_ascii=False)) + finally: + try: + await ws.close() + except: + pass diff --git a/test/test.py b/test/test.py index f395142..3f619e4 100644 --- a/test/test.py +++ b/test/test.py @@ -3,28 +3,39 @@ import websockets import json AUDIO_FILE = "./data/khabar_1_5.wav" -WS_URL = "ws://127.0.0.1:7200/asr" +WS_URL = "ws://127.0.0.1:3002/asr" async def send_audio(): - async with websockets.connect(WS_URL) as ws: + async with websockets.connect(WS_URL, max_size=None) as ws: print("Connected to ASR WebSocket") - # Send audio in chunks + # --- Send audio in binary chunks --- with open(AUDIO_FILE, "rb") as f: chunk_size = 100000 while chunk := f.read(chunk_size): await ws.send(chunk) - # Tell server no more audio is coming - await ws.send(json.dumps({"eof": True})) + # --- Tell server that audio stream is finished --- + await ws.send(json.dumps({"event": "end"})) - # Receive results + # --- Receive partial + final transcripts --- try: while True: - message = await ws.recv() - data = json.loads(message) + msg = await ws.recv() + + # Triton code may send JSON or text + try: + data = json.loads(msg) + except Exception: + print("Non-JSON frame:", msg) + continue + print("Received:", data) + # Optional: stop when server sends is_final + if data.get("is_final"): + print("Final:", data["text"]) + except websockets.exceptions.ConnectionClosed: print("WebSocket closed by server")