# clients/triton_grpc_client.py
from __future__ import annotations

import asyncio
import inspect
import json
import os
import tempfile
import time
import uuid
from contextlib import asynccontextmanager, suppress
from pathlib import Path
from typing import AsyncGenerator, Awaitable, Callable, Optional, Tuple

import numpy as np
import tritonclient.grpc.aio as grpcclient
from dotenv import load_dotenv
from pydub import AudioSegment
from tritonclient.utils import np_to_triton_dtype

try:
    # Raised by Starlette/FastAPI when the peer disconnects; referenced in
    # send_messages() below, so it must be imported at runtime.
    from starlette.websockets import WebSocketDisconnect
except ImportError:
    class WebSocketDisconnect(Exception):
        """Fallback so `except WebSocketDisconnect` stays valid without Starlette."""

load_dotenv()

# ---- constants ----
SAMPLE_RATE_HZ = 16_000
INPUT_WAV_TENSOR = "WAV"
INPUT_LEN_TENSOR = "WAV_LENS"
OUTPUT_TEXT_TENSOR = "TRANSCRIPTS"
MODEL_NAME = "transducer"
ZERO_PAD_REQUEST_CONTENT = True
TRITON_URL = os.getenv("TRITON_URL")
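# Example (assumption, not shipped with this module): a .env entry such as
#   TRITON_URL=localhost:8001
# i.e. Triton's default gRPC port. `grpcclient.InferenceServerClient` expects a
# bare host:port, without an http:// scheme.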

# Strong references to fire-and-forget background tasks; asyncio's event loop
# keeps only weak references, so an unreferenced task can be garbage-collected
# before it finishes.
_BACKGROUND_TASKS: set[asyncio.Task] = set()


class _WebsocketsAdapter:
    """
    Adapter that makes a `websockets` server connection behave like a
    Starlette/FastAPI WebSocket for the surface this module expects:
      - accept()
      - receive() -> dict with "bytes" or "text"
      - send_json(obj)
      - close()
    """

    def __init__(self, ws):
        self._ws = ws
        self._is_server_protocol = True  # semantic flag

    async def accept(self):
        # A websockets server connection is already accepted on creation; no-op.
        return

    async def receive(self):
        # websockets .recv() returns bytes or str.
        data = await self._ws.recv()
        if isinstance(data, bytes):
            return {"bytes": data}
        return {"text": data}

    async def send_json(self, obj):
        await self._ws.send(json.dumps(obj, ensure_ascii=False))

    async def close(self):
        await self._ws.close()
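
    # Example (sketch): serving the adapter via the `websockets` package.
    # `handler` and the port are hypothetical; `websockets.serve` is the real API.
    #
    #   import websockets
    #
    #   async def handler(ws):
    #       await TritonGrpcClient(TRITON_URL).websocket_stream_from_websocket(
    #           ws, filename="mic.webm"
    #       )
    #
    #   async def main():
    #       async with websockets.serve(handler, "0.0.0.0", 8765):
    #           await asyncio.Future()  # run until cancelled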


class TritonGrpcClient:
    """
    Owns only: audio -> chunking -> Triton -> queue -> streaming.
    Keeps producing even if the HTTP/WS request disconnects.
    MinIO/DB/file finalization stays outside, via callbacks.
    """

    def __init__(self, triton_url: str):
        self.triton_url = triton_url

    # ---------- SSE: return async generator ----------
    def event_stream_from_bytes(
        self,
        *,
        raw: bytes,
        filename: str,
        content_type: Optional[str] = None,
        on_final_text: Optional[Callable[[str], Awaitable[None]]] = None,
        on_error: Optional[Callable[[str], Awaitable[None]]] = None,
    ) -> AsyncGenerator[str, None]:
        async def _gen() -> AsyncGenerator[str, None]:
            if not raw:
                raise ValueError("Uploaded file is empty")

            tmp_path = self._write_temp_file(raw, filename)
            queue: asyncio.Queue[dict[str, str | None]] = asyncio.Queue()
            drop_event = asyncio.Event()  # set on disconnect -> stop emitting only

            # Spawn the producer and DO NOT await it here; it lives independently.
            # Keep a strong reference so the task cannot be garbage-collected mid-flight.
            task = asyncio.create_task(
                self._produce_transcripts(tmp_path, queue, on_final_text, on_error, drop_event),
                name=f"triton-producer-{uuid.uuid4().hex[:8]}",
            )
            _BACKGROUND_TASKS.add(task)
            task.add_done_callback(_BACKGROUND_TASKS.discard)

            print("[SSE] Client connected", flush=True)
            try:
                while True:
                    msg = await queue.get()
                    if msg.get("event") == "done":
                        print("[SSE] Job finished → closing stream", flush=True)
                        break
                    yield f"data: {json.dumps(msg, ensure_ascii=False)}\n\n"
            except asyncio.CancelledError:
                # Request closed; keep the producer alive, just stop emitting.
                print("[SSE] Client disconnected (background continues)", flush=True)
                drop_event.set()
                return
            finally:
                print("[SSE] event_stream END", flush=True)

        return _gen()
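
    # Example (sketch, assumes FastAPI; `app` and the route are hypothetical,
    # StreamingResponse is the real Starlette/FastAPI class):
    #
    #   from fastapi import FastAPI, UploadFile
    #   from fastapi.responses import StreamingResponse
    #
    #   app = FastAPI()
    #   client = TritonGrpcClient(TRITON_URL)
    #
    #   @app.post("/transcribe")
    #   async def transcribe(file: UploadFile):
    #       gen = client.event_stream_from_bytes(
    #           raw=await file.read(), filename=file.filename
    #       )
    #       return StreamingResponse(gen, media_type="text/event-stream")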

    # ---------- Triton connection helper ----------
    @asynccontextmanager
    async def _open_triton(self):
        client = grpcclient.InferenceServerClient(self.triton_url, verbose=False)
        try:
            yield client
        finally:
            with suppress(Exception):
                await client.close()

    async def _produce_transcripts(
        self,
        tmp_path: Path,
        queue: "asyncio.Queue[dict[str, str | None]]",
        on_final_text: Optional[Callable[[str], Awaitable[None]]],
        on_error: Optional[Callable[[str], Awaitable[None]]],
        drop_event: asyncio.Event,
    ) -> None:
        print("[BG] Started producer", flush=True)
        last_msg: dict[str, str] | None = None

        try:
            print("[BG] Entering stream_transcript loop", flush=True)
            async for last_msg in self._stream_transcript(str(tmp_path)):
                if not drop_event.is_set():
                    await queue.put(last_msg)

            print("[BG] stream_transcript finished", flush=True)
            final_text = (last_msg or {}).get("text", "").strip()

            if on_final_text:
                with suppress(Exception):
                    await on_final_text(final_text)

        except Exception as exc:
            print(f"[BG] EXCEPTION: {exc!r}", flush=True)
            if on_error:
                with suppress(Exception):
                    await on_error(str(exc))
            if not drop_event.is_set():
                await queue.put({"event": "error", "detail": str(exc)})
        finally:
            print("[BG] Cleaning up temp file", flush=True)
            tmp_path.unlink(missing_ok=True)  # missing_ok already covers FileNotFoundError

            if not drop_event.is_set():
                await queue.put({"event": "done"})
            print("[BG] producer END", flush=True)

    # ---------- WebSocket: same producer, push over WS ----------
    async def websocket_stream_from_websocket(
        self,
        websocket,
        *,
        filename: str = "stream",
        content_type: Optional[str] = None,
        on_final_text: Optional[Callable[..., Awaitable[None]]] = None,
        on_error: Optional[Callable[[str], Awaitable[None]]] = None,
        accept_in_client: bool = False,
    ) -> None:
        """
        Supports both:
        - Starlette/FastAPI WebSocket objects (expose .receive(), .send_json(), .accept())
        - `websockets` server connections (expose .recv(), .send(), .close())
        The adapter wraps the latter so the rest of the logic stays unchanged.
        """
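        # Client-side protocol sketch (hypothetical client, derived from the
        # receive/send loops below):
        #
        #   await ws.send(audio_bytes_chunk)              # repeat per chunk
        #   await ws.send(json.dumps({"event": "end"}))   # finish the stream
        #   async for raw in ws:                          # JSON replies:
        #       msg = json.loads(raw)                     #   {"time","text","is_final"},
        #       ...                                       #   {"event":"heartbeat"}, {"event":"done"}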
        # If this is a `websockets` server connection, wrap it.
        # Heuristic: Starlette has .receive(); websockets has .recv().
        if not hasattr(websocket, "receive") and hasattr(websocket, "recv"):
            websocket = _WebsocketsAdapter(websocket)

        # If the caller requested a server-side accept for ASGI websockets, do it here.
        # For the websockets adapter, accept() is a no-op.
        if accept_in_client:
            await websocket.accept()

        # Append all raw bytes to a temp file so the endpoint can upload it later
        # in the finalizer.
        src_suffix = Path(filename).suffix or ".bin"
        src_path = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}{src_suffix}"
        f_src = src_path.open("wb")

        # Background pipeline control
        queue: asyncio.Queue[dict[str, str | None]] = asyncio.Queue()
        drop_event = asyncio.Event()  # if the client disconnects, stop SENDING but keep producing

        async def _call_final_cb(final_text: str) -> None:
            if not on_final_text:
                return
            try:
                # Prefer callback(final_text, source_audio_path); fall back to
                # callback(final_text). inspect.signature handles bound methods
                # and partials, unlike __code__.co_argcount.
                try:
                    n_params = len(inspect.signature(on_final_text).parameters)
                except (TypeError, ValueError):
                    n_params = 1
                if n_params >= 2:
                    await on_final_text(final_text, src_path)
                else:
                    await on_final_text(final_text)
            except Exception:
                pass
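
        # Example finalizer (sketch; `upload_to_minio` and `save_transcript` are
        # hypothetical application hooks, not part of this module):
        #
        #   async def on_final_text(text: str, audio_path: Path) -> None:
        #       await upload_to_minio(audio_path)
        #       await save_transcript(text)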

        async def recv_frames_and_transcode(proc):
            """
            Read frames via the (possibly adapted) `websocket.receive()` and write
            them to ffmpeg stdin; close stdin on {"event": "end"} or disconnect.
            """
            try:
                while True:
                    msg = await websocket.receive()
                    if msg.get("type") == "websocket.disconnect":
                        # Starlette reports disconnect as a message here, not an exception.
                        break
                    if "bytes" in msg and msg["bytes"] is not None:
                        chunk = msg["bytes"]
                        f_src.write(chunk)
                        proc.stdin.write(chunk)  # type: ignore[attr-defined]
                        await proc.stdin.drain()  # type: ignore[attr-defined]
                    elif "text" in msg and msg["text"] is not None:
                        try:
                            payload = json.loads(msg["text"])
                            if payload.get("event") == "end":
                                break
                        except Exception:
                            # Ignore non-JSON text frames.
                            pass
                    else:
                        # Ignore pings/other control frames.
                        await asyncio.sleep(0)
            except Exception:
                # If the underlying connection closed abruptly, just stop receiving.
                pass
            finally:
                with suppress(Exception):
                    f_src.flush()
                    f_src.close()
                with suppress(Exception):
                    # Some implementations use close(); others have no stdin to close.
                    proc.stdin.close()  # type: ignore[attr-defined]

        async def transcribe_from_ffmpeg_stdout(proc):
            """
            Read float32 PCM from ffmpeg stdout, chunk it on the fly, call Triton
            per chunk, and push partial/final messages to the queue.
            """
            # Triton session + chunk sizes
            async with self._open_triton() as client:
                first_sz, chunk_sz = await self._get_chunk_sizes(client, MODEL_NAME)
                seq_id = uuid.uuid4().int & 0x7FFF_FFFF_FFFF_FFFF
                full_tx = ""
                have_sent_any = False
                need = first_sz  # samples needed for the next chunk

                # PCM sample buffer (float32), plus a byte-level remainder:
                # proc.stdout.read() may return a length that is not a multiple
                # of 4, and np.frombuffer would raise on a partial sample.
                buf = np.empty(0, dtype=np.float32)
                byte_rem = b""

                async def infer_one(raw: np.ndarray, eff_len: int, is_first: bool, is_last: bool):
                    nonlocal full_tx
                    wav_np = raw[None, :]  # (1, T)
                    len_np = np.array([[eff_len]], np.int32)
                    inp_wav = grpcclient.InferInput(INPUT_WAV_TENSOR, wav_np.shape, np_to_triton_dtype(np.float32))
                    inp_len = grpcclient.InferInput(INPUT_LEN_TENSOR, len_np.shape, np_to_triton_dtype(np.int32))
                    inp_wav.set_data_from_numpy(wav_np)
                    inp_len.set_data_from_numpy(len_np)
                    outs = [grpcclient.InferRequestedOutput(OUTPUT_TEXT_TENSOR)]

                    resp = await client.infer(
                        MODEL_NAME,
                        inputs=[inp_wav, inp_len],
                        outputs=outs,
                        sequence_id=seq_id,
                        sequence_start=is_first,
                        sequence_end=is_last,
                    )
                    txt = b" ".join(resp.as_numpy(OUTPUT_TEXT_TENSOR)).decode().strip()
                    if not txt:
                        return None
                    delta = txt[len(full_tx):] if txt.startswith(full_tx) else txt
                    full_tx = txt
                    return delta or None

                # Read ffmpeg stdout as bytes; convert to float32 samples.
                try:
                    while True:
                        chunk = await proc.stdout.read(8192 * 4)  # up to 8192 float32 samples
                        if not chunk:
                            break
                        data = byte_rem + chunk
                        usable = len(data) - (len(data) % 4)  # whole samples only
                        byte_rem = data[usable:]
                        new = np.frombuffer(data[:usable], dtype=np.float32)
                        if new.size == 0:
                            continue
                        buf = new if buf.size == 0 else np.concatenate((buf, new), axis=0)

                        # While there are enough samples for the next piece, send it.
                        while buf.size >= need:
                            take = need
                            piece = buf[:take]
                            buf = buf[take:]
                            is_first = not have_sent_any
                            have_sent_any = True

                            delta = await infer_one(piece, take, is_first=is_first, is_last=False)
                            if delta and not drop_event.is_set():
                                await queue.put({"time": time.time(), "text": delta, "is_final": False})

                            # After the first chunk, switch to the normal chunk size.
                            need = chunk_sz

                finally:
                    # End of stream: flush any remainder, then the configured flush chunk.
                    if buf.size > 0:
                        # Pad to chunk_sz for model framing; eff_len carries the real (short) length.
                        eff = int(buf.size)
                        pad = np.zeros(chunk_sz, dtype=np.float32)
                        pad[:eff] = buf[:eff]
                        delta = await infer_one(pad, eff, is_first=not have_sent_any, is_last=False)
                        if delta and not drop_event.is_set():
                            await queue.put({"time": time.time(), "text": delta, "is_final": False})
                        have_sent_any = True

                    if ZERO_PAD_REQUEST_CONTENT:
                        # Zero-length "flush" chunk that also ends the sequence.
                        zero = np.zeros(0, dtype=np.float32)
                        await infer_one(zero, 0, is_first=not have_sent_any, is_last=True)
                    else:
                        # No explicit flush: still mark the sequence end with an empty last chunk.
                        zero = np.zeros(0, dtype=np.float32)
                        await infer_one(zero, 0, is_first=False, is_last=True)

                    # Emit the final full transcript, like the file-based path.
                    if not drop_event.is_set():
                        await queue.put({"time": time.time(), "text": full_tx.strip(), "is_final": True})

                    # Call the user finalizer.
                    await _call_final_cb(full_tx.strip())

                    # Signal done to the sender.
                    if not drop_event.is_set():
                        await queue.put({"event": "done"})

        async def send_messages():
            HEARTBEAT_SECS = 10
            print("[WS] Client connected", flush=True)
            try:
                while True:
                    try:
                        # Wait for a real message for up to HEARTBEAT_SECS.
                        msg = await asyncio.wait_for(queue.get(), timeout=HEARTBEAT_SECS)
                    except asyncio.TimeoutError:
                        # No real message → send a heartbeat.
                        hb = {"event": "heartbeat", "t": time.time()}
                        await websocket.send_json(hb)
                        continue

                    if msg.get("event") == "done":
                        print("[WS] Job finished → closing socket", flush=True)
                        break

                    await websocket.send_json(msg)

            except WebSocketDisconnect:
                print("[WS] Client disconnected (background continues)", flush=True)
                drop_event.set()
            except asyncio.CancelledError:
                print("[WS] Client cancelled (background continues)", flush=True)
                drop_event.set()
            finally:
                with suppress(Exception):
                    await websocket.close()
                print("[WS] send_messages END", flush=True)

        # ----- Start ffmpeg (stdin bytes -> stdout float32 PCM @ 16 kHz mono)
        proc = await asyncio.create_subprocess_exec(
            "ffmpeg",
            "-hide_banner", "-loglevel", "error",
            "-i", "pipe:0",
            "-ac", "1",
            "-ar", str(SAMPLE_RATE_HZ),
            "-f", "f32le",
            "pipe:1",
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
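
        # Equivalent CLI invocation, for reference:
        #   ffmpeg -hide_banner -loglevel error -i pipe:0 -ac 1 -ar 16000 -f f32le pipe:1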

        async def drain_ffmpeg_stderr(proc):
            # Keep draining stderr: ffmpeg blocks once the pipe buffer fills.
            try:
                while True:
                    line = await proc.stderr.readline()
                    if not line:
                        break
                    # Log it if needed:
                    # print(f"[ffmpeg] {line.decode().rstrip()}", flush=True)
                    await asyncio.sleep(0)
            except asyncio.CancelledError:
                pass

        stderr_task = asyncio.create_task(drain_ffmpeg_stderr(proc), name="ffmpeg-stderr")

        # Run the pipeline concurrently:
        # - receive frames & feed ffmpeg
        # - read PCM & call Triton
        # - send queue messages to the client
        tasks = [
            asyncio.create_task(recv_frames_and_transcode(proc), name="ws-recv"),
            asyncio.create_task(transcribe_from_ffmpeg_stdout(proc), name="ws-triton"),
            asyncio.create_task(send_messages(), name="ws-send"),
            stderr_task,
        ]

        try:
            await asyncio.gather(*tasks)
        finally:
            for t in tasks:
                if not t.done():
                    t.cancel()
            # Let cancelled tasks actually finish before tearing down the process.
            await asyncio.gather(*tasks, return_exceptions=True)
            with suppress(Exception):
                proc.terminate()
            with suppress(Exception):
                await proc.wait()
            with suppress(Exception):
                if not f_src.closed:
                    f_src.close()

    @staticmethod
    def _write_temp_file(raw: bytes, filename: str) -> Path:
        tmp = Path(tempfile.gettempdir()) / f"{uuid.uuid4()}{Path(filename).suffix}"
        tmp.write_bytes(raw)
        print(f"[REQ] Temp file written at {tmp}", flush=True)
        return tmp

    @staticmethod
    def _pcm_int16_to_float32(x: np.ndarray) -> np.ndarray:
        # Map int16 PCM to float32 in [-1.0, 1.0].
        return (x.astype(np.float32) / 32768.0).clip(-1.0, 1.0)

    def _load_audio(self, path: str, target_sr: int = SAMPLE_RATE_HZ):
        audio = AudioSegment.from_file(path)
        audio = audio.set_channels(1).set_frame_rate(target_sr).set_sample_width(2)
        pcm_int16 = np.frombuffer(audio.raw_data, dtype=np.int16)
        return pcm_int16, target_sr

    async def _get_chunk_sizes(self, client: grpcclient.InferenceServerClient, model: str) -> Tuple[int, int]:
        try:
            cfg = await client.get_model_config(model, as_json=True)
            # The gRPC response nests the model config under "config" when
            # as_json=True; tolerate either shape.
            cfg = cfg.get("config", cfg)
            params = {p["key"]: p["value"]["string_param"] for p in cfg.get("parameters", [])}
            first_sec = float(params.get("chunk_size_first", params.get("chunk_size", "0.465")))
            norm_sec = float(params.get("chunk_size", "0.32"))
        except Exception:
            first_sec, norm_sec = 0.465, 0.32
        return int(first_sec * SAMPLE_RATE_HZ), int(norm_sec * SAMPLE_RATE_HZ)
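
    # The parameters above would come from the model's config.pbtxt, e.g.
    # (sketch; the values are this file's fallback defaults, not a real deployment):
    #
    #   parameters [
    #     { key: "chunk_size_first" value: { string_param: "0.465" } },
    #     { key: "chunk_size"       value: { string_param: "0.32" } }
    #   ]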

    async def _stream_transcript(self, path: str) -> AsyncGenerator[dict, None]:
        """
        Read audio from `path`, chunk it, and stream it to Triton, yielding
        partial/final messages.

        Yields dicts of the form:
            {"time": <unix-time>, "text": "<delta text>", "is_final": False}
        and finally:
            {"time": <unix-time>, "text": "<full transcript>", "is_final": True}
        """
        # Load audio (int16 PCM) and convert to float32 in [-1, 1].
        pcm_int16, _sr = self._load_audio(path, target_sr=SAMPLE_RATE_HZ)
        if pcm_int16.size == 0:
            return

        wav = self._pcm_int16_to_float32(pcm_int16)  # 1-D float32 numpy array

        # Triton client + chunk sizes
        async with self._open_triton() as client:
            first_sz, chunk_sz = await self._get_chunk_sizes(client, MODEL_NAME)
            seq_id = uuid.uuid4().int & 0x7FFF_FFFF_FFFF_FFFF

            full_tx = ""  # accumulated full transcript text
            have_sent_any = False

            async def infer_one(raw: np.ndarray, eff_len: int, is_first: bool, is_last: bool):
                """
                Send one request to Triton and return the delta text (or None).
                raw: 1-D float32 array (length >= eff_len; zero-padded if needed)
                eff_len: number of valid samples in raw
                """
                nonlocal full_tx
                wav_np = raw[None, :]  # shape (1, T)
                len_np = np.array([[eff_len]], np.int32)
                inp_wav = grpcclient.InferInput(INPUT_WAV_TENSOR, wav_np.shape, np_to_triton_dtype(np.float32))
                inp_len = grpcclient.InferInput(INPUT_LEN_TENSOR, len_np.shape, np_to_triton_dtype(np.int32))
                inp_wav.set_data_from_numpy(wav_np)
                inp_len.set_data_from_numpy(len_np)
                outs = [grpcclient.InferRequestedOutput(OUTPUT_TEXT_TENSOR)]

                resp = await client.infer(
                    MODEL_NAME,
                    inputs=[inp_wav, inp_len],
                    outputs=outs,
                    sequence_id=seq_id,
                    sequence_start=is_first,
                    sequence_end=is_last,
                )
                # Join parts if the model returned multiple strings; decode bytes -> str.
                txt = b" ".join(resp.as_numpy(OUTPUT_TEXT_TENSOR)).decode().strip()
                if not txt:
                    return None
                delta = txt[len(full_tx):] if txt.startswith(full_tx) else txt
                full_tx = txt
                return delta or None

            # Iterate over the waveform in chunks.
            T = wav.shape[0]
            offset = 0
            need = first_sz
            # If first_sz is 0 for some reason, fall back to chunk_sz.
            if need <= 0:
                need = chunk_sz

            while offset < T:
                take = min(need, T - offset)
                piece = wav[offset : offset + take]

                # If the piece is shorter than `need` (e.g. the tail), pad to `need`
                # for model framing; eff_len tracks the real length.
                eff = int(piece.size)
                if eff < need:
                    pad = np.zeros(need, dtype=np.float32)
                    pad[:eff] = piece
                    to_send = pad
                else:
                    to_send = piece

                is_first = not have_sent_any
                have_sent_any = True

                # Real chunks are never marked last; the flush below ends the sequence.
                delta = await infer_one(to_send, eff, is_first=is_first, is_last=False)
                if delta:
                    yield {"time": time.time(), "text": delta, "is_final": False}

                offset += take
                need = chunk_sz  # after the first chunk, normal chunk size

            # After all real chunks, end the sequence with a zero-length flush chunk.
            if ZERO_PAD_REQUEST_CONTENT:
                zero = np.zeros(0, dtype=np.float32)
                # eff_len = 0; mark is_first only if nothing was ever sent.
                await infer_one(zero, 0, is_first=not have_sent_any, is_last=True)
            else:
                # Same flush, but never as the first request.
                zero = np.zeros(0, dtype=np.float32)
                await infer_one(zero, 0, is_first=False, is_last=True)

            # Emit the final full transcript. The Triton client is closed by the
            # _open_triton context manager; no extra cleanup is needed here.
            yield {"time": time.time(), "text": full_tx.strip(), "is_final": True}
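

# Minimal smoke test (sketch; assumes a reachable Triton server and a local
# file named sample.wav — both are assumptions, not shipped with this module):
#
#   async def _demo():
#       client = TritonGrpcClient(TRITON_URL or "localhost:8001")
#       async for msg in client._stream_transcript("sample.wav"):
#           print(msg)
#
#   if __name__ == "__main__":
#       asyncio.run(_demo())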