234 lines
8.5 KiB
Python
234 lines
8.5 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
Triton Python Backend: face alignment + embedding (BLS).

For each face crop in a request, this model fits a similarity transform
from 5 facial keypoints to a canonical template, warps the crop to a
112x112 aligned face, and sends the aligned faces to the `face_embeding`
model through a BLS call. It bridges the `face_allignment` ->
`face_embeding` pipeline.

Inputs (batched):
    input     : FP32 [N,3,160,160] NCHW face crops; values are assumed to be
                in [-1, 1].
    landmarks : FP32 [N,5,2] (x, y) keypoints normalized to [0, 1] relative
                to the crop (they are multiplied by the crop size to get
                pixel coordinates).
    score     : FP32 [N,1] detection scores; faces scoring below 0.9 are
                skipped and not sent to the embedding model.

Outputs:
    output : FP32 embeddings returned by `face_embeding` (512 values per
             face).

Notes:
    * Color order is preserved; no channel swapping.
    * Pixel values are assumed to be in [-1, 1]. They are mapped to uint8
      [0, 255] for the warp and mapped back to [-1, 1] afterwards, so any
      further normalization belongs in the embedding model.
    * The canonical 5-point template is defined on a 96x112 crop and is
      stretched to 112x112 (x scaled by 112/96, y unchanged). NOTE: the
      commonly used insightface 112x112 ArcFace template instead shifts the
      same points by +8 px in x; confirm which one the embedding model was
      trained with.
|
||
"""
|
||
|
||
# import os
|
||
import json
|
||
import numpy as np
|
||
import cv2
|
||
|
||
import triton_python_backend_utils as pb_utils
|
||
|
||
|
||
# import logging
|
||
|
||
# # Put this at the top of your script or inside initialize()
|
||
# logging.basicConfig(level=logging.INFO)
|
||
# logger = logging.getLogger(__name__)
|
||
|
||
|
||
# --------------------------------------------------------------------------- #
|
||
# Utility: build canonical destination template once and reuse #
|
||
# --------------------------------------------------------------------------- #
|
||
def _canonical_template(
|
||
output_w: int, output_h: int, scale_factor: float
|
||
) -> np.ndarray:
|
||
"""
|
||
Compute canonical destination 5-point template scaled to the desired output
|
||
size and zoomed by `scale_factor`.
|
||
|
||
Returns:
|
||
(5,2) float32 array of (x,y) coords in output image space.
|
||
"""
|
||
# Canonical template as provided (nominal crop 96x112).
|
||
# Order: left_eye, right_eye, nose, left_mouth, right_mouth
|
||
reference_points = np.array(
|
||
[
|
||
[30.2946, 51.6963],
|
||
[65.5318, 51.5014],
|
||
[48.0252, 71.7366],
|
||
[33.5493, 92.3655],
|
||
[62.7299, 92.2041],
|
||
],
|
||
dtype=np.float32,
|
||
)
|
||
default_crop_size = np.array([96.0, 112.0], dtype=np.float32) # (w, h)
|
||
|
||
# Scale to target output size
|
||
scale_xy = np.array([output_w, output_h], dtype=np.float32) / default_crop_size
|
||
dst_kps = reference_points * scale_xy
|
||
|
||
# Apply zoom about the center
|
||
center = dst_kps.mean(axis=0, keepdims=True)
|
||
dst_kps = (dst_kps - center) * scale_factor + center
|
||
return dst_kps.astype(np.float32)
|
||
|
||
|
||
def _estimate_affine(src_kps: np.ndarray, dst_kps: np.ndarray) -> np.ndarray:
    """
    Fit a 2x3 transform that maps ``src_kps`` onto ``dst_kps``.

    ``cv2.estimateAffinePartial2D`` restricts the fit to rotation, uniform
    scale and translation. It is run with the LMEDS robust estimator so a
    single bad landmark does not dominate the fit.

    Args:
        src_kps: (K, 2) float32 source keypoints, (x, y) in input pixels.
        dst_kps: (K, 2) float32 destination keypoints, (x, y) in output pixels.

    Returns:
        (2, 3) float32 matrix. If OpenCV cannot produce an estimate (it
        returns ``None``), this is the plain 2x3 identity with no
        translation. The warp then still runs, but the face is not aligned.
    """
    transform, _inlier_mask = cv2.estimateAffinePartial2D(
        src_kps, dst_kps, method=cv2.LMEDS
    )
    if transform is not None:
        return transform.astype(np.float32)
    # No estimate available: identity matrix, no translation.
    return np.eye(2, 3, dtype=np.float32)
|
||
|
||
|
||
def _warp_image_nchw(
    img_chw: np.ndarray, M: np.ndarray, out_w: int, out_h: int
) -> np.ndarray:
    """
    Warp a single CHW image with affine matrix ``M`` into ``out_w`` x ``out_h``.

    The input is assumed to hold values in [-1, 1] (TODO: confirm against the
    upstream model). It is mapped to uint8 [0, 255] for ``cv2.warpAffine``,
    which uses bicubic interpolation and replicated borders. The result is
    mapped back to [-1, 1]. Channel order is left untouched.

    Args:
        img_chw: (C, H, W) float image, values assumed in [-1, 1].
        M: (2, 3) affine matrix mapping source pixel coordinates to output
            pixel coordinates (``WARP_INVERSE_MAP`` is not set).
        out_w: Output width in pixels.
        out_h: Output height in pixels.

    Returns:
        (C, out_h, out_w) float32 image with values in [-1, 1].
    """
    # OpenCV wants HWC.
    img_hwc = np.transpose(img_chw, (1, 2, 0))
    # [-1, 1] -> [0, 255]. Round to nearest: a bare astype(uint8) truncates,
    # which darkens every pixel by ~0.5 LSB on average.
    img_u8 = np.rint((img_hwc + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
    # The transpose above yields a strided layout; hand cv2 a C-contiguous buffer.
    img_u8 = np.ascontiguousarray(img_u8)

    warped = cv2.warpAffine(
        img_u8,
        M,
        dsize=(out_w, out_h),  # OpenCV takes (width, height)
        flags=cv2.INTER_CUBIC,
        borderMode=cv2.BORDER_REPLICATE,
    )
    # OpenCV drops the channel axis for single-channel images; restore it.
    if warped.ndim == 2:
        warped = warped[:, :, np.newaxis]

    # Back to CHW, then [0, 255] -> [-1, 1].
    warped_chw = np.transpose(warped, (2, 0, 1)).astype(np.float32)
    return ((warped_chw / 255.0) - 0.5) / 0.5
|
||
|
||
|
||
class TritonPythonModel:
    """
    Triton entrypoint class; Triton creates one instance per model instance.

    For every request, each face whose detection score passes the threshold
    is aligned to a canonical 112x112 template. The aligned faces are sent to
    the embedding model in a single BLS call. The response tensor ``output``
    has one row per input face, in input order; rows for skipped faces are
    all zeros.
    """

    # Zoom applied to the canonical template. NOTE(review): this stays at
    # 0.93 even though initialize() parses a ``scale_factor`` parameter into
    # ``self.default_scale``. Switching to the parameter would change the
    # alignment, and therefore every embedding, of existing deployments.
    _TEMPLATE_SCALE = 0.93

    def initialize(self, args):
        """
        Called once when the model is loaded.

        Optional string parameters read from the model config:
            scale_factor     parsed into ``self.default_scale`` (not applied;
                             see ``_TEMPLATE_SCALE``), default "1.0"
            score_threshold  minimum detection score for a face to be
                             embedded, default "0.9"
        """
        model_config = json.loads(args["model_config"])
        params = model_config.get("parameters", {})

        self.default_scale = float(
            params.get("scale_factor", {}).get("string_value", "1.0")
        )
        # Faces scoring below this are skipped; their output rows stay zero.
        self.score_threshold = float(
            params.get("score_threshold", {}).get("string_value", "0.9")
        )

        # Aligned face size; hard-coded to match the embedding model's input.
        self.out_w = 112
        self.out_h = 112
        # Embedding width used when no face in a request is embedded (the
        # embedding model is not called, so its output width is unknown).
        self.embedding_dim = 512

        # The template does not depend on the sample, so build it once.
        self.base_template = _canonical_template(
            self.out_w, self.out_h, self._TEMPLATE_SCALE
        )
        self.embeding_model_name = "face_embeding"

    def execute(self, requests):
        """
        Align and embed the faces of every request.

        Returns one response per request, in order. A request that fails
        (missing or malformed input, embedding-model error) gets an error
        response instead of aborting the whole batch.
        """
        responses = []
        for request in requests:
            try:
                embeddings = self._process_request(request)
                response = pb_utils.InferenceResponse(
                    output_tensors=[pb_utils.Tensor("output", embeddings)]
                )
            except Exception as exc:  # top-level boundary: report per request
                response = pb_utils.InferenceResponse(
                    output_tensors=[], error=pb_utils.TritonError(str(exc))
                )
            responses.append(response)
        return responses

    def _process_request(self, request):
        """Return a (B, D) float32 embedding array for one request; skipped rows are zero."""
        imgs, lmks, scores = self._read_inputs(request)
        batch_size = imgs.shape[0]

        aligned, valid_indices = self._align_faces(imgs, lmks, scores)
        if not aligned:
            # Nothing passed the threshold: skip the BLS call entirely.
            return np.zeros((batch_size, self.embedding_dim), dtype=np.float32)

        embedded = self._embed(np.stack(aligned))  # [valid_N, D]
        # Scatter back to input order so row i always belongs to face i.
        out = np.zeros((batch_size,) + embedded.shape[1:], dtype=np.float32)
        out[valid_indices] = embedded
        return out

    @staticmethod
    def _read_inputs(request):
        """Fetch ``input``/``landmarks``/``score`` and give each a leading batch axis."""
        img_tensor = pb_utils.get_input_tensor_by_name(request, "input")
        lmk_tensor = pb_utils.get_input_tensor_by_name(request, "landmarks")
        score_tensor = pb_utils.get_input_tensor_by_name(request, "score")
        if img_tensor is None or lmk_tensor is None or score_tensor is None:
            raise ValueError(
                "request must provide 'input', 'landmarks' and 'score' tensors"
            )

        imgs = img_tensor.as_numpy()  # [B,3,H,W]; 160x160 in this pipeline
        if imgs.ndim == 3:
            imgs = imgs[np.newaxis, ...]
        if imgs.ndim != 4:
            raise ValueError(f"'input' must be [B,C,H,W], got shape {imgs.shape}")
        batch_size = imgs.shape[0]

        # [5,2] for a single face or [B,5,2]; reshape also checks the count.
        lmks = lmk_tensor.as_numpy().reshape(batch_size, 5, 2)

        # Accept [B], [B,1] (or [1] for a single face) and take the first
        # value per face. Prepending an axis to a 1-D [B] vector, as before,
        # produced [1,B] and indexed out of range for B > 1.
        raw_scores = score_tensor.as_numpy()
        if batch_size:
            scores = raw_scores.reshape(batch_size, -1)[:, 0]
        else:
            scores = raw_scores.reshape(0)
        return imgs, lmks, scores

    def _align_faces(self, imgs, lmks, scores):
        """Warp each face whose score passes the threshold; return (aligned list, indices)."""
        height, width = imgs.shape[2], imgs.shape[3]
        # Landmarks are normalised to the crop; scale x by W and y by H.
        to_pixels = np.array([width, height], dtype=np.float32)

        aligned, valid_indices = [], []
        for i, (img, kps, score) in enumerate(zip(imgs, lmks, scores)):
            # `not >=` also rejects NaN scores, as the old max(0.0, s) < 0.9 did.
            if not score >= self.score_threshold:
                continue
            src_kps = kps.astype(np.float32) * to_pixels
            M = _estimate_affine(src_kps, self.base_template)
            aligned.append(_warp_image_nchw(img, M, self.out_w, self.out_h))
            valid_indices.append(i)
        return aligned, valid_indices

    def _embed(self, aligned_batch):
        """Run the embedding model on an aligned [N,3,112,112] batch via BLS."""
        infer_request = pb_utils.InferenceRequest(
            model_name=self.embeding_model_name,
            requested_output_names=["output"],
            inputs=[pb_utils.Tensor("input", aligned_batch)],
        )
        infer_response = infer_request.exec()
        if infer_response.has_error():
            raise pb_utils.TritonModelException(infer_response.error().message())

        output = pb_utils.get_output_tensor_by_name(infer_response, "output")
        if output is None:
            raise ValueError(
                f"model '{self.embeding_model_name}' returned no 'output' tensor"
            )
        # NOTE(review): as_numpy() needs a CPU tensor; confirm the embedding
        # model's BLS output is delivered in CPU memory.
        return output.as_numpy()

    def finalize(self):
        """
        Called when the model is being unloaded. Nothing to clean up here.
        """
        return
|