from transformers import AutoImageProcessor, AutoModelForDepthEstimation
from PIL import Image
import torch
import torch.nn.functional as F
import io
import base64
import numpy as np


class EndpointHandler:
    def __init__(self, path=""):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoImageProcessor.from_pretrained(path)
        self.model = AutoModelForDepthEstimation.from_pretrained(path)
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data):
        """
        Supports both common endpoint input styles:
        1) JSON: {"inputs": "<base64>"} (recommended)
        2) Raw bytes passed through as inputs (fallback)
        """
        inputs = data.get("inputs", None)
        if inputs is None:
            raise ValueError('Missing "inputs". Send JSON {"inputs": "<base64>"} or raw bytes.')

        # Decode inputs -> image_bytes
        if isinstance(inputs, str):
            # JSON base64 string
            try:
                image_bytes = base64.b64decode(inputs)
            except Exception as e:
                raise ValueError(f'Failed to base64-decode "inputs" string: {e}') from e
        elif isinstance(inputs, (bytes, bytearray)):
            # raw bytes
            image_bytes = bytes(inputs)
        else:
            raise ValueError(f"Unsupported inputs type: {type(inputs)}")

        # Load image
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        orig_w, orig_h = image.size

        # Preprocess
        inputs_t = self.processor(images=image, return_tensors="pt")
        inputs_t = {k: v.to(self.device) for k, v in inputs_t.items()}

        # Inference
        with torch.no_grad():
            outputs = self.model(**inputs_t)
            predicted_depth = outputs.predicted_depth  # [B, H, W]

        # Upsample to original image size
        depth = predicted_depth.unsqueeze(1)  # [B, 1, H, W]
        depth = F.interpolate(
            depth,
            size=(orig_h, orig_w),
            mode="bicubic",
            align_corners=False,
        )
        depth = depth.squeeze(1).squeeze(0)  # [H, W]
        depth_np = depth.detach().float().cpu().numpy()

        # Visualization (0..255 grayscale)
        dmin, dmax = float(depth_np.min()), float(depth_np.max())
        denom = (dmax - dmin) if (dmax - dmin) > 1e-12 else 1.0
        depth_norm = (depth_np - dmin) / denom
        depth_uint8 = (depth_norm * 255.0).clip(0, 255).astype(np.uint8)
        depth_img = Image.fromarray(depth_uint8, mode="L")

        buf = io.BytesIO()
        depth_img.save(buf, format="PNG")
        depth_png_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")

        # Raw float16 depth (compact). NOTE: relative depth, not meters.
        depth_f16 = depth_np.astype(np.float16)
        depth_raw_base64_f16 = base64.b64encode(depth_f16.tobytes()).decode("utf-8")

        return {
            "type": "relative_depth",
            "width": orig_w,
            "height": orig_h,
            "depth_png_base64": depth_png_base64,
            "depth_raw_base64_f16": depth_raw_base64_f16,
            "raw_dtype": "float16",
            "raw_shape": [orig_h, orig_w],
            "viz_min": dmin,
            "viz_max": dmax,
        }
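

# ---------------------------------------------------------------------------
# Minimal local smoke test: a sketch only, not part of the endpoint contract.
# It exercises the JSON {"inputs": "<base64>"} request style described in the
# docstring above. The path "." is an assumption (it expects a compatible
# depth-estimation checkpoint's config and weights next to this handler);
# point it at whatever model repository your deployment actually uses.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path=".")  # placeholder path, see note above

    # Build a small synthetic RGB image and base64-encode it, mimicking a
    # JSON request body of the form {"inputs": "<base64>"}.
    demo = Image.new("RGB", (64, 48), color=(128, 64, 200))
    buf = io.BytesIO()
    demo.save(buf, format="PNG")
    payload = {"inputs": base64.b64encode(buf.getvalue()).decode("utf-8")}

    result = handler(payload)
    print("keys:", sorted(result.keys()))
    print("raw_shape:", result["raw_shape"],
          "viz range:", result["viz_min"], "to", result["viz_max"])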