from transformers import AutoImageProcessor, AutoModelForDepthEstimation
from PIL import Image
import torch
import torch.nn.functional as F
import io
import base64
import numpy as np
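

# Illustrative request sketch (not executed by this handler): __call__ below
# accepts either raw image bytes or JSON carrying base64-encoded bytes. A
# client using the recommended JSON style could build its request body like
# this ("image.jpg" is a placeholder, not an asset of this repo):
#
#     import base64, json
#     with open("image.jpg", "rb") as f:
#         body = json.dumps({"inputs": base64.b64encode(f.read()).decode("utf-8")})
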
class EndpointHandler:
    def __init__(self, path=""):
        # Run on the GPU when one is available, otherwise fall back to CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load the image processor and depth-estimation model from `path`.
        self.processor = AutoImageProcessor.from_pretrained(path)
        self.model = AutoModelForDepthEstimation.from_pretrained(path)
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data):
        """
        Supports both common endpoint input styles:
        1) JSON: {"inputs": "<base64-encoded image bytes>"} (recommended)
        2) Raw bytes passed through as inputs (fallback)
        """
        inputs = data.get("inputs", None)
        if inputs is None:
            raise ValueError('Missing "inputs". Send JSON {"inputs": "<base64>"} or raw bytes.')

        # Normalize the payload to raw image bytes.
        if isinstance(inputs, str):
            # JSON style: the string should be base64-encoded image bytes.
            try:
                image_bytes = base64.b64decode(inputs)
            except Exception as e:
                raise ValueError(f'Failed to base64-decode "inputs" string: {e}')
        elif isinstance(inputs, (bytes, bytearray)):
            # Raw-bytes style: use the payload as-is.
            image_bytes = bytes(inputs)
        else:
            raise ValueError(f'Unsupported inputs type: {type(inputs)}')

        # Decode the image and remember its original resolution for resizing later.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        orig_w, orig_h = image.size

        # Preprocess and move all input tensors to the model's device.
        inputs_t = self.processor(images=image, return_tensors="pt")
        inputs_t = {k: v.to(self.device) for k, v in inputs_t.items()}

        # Forward pass without gradient tracking.
        with torch.no_grad():
            outputs = self.model(**inputs_t)
            predicted_depth = outputs.predicted_depth

        # Upsample the (1, H', W') prediction back to the original image size.
        depth = predicted_depth.unsqueeze(1)
        depth = F.interpolate(
            depth,
            size=(orig_h, orig_w),
            mode="bicubic",
            align_corners=False,
        )
        depth = depth.squeeze(1).squeeze(0)
        depth_np = depth.detach().float().cpu().numpy()

        # Min-max normalize to [0, 1] for an 8-bit visualization; the guard
        # avoids dividing by zero on a constant depth map.
        dmin, dmax = float(depth_np.min()), float(depth_np.max())
        denom = (dmax - dmin) if (dmax - dmin) > 1e-12 else 1.0
        depth_norm = (depth_np - dmin) / denom
        depth_uint8 = (depth_norm * 255.0).clip(0, 255).astype(np.uint8)

        # Encode the normalized map as a grayscale PNG for quick previews.
        depth_img = Image.fromarray(depth_uint8, mode="L")
        buf = io.BytesIO()
        depth_img.save(buf, format="PNG")
        depth_png_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")

        # Also return the un-normalized depth as raw float16 bytes so clients
        # can recover relative depth without the PNG's 8-bit quantization.
        depth_f16 = depth_np.astype(np.float16)
        depth_raw_base64_f16 = base64.b64encode(depth_f16.tobytes()).decode("utf-8")

        return {
            "type": "relative_depth",
            "width": orig_w,
            "height": orig_h,
            "depth_png_base64": depth_png_base64,
            "depth_raw_base64_f16": depth_raw_base64_f16,
            "raw_dtype": "float16",
            "raw_shape": [orig_h, orig_w],
            "viz_min": dmin,
            "viz_max": dmax,
        }
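

# Client-side sketch (illustrative, not used by the endpoint runtime): how a
# caller could reconstruct the arrays from the response dict returned above.
# The function name and its `payload` parameter are assumptions for this
# example, not part of the endpoint contract.
def decode_depth_response(payload):
    """Rebuild the float16 depth map and the PNG visualization from a response."""
    h, w = payload["raw_shape"]
    depth = np.frombuffer(
        base64.b64decode(payload["depth_raw_base64_f16"]), dtype=np.float16
    ).reshape(h, w)
    viz = Image.open(io.BytesIO(base64.b64decode(payload["depth_png_base64"])))
    return depth, viz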