xinjie.wang commited on
Commit
be013ba
·
1 Parent(s): 9e78843
app.py.bak → app.bk.py RENAMED
File without changes
app_full.py.bak DELETED
@@ -1,571 +0,0 @@
1
- import os as _os
2
- import sys as _sys
3
- import subprocess as _subprocess
4
-
5
- print("=" * 60, flush=True)
6
- print("[DEBUG] ===== Environment Diagnostics (no CUDA init) =====", flush=True)
7
- print(f"[DEBUG] Python: {_sys.version}", flush=True)
8
- print(f"[DEBUG] CWD: {_os.getcwd()}", flush=True)
9
-
10
- try:
11
- _nvcc_out = _subprocess.check_output(["nvcc", "--version"], stderr=_subprocess.STDOUT, text=True)
12
- print(f"[DEBUG] nvcc: {_nvcc_out.strip().splitlines()[-1]}", flush=True)
13
- except Exception as _e:
14
- print(f"[DEBUG] nvcc not found: {_e}", flush=True)
15
-
16
- try:
17
- _smi_out = _subprocess.check_output(["nvidia-smi", "-L"], stderr=_subprocess.STDOUT, text=True)
18
- print(f"[DEBUG] nvidia-smi -L: {_smi_out.strip()}", flush=True)
19
- except Exception:
20
- print("[DEBUG] nvidia-smi not available at startup (expected for ZeroGPU)", flush=True)
21
-
22
- try:
23
- with open("/proc/driver/nvidia/version") as _f:
24
- _lines = _f.read().strip().splitlines()
25
- print(f"[DEBUG] NVIDIA driver: {_lines[0] if _lines else 'unknown'}", flush=True)
26
- except Exception:
27
- print("[DEBUG] /proc/driver/nvidia/version not found", flush=True)
28
-
29
- for _env_key in sorted(_os.environ):
30
- if any(_kw in _env_key.upper() for _kw in ["CUDA", "GPU", "NVIDIA", "ZERO", "SPACES"]):
31
- print(f"[DEBUG] ENV {_env_key}={_os.environ[_env_key]}", flush=True)
32
-
33
- print("=" * 60, flush=True)
34
-
35
- # Project EmbodiedGen
36
- #
37
- # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
38
- #
39
- # Licensed under the Apache License, Version 2.0 (the "License");
40
- # you may not use this file except in compliance with the License.
41
- # You may obtain a copy of the License at
42
- #
43
- # http://www.apache.org/licenses/LICENSE-2.0
44
- #
45
- # Unless required by applicable law or agreed to in writing, software
46
- # distributed under the License is distributed on an "AS IS" BASIS,
47
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
48
- # implied. See the License for the specific language governing
49
- # permissions and limitations under the License.
50
-
51
-
52
- import os
53
-
54
- # GRADIO_APP == "imageto3d_sam3d", sam3d object model, by default.
55
- # GRADIO_APP == "imageto3d", TRELLIS model.
56
- os.environ["GRADIO_APP"] = "imageto3d_sam3d"
57
- from glob import glob
58
-
59
- import gradio as gr
60
- from app_style import custom_theme, image_css, lighting_css
61
- from common import (
62
- MAX_SEED,
63
- VERSION,
64
- active_btn_by_content,
65
- end_session,
66
- extract_3d_representations_v3,
67
- extract_urdf,
68
- get_seed,
69
- image_to_3d,
70
- preprocess_image_fn,
71
- preprocess_sam_image_fn,
72
- select_point,
73
- start_session,
74
- )
75
-
76
- app_name = os.getenv("GRADIO_APP")
77
- if app_name == "imageto3d_sam3d":
78
- _enable_pre_resize_default = False
79
- sample_step = 25
80
- bg_rm_model_name = "rembg" # "rembg", "rmbg14"
81
- elif app_name == "imageto3d":
82
- _enable_pre_resize_default = True
83
- sample_step = 12
84
- bg_rm_model_name = "rembg" # "rembg", "rmbg14"
85
-
86
- current_rmbg_tag = bg_rm_model_name
87
- def set_current_rmbg_tag(rmbg: str) -> None:
88
- global current_rmbg_tag
89
- current_rmbg_tag = rmbg
90
-
91
-
92
- def preprocess_example_image(
93
- img: str,
94
- ) -> tuple[object, object, gr.Button]:
95
- image, image_cache = preprocess_image_fn(
96
- img, current_rmbg_tag, _enable_pre_resize_default
97
- )
98
- return image, image_cache, gr.Button(interactive=True)
99
-
100
- with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
101
- gr.HTML(image_css, visible=False)
102
- # gr.HTML(lighting_css, visible=False)
103
- gr.Markdown(
104
- """
105
- ## ***EmbodiedGen***: Image-to-3D Asset
106
- **🔖 Version**: {VERSION}
107
- <p style="display: flex; gap: 10px; flex-wrap: nowrap;">
108
- <a href="https://horizonrobotics.github.io/EmbodiedGen">
109
- <img alt="📖 Documentation" src="https://img.shields.io/badge/📖-Documentation-blue">
110
- </a>
111
- <a href="https://arxiv.org/abs/2506.10600">
112
- <img alt="📄 arXiv" src="https://img.shields.io/badge/📄-arXiv-b31b1b">
113
- </a>
114
- <a href="https://github.com/HorizonRobotics/EmbodiedGen">
115
- <img alt="💻 GitHub" src="https://img.shields.io/badge/GitHub-000000?logo=github">
116
- </a>
117
- <a href="https://www.youtube.com/watch?v=rG4odybuJRk">
118
- <img alt="🎥 Video" src="https://img.shields.io/badge/🎥-Video-red">
119
- </a>
120
- </p>
121
-
122
- 🖼️ Generate physically plausible 3D asset from single input image.
123
- """.format(
124
- VERSION=VERSION
125
- ),
126
- elem_classes=["header"],
127
- )
128
- enable_pre_resize = gr.State(_enable_pre_resize_default)
129
- with gr.Row():
130
- with gr.Column(scale=3):
131
- with gr.Tabs() as input_tabs:
132
- with gr.Tab(
133
- label="Image(auto seg)", id=0
134
- ) as single_image_input_tab:
135
- raw_image_cache = gr.Image(
136
- format="png",
137
- image_mode="RGB",
138
- type="pil",
139
- visible=False,
140
- )
141
- image_prompt = gr.Image(
142
- label="Input Image",
143
- format="png",
144
- image_mode="RGBA",
145
- type="pil",
146
- height=400,
147
- elem_classes=["image_fit"],
148
- )
149
- gr.Markdown(
150
- """
151
- If you are not satisfied with the auto segmentation
152
- result, please switch to the `Image(SAM seg)` tab."""
153
- )
154
- with gr.Tab(
155
- label="Image(SAM seg)", id=1
156
- ) as samimage_input_tab:
157
- with gr.Row():
158
- with gr.Column(scale=1):
159
- image_prompt_sam = gr.Image(
160
- label="Input Image",
161
- type="numpy",
162
- height=400,
163
- elem_classes=["image_fit"],
164
- )
165
- image_seg_sam = gr.Image(
166
- label="SAM Seg Image",
167
- image_mode="RGBA",
168
- type="pil",
169
- height=400,
170
- visible=False,
171
- )
172
- with gr.Column(scale=1):
173
- image_mask_sam = gr.AnnotatedImage(
174
- elem_classes=["image_fit"]
175
- )
176
-
177
- fg_bg_radio = gr.Radio(
178
- ["foreground_point", "background_point"],
179
- label="Select foreground(green) or background(red) points, by default foreground", # noqa
180
- value="foreground_point",
181
- )
182
- gr.Markdown(
183
- """ Click the `Input Image` to select SAM points,
184
- after get the satisified segmentation, click `Generate`
185
- button to generate the 3D asset. \n
186
- Note: If the segmented foreground is too small relative
187
- to the entire image area, the generation will fail.
188
- """
189
- )
190
-
191
- with gr.Accordion(label="Generation Settings", open=False):
192
- with gr.Row():
193
- seed = gr.Slider(
194
- 0, MAX_SEED, label="Seed", value=0, step=1
195
- )
196
- texture_size = gr.Slider(
197
- 1024,
198
- 4096,
199
- label="UV texture size",
200
- value=2048,
201
- step=256,
202
- )
203
- rmbg_tag = gr.Radio(
204
- choices=["rembg", "rmbg14"],
205
- value=bg_rm_model_name,
206
- label="Background Removal Model",
207
- )
208
- with gr.Row():
209
- randomize_seed = gr.Checkbox(
210
- label="Randomize Seed", value=False
211
- )
212
- project_delight = gr.Checkbox(
213
- label="Back-project Delight",
214
- value=True,
215
- )
216
- gr.Markdown("Geo Structure Generation")
217
- with gr.Row():
218
- ss_guidance_strength = gr.Slider(
219
- 0.0,
220
- 10.0,
221
- label="Guidance Strength",
222
- value=7.5,
223
- step=0.1,
224
- )
225
- ss_sampling_steps = gr.Slider(
226
- 1,
227
- 50,
228
- label="Sampling Steps",
229
- value=sample_step,
230
- step=1,
231
- )
232
- gr.Markdown("Visual Appearance Generation")
233
- with gr.Row():
234
- slat_guidance_strength = gr.Slider(
235
- 0.0,
236
- 10.0,
237
- label="Guidance Strength",
238
- value=3.0,
239
- step=0.1,
240
- )
241
- slat_sampling_steps = gr.Slider(
242
- 1,
243
- 50,
244
- label="Sampling Steps",
245
- value=sample_step,
246
- step=1,
247
- )
248
-
249
- generate_btn = gr.Button(
250
- "🚀 1. Generate(~2 mins)",
251
- variant="primary",
252
- interactive=False,
253
- )
254
- model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
255
- # with gr.Row():
256
- # extract_rep3d_btn = gr.Button(
257
- # "🔍 2. Extract 3D Representation(~2 mins)",
258
- # variant="primary",
259
- # interactive=False,
260
- # )
261
- with gr.Accordion(
262
- label="Enter Asset Attributes(optional)", open=False
263
- ):
264
- asset_cat_text = gr.Textbox(
265
- label="Enter Asset Category (e.g., chair)"
266
- )
267
- height_range_text = gr.Textbox(
268
- label="Enter **Height Range** in meter (e.g., 0.5-0.6)"
269
- )
270
- mass_range_text = gr.Textbox(
271
- label="Enter **Mass Range** in kg (e.g., 1.1-1.2)"
272
- )
273
- asset_version_text = gr.Textbox(
274
- label=f"Enter version (e.g., {VERSION})"
275
- )
276
- with gr.Row():
277
- extract_urdf_btn = gr.Button(
278
- "🧩 2. Extract URDF with physics(~1 mins)",
279
- variant="primary",
280
- interactive=False,
281
- )
282
- with gr.Row():
283
- gr.Markdown(
284
- "#### Estimated Asset 3D Attributes(No input required)"
285
- )
286
- with gr.Row():
287
- est_type_text = gr.Textbox(
288
- label="Asset category", interactive=False
289
- )
290
- est_height_text = gr.Textbox(
291
- label="Real height(.m)", interactive=False
292
- )
293
- est_mass_text = gr.Textbox(
294
- label="Mass(.kg)", interactive=False
295
- )
296
- est_mu_text = gr.Textbox(
297
- label="Friction coefficient", interactive=False
298
- )
299
- with gr.Row():
300
- download_urdf = gr.DownloadButton(
301
- label="⬇️ 3. Download URDF",
302
- variant="primary",
303
- interactive=False,
304
- )
305
-
306
- gr.Markdown(
307
- """ NOTE: If `Asset Attributes` are provided, it will guide
308
- GPT to perform physical attributes restoration. \n
309
- The `Download URDF` file is restored to the real scale and
310
- has quality inspection, open with an editor to view details.
311
- """
312
- )
313
- with gr.Row() as single_image_example:
314
- examples = gr.Examples(
315
- label="Image Gallery",
316
- examples=[
317
- [image_path]
318
- for image_path in sorted(
319
- glob("assets/example_image/*")
320
- )
321
- ],
322
- inputs=[image_prompt],
323
- fn=preprocess_example_image,
324
- outputs=[image_prompt, raw_image_cache, generate_btn],
325
- run_on_click=True,
326
- examples_per_page=10,
327
- cache_examples=False,
328
- )
329
-
330
- with gr.Row(visible=False) as single_sam_image_example:
331
- examples = gr.Examples(
332
- label="Image Gallery",
333
- examples=[
334
- [image_path]
335
- for image_path in sorted(
336
- glob("assets/example_image/*")
337
- )
338
- ],
339
- inputs=[image_prompt_sam],
340
- fn=preprocess_sam_image_fn,
341
- outputs=[image_prompt_sam, raw_image_cache],
342
- run_on_click=True,
343
- examples_per_page=10,
344
- )
345
- with gr.Column(scale=2):
346
- gr.Markdown("<br>")
347
- video_output = gr.Video(
348
- label="Generated 3D Asset",
349
- autoplay=True,
350
- loop=True,
351
- height=400,
352
- )
353
- model_output_gs = gr.Model3D(
354
- label="Gaussian Representation", height=350, interactive=False
355
- )
356
- aligned_gs = gr.Textbox(visible=False)
357
- gr.Markdown(
358
- """ The rendering of `Gaussian Representation` takes additional 10s. """ # noqa
359
- )
360
- with gr.Row():
361
- model_output_mesh = gr.Model3D(
362
- label="Mesh Representation",
363
- height=350,
364
- interactive=False,
365
- clear_color=[0, 0, 0, 1],
366
- elem_id="lighter_mesh",
367
- )
368
-
369
- is_samimage = gr.State(False)
370
- output_buf = gr.State()
371
- selected_points = gr.State(value=[])
372
-
373
- demo.load(start_session)
374
- demo.unload(end_session)
375
-
376
- single_image_input_tab.select(
377
- lambda: tuple(
378
- [False, gr.Row.update(visible=True), gr.Row.update(visible=False)]
379
- ),
380
- outputs=[is_samimage, single_image_example, single_sam_image_example],
381
- )
382
- samimage_input_tab.select(
383
- lambda: tuple(
384
- [True, gr.Row.update(visible=True), gr.Row.update(visible=False)]
385
- ),
386
- outputs=[is_samimage, single_sam_image_example, single_image_example],
387
- )
388
-
389
- image_prompt.upload(
390
- lambda img, rmbg: preprocess_image_fn(img, rmbg, _enable_pre_resize_default),
391
- inputs=[image_prompt, rmbg_tag],
392
- outputs=[image_prompt, raw_image_cache],
393
- queue=False,
394
- ).success(
395
- active_btn_by_content,
396
- inputs=image_prompt,
397
- outputs=generate_btn,
398
- )
399
- rmbg_tag.change(
400
- set_current_rmbg_tag,
401
- inputs=[rmbg_tag],
402
- outputs=[],
403
- )
404
-
405
- image_prompt.change(
406
- lambda: tuple(
407
- [
408
- # gr.Button(interactive=False),
409
- gr.Button(interactive=False),
410
- gr.Button(interactive=False),
411
- None,
412
- "",
413
- None,
414
- None,
415
- "",
416
- "",
417
- "",
418
- "",
419
- "",
420
- "",
421
- "",
422
- "",
423
- ]
424
- ),
425
- outputs=[
426
- # extract_rep3d_btn,
427
- extract_urdf_btn,
428
- download_urdf,
429
- model_output_gs,
430
- aligned_gs,
431
- model_output_mesh,
432
- video_output,
433
- asset_cat_text,
434
- height_range_text,
435
- mass_range_text,
436
- asset_version_text,
437
- est_type_text,
438
- est_height_text,
439
- est_mass_text,
440
- est_mu_text,
441
- ],
442
- )
443
- image_prompt.clear(
444
- lambda: gr.Button(interactive=False),
445
- outputs=[generate_btn],
446
- )
447
-
448
- image_prompt_sam.upload(
449
- preprocess_sam_image_fn,
450
- inputs=[image_prompt_sam],
451
- outputs=[image_prompt_sam, raw_image_cache],
452
- )
453
- image_prompt_sam.change(
454
- lambda: tuple(
455
- [
456
- # gr.Button(interactive=False),
457
- gr.Button(interactive=False),
458
- gr.Button(interactive=False),
459
- None,
460
- None,
461
- None,
462
- "",
463
- "",
464
- "",
465
- "",
466
- "",
467
- "",
468
- "",
469
- "",
470
- None,
471
- [],
472
- ]
473
- ),
474
- outputs=[
475
- # extract_rep3d_btn,
476
- extract_urdf_btn,
477
- download_urdf,
478
- model_output_gs,
479
- model_output_mesh,
480
- video_output,
481
- asset_cat_text,
482
- height_range_text,
483
- mass_range_text,
484
- asset_version_text,
485
- est_type_text,
486
- est_height_text,
487
- est_mass_text,
488
- est_mu_text,
489
- image_mask_sam,
490
- selected_points,
491
- ],
492
- )
493
-
494
- image_prompt_sam.select(
495
- select_point,
496
- [
497
- image_prompt_sam,
498
- selected_points,
499
- fg_bg_radio,
500
- ],
501
- [image_mask_sam, image_seg_sam],
502
- )
503
- image_seg_sam.change(
504
- active_btn_by_content,
505
- inputs=image_seg_sam,
506
- outputs=generate_btn,
507
- )
508
-
509
- generate_btn.click(
510
- get_seed,
511
- inputs=[randomize_seed, seed],
512
- outputs=[seed],
513
- ).success(
514
- image_to_3d,
515
- inputs=[
516
- image_prompt,
517
- seed,
518
- ss_sampling_steps,
519
- slat_sampling_steps,
520
- raw_image_cache,
521
- ss_guidance_strength,
522
- slat_guidance_strength,
523
- image_seg_sam,
524
- is_samimage,
525
- ],
526
- outputs=[output_buf, video_output],
527
- ).success(
528
- extract_3d_representations_v3,
529
- inputs=[
530
- output_buf,
531
- project_delight,
532
- texture_size,
533
- ],
534
- outputs=[
535
- model_output_mesh,
536
- model_output_gs,
537
- model_output_obj,
538
- aligned_gs,
539
- ],
540
- ).success(
541
- lambda: gr.Button(interactive=True),
542
- outputs=[extract_urdf_btn],
543
- )
544
-
545
- extract_urdf_btn.click(
546
- extract_urdf,
547
- inputs=[
548
- aligned_gs,
549
- model_output_obj,
550
- asset_cat_text,
551
- height_range_text,
552
- mass_range_text,
553
- asset_version_text,
554
- ],
555
- outputs=[
556
- download_urdf,
557
- est_type_text,
558
- est_height_text,
559
- est_mass_text,
560
- est_mu_text,
561
- ],
562
- queue=True,
563
- show_progress="full",
564
- ).success(
565
- lambda: gr.Button(interactive=True),
566
- outputs=[download_urdf],
567
- )
568
-
569
-
570
- if __name__ == "__main__":
571
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
common.bk.py ADDED
@@ -0,0 +1,797 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import spaces
18
+ from embodied_gen.utils.monkey_patch.trellis import monkey_path_trellis
19
+
20
+ monkey_path_trellis()
21
+
22
+ import gc
23
+ import logging
24
+ import os
25
+ import shutil
26
+ import subprocess
27
+ import sys
28
+ from glob import glob
29
+
30
+ import cv2
31
+ import gradio as gr
32
+ import numpy as np
33
+ import torch
34
+ import trimesh
35
+ from PIL import Image
36
+ from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
37
+ from embodied_gen.data.backproject_v3 import entrypoint as backproject_api_v3
38
+ from embodied_gen.data.differentiable_render import entrypoint as render_api
39
+ from embodied_gen.data.utils import trellis_preprocess, zip_files
40
+ from embodied_gen.models.delight_model import DelightingModel
41
+ from embodied_gen.models.gs_model import GaussianOperator
42
+ from embodied_gen.models.sam3d import Sam3dInference
43
+ from embodied_gen.models.segment_model import (
44
+ BMGG14Remover,
45
+ RembgRemover,
46
+ SAMPredictor,
47
+ )
48
+ from embodied_gen.models.sr_model import ImageRealESRGAN, ImageStableSR
49
+ from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
50
+ from embodied_gen.scripts.render_mv import build_texture_gen_pipe, infer_pipe
51
+ from embodied_gen.scripts.text2image import (
52
+ build_text2img_ip_pipeline,
53
+ build_text2img_pipeline,
54
+ text2img_gen,
55
+ )
56
+ from embodied_gen.utils.gpt_clients import GPT_CLIENT
57
+ from embodied_gen.utils.process_media import (
58
+ filter_image_small_connected_components,
59
+ keep_largest_connected_component,
60
+ merge_images_video,
61
+ )
62
+ from embodied_gen.utils.tags import VERSION
63
+ from embodied_gen.utils.trender import pack_state, render_video, unpack_state
64
+ from embodied_gen.validators.quality_checkers import (
65
+ BaseChecker,
66
+ ImageAestheticChecker,
67
+ ImageSegChecker,
68
+ MeshGeoChecker,
69
+ )
70
+ from embodied_gen.validators.urdf_convertor import URDFGenerator
71
+
72
+ current_file_path = os.path.abspath(__file__)
73
+ current_dir = os.path.dirname(current_file_path)
74
+ sys.path.append(os.path.join(current_dir, ".."))
75
+ from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
76
+ from thirdparty.TRELLIS.trellis.utils import postprocessing_utils
77
+
78
+ logging.basicConfig(
79
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
80
+ )
81
+ logger = logging.getLogger(__name__)
82
+
83
+ os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
84
+ os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")
85
+ MAX_SEED = 100000
86
+
87
+ # DELIGHT = DelightingModel()
88
+ # IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
89
+ # IMAGESR_MODEL = ImageStableSR()
90
+ if os.getenv("GRADIO_APP").startswith("imageto3d"):
91
+ RBG_REMOVER = RembgRemover()
92
+ RBG14_REMOVER = BMGG14Remover()
93
+ SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
94
+ if "sam3d" in os.getenv("GRADIO_APP"):
95
+ PIPELINE = Sam3dInference()
96
+ else:
97
+ PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
98
+ "microsoft/TRELLIS-image-large"
99
+ )
100
+ # PIPELINE.cuda()
101
+ SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
102
+ GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
103
+ AESTHETIC_CHECKER = ImageAestheticChecker()
104
+ CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
105
+ TMP_DIR = os.path.join(
106
+ os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
107
+ )
108
+ os.makedirs(TMP_DIR, exist_ok=True)
109
+ elif os.getenv("GRADIO_APP").startswith("textto3d"):
110
+ RBG_REMOVER = RembgRemover()
111
+ RBG14_REMOVER = BMGG14Remover()
112
+ if "sam3d" in os.getenv("GRADIO_APP"):
113
+ PIPELINE = Sam3dInference()
114
+ else:
115
+ PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
116
+ "microsoft/TRELLIS-image-large"
117
+ )
118
+ # PIPELINE.cuda()
119
+ text_model_dir = "weights/Kolors"
120
+ PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3)
121
+ PIPELINE_IMG = build_text2img_pipeline(text_model_dir)
122
+ SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
123
+ GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
124
+ AESTHETIC_CHECKER = ImageAestheticChecker()
125
+ CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
126
+ TMP_DIR = os.path.join(
127
+ os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
128
+ )
129
+ os.makedirs(TMP_DIR, exist_ok=True)
130
+ elif os.getenv("GRADIO_APP") == "texture_edit":
131
+ DELIGHT = DelightingModel()
132
+ IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
133
+ PIPELINE_IP = build_texture_gen_pipe(
134
+ base_ckpt_dir="./weights",
135
+ ip_adapt_scale=0.7,
136
+ device="cuda",
137
+ )
138
+ PIPELINE = build_texture_gen_pipe(
139
+ base_ckpt_dir="./weights",
140
+ ip_adapt_scale=0,
141
+ device="cuda",
142
+ )
143
+ TMP_DIR = os.path.join(
144
+ os.path.dirname(os.path.abspath(__file__)), "sessions/texture_edit"
145
+ )
146
+ os.makedirs(TMP_DIR, exist_ok=True)
147
+
148
+
149
+ def start_session(req: gr.Request) -> None:
150
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
151
+ os.makedirs(user_dir, exist_ok=True)
152
+
153
+
154
+ def end_session(req: gr.Request) -> None:
155
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
156
+ if os.path.exists(user_dir):
157
+ shutil.rmtree(user_dir)
158
+
159
+
160
+ def preprocess_image_fn(
161
+ image: str | np.ndarray | Image.Image,
162
+ rmbg_tag: str = "rembg",
163
+ preprocess: bool = True,
164
+ ) -> tuple[Image.Image, Image.Image]:
165
+ if isinstance(image, str):
166
+ image = Image.open(image)
167
+ elif isinstance(image, np.ndarray):
168
+ image = Image.fromarray(image)
169
+
170
+ image_cache = image.copy() # resize_pil(image.copy(), 1024)
171
+
172
+ bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
173
+ image = bg_remover(image)
174
+ image = keep_largest_connected_component(image)
175
+
176
+ if preprocess:
177
+ image = trellis_preprocess(image)
178
+
179
+ return image, image_cache
180
+
181
+
182
+ def preprocess_sam_image_fn(
183
+ image: Image.Image,
184
+ ) -> tuple[Image.Image, Image.Image]:
185
+ if isinstance(image, np.ndarray):
186
+ image = Image.fromarray(image)
187
+
188
+ sam_image = SAM_PREDICTOR.preprocess_image(image)
189
+ image_cache = sam_image.copy()
190
+ SAM_PREDICTOR.predictor.set_image(sam_image)
191
+
192
+ return sam_image, image_cache
193
+
194
+
195
+ def active_btn_by_content(content: gr.Image) -> gr.Button:
196
+ interactive = True if content is not None else False
197
+
198
+ return gr.Button(interactive=interactive)
199
+
200
+
201
+ def active_btn_by_text_content(content: gr.Textbox) -> gr.Button:
202
+ if content is not None and len(content) > 0:
203
+ interactive = True
204
+ else:
205
+ interactive = False
206
+
207
+ return gr.Button(interactive=interactive)
208
+
209
+
210
+ def get_selected_image(
211
+ choice: str, sample1: str, sample2: str, sample3: str
212
+ ) -> str:
213
+ if choice == "sample1":
214
+ return sample1
215
+ elif choice == "sample2":
216
+ return sample2
217
+ elif choice == "sample3":
218
+ return sample3
219
+ else:
220
+ raise ValueError(f"Invalid choice: {choice}")
221
+
222
+
223
+ def get_cached_image(image_path: str) -> Image.Image:
224
+ if isinstance(image_path, Image.Image):
225
+ return image_path
226
+ return Image.open(image_path).resize((512, 512))
227
+
228
+
229
+ def get_seed(randomize_seed: bool, seed: int, max_seed: int = MAX_SEED) -> int:
230
+ return np.random.randint(0, max_seed) if randomize_seed else seed
231
+
232
+
233
+ def select_point(
234
+ image: np.ndarray,
235
+ sel_pix: list,
236
+ point_type: str,
237
+ evt: gr.SelectData,
238
+ ):
239
+ if point_type == "foreground_point":
240
+ sel_pix.append((evt.index, 1)) # append the foreground_point
241
+ elif point_type == "background_point":
242
+ sel_pix.append((evt.index, 0)) # append the background_point
243
+ else:
244
+ sel_pix.append((evt.index, 1)) # default foreground_point
245
+
246
+ masks = SAM_PREDICTOR.generate_masks(image, sel_pix)
247
+ seg_image = SAM_PREDICTOR.get_segmented_image(image, masks)
248
+
249
+ for point, label in sel_pix:
250
+ color = (255, 0, 0) if label == 0 else (0, 255, 0)
251
+ marker_type = 1 if label == 0 else 5
252
+ cv2.drawMarker(
253
+ image,
254
+ point,
255
+ color,
256
+ markerType=marker_type,
257
+ markerSize=15,
258
+ thickness=10,
259
+ )
260
+
261
+ torch.cuda.empty_cache()
262
+
263
+ return (image, masks), seg_image
264
+
265
+
266
+ @spaces.GPU(duration=300)
267
+ def image_to_3d(
268
+ image: Image.Image,
269
+ seed: int,
270
+ ss_sampling_steps: int,
271
+ slat_sampling_steps: int,
272
+ raw_image_cache: Image.Image,
273
+ ss_guidance_strength: float,
274
+ slat_guidance_strength: float,
275
+ sam_image: Image.Image = None,
276
+ is_sam_image: bool = False,
277
+ req: gr.Request = None,
278
+ ) -> tuple[dict, str]:
279
+ if is_sam_image:
280
+ seg_image = filter_image_small_connected_components(sam_image)
281
+ seg_image = Image.fromarray(seg_image, mode="RGBA")
282
+ else:
283
+ seg_image = image
284
+
285
+ if isinstance(seg_image, np.ndarray):
286
+ seg_image = Image.fromarray(seg_image)
287
+
288
+ logger.info("Start generating 3D representation from image...")
289
+ if isinstance(PIPELINE, Sam3dInference):
290
+ outputs = PIPELINE.run(
291
+ seg_image,
292
+ seed=seed,
293
+ stage1_inference_steps=ss_sampling_steps,
294
+ stage2_inference_steps=slat_sampling_steps,
295
+ )
296
+ else:
297
+ PIPELINE.cuda()
298
+ seg_image = trellis_preprocess(seg_image)
299
+ outputs = PIPELINE.run(
300
+ seg_image,
301
+ seed=seed,
302
+ formats=["gaussian", "mesh"],
303
+ preprocess_image=False,
304
+ sparse_structure_sampler_params={
305
+ "steps": ss_sampling_steps,
306
+ "cfg_strength": ss_guidance_strength,
307
+ },
308
+ slat_sampler_params={
309
+ "steps": slat_sampling_steps,
310
+ "cfg_strength": slat_guidance_strength,
311
+ },
312
+ )
313
+ # Set back to cpu for memory saving.
314
+ PIPELINE.cpu()
315
+
316
+ gs_model = outputs["gaussian"][0]
317
+ mesh_model = outputs["mesh"][0]
318
+ color_images = render_video(gs_model, r=1.85)["color"]
319
+ normal_images = render_video(mesh_model, r=1.85)["normal"]
320
+
321
+ output_root = os.path.join(TMP_DIR, str(req.session_hash))
322
+ os.makedirs(output_root, exist_ok=True)
323
+ seg_image.save(f"{output_root}/seg_image.png")
324
+ raw_image_cache.save(f"{output_root}/raw_image.png")
325
+
326
+ video_path = os.path.join(output_root, "gs_mesh.mp4")
327
+ merge_images_video(color_images, normal_images, video_path)
328
+ state = pack_state(gs_model, mesh_model)
329
+
330
+ gc.collect()
331
+ torch.cuda.empty_cache()
332
+
333
+ return state, video_path
334
+
335
+
336
+ def extract_3d_representations_v2(
337
+ state: dict,
338
+ enable_delight: bool,
339
+ texture_size: int,
340
+ req: gr.Request,
341
+ ):
342
+ """Back-Projection Version of Texture Super-Resolution."""
343
+ output_root = TMP_DIR
344
+ user_dir = os.path.join(output_root, str(req.session_hash))
345
+ gs_model, mesh_model = unpack_state(state, device="cpu")
346
+
347
+ filename = "sample"
348
+ gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
349
+ gs_model.save_ply(gs_path)
350
+
351
+ # Rotate mesh and GS by 90 degrees around Z-axis.
352
+ rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
353
+ gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
354
+ mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
355
+
356
+ # Addtional rotation for GS to align mesh.
357
+ gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
358
+ pose = GaussianOperator.trans_to_quatpose(gs_rot)
359
+ aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
360
+ GaussianOperator.resave_ply(
361
+ in_ply=gs_path,
362
+ out_ply=aligned_gs_path,
363
+ instance_pose=pose,
364
+ device="cpu",
365
+ )
366
+ color_path = os.path.join(user_dir, "color.png")
367
+ render_gs_api(
368
+ input_gs=aligned_gs_path,
369
+ output_path=color_path,
370
+ elevation=[20, -10, 60, -50],
371
+ num_images=12,
372
+ )
373
+
374
+ mesh = trimesh.Trimesh(
375
+ vertices=mesh_model.vertices.cpu().numpy(),
376
+ faces=mesh_model.faces.cpu().numpy(),
377
+ )
378
+ mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
379
+ mesh.vertices = mesh.vertices @ np.array(rot_matrix)
380
+
381
+ mesh_obj_path = os.path.join(user_dir, f"{filename}.obj")
382
+ mesh.export(mesh_obj_path)
383
+
384
+ mesh = backproject_api(
385
+ delight_model=DELIGHT,
386
+ imagesr_model=IMAGESR_MODEL,
387
+ color_path=color_path,
388
+ mesh_path=mesh_obj_path,
389
+ output_path=mesh_obj_path,
390
+ skip_fix_mesh=False,
391
+ delight=enable_delight,
392
+ texture_wh=[texture_size, texture_size],
393
+ elevation=[20, -10, 60, -50],
394
+ num_images=12,
395
+ )
396
+
397
+ mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
398
+ mesh.export(mesh_glb_path)
399
+
400
+ return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
401
+
402
+
403
+ def extract_3d_representations_v3(
404
+ state: dict,
405
+ enable_delight: bool,
406
+ texture_size: int,
407
+ req: gr.Request,
408
+ ):
409
+ """Back-Projection Version with Optimization-Based."""
410
+ output_root = TMP_DIR
411
+ user_dir = os.path.join(output_root, str(req.session_hash))
412
+ gs_model, mesh_model = unpack_state(state, device="cpu")
413
+
414
+ filename = "sample"
415
+ gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
416
+ gs_model.save_ply(gs_path)
417
+
418
+ # Rotate mesh and GS by 90 degrees around Z-axis.
419
+ rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
420
+ gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
421
+ mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
422
+
423
+ # Addtional rotation for GS to align mesh.
424
+ gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
425
+ pose = GaussianOperator.trans_to_quatpose(gs_rot)
426
+ aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
427
+ GaussianOperator.resave_ply(
428
+ in_ply=gs_path,
429
+ out_ply=aligned_gs_path,
430
+ instance_pose=pose,
431
+ device="cpu",
432
+ )
433
+
434
+ mesh = trimesh.Trimesh(
435
+ vertices=mesh_model.vertices.cpu().numpy(),
436
+ faces=mesh_model.faces.cpu().numpy(),
437
+ )
438
+ mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
439
+ mesh.vertices = mesh.vertices @ np.array(rot_matrix)
440
+
441
+ mesh_obj_path = os.path.join(user_dir, f"{filename}.obj")
442
+ mesh.export(mesh_obj_path)
443
+
444
+ mesh = backproject_api_v3(
445
+ gs_path=aligned_gs_path,
446
+ mesh_path=mesh_obj_path,
447
+ output_path=mesh_obj_path,
448
+ skip_fix_mesh=False,
449
+ texture_size=texture_size,
450
+ )
451
+
452
+ mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
453
+ mesh.export(mesh_glb_path)
454
+
455
+ return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
456
+
457
+
458
+ def extract_urdf(
459
+ gs_path: str,
460
+ mesh_obj_path: str,
461
+ asset_cat_text: str,
462
+ height_range_text: str,
463
+ mass_range_text: str,
464
+ asset_version_text: str,
465
+ req: gr.Request = None,
466
+ ):
467
+ output_root = TMP_DIR
468
+ if req is not None:
469
+ output_root = os.path.join(output_root, str(req.session_hash))
470
+
471
+ # Convert to URDF and recover attrs by GPT.
472
+ filename = "sample"
473
+ urdf_convertor = URDFGenerator(
474
+ GPT_CLIENT, render_view_num=4, decompose_convex=True
475
+ )
476
+ asset_attrs = {
477
+ "version": VERSION,
478
+ "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
479
+ }
480
+ if asset_version_text:
481
+ asset_attrs["version"] = asset_version_text
482
+ if asset_cat_text:
483
+ asset_attrs["category"] = asset_cat_text.lower()
484
+ if height_range_text:
485
+ try:
486
+ min_height, max_height = map(float, height_range_text.split("-"))
487
+ asset_attrs["min_height"] = min_height
488
+ asset_attrs["max_height"] = max_height
489
+ except ValueError:
490
+ return "Invalid height input format. Use the format: min-max."
491
+ if mass_range_text:
492
+ try:
493
+ min_mass, max_mass = map(float, mass_range_text.split("-"))
494
+ asset_attrs["min_mass"] = min_mass
495
+ asset_attrs["max_mass"] = max_mass
496
+ except ValueError:
497
+ return "Invalid mass input format. Use the format: min-max."
498
+
499
+ urdf_path = urdf_convertor(
500
+ mesh_path=mesh_obj_path,
501
+ output_root=f"{output_root}/URDF_{filename}",
502
+ **asset_attrs,
503
+ )
504
+
505
+ # Rescale GS and save to URDF/mesh folder.
506
+ real_height = urdf_convertor.get_attr_from_urdf(
507
+ urdf_path, attr_name="real_height"
508
+ )
509
+ out_gs = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}_gs.ply" # noqa
510
+ GaussianOperator.resave_ply(
511
+ in_ply=gs_path,
512
+ out_ply=out_gs,
513
+ real_height=real_height,
514
+ device="cpu",
515
+ )
516
+
517
+ # Quality check and update .urdf file.
518
+ mesh_out = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}.obj" # noqa
519
+ trimesh.load(mesh_out).export(mesh_out.replace(".obj", ".glb"))
520
+ # image_paths = render_asset3d(
521
+ # mesh_path=mesh_out,
522
+ # output_root=f"{output_root}/URDF_{filename}",
523
+ # output_subdir="qa_renders",
524
+ # num_images=8,
525
+ # elevation=(30, -30),
526
+ # distance=5.5,
527
+ # )
528
+
529
+ image_dir = f"{output_root}/URDF_{filename}/{urdf_convertor.output_render_dir}/image_color" # noqa
530
+ image_paths = glob(f"{image_dir}/*.png")
531
+ images_list = []
532
+ for checker in CHECKERS:
533
+ images = image_paths
534
+ if isinstance(checker, ImageSegChecker):
535
+ images = [
536
+ f"{TMP_DIR}/{req.session_hash}/raw_image.png",
537
+ f"{TMP_DIR}/{req.session_hash}/seg_image.png",
538
+ ]
539
+ images_list.append(images)
540
+
541
+ results = BaseChecker.validate(CHECKERS, images_list)
542
+ urdf_convertor.add_quality_tag(urdf_path, results)
543
+
544
+ # Zip urdf files
545
+ urdf_zip = zip_files(
546
+ input_paths=[
547
+ f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}",
548
+ f"{output_root}/URDF_{filename}/{filename}.urdf",
549
+ ],
550
+ output_zip=f"{output_root}/urdf_{filename}.zip",
551
+ )
552
+
553
+ estimated_type = urdf_convertor.estimated_attrs["category"]
554
+ estimated_height = urdf_convertor.estimated_attrs["height"]
555
+ estimated_mass = urdf_convertor.estimated_attrs["mass"]
556
+ estimated_mu = urdf_convertor.estimated_attrs["mu"]
557
+
558
+ return (
559
+ urdf_zip,
560
+ estimated_type,
561
+ estimated_height,
562
+ estimated_mass,
563
+ estimated_mu,
564
+ )
565
+
566
+
567
+ @spaces.GPU(duration=300)
568
+ def text2image_fn(
569
+ prompt: str,
570
+ guidance_scale: float,
571
+ infer_step: int = 50,
572
+ ip_image: Image.Image | str = None,
573
+ ip_adapt_scale: float = 0.3,
574
+ image_wh: int | tuple[int, int] = [1024, 1024],
575
+ rmbg_tag: str = "rembg",
576
+ seed: int = None,
577
+ enable_pre_resize: bool = True,
578
+ n_sample: int = 3,
579
+ req: gr.Request = None,
580
+ ):
581
+ if isinstance(image_wh, int):
582
+ image_wh = (image_wh, image_wh)
583
+ output_root = TMP_DIR
584
+ if req is not None:
585
+ output_root = os.path.join(output_root, str(req.session_hash))
586
+ os.makedirs(output_root, exist_ok=True)
587
+
588
+ pipeline = PIPELINE_IMG if ip_image is None else PIPELINE_IMG_IP
589
+ if ip_image is not None:
590
+ pipeline.set_ip_adapter_scale([ip_adapt_scale])
591
+
592
+ images = text2img_gen(
593
+ prompt=prompt,
594
+ n_sample=n_sample,
595
+ guidance_scale=guidance_scale,
596
+ pipeline=pipeline,
597
+ ip_image=ip_image,
598
+ image_wh=image_wh,
599
+ infer_step=infer_step,
600
+ seed=seed,
601
+ )
602
+
603
+ for idx in range(len(images)):
604
+ image = images[idx]
605
+ images[idx], _ = preprocess_image_fn(
606
+ image, rmbg_tag, enable_pre_resize
607
+ )
608
+
609
+ save_paths = []
610
+ for idx, image in enumerate(images):
611
+ save_path = f"{output_root}/sample_{idx}.png"
612
+ image.save(save_path)
613
+ save_paths.append(save_path)
614
+
615
+ logger.info(f"Images saved to {output_root}")
616
+
617
+ gc.collect()
618
+ torch.cuda.empty_cache()
619
+
620
+ return save_paths + save_paths
621
+
622
+
623
+ @spaces.GPU(duration=120)
624
+ def generate_condition(mesh_path: str, req: gr.Request, uuid: str = "sample"):
625
+ output_root = os.path.join(TMP_DIR, str(req.session_hash))
626
+
627
+ _ = render_api(
628
+ mesh_path=mesh_path,
629
+ output_root=f"{output_root}/condition",
630
+ uuid=str(uuid),
631
+ )
632
+
633
+ gc.collect()
634
+ torch.cuda.empty_cache()
635
+
636
+ return None, None, None
637
+
638
+
639
+ @spaces.GPU(duration=300)
640
+ def generate_texture_mvimages(
641
+ prompt: str,
642
+ controlnet_cond_scale: float = 0.55,
643
+ guidance_scale: float = 9,
644
+ strength: float = 0.9,
645
+ num_inference_steps: int = 50,
646
+ seed: int = 0,
647
+ ip_adapt_scale: float = 0,
648
+ ip_img_path: str = None,
649
+ uid: str = "sample",
650
+ sub_idxs: tuple[tuple[int]] = ((0, 1, 2), (3, 4, 5)),
651
+ req: gr.Request = None,
652
+ ) -> list[str]:
653
+ output_root = os.path.join(TMP_DIR, str(req.session_hash))
654
+ use_ip_adapter = True if ip_img_path and ip_adapt_scale > 0 else False
655
+ PIPELINE_IP.set_ip_adapter_scale([ip_adapt_scale])
656
+ img_save_paths = infer_pipe(
657
+ index_file=f"{output_root}/condition/index.json",
658
+ controlnet_cond_scale=controlnet_cond_scale,
659
+ guidance_scale=guidance_scale,
660
+ strength=strength,
661
+ num_inference_steps=num_inference_steps,
662
+ ip_adapt_scale=ip_adapt_scale,
663
+ ip_img_path=ip_img_path,
664
+ uid=uid,
665
+ prompt=prompt,
666
+ save_dir=f"{output_root}/multi_view",
667
+ sub_idxs=sub_idxs,
668
+ pipeline=PIPELINE_IP if use_ip_adapter else PIPELINE,
669
+ seed=seed,
670
+ )
671
+
672
+ gc.collect()
673
+ torch.cuda.empty_cache()
674
+
675
+ return img_save_paths + img_save_paths
676
+
677
+
678
+ def backproject_texture(
679
+ mesh_path: str,
680
+ input_image: str,
681
+ texture_size: int,
682
+ uuid: str = "sample",
683
+ req: gr.Request = None,
684
+ ) -> str:
685
+ output_root = os.path.join(TMP_DIR, str(req.session_hash))
686
+ output_dir = os.path.join(output_root, "texture_mesh")
687
+ os.makedirs(output_dir, exist_ok=True)
688
+ command = [
689
+ "backproject-cli",
690
+ "--mesh_path",
691
+ mesh_path,
692
+ "--input_image",
693
+ input_image,
694
+ "--output_root",
695
+ output_dir,
696
+ "--uuid",
697
+ f"{uuid}",
698
+ "--texture_size",
699
+ str(texture_size),
700
+ "--skip_fix_mesh",
701
+ ]
702
+
703
+ _ = subprocess.run(
704
+ command, capture_output=True, text=True, encoding="utf-8"
705
+ )
706
+ output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
707
+ output_glb_mesh = os.path.join(output_dir, f"{uuid}.glb")
708
+ _ = trimesh.load(output_obj_mesh).export(output_glb_mesh)
709
+
710
+ zip_file = zip_files(
711
+ input_paths=[
712
+ output_glb_mesh,
713
+ output_obj_mesh,
714
+ os.path.join(output_dir, "material.mtl"),
715
+ os.path.join(output_dir, "material_0.png"),
716
+ ],
717
+ output_zip=os.path.join(output_dir, f"{uuid}.zip"),
718
+ )
719
+
720
+ gc.collect()
721
+ torch.cuda.empty_cache()
722
+
723
+ return output_glb_mesh, output_obj_mesh, zip_file
724
+
725
+
726
+ @spaces.GPU(duration=300)
727
+ def backproject_texture_v2(
728
+ mesh_path: str,
729
+ input_image: str,
730
+ texture_size: int,
731
+ enable_delight: bool = True,
732
+ fix_mesh: bool = False,
733
+ no_mesh_post_process: bool = False,
734
+ uuid: str = "sample",
735
+ req: gr.Request = None,
736
+ ) -> str:
737
+ output_root = os.path.join(TMP_DIR, str(req.session_hash))
738
+ output_dir = os.path.join(output_root, "texture_mesh")
739
+ os.makedirs(output_dir, exist_ok=True)
740
+
741
+ textured_mesh = backproject_api(
742
+ delight_model=DELIGHT,
743
+ imagesr_model=IMAGESR_MODEL,
744
+ color_path=input_image,
745
+ mesh_path=mesh_path,
746
+ output_path=f"{output_dir}/{uuid}.obj",
747
+ skip_fix_mesh=not fix_mesh,
748
+ delight=enable_delight,
749
+ texture_wh=[texture_size, texture_size],
750
+ no_mesh_post_process=no_mesh_post_process,
751
+ )
752
+
753
+ output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
754
+ output_glb_mesh = os.path.join(output_dir, f"{uuid}.glb")
755
+ _ = textured_mesh.export(output_glb_mesh)
756
+
757
+ zip_file = zip_files(
758
+ input_paths=[
759
+ output_glb_mesh,
760
+ output_obj_mesh,
761
+ os.path.join(output_dir, "material.mtl"),
762
+ os.path.join(output_dir, "material_0.png"),
763
+ ],
764
+ output_zip=os.path.join(output_dir, f"{uuid}.zip"),
765
+ )
766
+
767
+ gc.collect()
768
+ torch.cuda.empty_cache()
769
+
770
+ return output_glb_mesh, output_obj_mesh, zip_file
771
+
772
+
773
+ @spaces.GPU(duration=120)
774
+ def render_result_video(
775
+ mesh_path: str, video_size: int, req: gr.Request, uuid: str = ""
776
+ ) -> str:
777
+ output_root = os.path.join(TMP_DIR, str(req.session_hash))
778
+ output_dir = os.path.join(output_root, "texture_mesh")
779
+
780
+ _ = render_api(
781
+ mesh_path=mesh_path,
782
+ output_root=output_dir,
783
+ num_images=90,
784
+ elevation=[20],
785
+ with_mtl=True,
786
+ pbr_light_factor=1,
787
+ uuid=str(uuid),
788
+ gen_color_mp4=True,
789
+ gen_glonormal_mp4=True,
790
+ distance=5.5,
791
+ resolution_hw=(video_size, video_size),
792
+ )
793
+
794
+ gc.collect()
795
+ torch.cuda.empty_cache()
796
+
797
+ return f"{output_dir}/color.mp4"
common.py CHANGED
@@ -88,38 +88,20 @@ MAX_SEED = 100000
88
  # IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
89
  # IMAGESR_MODEL = ImageStableSR()
90
  if os.getenv("GRADIO_APP").startswith("imageto3d"):
91
- print("[INIT 1/7] Loading RembgRemover ...", flush=True)
92
  RBG_REMOVER = RembgRemover()
93
- print("[INIT 1/7] RembgRemover done.", flush=True)
94
-
95
- print("[INIT 2/7] Loading BMGG14Remover ...", flush=True)
96
  RBG14_REMOVER = BMGG14Remover()
97
- print("[INIT 2/7] BMGG14Remover done.", flush=True)
98
-
99
- print("[INIT 3/7] Loading SAMPredictor(cpu) ...", flush=True)
100
  SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
101
- print("[INIT 3/7] SAMPredictor done.", flush=True)
102
-
103
- if "sam3d" in os.getenv("GRADIO_APP"):
104
- print("[INIT 4/7] Loading Sam3dInference ...", flush=True)
105
- PIPELINE = Sam3dInference()
106
- print("[INIT 4/7] Sam3dInference done.", flush=True)
107
- else:
108
- print("[INIT 4/7] Loading TrellisImageTo3DPipeline ...", flush=True)
109
- PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
110
- "microsoft/TRELLIS-image-large"
111
- )
112
- print("[INIT 4/7] TrellisImageTo3DPipeline done.", flush=True)
113
- # PIPELINE.cuda()
114
- print("[INIT 5/7] Loading SEG_CHECKER ...", flush=True)
115
- SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
116
- print("[INIT 5/7] SEG_CHECKER done.", flush=True)
117
-
118
- print("[INIT 6/7] Loading GEO_CHECKER + AESTHETIC_CHECKER ...", flush=True)
119
- GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
120
- AESTHETIC_CHECKER = ImageAestheticChecker()
121
- CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
122
- print("[INIT 6/7] Checkers done.", flush=True)
123
  TMP_DIR = os.path.join(
124
  os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
125
  )
@@ -281,22 +263,8 @@ def select_point(
281
  return (image, masks), seg_image
282
 
283
 
284
- @spaces.GPU(duration=60)
285
- def _gpu_alloc_test():
286
- """Minimal @spaces.GPU test - no model tensors, just GPU allocation."""
287
- import torch, time as _t
288
- print(f"[GPU-ALLOC-TEST] entered function body at {_t.strftime('%H:%M:%S')}", flush=True)
289
- print(f"[GPU-ALLOC-TEST] cuda.is_available={torch.cuda.is_available()}", flush=True)
290
- if torch.cuda.is_available():
291
- x = torch.randn(4, 4, device="cuda")
292
- print(f"[GPU-ALLOC-TEST] test tensor OK: {x.device}", flush=True)
293
- del x
294
- print("[GPU-ALLOC-TEST] done!", flush=True)
295
- return True
296
-
297
-
298
- @spaces.GPU(duration=100)
299
- def _image_to_3d_inner(
300
  image: Image.Image,
301
  seed: int,
302
  ss_sampling_steps: int,
@@ -308,7 +276,6 @@ def _image_to_3d_inner(
308
  is_sam_image: bool = False,
309
  req: gr.Request = None,
310
  ) -> tuple[dict, str]:
311
- print("[STEP 0] >>>>>> image_to_3d function body entered! <<<<<<", flush=True)
312
  if is_sam_image:
313
  seg_image = filter_image_small_connected_components(sam_image)
314
  seg_image = Image.fromarray(seg_image, mode="RGBA")
@@ -318,24 +285,16 @@ def _image_to_3d_inner(
318
  if isinstance(seg_image, np.ndarray):
319
  seg_image = Image.fromarray(seg_image)
320
 
321
- print("[STEP 1] image_to_3d entered, cuda available:", torch.cuda.is_available(), flush=True)
322
- if torch.cuda.is_available():
323
- print("[STEP 1] device:", torch.cuda.get_device_name(0), flush=True)
324
-
325
  logger.info("Start generating 3D representation from image...")
326
  if isinstance(PIPELINE, Sam3dInference):
327
- print("[STEP 2] Calling PIPELINE.run (Sam3dInference) ...", flush=True)
328
  outputs = PIPELINE.run(
329
  seg_image,
330
  seed=seed,
331
  stage1_inference_steps=ss_sampling_steps,
332
  stage2_inference_steps=slat_sampling_steps,
333
  )
334
- print("[STEP 2] PIPELINE.run done.", flush=True)
335
  else:
336
- print("[STEP 2] Moving PIPELINE to cuda ...", flush=True)
337
  PIPELINE.cuda()
338
- print("[STEP 2] PIPELINE.cuda() done. Running inference ...", flush=True)
339
  seg_image = trellis_preprocess(seg_image)
340
  outputs = PIPELINE.run(
341
  seg_image,
@@ -351,76 +310,29 @@ def _image_to_3d_inner(
351
  "cfg_strength": slat_guidance_strength,
352
  },
353
  )
354
- print("[STEP 2] PIPELINE.run done. Moving back to cpu ...", flush=True)
355
  # Set back to cpu for memory saving.
356
  PIPELINE.cpu()
357
 
358
- print("[STEP 3] Extracting gs_model and mesh_model ...", flush=True)
359
  gs_model = outputs["gaussian"][0]
360
  mesh_model = outputs["mesh"][0]
361
-
362
- print("[STEP 4] Rendering color video ...", flush=True)
363
  color_images = render_video(gs_model, r=1.85)["color"]
364
- print("[STEP 4] Rendering normal video ...", flush=True)
365
  normal_images = render_video(mesh_model, r=1.85)["normal"]
366
- print("[STEP 4] Render done.", flush=True)
367
 
368
  output_root = os.path.join(TMP_DIR, str(req.session_hash))
369
  os.makedirs(output_root, exist_ok=True)
370
  seg_image.save(f"{output_root}/seg_image.png")
371
  raw_image_cache.save(f"{output_root}/raw_image.png")
372
 
373
- print("[STEP 5] Merging video and packing state ...", flush=True)
374
  video_path = os.path.join(output_root, "gs_mesh.mp4")
375
  merge_images_video(color_images, normal_images, video_path)
376
  state = pack_state(gs_model, mesh_model)
377
 
378
  gc.collect()
379
  torch.cuda.empty_cache()
380
- print("[STEP 6] image_to_3d done!", flush=True)
381
 
382
  return state, video_path
383
 
384
 
385
- def image_to_3d(
386
- image,
387
- seed,
388
- ss_sampling_steps,
389
- slat_sampling_steps,
390
- raw_image_cache,
391
- ss_guidance_strength,
392
- slat_guidance_strength,
393
- sam_image=None,
394
- is_sam_image=False,
395
- req=None,
396
- ):
397
- """Wrapper outside @spaces.GPU to diagnose where the hang occurs."""
398
- import time as _time
399
- _t0 = _time.time()
400
- print(f"[WRAPPER] image_to_3d called at {_time.strftime('%H:%M:%S')}", flush=True)
401
- print(f"[WRAPPER] Step 1: calling _gpu_alloc_test (minimal @spaces.GPU) ...", flush=True)
402
- try:
403
- _gpu_alloc_test()
404
- print(f"[WRAPPER] Step 1 done in {_time.time()-_t0:.1f}s. GPU alloc works!", flush=True)
405
- except Exception as e:
406
- print(f"[WRAPPER] Step 1 _gpu_alloc_test FAILED: {type(e).__name__}: {e}", flush=True)
407
- raise
408
-
409
- _t1 = _time.time()
410
- print(f"[WRAPPER] Step 2: calling _image_to_3d_inner (heavy, 13.7G tensors) ...", flush=True)
411
- try:
412
- result = _image_to_3d_inner(
413
- image, seed, ss_sampling_steps, slat_sampling_steps,
414
- raw_image_cache, ss_guidance_strength, slat_guidance_strength,
415
- sam_image, is_sam_image, req,
416
- )
417
- print(f"[WRAPPER] _image_to_3d_inner returned in {_time.time()-_t1:.1f}s (total {_time.time()-_t0:.1f}s)", flush=True)
418
- return result
419
- except Exception as e:
420
- print(f"[WRAPPER] _image_to_3d_inner FAILED after {_time.time()-_t1:.1f}s: {type(e).__name__}: {e}", flush=True)
421
- raise
422
-
423
-
424
  def extract_3d_representations_v2(
425
  state: dict,
426
  enable_delight: bool,
 
88
  # IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
89
  # IMAGESR_MODEL = ImageStableSR()
90
  if os.getenv("GRADIO_APP").startswith("imageto3d"):
 
91
  RBG_REMOVER = RembgRemover()
 
 
 
92
  RBG14_REMOVER = BMGG14Remover()
 
 
 
93
  SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
94
+ # if "sam3d" in os.getenv("GRADIO_APP"):
95
+ # PIPELINE = Sam3dInference()
96
+ # else:
97
+ # PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
98
+ # "microsoft/TRELLIS-image-large"
99
+ # )
100
+ # # PIPELINE.cuda()
101
+ # SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
102
+ # GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
103
+ # AESTHETIC_CHECKER = ImageAestheticChecker()
104
+ # CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
 
 
 
 
 
 
 
 
 
 
 
105
  TMP_DIR = os.path.join(
106
  os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
107
  )
 
263
  return (image, masks), seg_image
264
 
265
 
266
+ @spaces.GPU(duration=300)
267
+ def image_to_3d(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
  image: Image.Image,
269
  seed: int,
270
  ss_sampling_steps: int,
 
276
  is_sam_image: bool = False,
277
  req: gr.Request = None,
278
  ) -> tuple[dict, str]:
 
279
  if is_sam_image:
280
  seg_image = filter_image_small_connected_components(sam_image)
281
  seg_image = Image.fromarray(seg_image, mode="RGBA")
 
285
  if isinstance(seg_image, np.ndarray):
286
  seg_image = Image.fromarray(seg_image)
287
 
 
 
 
 
288
  logger.info("Start generating 3D representation from image...")
289
  if isinstance(PIPELINE, Sam3dInference):
 
290
  outputs = PIPELINE.run(
291
  seg_image,
292
  seed=seed,
293
  stage1_inference_steps=ss_sampling_steps,
294
  stage2_inference_steps=slat_sampling_steps,
295
  )
 
296
  else:
 
297
  PIPELINE.cuda()
 
298
  seg_image = trellis_preprocess(seg_image)
299
  outputs = PIPELINE.run(
300
  seg_image,
 
310
  "cfg_strength": slat_guidance_strength,
311
  },
312
  )
 
313
  # Set back to cpu for memory saving.
314
  PIPELINE.cpu()
315
 
 
316
  gs_model = outputs["gaussian"][0]
317
  mesh_model = outputs["mesh"][0]
 
 
318
  color_images = render_video(gs_model, r=1.85)["color"]
 
319
  normal_images = render_video(mesh_model, r=1.85)["normal"]
 
320
 
321
  output_root = os.path.join(TMP_DIR, str(req.session_hash))
322
  os.makedirs(output_root, exist_ok=True)
323
  seg_image.save(f"{output_root}/seg_image.png")
324
  raw_image_cache.save(f"{output_root}/raw_image.png")
325
 
 
326
  video_path = os.path.join(output_root, "gs_mesh.mp4")
327
  merge_images_video(color_images, normal_images, video_path)
328
  state = pack_state(gs_model, mesh_model)
329
 
330
  gc.collect()
331
  torch.cuda.empty_cache()
 
332
 
333
  return state, video_path
334
 
335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  def extract_3d_representations_v2(
337
  state: dict,
338
  enable_delight: bool,
embodied_gen/models/sam3d.py CHANGED
@@ -22,8 +22,7 @@ import sys
22
 
23
  import numpy as np
24
  from hydra.utils import instantiate
25
- # from modelscope import snapshot_download
26
- from huggingface_hub import snapshot_download
27
  from omegaconf import OmegaConf
28
  from PIL import Image
29
 
@@ -65,12 +64,8 @@ class Sam3dInference:
65
  def __init__(
66
  self, local_dir: str = "weights/sam-3d-objects", compile: bool = False
67
  ) -> None:
68
- print("[SAM3D-INIT] Starting Sam3dInference.__init__", flush=True)
69
  if not os.path.exists(local_dir):
70
- print("[SAM3D-INIT] Downloading weights ...", flush=True)
71
- # snapshot_download("facebook/sam-3d-objects", local_dir=local_dir)
72
- snapshot_download(repo_id="tuandao-zenai/sam-3d-objects", local_dir=local_dir)
73
- print("[SAM3D-INIT] Download done.", flush=True)
74
  config_file = os.path.join(local_dir, "checkpoints/pipeline.yaml")
75
  config = OmegaConf.load(config_file)
76
  config.rendering_engine = "nvdiffrast"
@@ -83,9 +78,7 @@ class Sam3dInference:
83
  config["slat_decoder_gs_ckpt_path"] = config.pop(
84
  "slat_decoder_gs_4_ckpt_path", "slat_decoder_gs_4.ckpt"
85
  )
86
- print("[SAM3D-INIT] Instantiating InferencePipelinePointMap ...", flush=True)
87
  self.pipeline: InferencePipelinePointMap = instantiate(config)
88
- print("[SAM3D-INIT] Sam3dInference.__init__ done.", flush=True)
89
 
90
  def merge_mask_to_rgba(
91
  self, image: np.ndarray, mask: np.ndarray
@@ -107,15 +100,11 @@ class Sam3dInference:
107
  stage1_inference_steps: int = 25,
108
  stage2_inference_steps: int = 25,
109
  ) -> dict:
110
- print("[SAM3D-RUN] Entering Sam3dInference.run", flush=True)
111
  if isinstance(image, Image.Image):
112
  image = np.array(image)
113
  if mask is not None:
114
  image = self.merge_mask_to_rgba(image, mask)
115
- print(f"[SAM3D-RUN] image shape: {image.shape}, dtype: {image.dtype}", flush=True)
116
- print(f"[SAM3D-RUN] seed={seed}, stage1_steps={stage1_inference_steps}, stage2_steps={stage2_inference_steps}", flush=True)
117
- print("[SAM3D-RUN] Calling self.pipeline.run ...", flush=True)
118
- result = self.pipeline.run(
119
  image,
120
  None,
121
  seed,
@@ -130,8 +119,6 @@ class Sam3dInference:
130
  stage2_inference_steps=stage2_inference_steps,
131
  pointmap=pointmap,
132
  )
133
- print("[SAM3D-RUN] self.pipeline.run returned.", flush=True)
134
- return result
135
 
136
 
137
  if __name__ == "__main__":
 
22
 
23
  import numpy as np
24
  from hydra.utils import instantiate
25
+ from modelscope import snapshot_download
 
26
  from omegaconf import OmegaConf
27
  from PIL import Image
28
 
 
64
  def __init__(
65
  self, local_dir: str = "weights/sam-3d-objects", compile: bool = False
66
  ) -> None:
 
67
  if not os.path.exists(local_dir):
68
+ snapshot_download("facebook/sam-3d-objects", local_dir=local_dir)
 
 
 
69
  config_file = os.path.join(local_dir, "checkpoints/pipeline.yaml")
70
  config = OmegaConf.load(config_file)
71
  config.rendering_engine = "nvdiffrast"
 
78
  config["slat_decoder_gs_ckpt_path"] = config.pop(
79
  "slat_decoder_gs_4_ckpt_path", "slat_decoder_gs_4.ckpt"
80
  )
 
81
  self.pipeline: InferencePipelinePointMap = instantiate(config)
 
82
 
83
  def merge_mask_to_rgba(
84
  self, image: np.ndarray, mask: np.ndarray
 
100
  stage1_inference_steps: int = 25,
101
  stage2_inference_steps: int = 25,
102
  ) -> dict:
 
103
  if isinstance(image, Image.Image):
104
  image = np.array(image)
105
  if mask is not None:
106
  image = self.merge_mask_to_rgba(image, mask)
107
+ return self.pipeline.run(
 
 
 
108
  image,
109
  None,
110
  seed,
 
119
  stage2_inference_steps=stage2_inference_steps,
120
  pointmap=pointmap,
121
  )
 
 
122
 
123
 
124
  if __name__ == "__main__":
embodied_gen/utils/monkey_patch/sam3d.py CHANGED
@@ -40,7 +40,7 @@ def monkey_patch_sam3d():
40
  if sam3d_root not in sys.path:
41
  sys.path.insert(0, sam3d_root)
42
 
43
- def patch_pointmap_infer_pipeline():
44
  """Patches InferencePipelinePointMap.run to handle pointmap generation and 3D structure sampling."""
45
  try:
46
  from sam3d_objects.pipeline.inference_pipeline_pointmap import (
@@ -202,7 +202,7 @@ def monkey_patch_sam3d():
202
 
203
  InferencePipelinePointMap.run = patch_run
204
 
205
- def patch_infer_init():
206
  """Patches InferencePipeline.__init__ to allow CPU offloading during model initialization."""
207
  import torch
208
 
@@ -380,7 +380,7 @@ def monkey_patch_sam3d():
380
 
381
  InferencePipeline.__init__ = patch_init
382
 
383
- # patch_pointmap_infer_pipeline()
384
- # patch_infer_init()
385
 
386
  return
 
40
  if sam3d_root not in sys.path:
41
  sys.path.insert(0, sam3d_root)
42
 
43
+ def # patch_pointmap_infer_pipeline():
44
  """Patches InferencePipelinePointMap.run to handle pointmap generation and 3D structure sampling."""
45
  try:
46
  from sam3d_objects.pipeline.inference_pipeline_pointmap import (
 
202
 
203
  InferencePipelinePointMap.run = patch_run
204
 
205
+ def # patch_infer_init():
206
  """Patches InferencePipeline.__init__ to allow CPU offloading during model initialization."""
207
  import torch
208
 
 
380
 
381
  InferencePipeline.__init__ = patch_init
382
 
383
+ # # patch_pointmap_infer_pipeline()
384
+ # # patch_infer_init()
385
 
386
  return
thirdparty/sam3d/sam3d_objects/pipeline/inference_pipeline.py CHANGED
@@ -98,7 +98,6 @@ class InferencePipeline:
98
  logger.info(f"self.device: {self.device}")
99
  logger.info(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', None)}")
100
  logger.info(f"Actually using GPU: {torch.cuda.current_device()}")
101
- print(f"[PIPE-INIT] entering with self.device ({self.device}) ...", flush=True)
102
  with self.device:
103
  self.decode_formats = decode_formats
104
  self.pad_size = pad_size
@@ -130,41 +129,33 @@ class InferencePipeline:
130
  self.slat_preprocessor = slat_preprocessor
131
 
132
  logger.info("Loading model weights...")
133
- print("[PIPE-INIT] Loading ss_generator ...", flush=True)
134
  ss_generator = self.init_ss_generator(
135
  ss_generator_config_path, ss_generator_ckpt_path
136
  )
137
- print("[PIPE-INIT] Loading slat_generator ...", flush=True)
138
  slat_generator = self.init_slat_generator(
139
  slat_generator_config_path, slat_generator_ckpt_path
140
  )
141
- print("[PIPE-INIT] Loading ss_decoder ...", flush=True)
142
  ss_decoder = self.init_ss_decoder(
143
  ss_decoder_config_path, ss_decoder_ckpt_path
144
  )
145
- print("[PIPE-INIT] Loading ss_encoder ...", flush=True)
146
  ss_encoder = self.init_ss_encoder(
147
  ss_encoder_config_path, ss_encoder_ckpt_path
148
  )
149
- print("[PIPE-INIT] Loading slat_decoder_gs ...", flush=True)
150
  slat_decoder_gs = self.init_slat_decoder_gs(
151
  slat_decoder_gs_config_path, slat_decoder_gs_ckpt_path
152
  )
153
- print("[PIPE-INIT] Loading slat_decoder_gs_4 ...", flush=True)
154
  slat_decoder_gs_4 = self.init_slat_decoder_gs(
155
  slat_decoder_gs_4_config_path, slat_decoder_gs_4_ckpt_path
156
  )
157
- print("[PIPE-INIT] Loading slat_decoder_mesh ...", flush=True)
158
  slat_decoder_mesh = self.init_slat_decoder_mesh(
159
  slat_decoder_mesh_config_path, slat_decoder_mesh_ckpt_path
160
  )
161
 
162
  # Load conditioner embedder so that we only load it once
163
- print("[PIPE-INIT] Loading ss_condition_embedder ...", flush=True)
164
  ss_condition_embedder = self.init_ss_condition_embedder(
165
  ss_generator_config_path, ss_generator_ckpt_path
166
  )
167
- print("[PIPE-INIT] Loading slat_condition_embedder ...", flush=True)
168
  slat_condition_embedder = self.init_slat_condition_embedder(
169
  slat_generator_config_path, slat_generator_ckpt_path
170
  )
@@ -202,7 +193,6 @@ class InferencePipeline:
202
  "slat_decoder_mesh": slat_decoder_mesh,
203
  }
204
  )
205
- print("[PIPE-INIT] All models loaded into ModuleDict.", flush=True)
206
  logger.info("Loading model weights completed!")
207
 
208
  if self.compile_model:
 
98
  logger.info(f"self.device: {self.device}")
99
  logger.info(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', None)}")
100
  logger.info(f"Actually using GPU: {torch.cuda.current_device()}")
 
101
  with self.device:
102
  self.decode_formats = decode_formats
103
  self.pad_size = pad_size
 
129
  self.slat_preprocessor = slat_preprocessor
130
 
131
  logger.info("Loading model weights...")
132
+
133
  ss_generator = self.init_ss_generator(
134
  ss_generator_config_path, ss_generator_ckpt_path
135
  )
 
136
  slat_generator = self.init_slat_generator(
137
  slat_generator_config_path, slat_generator_ckpt_path
138
  )
 
139
  ss_decoder = self.init_ss_decoder(
140
  ss_decoder_config_path, ss_decoder_ckpt_path
141
  )
 
142
  ss_encoder = self.init_ss_encoder(
143
  ss_encoder_config_path, ss_encoder_ckpt_path
144
  )
 
145
  slat_decoder_gs = self.init_slat_decoder_gs(
146
  slat_decoder_gs_config_path, slat_decoder_gs_ckpt_path
147
  )
 
148
  slat_decoder_gs_4 = self.init_slat_decoder_gs(
149
  slat_decoder_gs_4_config_path, slat_decoder_gs_4_ckpt_path
150
  )
 
151
  slat_decoder_mesh = self.init_slat_decoder_mesh(
152
  slat_decoder_mesh_config_path, slat_decoder_mesh_ckpt_path
153
  )
154
 
155
  # Load conditioner embedder so that we only load it once
 
156
  ss_condition_embedder = self.init_ss_condition_embedder(
157
  ss_generator_config_path, ss_generator_ckpt_path
158
  )
 
159
  slat_condition_embedder = self.init_slat_condition_embedder(
160
  slat_generator_config_path, slat_generator_ckpt_path
161
  )
 
193
  "slat_decoder_mesh": slat_decoder_mesh,
194
  }
195
  )
 
196
  logger.info("Loading model weights completed!")
197
 
198
  if self.compile_model:
thirdparty/sam3d/sam3d_objects/pipeline/inference_pipeline_pointmap.py CHANGED
@@ -332,11 +332,8 @@ class InferencePipelinePointMap(InferencePipeline):
332
  estimate_plane=False,
333
  ) -> dict:
334
  image = self.merge_image_and_mask(image, mask)
335
- print(f"[PIPE-RUN] entering with self.device ({self.device}) ...", flush=True)
336
  with self.device:
337
- print("[PIPE-RUN] compute_pointmap ...", flush=True)
338
  pointmap_dict = self.compute_pointmap(image, pointmap)
339
- print("[PIPE-RUN] compute_pointmap done.", flush=True)
340
  pointmap = pointmap_dict["pointmap"]
341
  pts = type(self)._down_sample_img(pointmap)
342
  pts_colors = type(self)._down_sample_img(pointmap_dict["pts_color"])
@@ -344,21 +341,18 @@ class InferencePipelinePointMap(InferencePipeline):
344
  if estimate_plane:
345
  return self.estimate_plane(pointmap_dict, image)
346
 
347
- print("[PIPE-RUN] preprocess_image (ss) ...", flush=True)
348
  ss_input_dict = self.preprocess_image(
349
  image, self.ss_preprocessor, pointmap=pointmap
350
  )
351
- print("[PIPE-RUN] preprocess_image (slat) ...", flush=True)
352
  slat_input_dict = self.preprocess_image(image, self.slat_preprocessor)
353
  if seed is not None:
354
  torch.manual_seed(seed)
355
- print("[PIPE-RUN] sample_sparse_structure (stage1) ...", flush=True)
356
  ss_return_dict = self.sample_sparse_structure(
357
  ss_input_dict,
358
  inference_steps=stage1_inference_steps,
359
  use_distillation=use_stage1_distillation,
360
  )
361
- print("[PIPE-RUN] sample_sparse_structure done.", flush=True)
362
 
363
  # We could probably use the decoder from the models themselves
364
  pointmap_scale = ss_input_dict.get("pointmap_scale", None)
@@ -385,20 +379,15 @@ class InferencePipelinePointMap(InferencePipeline):
385
  # return ss_return_dict
386
 
387
  coords = ss_return_dict["coords"]
388
- print("[PIPE-RUN] sample_slat (stage2) ...", flush=True)
389
  slat = self.sample_slat(
390
  slat_input_dict,
391
  coords,
392
  inference_steps=stage2_inference_steps,
393
  use_distillation=use_stage2_distillation,
394
  )
395
- print("[PIPE-RUN] sample_slat done.", flush=True)
396
- print("[PIPE-RUN] decode_slat ...", flush=True)
397
  outputs = self.decode_slat(
398
  slat, self.decode_formats if decode_formats is None else decode_formats
399
  )
400
- print("[PIPE-RUN] decode_slat done.", flush=True)
401
- print("[PIPE-RUN] postprocess_slat_output ...", flush=True)
402
  outputs = self.postprocess_slat_output(
403
  outputs, with_mesh_postprocess, with_texture_baking, use_vertex_color
404
  )
@@ -424,7 +413,6 @@ class InferencePipelinePointMap(InferencePipeline):
424
  )
425
 
426
  # glb.export("sample.glb")
427
- print("[PIPE-RUN] ALL DONE, returning results.", flush=True)
428
  logger.info("Finished!")
429
 
430
  return {
 
332
  estimate_plane=False,
333
  ) -> dict:
334
  image = self.merge_image_and_mask(image, mask)
 
335
  with self.device:
 
336
  pointmap_dict = self.compute_pointmap(image, pointmap)
 
337
  pointmap = pointmap_dict["pointmap"]
338
  pts = type(self)._down_sample_img(pointmap)
339
  pts_colors = type(self)._down_sample_img(pointmap_dict["pts_color"])
 
341
  if estimate_plane:
342
  return self.estimate_plane(pointmap_dict, image)
343
 
 
344
  ss_input_dict = self.preprocess_image(
345
  image, self.ss_preprocessor, pointmap=pointmap
346
  )
347
+
348
  slat_input_dict = self.preprocess_image(image, self.slat_preprocessor)
349
  if seed is not None:
350
  torch.manual_seed(seed)
 
351
  ss_return_dict = self.sample_sparse_structure(
352
  ss_input_dict,
353
  inference_steps=stage1_inference_steps,
354
  use_distillation=use_stage1_distillation,
355
  )
 
356
 
357
  # We could probably use the decoder from the models themselves
358
  pointmap_scale = ss_input_dict.get("pointmap_scale", None)
 
379
  # return ss_return_dict
380
 
381
  coords = ss_return_dict["coords"]
 
382
  slat = self.sample_slat(
383
  slat_input_dict,
384
  coords,
385
  inference_steps=stage2_inference_steps,
386
  use_distillation=use_stage2_distillation,
387
  )
 
 
388
  outputs = self.decode_slat(
389
  slat, self.decode_formats if decode_formats is None else decode_formats
390
  )
 
 
391
  outputs = self.postprocess_slat_output(
392
  outputs, with_mesh_postprocess, with_texture_baking, use_vertex_color
393
  )
 
413
  )
414
 
415
  # glb.export("sample.glb")
 
416
  logger.info("Finished!")
417
 
418
  return {