# -*- coding: utf-8 -*-
"""
Triton Python Backend: Face Warp / Alignment
This model warps each input face crop from 160x160 to a canonical 112x112
aligned face using 5 facial keypoints. Intended to bridge your
`face_allignment` → `face_embeding` pipeline.
Inputs (batched):
input : FP32 [N,3,160,160] NCHW face crops.
landmarks : FP32 [N,5,2] pixel coords (x,y) in 160x160 image space.
scale : FP32 [N] or [1] (optional) per-sample zoom; >1 zooms in.
Outputs:
output : FP32 [N,3,112,112] NCHW aligned faces.
# matrix : FP32 [N,2,3] optional affine matrices (commented out below).
Notes:
* Color order is preserved; no channel swapping.
* Value range is preserved; if your downstream embedding model expects
normalization (mean/std), perform that there (or in an ensemble step).
* The canonical 5-point template is scaled from a 96x112 source template
to 112x112 output width/height; matches typical ArcFace preprocessing.
"""
import os
import json
import numpy as np
import cv2
import triton_python_backend_utils as pb_utils
# --------------------------------------------------------------------------- #
# Utility: build canonical destination template once and reuse #
# --------------------------------------------------------------------------- #
def _canonical_template(output_w: int, output_h: int, scale_factor: float) -> np.ndarray:
"""
Compute canonical destination 5-point template scaled to the desired output
size and zoomed by `scale_factor`.
Returns:
(5,2) float32 array of (x,y) coords in output image space.
"""
# Canonical template as provided (nominal crop 96x112).
# Order: left_eye, right_eye, nose, left_mouth, right_mouth
reference_points = np.array(
[
[30.2946, 51.6963],
[65.5318, 51.5014],
[48.0252, 71.7366],
[33.5493, 92.3655],
[62.7299, 92.2041],
],
dtype=np.float32,
)
default_crop_size = np.array([96.0, 112.0], dtype=np.float32) # (w, h)
# Scale to target output size
scale_xy = np.array([output_w, output_h], dtype=np.float32) / default_crop_size
dst_kps = reference_points * scale_xy
# Apply zoom about the center
center = dst_kps.mean(axis=0, keepdims=True)
dst_kps = (dst_kps - center) * scale_factor + center
return dst_kps.astype(np.float32)
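
# Quick sanity check (numbers derived from the template above): with the default
# output size and no zoom, _canonical_template(112, 112, 1.0) maps the left eye
# from (30.2946, 51.6963) in 96x112 template space to roughly (35.34, 51.70) in
# 112x112 space (x is scaled by 112/96, y is unchanged). A scale_factor of 1.1
# then pushes every point about 10% further from the template centroid (zoom in).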
def _estimate_affine(src_kps: np.ndarray, dst_kps: np.ndarray) -> np.ndarray:
"""
    Estimate the 2x3 similarity transform (rotation, uniform scale, translation)
    that maps src_kps -> dst_kps, using cv2.estimateAffinePartial2D with LMEDS
    for robustness against outlier keypoints.
"""
    # cv2 expects (N,2) point arrays; make them contiguous float32 as promised.
    src_kps = np.ascontiguousarray(src_kps, dtype=np.float32)
    dst_kps = np.ascontiguousarray(dst_kps, dtype=np.float32)
    M, _ = cv2.estimateAffinePartial2D(src_kps, dst_kps, method=cv2.LMEDS)
if M is None:
# Fallback: identity with translation to keep image valid.
M = np.array([[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0]], dtype=np.float32)
return M.astype(np.float32)
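
# The matrix returned by estimateAffinePartial2D has the similarity form
# [[a, -b, tx], [b, a, ty]], where scale = sqrt(a*a + b*b) and
# rotation = atan2(b, a). As a sanity check, calling it with dst_kps equal to
# src_kps should yield M close to [[1, 0, 0], [0, 1, 0]].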
def _warp_image_nchw(img_chw: np.ndarray, M: np.ndarray, out_w: int, out_h: int) -> np.ndarray:
"""
    Warp a single CHW FP32 image with affine matrix M to size (out_w, out_h).
Args:
img_chw: (3,H,W) float32
M: (2,3) float32
out_w, out_h: ints
Returns:
(3,out_h,out_w) float32 aligned image.
"""
    # Convert CHW -> HWC for cv2.warpAffine (expects HxWxC; color order agnostic).
    img_hwc = np.transpose(img_chw, (1, 2, 0))  # H,W,C
warped = cv2.warpAffine(
img_hwc,
M,
dsize=(out_w, out_h), # (width, height)
flags=cv2.INTER_CUBIC,
borderMode=cv2.BORDER_REPLICATE,
)
    # Back to CHW
warped_chw = np.transpose(warped, (2, 0, 1))
return warped_chw.astype(np.float32)
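
# Note: cv2.warpAffine treats M as the forward src -> dst transform (it inverts
# it internally), so an identity M would simply copy the top-left out_w x out_h
# region of the input. Here M always comes from _estimate_affine, so the 160x160
# crop is rotated, scaled, and translated until the detected keypoints land on
# the canonical template.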
class TritonPythonModel:
"""
Triton entrypoint class. One instance per model instance.
"""
def initialize(self, args):
"""
Called once when the model is loaded.
"""
# Parse model config to get default scale factor (if provided).
model_config = json.loads(args['model_config'])
params = model_config.get('parameters', {})
self.default_scale = float(params.get('scale_factor', {}).get('string_value', '1.0'))
        # Aligned-face size is fixed at 112x112 to match the embedding model and
        # the output dims declared in config.pbtxt.
self.out_w = 112
self.out_h = 112
        # Precompute the canonical template for the configured default scale.
        self.base_template = _canonical_template(self.out_w, self.out_h, self.default_scale)
self.embeding_model_name = "face_embeding"
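
    # `scale_factor` is expected to arrive through the model's `parameters`
    # block in config.pbtxt; an illustrative (unconfirmed) stanza:
    #
    #   parameters: {
    #     key: "scale_factor"
    #     value: { string_value: "1.0" }
    #   }
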
def execute(self, requests):
responses = []
for request in requests:
# ---- Fetch tensors ----
in_img_tensor = pb_utils.get_input_tensor_by_name(request, "input")
in_lmk_tensor = pb_utils.get_input_tensor_by_name(request, "landmarks")
score_tensor = pb_utils.get_input_tensor_by_name(request, "score")
imgs = in_img_tensor.as_numpy() # [B,3,160,160]
lmks = in_lmk_tensor.as_numpy() # [B,5,2]
scores = score_tensor.as_numpy() # [B,1]
# Ensure batch dimension
if imgs.ndim == 3:
imgs = imgs[np.newaxis, ...]
if lmks.ndim == 2:
lmks = lmks[np.newaxis, ...]
if scores.ndim == 1:
scores = scores[np.newaxis, ...]
batch_size = imgs.shape[0]
aligned_imgs = []
valid_indices = []
            # Per-sample embedding buffer; faces skipped below keep all-zero rows.
            embedding_out = np.zeros((batch_size, 512), dtype=np.float32)
for i in range(batch_size):
                # Detection/landmark confidence; low-confidence faces are skipped
                # and keep an all-zero embedding.
                score = float(scores[i].reshape(-1)[0])
                if score < 0.9:
                    continue
src_img = imgs[i]
src_kps = lmks[i].astype(np.float32)
# Align
dst_kps = self.base_template
M = _estimate_affine(src_kps, dst_kps)
warped = _warp_image_nchw(src_img, M, self.out_w, self.out_h)
aligned_imgs.append(warped)
valid_indices.append(i)
            # Only call the embedding model when at least one face passed the
            # score threshold; skipped samples keep their zero embeddings.
            if aligned_imgs:
                aligned_batch = np.stack(aligned_imgs)  # [valid_N, 3, 112, 112]
                infer_input = pb_utils.Tensor("input", aligned_batch)
                inference_request = pb_utils.InferenceRequest(
                    model_name=self.embeding_model_name,
                    requested_output_names=["output"],
                    inputs=[infer_input],
                )
                inference_response = inference_request.exec()
                if inference_response.has_error():
                    raise pb_utils.TritonModelException(
                        inference_response.error().message()
                    )
                embeddings = pb_utils.get_output_tensor_by_name(
                    inference_response, "output"
                ).as_numpy()
                # Scatter valid embeddings back to their original batch positions.
                embedding_out[valid_indices] = embeddings
            embedding_tensor_list = [pb_utils.Tensor("output", embedding_out)]
            responses.append(pb_utils.InferenceResponse(output_tensors=embedding_tensor_list))
return responses
def finalize(self):
"""
Called when the model is being unloaded. Nothing to clean up here.
"""
return
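
# Client-side sketch (not part of the backend). Assuming the model is deployed
# under the placeholder name "face_warp" on a local Triton HTTP endpoint, a
# request could look roughly like this with tritonclient; names and shapes
# mirror the config sketched at the top of this file:
#
#   import numpy as np
#   import tritonclient.http as httpclient
#
#   client = httpclient.InferenceServerClient(url="localhost:8000")
#   faces = np.random.rand(2, 3, 160, 160).astype(np.float32)
#   lmks = (np.random.rand(2, 5, 2) * 160).astype(np.float32)
#   scores = np.ones((2, 1), dtype=np.float32)
#
#   inputs = [
#       httpclient.InferInput("input", list(faces.shape), "FP32"),
#       httpclient.InferInput("landmarks", list(lmks.shape), "FP32"),
#       httpclient.InferInput("score", list(scores.shape), "FP32"),
#   ]
#   inputs[0].set_data_from_numpy(faces)
#   inputs[1].set_data_from_numpy(lmks)
#   inputs[2].set_data_from_numpy(scores)
#
#   result = client.infer("face_warp", inputs)
#   embeddings = result.as_numpy("output")  # [2, 512]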