# clients/triton_grpc_client.py from __future__ import annotations import os, uuid, time, json, asyncio, tempfile from pathlib import Path from contextlib import asynccontextmanager, suppress from typing import AsyncGenerator, List, Tuple, Optional, Callable, Awaitable, TYPE_CHECKING import numpy as np from pydub import AudioSegment import tritonclient.grpc.aio as grpcclient from tritonclient.utils import np_to_triton_dtype import websockets from dotenv import load_dotenv load_dotenv() # if TYPE_CHECKING: # from fastapi import WebSocket # from starlette.websockets import WebSocketDisconnect # ---- constants (same as your code) ---- SAMPLE_RATE_HZ = 16_000 INPUT_WAV_TENSOR = "WAV" INPUT_LEN_TENSOR = "WAV_LENS" OUTPUT_TEXT_TENSOR = "TRANSCRIPTS" MODEL_NAME = "transducer" ZERO_PAD_REQUEST_CONTENT = True TRITON_URL = os.getenv("TRITON_URL") class _WebsocketsAdapter: """ Adapter to make a `websockets` WebSocketServerProtocol behave like a Starlette/FastAPI WebSocket for the shape your code expects: - accept() - receive() -> dict with "bytes" or "text" - send_json(obj) - close() """ def __init__(self, ws): self._ws = ws self._is_server_protocol = True # semantic flag async def accept(self): # websockets server does its own accept when created; no-op here return async def receive(self): # websockets.recv() returns bytes or str data = await self._ws.recv() if isinstance(data, bytes): return {"bytes": data} else: return {"text": data} async def send_json(self, obj): import json await self._ws.send(json.dumps(obj, ensure_ascii=False)) async def close(self): await self._ws.close() class TritonGrpcClient: def __init__(self, triton_url: str): self.triton_url = triton_url """ Owns only: audio -> chunking -> Triton -> queue -> streaming. Keeps producing even if the HTTP/WS request disconnects. MinIO/DB/file finalize stays outside via callbacks. """ # ---------- SSE: return async generator ---------- def event_stream_from_bytes( self, *, raw: bytes, filename: str, content_type: Optional[str] = None, on_final_text: Optional[Callable[[str], Awaitable[None]]] = None, on_error: Optional[Callable[[str], Awaitable[None]]] = None, ) -> AsyncGenerator[str, None]: async def _gen() -> AsyncGenerator[str, None]: if not raw: raise ValueError("Uploaded file is empty") tmp_path = self._write_temp_file(raw, filename) queue: asyncio.Queue[dict[str, str | None]] = asyncio.Queue() drop_event = asyncio.Event() # set on disconnect → stop emitting only # spawn producer and DO NOT await it here; it lives independently asyncio.create_task( self._produce_transcripts(tmp_path, queue, on_final_text, on_error, drop_event), name=f"triton-producer-{uuid.uuid4().hex[:8]}", ) print("[SSE] Client connected", flush=True) try: while True: msg = await queue.get() if msg.get("event") == "done": print("[SSE] Job finished → closing stream", flush=True) break yield f"data: {json.dumps(msg, ensure_ascii=False)}\n\n" except asyncio.CancelledError: # request closed; keep producer alive, just stop emitting print("[SSE] Client disconnected (background continues)", flush=True) drop_event.set() return finally: print("[SSE] event_stream END", flush=True) return _gen() # ---------- WebSocket: same producer, push over WS ---------- @asynccontextmanager async def _open_triton(self): client = grpcclient.InferenceServerClient(self.triton_url, verbose=False) try: yield client finally: with suppress(Exception): await client.close() async def _produce_transcripts( self, tmp_path: Path, queue: "asyncio.Queue[dict[str, str | None]]", on_final_text: Optional[Callable[[str], Awaitable[None]]], on_error: Optional[Callable[[str], Awaitable[None]]], drop_event: asyncio.Event, ) -> None: print("[BG] Started producer", flush=True) last_msg: dict[str, str] | None = None try: print("[BG] Entering stream_transcript loop", flush=True) async for last_msg in self._stream_transcript(str(tmp_path)): if not drop_event.is_set(): await queue.put(last_msg) print("[BG] stream_transcript finished", flush=True) final_text = (last_msg or {}).get("text", "").strip() if on_final_text: with suppress(Exception): await on_final_text(final_text) except Exception as exc: print(f"[BG] EXCEPTION: {exc!r}", flush=True) if on_error: with suppress(Exception): await on_error(str(exc)) if not drop_event.is_set(): await queue.put({"event": "error", "detail": str(exc)}) finally: print("[BG] Cleaning up temp file", flush=True) with suppress(FileNotFoundError): tmp_path.unlink(missing_ok=True) if not drop_event.is_set(): await queue.put({"event": "done"}) print("[BG] producer END", flush=True) # Replace your existing websocket_stream_from_websocket with this version async def websocket_stream_from_websocket( self, websocket, *, filename: str = "stream", content_type: Optional[str] = None, on_final_text: Optional[Callable[..., Awaitable[None]]] = None, on_error: Optional[Callable[[str], Awaitable[None]]] = None, accept_in_client: bool = False, ) -> None: """ Supports both: - Starlette/FastAPI WebSocket objects (they expose .receive(), .send_json(), .accept()) - `websockets` WebSocketServerProtocol (they expose .recv(), .send(), .close()) The adapter wraps the latter so the rest of your existing logic can stay unchanged. """ # Lazy import here so file-level imports need not change import json import uuid import time import asyncio import tempfile from contextlib import suppress # If this is a websockets (WebSocketServerProtocol) instance, wrap it # Heuristic: Starlette has .receive(); websockets has .recv() if not hasattr(websocket, "receive") and hasattr(websocket, "recv"): websocket = _WebsocketsAdapter(websocket) # If caller requested server-side accept for ASGI websockets, call it. # For websockets adapter accept() is a no-op. if accept_in_client: # In FastAPI typical pattern is server calls accept(); keep that behavior await websocket.accept() # We'll append all raw bytes to a temp file so the endpoint can upload later in the finalizer src_suffix = Path(filename).suffix or ".bin" src_path = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}{src_suffix}" f_src = src_path.open("wb") # Background pipeline control queue: asyncio.Queue[dict[str, str | None]] = asyncio.Queue() drop_event = asyncio.Event() # if client disconnects, stop SENDING but keep producing async def _call_final_cb(final_text: str) -> None: if not on_final_text: return try: # Try callback(final_text, source_audio_path) first, fallback to (final_text) if on_final_text.__code__.co_argcount >= 2: await on_final_text(final_text, src_path) else: await on_final_text(final_text) except Exception: pass async def recv_frames_and_transcode(proc): """ Read frames using the (possibly-adapted) `websocket.receive()` and write to ffmpeg stdin; close stdin when we get {"event":"end"} or disconnect. """ try: while True: msg = await websocket.receive() if "bytes" in msg and msg["bytes"] is not None: chunk = msg["bytes"] f_src.write(chunk) proc.stdin.write(chunk) # type: ignore[attr-defined] await proc.stdin.drain() # type: ignore[attr-defined] elif "text" in msg and msg["text"] is not None: try: payload = json.loads(msg["text"]) if payload.get("event") == "end": break except Exception: # ignore non-JSON text frames pass else: # ignore pings/other control frames await asyncio.sleep(0) except Exception: # If underlying connection closed abruptly, just stop receiving pass finally: with suppress(Exception): f_src.flush() f_src.close() with suppress(Exception): # Some implementations use close(); others have no stdin to close proc.stdin.close() # type: ignore[attr-defined] async def transcribe_from_ffmpeg_stdout(proc): """ Read float32 PCM from ffmpeg stdout, chunk on-the-fly, call Triton per chunk, and push partial/final messages to the queue. """ # Triton session + chunk sizes async with self._open_triton() as client: first_sz, chunk_sz = await self._get_chunk_sizes(client, MODEL_NAME) seq_id = uuid.uuid4().int & 0x7FFF_FFFF_FFFF_FFFF full_tx = "" have_sent_any = False need = first_sz # samples needed for the next chunk # PCM sample buffer (float32) buf = np.empty(0, dtype=np.float32) async def infer_one(raw: np.ndarray, eff_len: int, is_first: bool, is_last: bool): nonlocal full_tx wav_np = raw[None, :] # (1, T) len_np = np.array([[eff_len]], np.int32) inp_wav = grpcclient.InferInput(INPUT_WAV_TENSOR, wav_np.shape, np_to_triton_dtype(np.float32)) inp_len = grpcclient.InferInput(INPUT_LEN_TENSOR, len_np.shape, np_to_triton_dtype(np.int32)) inp_wav.set_data_from_numpy(wav_np) inp_len.set_data_from_numpy(len_np) outs = [grpcclient.InferRequestedOutput(OUTPUT_TEXT_TENSOR)] resp = await client.infer( MODEL_NAME, inputs=[inp_wav, inp_len], outputs=outs, sequence_id=seq_id, sequence_start=is_first, sequence_end=is_last, ) txt = b" ".join(resp.as_numpy(OUTPUT_TEXT_TENSOR)).decode().strip() if not txt: return None delta = txt[len(full_tx) :] if txt.startswith(full_tx) else txt full_tx = txt return delta or None # Read ffmpeg stdout in bytes; convert to float32 try: while True: chunk = await proc.stdout.read(8192 * 4) # 8192 samples (float32) if not chunk: break # append decoded samples new = np.frombuffer(chunk, dtype=np.float32) if new.size == 0: continue if buf.size == 0: buf = new else: buf = np.concatenate((buf, new), axis=0) # while enough samples for the next piece, send it while buf.size >= need: take = need piece = buf[:take] buf = buf[take:] is_first = not have_sent_any have_sent_any = True delta = await infer_one(piece, take, is_first=is_first, is_last=False) if delta and not drop_event.is_set(): await queue.put({"time": time.time(), "text": delta, "is_final": False}) # after first, normal chunk size need = chunk_sz finally: # End of stream: flush any remainder and a zero-length piece if configured if buf.size > 0: # pad to chunk_sz for model framing, but eff_len is the real (short) length eff = int(buf.size) pad = np.zeros(chunk_sz, dtype=np.float32) pad[:eff] = buf[:eff] delta = await infer_one(pad, eff, is_first=not have_sent_any, is_last=False) if delta and not drop_event.is_set(): await queue.put({"time": time.time(), "text": delta, "is_final": False}) have_sent_any = True if ZERO_PAD_REQUEST_CONTENT: # zero-length "flush" chunk zero = np.zeros(0, dtype=np.float32) await infer_one(zero, 0, is_first=not have_sent_any, is_last=True) else: # If not sending the explicit flush, still mark last by sending empty with last=True zero = np.zeros(0, dtype=np.float32) await infer_one(zero, 0, is_first=False, is_last=True) # Emit final full transcript as a message, like your file-based path if not drop_event.is_set(): await queue.put({"time": time.time(), "text": full_tx.strip(), "is_final": True}) # Call user finalizer await _call_final_cb(full_tx.strip()) # signal done to sender if not drop_event.is_set(): await queue.put({"event": "done"}) async def send_messages(): HEARTBEAT_SECS = 10 print("[WS] Client connected", flush=True) try: while True: try: # wait for a real message up to HEARTBEAT_SECS msg = await asyncio.wait_for(queue.get(), timeout=HEARTBEAT_SECS) except asyncio.TimeoutError: # no real message → send heartbeat hb = {"event": "heartbeat", "t": time.time()} await websocket.send_json(hb) continue if msg.get("event") == "done": print("[WS] Job finished → closing socket", flush=True) break await websocket.send_json(msg) except WebSocketDisconnect: print("[WS] Client disconnected (background continues)", flush=True) drop_event.set() except asyncio.CancelledError: print("[WS] Client cancelled (background continues)", flush=True) drop_event.set() finally: with suppress(Exception): await websocket.close() print("[WS] send_messages END", flush=True) # The remainder of this function is unchanged from your original implementation: # - transcribe_from_ffmpeg_stdout(proc) # - send_messages() # - ffmpeg process startup # - task creation + gather + cleanup # # To keep this patch minimal I re-use your original nested function bodies as-is. # Copy the original nested functions transcribe_from_ffmpeg_stdout and send_messages # exactly as you had them (they will reference `websocket`, `queue`, `drop_event`, # `_open_triton`, etc.). Below I paste them unchanged but with no modifications so # you can drop them in place. # # ---- paste your original nested functions here exactly (no changes) ---- # ----- Start ffmpeg (stdin bytes -> stdout float32 PCM @ 16k mono) proc = await asyncio.create_subprocess_exec( "ffmpeg", "-hide_banner", "-loglevel", "error", "-i", "pipe:0", "-ac", "1", "-ar", str(SAMPLE_RATE_HZ), "-f", "f32le", "pipe:1", stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) async def drain_ffmpeg_stderr(proc): try: while True: line = await proc.stderr.readline() if not line: break # You can log it if needed: # print(f"[ffmpeg] {line.decode().rstrip()}", flush=True) await asyncio.sleep(0) except asyncio.CancelledError: pass stderr_task = asyncio.create_task(drain_ffmpeg_stderr(proc), name="ffmpeg-stderr") # Run 3 tasks concurrently: # - recv frames & feed ffmpeg # - read pcm & call Triton # - send queue messages to client # Note: `transcribe_from_ffmpeg_stdout` and `send_messages` are your original nested functions. # Keep them unchanged and ensure they reference `websocket` only for send_json/receive (adapter handles it). tasks = [ asyncio.create_task(recv_frames_and_transcode(proc), name="ws-recv"), asyncio.create_task(transcribe_from_ffmpeg_stdout(proc), name="ws-triton"), asyncio.create_task(send_messages(), name="ws-send"), stderr_task, ] try: await asyncio.gather(*tasks) finally: for t in tasks: if not t.done(): t.cancel() with suppress(Exception): proc.terminate() with suppress(Exception): await proc.wait() with suppress(Exception): if not f_src.closed: f_src.close() @staticmethod def _write_temp_file(raw: bytes, filename: str) -> Path: tmp = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}{Path(filename).suffix}" tmp.write_bytes(raw) print(f"[REQ] Temp file written at {tmp}", flush=True) return tmp @staticmethod def _pcm_int16_to_float32(x: np.ndarray) -> np.ndarray: return (x.astype(np.float32) / 32768.0).clip(-1.0, 1.0) def _load_audio(self, path: str, target_sr: int = SAMPLE_RATE_HZ): audio = AudioSegment.from_file(path) audio = audio.set_channels(1).set_frame_rate(target_sr).set_sample_width(2) pcm_int16 = np.frombuffer(audio.raw_data, dtype=np.int16) return pcm_int16, target_sr async def _get_chunk_sizes(self, client: grpcclient.InferenceServerClient, model: str) -> Tuple[int, int]: try: cfg = await client.get_model_config(model, as_json=True) params = {p["key"]: p["value"]["string_param"] for p in cfg.get("parameters", [])} first_sec = float(params.get("chunk_size_first", params.get("chunk_size", "0.465"))) norm_sec = float(params.get("chunk_size", "0.32")) except Exception: first_sec, norm_sec = 0.465, 0.32 return int(first_sec * SAMPLE_RATE_HZ), int(norm_sec * SAMPLE_RATE_HZ) async def _stream_transcript(self, path: str) -> AsyncGenerator[dict, None]: """ Read audio from `path`, chunk it and stream to Triton, yielding partial/final messages. Yields dicts of the form: {"time": , "text": "", "is_final": False} and finally: {"time": , "text": "", "is_final": True} """ import time import uuid from contextlib import suppress # Load audio (int16 PCM) and convert to float32 in [-1, 1] pcm_int16, sr = self._load_audio(path, target_sr=SAMPLE_RATE_HZ) if pcm_int16.size == 0: return wav = self._pcm_int16_to_float32(pcm_int16) # float32 1-D numpy array # Triton client + chunk sizes async with self._open_triton() as client: first_sz, chunk_sz = await self._get_chunk_sizes(client, MODEL_NAME) seq_id = uuid.uuid4().int & 0x7FFF_FFFF_FFFF_FFFF full_tx = "" # accumulated full transcript text have_sent_any = False async def infer_one(raw: np.ndarray, eff_len: int, is_first: bool, is_last: bool): """ Send one request to Triton and return the delta text (or None). raw: 1-D float32 numpy array length == raw.shape[0] (should be >= eff_len; padded if needed) eff_len: number of valid samples in raw (int) """ nonlocal full_tx wav_np = raw[None, :] # shape (1, T) len_np = np.array([[eff_len]], np.int32) inp_wav = grpcclient.InferInput(INPUT_WAV_TENSOR, wav_np.shape, np_to_triton_dtype(np.float32)) inp_len = grpcclient.InferInput(INPUT_LEN_TENSOR, len_np.shape, np_to_triton_dtype(np.int32)) inp_wav.set_data_from_numpy(wav_np) inp_len.set_data_from_numpy(len_np) outs = [grpcclient.InferRequestedOutput(OUTPUT_TEXT_TENSOR)] resp = await client.infer( MODEL_NAME, inputs=[inp_wav, inp_len], outputs=outs, sequence_id=seq_id, sequence_start=is_first, sequence_end=is_last, ) # join parts if model returned multiple tensors; decode bytes -> str txt = b" ".join(resp.as_numpy(OUTPUT_TEXT_TENSOR)).decode().strip() if not txt: return None delta = txt[len(full_tx) :] if txt.startswith(full_tx) else txt full_tx = txt return delta or None # iterate over wav in chunks T = wav.shape[0] offset = 0 need = first_sz # If first_sz is 0 for some reason, fall back to chunk_sz if need <= 0: need = chunk_sz try: while offset < T: take = min(need, T - offset) piece = wav[offset : offset + take] # If piece is shorter than need (shouldn't happen except maybe for first), # pad to `need` as model may expect framing; eff_len tracks real length. eff = int(piece.size) if eff < need: pad = np.zeros(need, dtype=np.float32) pad[:eff] = piece to_send = pad else: to_send = piece is_first = not have_sent_any have_sent_any = True # Not last unless remaining after this is zero is_last = False delta = await infer_one(to_send, eff, is_first=is_first, is_last=is_last) if delta: yield {"time": time.time(), "text": delta, "is_final": False} offset += take need = chunk_sz # after first, normal chunk size # End of stream: if there is leftover silence/padding behavior handled above. # After all real chunks sent, optionally flush remainder (zero-length) as last=True. if ZERO_PAD_REQUEST_CONTENT: # send an explicit flush chunk (zero-length) with sequence_end=True zero = np.zeros(0, dtype=np.float32) # When sending zero-length, eff_len = 0. Mark is_first True only if never sent anything. await infer_one(zero, 0, is_first=(not have_sent_any), is_last=True) else: # send empty with last=True but not as first zero = np.zeros(0, dtype=np.float32) await infer_one(zero, 0, is_first=False, is_last=True) # Emit final full transcript yield {"time": time.time(), "text": full_tx.strip(), "is_final": True} finally: # nothing specific to clean in this generator; Triton client closed by context manager with suppress(Exception): pass