diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000..746120a
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,3 @@
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.plan filter=lfs diff=lfs merge=lfs -text
+*.pbtxt filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
index ee31109..4093d40 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
 build*
 CMakeLists.txt.user
+__pycache__/
+*.pyc
diff --git a/face_post_process/face_allignment/1/model.onnx b/face_post_process/face_allignment/1/model.onnx
new file mode 100644
index 0000000..7102624
--- /dev/null
+++ b/face_post_process/face_allignment/1/model.onnx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c72127f2cc0b17f7d2d8626ecc2a544aa0a138fe138d082ab7c4c217d74970f6
+size 109446440
diff --git a/face_post_process/face_allignment/1/model.plan b/face_post_process/face_allignment/1/model.plan
new file mode 100644
index 0000000..2365899
--- /dev/null
+++ b/face_post_process/face_allignment/1/model.plan
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:790c8fc9dfb39f3552eb4e76075e48ff0b92d9e79be9ebdfc6d70c686e4aeda3
+size 59224540
diff --git a/face_post_process/face_allignment/config.pbtxt b/face_post_process/face_allignment/config.pbtxt
new file mode 100644
index 0000000..aea0820
--- /dev/null
+++ b/face_post_process/face_allignment/config.pbtxt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b09a4985a89fb0a6ecc4ffbd30beb91013101315e68674f87da6a659aa5919e8
+size 500
diff --git a/face_post_process/face_embeding/1/model.onnx b/face_post_process/face_embeding/1/model.onnx
new file mode 100644
index 0000000..baf56d5
--- /dev/null
+++ b/face_post_process/face_embeding/1/model.onnx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:91f5fb93931901456c0d79d87def8b2dbdc7c0d4343122d22faecd6b9364be93
+size 260718443
diff --git a/face_post_process/face_embeding/1/model.plan b/face_post_process/face_embeding/1/model.plan
new file mode 100644
index 0000000..fa8cc03
--- /dev/null
+++ b/face_post_process/face_embeding/1/model.plan
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ade14618a895be9551c43403707ced5a2a8e3fa579cc73ba7647b1da254faea1
+size 414621820
diff --git a/face_post_process/face_embeding/config.pbtxt b/face_post_process/face_embeding/config.pbtxt
new file mode 100644
index 0000000..546115a
--- /dev/null
+++ b/face_post_process/face_embeding/config.pbtxt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e38a5af547dce4a2ac7563cc0c6a2407629018b6641153150064fd60e23b300e
+size 438
diff --git a/face_post_process/face_recognition/config.pbtxt b/face_post_process/face_recognition/config.pbtxt
new file mode 100644
index 0000000..fc5eaed
--- /dev/null
+++ b/face_post_process/face_recognition/config.pbtxt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b9788dd18519fc05b46c07cedcd8642e848f31517314b512b4447917190c30a
+size 986
diff --git a/face_post_process/face_warp/1/model.py b/face_post_process/face_warp/1/model.py
new file mode 100644
index 0000000..a9bb759
--- /dev/null
+++ b/face_post_process/face_warp/1/model.py
@@ -0,0 +1,203 @@
+# -*- coding: utf-8 -*-
+"""
+Triton Python Backend: Face Warp / Alignment
+
+This model warps each input face crop from 160x160 to a canonical 112x112
+aligned face using 5 facial keypoints, then runs the aligned crops through
+the `face_embeding` model via a BLS call and returns the resulting
+embeddings. Intended to bridge the `face_allignment` → `face_embeding`
+pipeline.
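+
+Assumed pipeline (inferred from the model names in this repo and from the
+test scripts, which query a `face_recognition` ensemble for
+embedding/bbox/score/landmarks):
+
+    face_allignment (detector + 5 landmarks) -> face_warp (this) -> face_embeding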
+
+Inputs (batched):
+    input     : FP32 [N,3,160,160] NCHW face crops.
+    landmarks : FP32 [N,5,2] pixel coords (x,y) in 160x160 image space.
+    score     : FP32 [N,1] detection confidence; faces scoring below 0.9
+                are skipped and produce an all-zero embedding row.
+
+Outputs:
+    output : FP32 [N,512] face embeddings from `face_embeding`
+             (zero rows for skipped faces).
+
+Notes:
+  * Color order is preserved; no channel swapping.
+  * Value range is preserved; if your downstream embedding model expects
+    normalization (mean/std), perform that there (or in an ensemble step).
+  * The canonical 5-point template is scaled from a 96x112 source template
+    to the 112x112 output size; this matches typical ArcFace preprocessing.
+"""
+
+import json
+
+import cv2
+import numpy as np
+
+import triton_python_backend_utils as pb_utils
+
+
+# --------------------------------------------------------------------------- #
+# Utility: build canonical destination template once and reuse                #
+# --------------------------------------------------------------------------- #
+def _canonical_template(output_w: int, output_h: int, scale_factor: float) -> np.ndarray:
+    """
+    Compute the canonical destination 5-point template scaled to the desired
+    output size and zoomed by `scale_factor`.
+
+    Returns:
+        (5,2) float32 array of (x,y) coords in output image space.
+    """
+    # Canonical template (nominal crop 96x112).
+    # Order: left_eye, right_eye, nose, left_mouth, right_mouth
+    reference_points = np.array(
+        [
+            [30.2946, 51.6963],
+            [65.5318, 51.5014],
+            [48.0252, 71.7366],
+            [33.5493, 92.3655],
+            [62.7299, 92.2041],
+        ],
+        dtype=np.float32,
+    )
+    default_crop_size = np.array([96.0, 112.0], dtype=np.float32)  # (w, h)
+
+    # Scale to the target output size.
+    scale_xy = np.array([output_w, output_h], dtype=np.float32) / default_crop_size
+    dst_kps = reference_points * scale_xy
+
+    # Apply zoom about the template center.
+    center = dst_kps.mean(axis=0, keepdims=True)
+    dst_kps = (dst_kps - center) * scale_factor + center
+    return dst_kps.astype(np.float32)
+
+
+def _estimate_affine(src_kps: np.ndarray, dst_kps: np.ndarray) -> np.ndarray:
+    """
+    Estimate the 2x3 affine transformation mapping src_kps -> dst_kps.
+
+    Uses cv2.estimateAffinePartial2D with LMEDS for robustness.
+    """
+    # cv2 expects contiguous float32 arrays of shape (N,2).
+    src = np.ascontiguousarray(src_kps, dtype=np.float32)
+    dst = np.ascontiguousarray(dst_kps, dtype=np.float32)
+    M, _ = cv2.estimateAffinePartial2D(src, dst, method=cv2.LMEDS)
+    if M is None:
+        # Fallback: identity transform to keep the output valid.
+        M = np.array([[1.0, 0.0, 0.0],
+                      [0.0, 1.0, 0.0]], dtype=np.float32)
+    return M.astype(np.float32)
+
+
+def _warp_image_chw(img_chw: np.ndarray, M: np.ndarray, out_w: int, out_h: int) -> np.ndarray:
+    """
+    Warp a single CHW FP32 image with affine matrix M to size (out_w, out_h).
+
+    Args:
+        img_chw: (3,H,W) float32
+        M: (2,3) float32
+        out_w, out_h: ints
+
+    Returns:
+        (3,out_h,out_w) float32 aligned image.
+    """
+    # Convert to HWC for cv2.warpAffine (expects HxWxC; BGR/RGB agnostic).
+    img_hwc = np.transpose(img_chw, (1, 2, 0))  # H,W,C
+    warped = cv2.warpAffine(
+        img_hwc,
+        M,
+        dsize=(out_w, out_h),  # (width, height)
+        flags=cv2.INTER_CUBIC,
+        borderMode=cv2.BORDER_REPLICATE,
+    )
+    # Back to CHW.
+    warped_chw = np.transpose(warped, (2, 0, 1))
+    return warped_chw.astype(np.float32)
+
+
+class TritonPythonModel:
+    """
+    Triton entrypoint class. One instance per model instance.
+    """
+
+    def initialize(self, args):
+        """
+        Called once when the model is loaded.
+        """
+        # Parse the model config to get the default scale factor (if provided).
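+        # NOTE: face_warp's config.pbtxt is an LFS pointer in this diff, so
+        # its exact contents are not visible; the parsing below assumes a
+        # `parameters` stanza of roughly this shape:
+        #
+        #   parameters: {
+        #     key: "scale_factor"
+        #     value: { string_value: "1.0" }
+        #   }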
+        model_config = json.loads(args['model_config'])
+        params = model_config.get('parameters', {})
+        self.default_scale = float(params.get('scale_factor', {}).get('string_value', '1.0'))
+
+        # Faces with a detection score below this threshold are skipped.
+        self.min_score = 0.9
+
+        # Output dimensions are hardcoded to match the 112x112 shape declared
+        # in config.pbtxt (they could also be parsed from the config).
+        self.out_w = 112
+        self.out_h = 112
+
+        # Precompute the canonical template once, zoomed by the configured
+        # scale factor.
+        self.base_template = _canonical_template(self.out_w, self.out_h, self.default_scale)
+        self.embeding_model_name = "face_embeding"
+
+    def execute(self, requests):
+        responses = []
+
+        for request in requests:
+            # ---- Fetch input tensors ----
+            in_img_tensor = pb_utils.get_input_tensor_by_name(request, "input")
+            in_lmk_tensor = pb_utils.get_input_tensor_by_name(request, "landmarks")
+            score_tensor = pb_utils.get_input_tensor_by_name(request, "score")
+
+            imgs = in_img_tensor.as_numpy()    # [B,3,160,160]
+            lmks = in_lmk_tensor.as_numpy()    # [B,5,2]
+            scores = score_tensor.as_numpy()   # [B,1]
+
+            # Ensure a batch dimension on every input.
+            if imgs.ndim == 3:
+                imgs = imgs[np.newaxis, ...]
+            if lmks.ndim == 2:
+                lmks = lmks[np.newaxis, ...]
+            scores = scores.reshape(-1, 1)  # normalize [B], [1] or [B,1] to [B,1]
+
+            batch_size = imgs.shape[0]
+            aligned_imgs = []
+            valid_indices = []
+
+            # Zero-filled output buffer; rows for skipped faces stay zero.
+            embedding_out = np.zeros((batch_size, 512), dtype=np.float32)
+
+            for i in range(batch_size):
+                if float(scores[i][0]) < self.min_score:
+                    continue  # Skip; leave this embedding row as zeros.
+                src_img = imgs[i]
+                src_kps = lmks[i].astype(np.float32)
+
+                # Align the crop to the canonical template.
+                M = _estimate_affine(src_kps, self.base_template)
+                warped = _warp_image_chw(src_img, M, self.out_w, self.out_h)
+
+                aligned_imgs.append(warped)
+                valid_indices.append(i)
+
+            # Only call the embedding model if there are valid samples.
+            if aligned_imgs:
+                aligned_batch = np.stack(aligned_imgs)  # [valid_N, 3, 112, 112]
+
+                infer_input = pb_utils.Tensor("input", aligned_batch)
+                inference_request = pb_utils.InferenceRequest(
+                    model_name=self.embeding_model_name,
+                    requested_output_names=["output"],
+                    inputs=[infer_input],
+                )
+                inference_response = inference_request.exec()
+                if inference_response.has_error():
+                    raise pb_utils.TritonModelException(
+                        inference_response.error().message())
+
+                # Scatter the valid embeddings back into the full-batch buffer
+                # so the output batch size always matches the input batch.
+                emb = pb_utils.get_output_tensor_by_name(
+                    inference_response, "output").as_numpy()
+                embedding_out[valid_indices] = emb
+
+            out_tensor = pb_utils.Tensor("output", embedding_out)
+            responses.append(pb_utils.InferenceResponse(output_tensors=[out_tensor]))
+
+        return responses
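+
+    # Assumed I/O contract of the downstream `face_embeding` model (its
+    # config.pbtxt is an LFS pointer in this diff; these names mirror the
+    # BLS call above rather than the visible config):
+    #   input  "input"  : FP32 [N,3,112,112]
+    #   output "output" : FP32 [N,512]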
+ """ + return \ No newline at end of file diff --git a/face_post_process/face_warp/config.pbtxt b/face_post_process/face_warp/config.pbtxt new file mode 100644 index 0000000..1b7d169 --- /dev/null +++ b/face_post_process/face_warp/config.pbtxt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:91f819bb4160084c18306177969a845a996841154442869649fb99c433b80934 +size 736 diff --git a/face_post_process/face_warp/requirements.txt b/face_post_process/face_warp/requirements.txt new file mode 100644 index 0000000..80d4dfa --- /dev/null +++ b/face_post_process/face_warp/requirements.txt @@ -0,0 +1,2 @@ +opencv-python-headless==4.10.0.84 +numpy==1.26.4 \ No newline at end of file diff --git a/face_post_process/shahab.jpg b/face_post_process/shahab.jpg new file mode 100644 index 0000000..e683d72 Binary files /dev/null and b/face_post_process/shahab.jpg differ diff --git a/face_post_process/test.py b/face_post_process/test.py new file mode 100644 index 0000000..132c15f --- /dev/null +++ b/face_post_process/test.py @@ -0,0 +1,29 @@ +import numpy as np +import tritonclient.http as httpclient + +# Connect to Triton +client = httpclient.InferenceServerClient(url="localhost:8089") + +# Prepare dummy input image (e.g., normalized float32 [0,1]) +input_data = np.random.rand(1, 3, 160, 160).astype(np.float32) + +# Create Triton input +input_tensor = httpclient.InferInput("input", input_data.shape, "FP32") +input_tensor.set_data_from_numpy(input_data) + +# Declare expected outputs +output_names = ["embedding", "bbox", "score", "landmarks"] +output_tensors = [httpclient.InferRequestedOutput(name) for name in output_names] + +# Send inference request +response = client.infer( + model_name="face_recognition", + inputs=[input_tensor], + outputs=output_tensors +) + +# Parse and print outputs +for name in output_names: + output = response.as_numpy(name) + print(f"{name}: shape={output.shape}, dtype={output.dtype}") + print(output) diff --git a/face_post_process/test2.py b/face_post_process/test2.py new file mode 100644 index 0000000..f28a47e --- /dev/null +++ b/face_post_process/test2.py @@ -0,0 +1,51 @@ +import numpy as np +import tritonclient.http as httpclient +import cv2 # or use PIL.Image if preferred +from pathlib import Path + +# Path to current .py file +current_file = Path(__file__) +current_dir = current_file.parent.resolve() + +# ----------------------------- +# Load JPEG and preprocess +# ----------------------------- +image_path = current_dir / "shahab.jpg" # path to your JPEG file +img = cv2.imread(image_path) # BGR, shape: (H, W, 3) +img = cv2.resize(img, (160, 160)) # resize to 160x160 +img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # convert to RGB +img = img.astype(np.float32) / 255.0 # normalize to [0, 1] + +# Change to NCHW (3, 160, 160) +img_chw = np.transpose(img, (2, 0, 1)) + +# Add batch dim: (1, 3, 160, 160) +input_data = img_chw[np.newaxis, :] + +# ----------------------------- +# Prepare Triton HTTP client +# ----------------------------- +client = httpclient.InferenceServerClient(url="localhost:9000") + +# Prepare input tensor +input_tensor = httpclient.InferInput("input", input_data.shape, "FP32") +input_tensor.set_data_from_numpy(input_data) + +# Prepare expected outputs +output_names = ["embedding", "bbox", "score", "landmarks"] +output_tensors = [httpclient.InferRequestedOutput(name) for name in output_names] + +# Send inference request +response = client.infer( + model_name="face_recognition", + inputs=[input_tensor], + outputs=output_tensors +) + +# 
+
+# -----------------------------
+# Print outputs
+# -----------------------------
+for name in output_names:
+    output = response.as_numpy(name)
+    print(f"{name}: shape={output.shape}, dtype={output.dtype}")
+    print(output)