"""FastAPI WebSocket endpoints that forward client audio to ``TritonGrpcClient`` for ASR."""
from fastapi import FastAPI, WebSocket
import asyncio
from asr.service import TritonGrpcClient
import json
from config.base import TRITON_URL, MIN_CHUNK_SIZE

app = FastAPI()
client = TritonGrpcClient(TRITON_URL)


async def _safe_close(ws: WebSocket) -> None:
    """Close *ws*, ignoring any error raised while closing.

    Closing a socket that is already closed or whose peer has disconnected
    raises, and at that point there is nothing useful left to do.
    """
    try:
        await ws.close()
    except Exception:
        pass  # deliberate best-effort cleanup


async def _send_error(ws: WebSocket, detail: str) -> None:
    """Try to send an ``{"event": "error", "detail": ...}`` frame to the client.

    If the send itself fails (e.g. the client already went away), the failure
    is printed instead of propagating out of the endpoint.
    """
    try:
        await ws.send_text(json.dumps({
            "event": "error",
            "detail": detail
        }, ensure_ascii=False))
    except Exception as send_exc:
        print(f"Could not deliver error to client: {send_exc}")


def _is_eof_message(text: str) -> bool:
    """Return True if *text* is a JSON object whose ``"eof"`` field is truthy.

    Raises:
        json.JSONDecodeError: if *text* is not valid JSON.
    """
    msg_data = json.loads(text)
    # Non-object JSON (list, number, ...) cannot carry an "eof" flag.
    return isinstance(msg_data, dict) and bool(msg_data.get("eof"))


@app.websocket("/asr")
async def websocket_endpoint_asr2(ws: WebSocket) -> None:
    """Streaming ASR endpoint that hands the socket to the client wrapper.

    All receiving/sending is delegated to
    ``client.websocket_stream_from_websocket``; with ``accept_in_client=True``
    the client presumably performs ``ws.accept()`` itself (confirm in
    ``asr.service``). On any exception the error is printed and the socket
    is closed best-effort.
    """
    print("Starting ASR2 WebSocket")
    try:
        await client.websocket_stream_from_websocket(ws, filename="stream.webm", accept_in_client=True)
    except Exception as e:
        print("ASR2 WebSocket exception:", e)
        await _safe_close(ws)


@app.websocket("/asr_3")
async def websocket_endpoint(ws: WebSocket) -> None:
    """
    WebSocket endpoint for real-time ASR.
    Accumulates audio and processes only once when complete.

    Protocol:
      * binary frames are appended to an in-memory audio buffer;
      * a text frame containing a JSON object with a truthy ``"eof"`` ends
        the upload; other JSON text frames are ignored;
      * after EOF, each event yielded by ``client.event_stream_from_bytes``
        is sent back as a JSON text frame;
      * on failure an ``{"event": "error", "detail": ...}`` frame is sent
        (best effort), and the socket is always closed at the end.

    If the client disconnects before sending EOF, the buffered audio is
    discarded and nothing is processed.
    """
    print("Starting ASR service")
    await ws.accept()

    # NOTE(review): this buffer is unbounded; an untrusted client can grow it
    # without limit. Consider enforcing a maximum upload size.
    audio_buffer = bytearray()
    try:
        # Read raw ASGI messages so that a text frame (the EOF signal) is never
        # consumed by a failed bytes-read and lost, and so a disconnect is
        # detected instead of being retried on a dead socket.
        while True:
            message = await ws.receive()
            if message["type"] == "websocket.disconnect":
                print("Client disconnected before EOF; discarding audio")
                return

            data = message.get("bytes")
            if data is not None:
                audio_buffer.extend(data)
                continue

            text = message.get("text")
            if text is not None and _is_eof_message(text):
                print("Received EOF signal, processing audio...")
                break

        # Process complete audio buffer ONCE
        if audio_buffer:
            print(f"Processing {len(audio_buffer)} bytes of audio")
            async for s in client.event_stream_from_bytes(
                raw=bytes(audio_buffer),
                filename="stream.webm"
            ):
                await ws.send_text(json.dumps(s, ensure_ascii=False))

    except Exception as e:
        print(f"WebSocket error: {e}")
        await _send_error(ws, str(e))
    finally:
        await _safe_close(ws)