pslime commited on
Commit
5c20c47
·
verified ·
1 Parent(s): 059456e

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +30 -20
handler.py CHANGED
@@ -11,7 +11,6 @@ class EndpointHandler:
11
  def __init__(self, path=""):
12
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
14
- # Load processor + model from the *endpoint repo*
15
  self.processor = AutoImageProcessor.from_pretrained(path)
16
  self.model = AutoModelForDepthEstimation.from_pretrained(path)
17
  self.model.to(self.device)
@@ -19,54 +18,65 @@ class EndpointHandler:
19
 
20
  def __call__(self, data):
21
  """
22
- Expected request body: raw image bytes (recommended)
23
- Hugging Face Endpoints typically pass:
24
- data["inputs"] -> bytes
25
  """
26
- image_bytes = data.get("inputs", None)
27
- if image_bytes is None:
28
- raise ValueError('Missing "inputs". Send raw image bytes as the request body.')
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  # Load image
31
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
32
  orig_w, orig_h = image.size
33
 
34
  # Preprocess
35
- inputs = self.processor(images=image, return_tensors="pt")
36
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
37
 
38
  # Inference
39
  with torch.no_grad():
40
- outputs = self.model(**inputs)
41
- predicted_depth = outputs.predicted_depth # shape: [B, H, W] (or similar)
42
 
43
- # Upsample depth to original image size (as in the docs)
44
- # Make it [B,1,H,W] for interpolate
45
- depth = predicted_depth.unsqueeze(1)
46
  depth = F.interpolate(
47
  depth,
48
  size=(orig_h, orig_w),
49
  mode="bicubic",
50
  align_corners=False,
51
  )
52
- depth = depth.squeeze(1).squeeze(0) # [H, W]
53
  depth_np = depth.detach().float().cpu().numpy()
54
 
55
- # ---- Make a nice visualization PNG (0..255) ----
56
  dmin, dmax = float(depth_np.min()), float(depth_np.max())
57
  denom = (dmax - dmin) if (dmax - dmin) > 1e-12 else 1.0
58
  depth_norm = (depth_np - dmin) / denom
59
  depth_uint8 = (depth_norm * 255.0).clip(0, 255).astype(np.uint8)
60
 
61
- depth_img = Image.fromarray(depth_uint8, mode="L") # grayscale
62
  buf = io.BytesIO()
63
  depth_img.save(buf, format="PNG")
64
  depth_png_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
65
 
66
- # ---- Optional: return raw depth as float16 bytes (compact) ----
67
  depth_f16 = depth_np.astype(np.float16)
68
- raw_bytes = depth_f16.tobytes()
69
- depth_raw_base64_f16 = base64.b64encode(raw_bytes).decode("utf-8")
70
 
71
  return {
72
  "type": "relative_depth",
 
11
  def __init__(self, path=""):
12
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
 
14
  self.processor = AutoImageProcessor.from_pretrained(path)
15
  self.model = AutoModelForDepthEstimation.from_pretrained(path)
16
  self.model.to(self.device)
 
18
 
19
  def __call__(self, data):
20
  """
21
+ Supports both common endpoint input styles:
22
+ 1) JSON: {"inputs": "<base64-encoded image bytes>"} (recommended)
23
+ 2) Raw bytes passed through as inputs (fallback)
24
  """
25
+ inputs = data.get("inputs", None)
26
+ if inputs is None:
27
+ raise ValueError('Missing "inputs". Send JSON {"inputs": "<base64>"} or raw bytes.')
28
+
29
+ # Decode inputs -> image_bytes
30
+ if isinstance(inputs, str):
31
+ # JSON base64 string
32
+ try:
33
+ image_bytes = base64.b64decode(inputs)
34
+ except Exception as e:
35
+ raise ValueError(f'Failed to base64-decode "inputs" string: {e}')
36
+ elif isinstance(inputs, (bytes, bytearray)):
37
+ # raw bytes
38
+ image_bytes = bytes(inputs)
39
+ else:
40
+ raise ValueError(f'Unsupported inputs type: {type(inputs)}')
41
 
42
  # Load image
43
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
44
  orig_w, orig_h = image.size
45
 
46
  # Preprocess
47
+ inputs_t = self.processor(images=image, return_tensors="pt")
48
+ inputs_t = {k: v.to(self.device) for k, v in inputs_t.items()}
49
 
50
  # Inference
51
  with torch.no_grad():
52
+ outputs = self.model(**inputs_t)
53
+ predicted_depth = outputs.predicted_depth # [B, H, W]
54
 
55
+ # Upsample to original image size
56
+ depth = predicted_depth.unsqueeze(1) # [B,1,H,W]
 
57
  depth = F.interpolate(
58
  depth,
59
  size=(orig_h, orig_w),
60
  mode="bicubic",
61
  align_corners=False,
62
  )
63
+ depth = depth.squeeze(1).squeeze(0) # [H,W]
64
  depth_np = depth.detach().float().cpu().numpy()
65
 
66
+ # Visualization (0..255 grayscale)
67
  dmin, dmax = float(depth_np.min()), float(depth_np.max())
68
  denom = (dmax - dmin) if (dmax - dmin) > 1e-12 else 1.0
69
  depth_norm = (depth_np - dmin) / denom
70
  depth_uint8 = (depth_norm * 255.0).clip(0, 255).astype(np.uint8)
71
 
72
+ depth_img = Image.fromarray(depth_uint8, mode="L")
73
  buf = io.BytesIO()
74
  depth_img.save(buf, format="PNG")
75
  depth_png_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
76
 
77
+ # Raw float16 depth (compact) NOTE: relative depth, not meters
78
  depth_f16 = depth_np.astype(np.float16)
79
+ depth_raw_base64_f16 = base64.b64encode(depth_f16.tobytes()).decode("utf-8")
 
80
 
81
  return {
82
  "type": "relative_depth",