# clients/triton_grpc_client.py
from __future__ import annotations
import os, uuid, time, json, asyncio, inspect, tempfile
from pathlib import Path
from contextlib import asynccontextmanager, suppress
from typing import AsyncGenerator, Tuple, Optional, Callable, Awaitable, TYPE_CHECKING
import numpy as np
from pydub import AudioSegment
import tritonclient.grpc.aio as grpcclient
from tritonclient.utils import np_to_triton_dtype
import websockets
from dotenv import load_dotenv
load_dotenv()
if TYPE_CHECKING:
    from fastapi import WebSocket  # noqa: F401  (type hints only)
# Starlette raises WebSocketDisconnect from its sockets; send_messages() below
# catches it. Fall back to a local stub so the plain-`websockets` path still
# imports when Starlette is absent.
try:
    from starlette.websockets import WebSocketDisconnect
except ImportError:
    class WebSocketDisconnect(Exception):
        pass
# ---- constants ----
SAMPLE_RATE_HZ = 16_000
INPUT_WAV_TENSOR = "WAV"
INPUT_LEN_TENSOR = "WAV_LENS"
OUTPUT_TEXT_TENSOR = "TRANSCRIPTS"
MODEL_NAME = "transducer"
ZERO_PAD_REQUEST_CONTENT = True
TRITON_URL = os.getenv("TRITON_URL")
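# Expected environment (illustrative): a .env file or process env providing,
# for example, TRITON_URL=localhost:8001 (Triton's gRPC endpoint).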
class _WebsocketsAdapter:
"""
Adapter to make a `websockets` WebSocketServerProtocol behave like
a Starlette/FastAPI WebSocket for the shape your code expects:
- accept()
- receive() -> dict with "bytes" or "text"
- send_json(obj)
- close()
"""
def __init__(self, ws):
self._ws = ws
self._is_server_protocol = True # semantic flag
async def accept(self):
# websockets server does its own accept when created; no-op here
return
async def receive(self):
# websockets.recv() returns bytes or str
data = await self._ws.recv()
if isinstance(data, bytes):
return {"bytes": data}
else:
return {"text": data}
    async def send_json(self, obj):
        await self._ws.send(json.dumps(obj, ensure_ascii=False))
async def close(self):
await self._ws.close()
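# A minimal sketch (not exercised by the service itself) of serving the
# adapter path with a standalone `websockets` server. The host, port, and
# TritonGrpcClient wiring below are illustrative assumptions.
async def _example_websockets_server(host: str = "0.0.0.0", port: int = 8765) -> None:
    client = TritonGrpcClient(TRITON_URL or "localhost:8001")

    async def handler(ws):
        # websocket_stream_from_websocket detects raw `websockets` connections
        # and wraps them in _WebsocketsAdapter automatically.
        await client.websocket_stream_from_websocket(ws, filename="stream.webm")

    async with websockets.serve(handler, host, port):
        await asyncio.Future()  # run until cancelled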
class TritonGrpcClient:
    """
    Owns only: audio -> chunking -> Triton -> queue -> streaming.
    Keeps producing even if the HTTP/WS request disconnects.
    MinIO/DB/file finalize stays outside via callbacks.
    """
    def __init__(self, triton_url: str):
        self.triton_url = triton_url
        # Strong references to fire-and-forget producer tasks so the event
        # loop cannot garbage-collect them mid-run.
        self._bg_tasks: set[asyncio.Task] = set()
# ---------- SSE: return async generator ----------
def event_stream_from_bytes(
self,
*,
raw: bytes,
filename: str,
content_type: Optional[str] = None,
on_final_text: Optional[Callable[[str], Awaitable[None]]] = None,
on_error: Optional[Callable[[str], Awaitable[None]]] = None,
) -> AsyncGenerator[str, None]:
        if not raw:
            raise ValueError("Uploaded file is empty")
        async def _gen() -> AsyncGenerator[str, None]:
            tmp_path = self._write_temp_file(raw, filename)
queue: asyncio.Queue[dict[str, str | None]] = asyncio.Queue()
drop_event = asyncio.Event() # set on disconnect → stop emitting only
            # Spawn the producer without awaiting it; it lives independently.
            # Keep a strong reference so it is not garbage-collected mid-run.
            task = asyncio.create_task(
                self._produce_transcripts(tmp_path, queue, on_final_text, on_error, drop_event),
                name=f"triton-producer-{uuid.uuid4().hex[:8]}",
            )
            self._bg_tasks.add(task)
            task.add_done_callback(self._bg_tasks.discard)
print("[SSE] Client connected", flush=True)
try:
while True:
msg = await queue.get()
if msg.get("event") == "done":
print("[SSE] Job finished → closing stream", flush=True)
break
yield f"data: {json.dumps(msg, ensure_ascii=False)}\n\n"
except asyncio.CancelledError:
# request closed; keep producer alive, just stop emitting
print("[SSE] Client disconnected (background continues)", flush=True)
drop_event.set()
return
finally:
print("[SSE] event_stream END", flush=True)
return _gen()
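    # A minimal sketch of wiring the SSE generator into a FastAPI endpoint
    # (FastAPI is assumed to be available in the embedding app; the route and
    # app object below are illustrative, not part of this module):
    #
    #   from fastapi import FastAPI, UploadFile
    #   from fastapi.responses import StreamingResponse
    #
    #   app = FastAPI()
    #   client = TritonGrpcClient(TRITON_URL or "localhost:8001")
    #
    #   @app.post("/transcribe/sse")
    #   async def transcribe(file: UploadFile):
    #       gen = client.event_stream_from_bytes(
    #           raw=await file.read(), filename=file.filename or "upload.wav"
    #       )
    #       return StreamingResponse(gen, media_type="text/event-stream")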
    # ---------- internals: Triton session & producer ----------
@asynccontextmanager
async def _open_triton(self):
client = grpcclient.InferenceServerClient(self.triton_url, verbose=False)
try:
yield client
finally:
with suppress(Exception):
await client.close()
async def _produce_transcripts(
self,
tmp_path: Path,
queue: "asyncio.Queue[dict[str, str | None]]",
on_final_text: Optional[Callable[[str], Awaitable[None]]],
on_error: Optional[Callable[[str], Awaitable[None]]],
drop_event: asyncio.Event,
) -> None:
print("[BG] Started producer", flush=True)
last_msg: dict[str, str] | None = None
try:
print("[BG] Entering stream_transcript loop", flush=True)
async for last_msg in self._stream_transcript(str(tmp_path)):
if not drop_event.is_set():
await queue.put(last_msg)
print("[BG] stream_transcript finished", flush=True)
final_text = (last_msg or {}).get("text", "").strip()
if on_final_text:
with suppress(Exception):
await on_final_text(final_text)
except Exception as exc:
print(f"[BG] EXCEPTION: {exc!r}", flush=True)
if on_error:
with suppress(Exception):
await on_error(str(exc))
if not drop_event.is_set():
await queue.put({"event": "error", "detail": str(exc)})
finally:
print("[BG] Cleaning up temp file", flush=True)
            tmp_path.unlink(missing_ok=True)  # missing_ok already covers FileNotFoundError
if not drop_event.is_set():
await queue.put({"event": "done"})
print("[BG] producer END", flush=True)
    # ---------- WebSocket: same producer, push over WS ----------
async def websocket_stream_from_websocket(
self,
websocket,
*,
filename: str = "stream",
content_type: Optional[str] = None,
on_final_text: Optional[Callable[..., Awaitable[None]]] = None,
on_error: Optional[Callable[[str], Awaitable[None]]] = None,
accept_in_client: bool = False,
) -> None:
"""
Supports both:
- Starlette/FastAPI WebSocket objects (they expose .receive(), .send_json(), .accept())
- `websockets` WebSocketServerProtocol (they expose .recv(), .send(), .close())
The adapter wraps the latter so the rest of your existing logic can stay unchanged.
"""
# If this is a websockets (WebSocketServerProtocol) instance, wrap it
# Heuristic: Starlette has .receive(); websockets has .recv()
if not hasattr(websocket, "receive") and hasattr(websocket, "recv"):
websocket = _WebsocketsAdapter(websocket)
        # If the caller asked this client to perform the ASGI accept, do it.
        # (FastAPI endpoints usually call accept() themselves; the adapter's
        # accept() is a no-op.)
        if accept_in_client:
await websocket.accept()
# We'll append all raw bytes to a temp file so the endpoint can upload later in the finalizer
src_suffix = Path(filename).suffix or ".bin"
src_path = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}{src_suffix}"
f_src = src_path.open("wb")
# Background pipeline control
queue: asyncio.Queue[dict[str, str | None]] = asyncio.Queue()
drop_event = asyncio.Event() # if client disconnects, stop SENDING but keep producing
        async def _call_final_cb(final_text: str) -> None:
            if not on_final_text:
                return
            try:
                # Prefer callback(final_text, source_audio_path); fall back to
                # callback(final_text). inspect.signature also handles partials
                # and bound methods, unlike __code__.co_argcount.
                if len(inspect.signature(on_final_text).parameters) >= 2:
                    await on_final_text(final_text, src_path)
                else:
                    await on_final_text(final_text)
            except Exception:
                pass
async def recv_frames_and_transcode(proc):
"""
Read frames using the (possibly-adapted) `websocket.receive()` and write to ffmpeg stdin;
close stdin when we get {"event":"end"} or disconnect.
"""
try:
while True:
msg = await websocket.receive()
if "bytes" in msg and msg["bytes"] is not None:
chunk = msg["bytes"]
f_src.write(chunk)
proc.stdin.write(chunk) # type: ignore[attr-defined]
await proc.stdin.drain() # type: ignore[attr-defined]
elif "text" in msg and msg["text"] is not None:
try:
payload = json.loads(msg["text"])
if payload.get("event") == "end":
break
except Exception:
# ignore non-JSON text frames
pass
                    elif msg.get("type") == "websocket.disconnect":
                        # Starlette reports disconnects as a message rather than
                        # raising here; stop reading instead of busy-looping.
                        break
                    else:
                        # ignore pings/other control frames
                        await asyncio.sleep(0)
except Exception:
# If underlying connection closed abruptly, just stop receiving
pass
finally:
with suppress(Exception):
f_src.flush()
f_src.close()
with suppress(Exception):
# Some implementations use close(); others have no stdin to close
proc.stdin.close() # type: ignore[attr-defined]
async def transcribe_from_ffmpeg_stdout(proc):
"""
Read float32 PCM from ffmpeg stdout, chunk on-the-fly, call Triton per chunk,
and push partial/final messages to the queue.
"""
# Triton session + chunk sizes
async with self._open_triton() as client:
first_sz, chunk_sz = await self._get_chunk_sizes(client, MODEL_NAME)
seq_id = uuid.uuid4().int & 0x7FFF_FFFF_FFFF_FFFF
full_tx = ""
have_sent_any = False
need = first_sz # samples needed for the next chunk
# PCM sample buffer (float32)
buf = np.empty(0, dtype=np.float32)
async def infer_one(raw: np.ndarray, eff_len: int, is_first: bool, is_last: bool):
nonlocal full_tx
wav_np = raw[None, :] # (1, T)
len_np = np.array([[eff_len]], np.int32)
inp_wav = grpcclient.InferInput(INPUT_WAV_TENSOR, wav_np.shape, np_to_triton_dtype(np.float32))
inp_len = grpcclient.InferInput(INPUT_LEN_TENSOR, len_np.shape, np_to_triton_dtype(np.int32))
inp_wav.set_data_from_numpy(wav_np)
inp_len.set_data_from_numpy(len_np)
outs = [grpcclient.InferRequestedOutput(OUTPUT_TEXT_TENSOR)]
resp = await client.infer(
MODEL_NAME,
inputs=[inp_wav, inp_len],
outputs=outs,
sequence_id=seq_id,
sequence_start=is_first,
sequence_end=is_last,
)
txt = b" ".join(resp.as_numpy(OUTPUT_TEXT_TENSOR)).decode().strip()
if not txt:
return None
delta = txt[len(full_tx) :] if txt.startswith(full_tx) else txt
full_tx = txt
return delta or None
                # Read ffmpeg stdout in bytes; convert to float32. read() may
                # return a length that is not a multiple of 4 bytes, so carry
                # any unaligned remainder over to the next iteration.
                byte_buf = b""
                try:
                    while True:
                        chunk = await proc.stdout.read(8192 * 4)  # up to 8192 float32 samples
                        if not chunk:
                            break
                        byte_buf += chunk
                        usable = len(byte_buf) - (len(byte_buf) % 4)
                        if usable == 0:
                            continue
                        new = np.frombuffer(byte_buf[:usable], dtype=np.float32)
                        byte_buf = byte_buf[usable:]
                        if buf.size == 0:
                            buf = new
                        else:
                            buf = np.concatenate((buf, new), axis=0)
# while enough samples for the next piece, send it
while buf.size >= need:
take = need
piece = buf[:take]
buf = buf[take:]
is_first = not have_sent_any
have_sent_any = True
delta = await infer_one(piece, take, is_first=is_first, is_last=False)
if delta and not drop_event.is_set():
await queue.put({"time": time.time(), "text": delta, "is_final": False})
# after first, normal chunk size
need = chunk_sz
finally:
# End of stream: flush any remainder and a zero-length piece if configured
if buf.size > 0:
# pad to chunk_sz for model framing, but eff_len is the real (short) length
eff = int(buf.size)
pad = np.zeros(chunk_sz, dtype=np.float32)
pad[:eff] = buf[:eff]
delta = await infer_one(pad, eff, is_first=not have_sent_any, is_last=False)
if delta and not drop_event.is_set():
await queue.put({"time": time.time(), "text": delta, "is_final": False})
have_sent_any = True
                    # Close the Triton sequence with a zero-length flush chunk
                    # (sequence_end=True); is_first matters only if no audio was
                    # ever sent.
                    zero = np.zeros(0, dtype=np.float32)
                    is_first_flag = (not have_sent_any) if ZERO_PAD_REQUEST_CONTENT else False
                    await infer_one(zero, 0, is_first=is_first_flag, is_last=True)
# Emit final full transcript as a message, like your file-based path
if not drop_event.is_set():
await queue.put({"time": time.time(), "text": full_tx.strip(), "is_final": True})
# Call user finalizer
await _call_final_cb(full_tx.strip())
# signal done to sender
if not drop_event.is_set():
await queue.put({"event": "done"})
async def send_messages():
HEARTBEAT_SECS = 10
print("[WS] Client connected", flush=True)
try:
while True:
try:
# wait for a real message up to HEARTBEAT_SECS
msg = await asyncio.wait_for(queue.get(), timeout=HEARTBEAT_SECS)
except asyncio.TimeoutError:
# no real message → send heartbeat
hb = {"event": "heartbeat", "t": time.time()}
await websocket.send_json(hb)
continue
if msg.get("event") == "done":
print("[WS] Job finished → closing socket", flush=True)
break
await websocket.send_json(msg)
            except (WebSocketDisconnect, websockets.exceptions.ConnectionClosed):
                print("[WS] Client disconnected (background continues)", flush=True)
                drop_event.set()
except asyncio.CancelledError:
print("[WS] Client cancelled (background continues)", flush=True)
drop_event.set()
finally:
with suppress(Exception):
await websocket.close()
print("[WS] send_messages END", flush=True)
# ----- Start ffmpeg (stdin bytes -> stdout float32 PCM @ 16k mono)
proc = await asyncio.create_subprocess_exec(
"ffmpeg",
"-hide_banner",
"-loglevel",
"error",
"-i",
"pipe:0",
"-ac",
"1",
"-ar",
str(SAMPLE_RATE_HZ),
"-f",
"f32le",
"pipe:1",
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
async def drain_ffmpeg_stderr(proc):
try:
while True:
line = await proc.stderr.readline()
if not line:
break
# You can log it if needed:
# print(f"[ffmpeg] {line.decode().rstrip()}", flush=True)
await asyncio.sleep(0)
except asyncio.CancelledError:
pass
stderr_task = asyncio.create_task(drain_ffmpeg_stderr(proc), name="ffmpeg-stderr")
        # Run four tasks concurrently:
        # - receive frames & feed ffmpeg stdin
        # - read PCM from ffmpeg stdout & call Triton
        # - forward queue messages to the client
        # - drain ffmpeg stderr
tasks = [
asyncio.create_task(recv_frames_and_transcode(proc), name="ws-recv"),
asyncio.create_task(transcribe_from_ffmpeg_stdout(proc), name="ws-triton"),
asyncio.create_task(send_messages(), name="ws-send"),
stderr_task,
]
try:
await asyncio.gather(*tasks)
finally:
for t in tasks:
if not t.done():
t.cancel()
with suppress(Exception):
proc.terminate()
with suppress(Exception):
await proc.wait()
with suppress(Exception):
if not f_src.closed:
f_src.close()
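    @staticmethod
    async def _example_ws_client(url: str = "ws://localhost:8765", path: str = "sample.webm") -> None:
        """
        A minimal client sketch (not used by the service) for the wire protocol
        above: binary frames carry encoded audio, a JSON text frame
        {"event": "end"} finishes the upload, and JSON messages with
        "text"/"is_final" (plus "heartbeat" events) stream back until the
        server closes the socket. The URL and file path are illustrative.
        """
        async with websockets.connect(url) as ws:
            with open(path, "rb") as f:
                while chunk := f.read(32_768):
                    await ws.send(chunk)
            await ws.send(json.dumps({"event": "end"}))
            async for raw in ws:  # ends when the server closes the socket
                msg = json.loads(raw)
                if msg.get("event") == "heartbeat":
                    continue
                print(msg.get("text", ""), flush=True)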
@staticmethod
def _write_temp_file(raw: bytes, filename: str) -> Path:
tmp = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}{Path(filename).suffix}"
tmp.write_bytes(raw)
print(f"[REQ] Temp file written at {tmp}", flush=True)
return tmp
@staticmethod
def _pcm_int16_to_float32(x: np.ndarray) -> np.ndarray:
return (x.astype(np.float32) / 32768.0).clip(-1.0, 1.0)
def _load_audio(self, path: str, target_sr: int = SAMPLE_RATE_HZ):
audio = AudioSegment.from_file(path)
audio = audio.set_channels(1).set_frame_rate(target_sr).set_sample_width(2)
pcm_int16 = np.frombuffer(audio.raw_data, dtype=np.int16)
return pcm_int16, target_sr
async def _get_chunk_sizes(self, client: grpcclient.InferenceServerClient, model: str) -> Tuple[int, int]:
try:
cfg = await client.get_model_config(model, as_json=True)
params = {p["key"]: p["value"]["string_param"] for p in cfg.get("parameters", [])}
first_sec = float(params.get("chunk_size_first", params.get("chunk_size", "0.465")))
norm_sec = float(params.get("chunk_size", "0.32"))
except Exception:
first_sec, norm_sec = 0.465, 0.32
return int(first_sec * SAMPLE_RATE_HZ), int(norm_sec * SAMPLE_RATE_HZ)
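    # For reference, the chunk sizes above come from `parameters` entries in
    # the model's config.pbtxt; a hypothetical example (values illustrative):
    #
    #   parameters [
    #     { key: "chunk_size"       value: { string_param: "0.32" } },
    #     { key: "chunk_size_first" value: { string_param: "0.465" } }
    #   ]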
async def _stream_transcript(self, path: str) -> AsyncGenerator[dict, None]:
"""
Read audio from `path`, chunk it and stream to Triton, yielding partial/final messages.
Yields dicts of the form:
{"time": <unix-time>, "text": "<delta text>", "is_final": False}
and finally:
{"time": <unix-time>, "text": "<full transcript>", "is_final": True}
"""
# Load audio (int16 PCM) and convert to float32 in [-1, 1]
pcm_int16, sr = self._load_audio(path, target_sr=SAMPLE_RATE_HZ)
if pcm_int16.size == 0:
return
wav = self._pcm_int16_to_float32(pcm_int16) # float32 1-D numpy array
# Triton client + chunk sizes
async with self._open_triton() as client:
first_sz, chunk_sz = await self._get_chunk_sizes(client, MODEL_NAME)
seq_id = uuid.uuid4().int & 0x7FFF_FFFF_FFFF_FFFF
full_tx = "" # accumulated full transcript text
have_sent_any = False
async def infer_one(raw: np.ndarray, eff_len: int, is_first: bool, is_last: bool):
"""
Send one request to Triton and return the delta text (or None).
raw: 1-D float32 numpy array length == raw.shape[0] (should be >= eff_len; padded if needed)
eff_len: number of valid samples in raw (int)
"""
nonlocal full_tx
wav_np = raw[None, :] # shape (1, T)
len_np = np.array([[eff_len]], np.int32)
inp_wav = grpcclient.InferInput(INPUT_WAV_TENSOR, wav_np.shape, np_to_triton_dtype(np.float32))
inp_len = grpcclient.InferInput(INPUT_LEN_TENSOR, len_np.shape, np_to_triton_dtype(np.int32))
inp_wav.set_data_from_numpy(wav_np)
inp_len.set_data_from_numpy(len_np)
outs = [grpcclient.InferRequestedOutput(OUTPUT_TEXT_TENSOR)]
resp = await client.infer(
MODEL_NAME,
inputs=[inp_wav, inp_len],
outputs=outs,
sequence_id=seq_id,
sequence_start=is_first,
sequence_end=is_last,
)
# join parts if model returned multiple tensors; decode bytes -> str
txt = b" ".join(resp.as_numpy(OUTPUT_TEXT_TENSOR)).decode().strip()
if not txt:
return None
delta = txt[len(full_tx) :] if txt.startswith(full_tx) else txt
full_tx = txt
return delta or None
# iterate over wav in chunks
T = wav.shape[0]
offset = 0
need = first_sz
# If first_sz is 0 for some reason, fall back to chunk_sz
if need <= 0:
need = chunk_sz
try:
while offset < T:
take = min(need, T - offset)
piece = wav[offset : offset + take]
                    # The last piece is usually shorter than `need`; pad it to
                    # `need` for model framing while eff_len keeps the real length.
eff = int(piece.size)
if eff < need:
pad = np.zeros(need, dtype=np.float32)
pad[:eff] = piece
to_send = pad
else:
to_send = piece
is_first = not have_sent_any
have_sent_any = True
                    # sequence_end is signalled by the explicit flush chunk below
                    is_last = False
delta = await infer_one(to_send, eff, is_first=is_first, is_last=is_last)
if delta:
yield {"time": time.time(), "text": delta, "is_final": False}
offset += take
need = chunk_sz # after first, normal chunk size
                # All real audio has been sent (padding handled above); close the
                # Triton sequence with an explicit zero-length flush chunk
                # (sequence_end=True). is_first matters only if nothing was sent.
                zero = np.zeros(0, dtype=np.float32)
                is_first_flag = (not have_sent_any) if ZERO_PAD_REQUEST_CONTENT else False
                await infer_one(zero, 0, is_first=is_first_flag, is_last=True)
# Emit final full transcript
yield {"time": time.time(), "text": full_tx.strip(), "is_final": True}
            finally:
                # Nothing to clean up here; the Triton client is closed by the
                # _open_triton context manager.
                pass
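if __name__ == "__main__":
    # Ad-hoc manual run (illustrative): serve the `websockets` sketch defined
    # above. Requires a reachable Triton server at TRITON_URL.
    asyncio.run(_example_websockets_server())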