pslime's picture
Update handler.py
5c20c47 verified
raw
history blame
3.31 kB
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
from PIL import Image
import torch
import torch.nn.functional as F
import io
import base64
import numpy as np
class EndpointHandler:
    """Inference-endpoint handler serving relative depth maps for single RGB images.

    Loads an ``AutoModelForDepthEstimation`` checkpoint from ``path`` and, per
    request, returns a grayscale PNG visualization plus the raw float16 depth
    buffer. NOTE: the depth is relative (model units), not metric meters.
    """

    def __init__(self, path=""):
        # Prefer GPU when available; all request tensors are moved to this device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoImageProcessor.from_pretrained(path)
        self.model = AutoModelForDepthEstimation.from_pretrained(path)
        self.model.to(self.device)
        self.model.eval()  # deterministic inference (disable dropout etc.)

    @staticmethod
    def _decode_image_bytes(inputs):
        """Normalize a supported "inputs" payload to raw image bytes.

        Accepts a base64-encoded string (JSON request style) or raw
        bytes/bytearray (pass-through style).

        Raises:
            ValueError: for an unsupported payload type or a string that
                cannot be base64-decoded.
        """
        if isinstance(inputs, str):
            # JSON base64 string
            try:
                return base64.b64decode(inputs)
            except Exception as e:
                # Chain the original error so the decode failure stays debuggable.
                raise ValueError(f'Failed to base64-decode "inputs" string: {e}') from e
        if isinstance(inputs, (bytes, bytearray)):
            # raw bytes
            return bytes(inputs)
        raise ValueError(f'Unsupported inputs type: {type(inputs)}')

    def __call__(self, data):
        """
        Supports both common endpoint input styles:
          1) JSON: {"inputs": "<base64-encoded image bytes>"} (recommended)
          2) Raw bytes passed through as inputs (fallback)

        Returns:
            dict with the depth visualization (`depth_png_base64`), the raw
            float16 depth buffer (`depth_raw_base64_f16`, row-major
            [height, width]), and the min/max used for normalization.

        Raises:
            ValueError: if "inputs" is missing, of an unsupported type, or
                not valid base64.
        """
        inputs = data.get("inputs", None)
        if inputs is None:
            raise ValueError('Missing "inputs". Send JSON {"inputs": "<base64>"} or raw bytes.')

        # Decode inputs -> image_bytes
        image_bytes = self._decode_image_bytes(inputs)

        # Load image; force RGB so grayscale/RGBA uploads also work.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        orig_w, orig_h = image.size

        # Preprocess and move tensors to the model's device.
        inputs_t = self.processor(images=image, return_tensors="pt")
        inputs_t = {k: v.to(self.device) for k, v in inputs_t.items()}

        # Inference (no autograd bookkeeping needed).
        with torch.no_grad():
            outputs = self.model(**inputs_t)
            predicted_depth = outputs.predicted_depth  # [B, H, W]

        # Upsample the prediction back to the original image resolution.
        depth = predicted_depth.unsqueeze(1)  # [B,1,H,W] — interpolate needs a channel dim
        depth = F.interpolate(
            depth,
            size=(orig_h, orig_w),
            mode="bicubic",
            align_corners=False,
        )
        depth = depth.squeeze(1).squeeze(0)  # [H,W] (single-image batch)
        depth_np = depth.detach().float().cpu().numpy()

        # Visualization (0..255 grayscale) via min-max normalization.
        dmin, dmax = float(depth_np.min()), float(depth_np.max())
        # Guard against a constant depth map (avoid division by ~zero).
        denom = (dmax - dmin) if (dmax - dmin) > 1e-12 else 1.0
        depth_norm = (depth_np - dmin) / denom
        depth_uint8 = (depth_norm * 255.0).clip(0, 255).astype(np.uint8)
        depth_img = Image.fromarray(depth_uint8, mode="L")
        buf = io.BytesIO()
        depth_img.save(buf, format="PNG")
        depth_png_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")

        # Raw float16 depth (compact) — NOTE: relative depth, not meters
        depth_f16 = depth_np.astype(np.float16)
        depth_raw_base64_f16 = base64.b64encode(depth_f16.tobytes()).decode("utf-8")

        return {
            "type": "relative_depth",
            "width": orig_w,
            "height": orig_h,
            "depth_png_base64": depth_png_base64,
            "depth_raw_base64_f16": depth_raw_base64_f16,
            "raw_dtype": "float16",
            "raw_shape": [orig_h, orig_w],
            "viz_min": dmin,
            "viz_max": dmax,
        }