skrithik commited on
Commit
3d79f62
·
verified ·
1 Parent(s): 4f7479d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -186
app.py CHANGED
@@ -1,10 +1,8 @@
1
- """FASHN VTON v1.5 HuggingFace Space Demo."""
2
 
3
  import os
4
  import platform
5
-
6
  import gradio as gr
7
- import spaces
8
  import torch
9
  from huggingface_hub import hf_hub_download
10
  from PIL import Image
@@ -18,20 +16,17 @@ WEIGHTS_DIR = os.path.join(SCRIPT_DIR, "weights")
18
  CATEGORIES = ["tops", "bottoms", "one-pieces"]
19
  GARMENT_PHOTO_TYPES = ["model", "flat-lay"]
20
 
21
- # Global pipeline instance (lazy loaded)
22
  _pipeline = None
23
 
24
 
25
  # ----------------- HELPERS ----------------- #
26
 
27
-
28
  def download_weights():
29
  """Download model weights from HuggingFace Hub."""
30
  os.makedirs(WEIGHTS_DIR, exist_ok=True)
31
  dwpose_dir = os.path.join(WEIGHTS_DIR, "dwpose")
32
  os.makedirs(dwpose_dir, exist_ok=True)
33
 
34
- # Download TryOnModel weights
35
  tryon_path = os.path.join(WEIGHTS_DIR, "model.safetensors")
36
  if not os.path.exists(tryon_path):
37
  print("Downloading TryOnModel weights...")
@@ -41,7 +36,6 @@ def download_weights():
41
  local_dir=WEIGHTS_DIR,
42
  )
43
 
44
- # Download DWPose models
45
  dwpose_files = ["yolox_l.onnx", "dw-ll_ucoco_384.onnx"]
46
  for filename in dwpose_files:
47
  filepath = os.path.join(dwpose_dir, filename)
@@ -53,56 +47,45 @@ def download_weights():
53
  local_dir=dwpose_dir,
54
  )
55
 
56
- print("Weights downloaded successfully!")
57
 
58
 
59
  # ----------------- MODEL LOADING ----------------- #
60
 
61
-
62
  def get_pipeline():
63
- """Lazy-load the pipeline on first use (ensures GPU is available on ZeroGPU)."""
64
  global _pipeline
 
65
  if _pipeline is None:
66
- # Check CUDA availability (will be true inside @spaces.GPU context)
67
- if not torch.cuda.is_available():
68
- raise gr.Error(
69
- "CUDA is not available. This demo requires a GPU to run. "
70
- "If you're on HuggingFace Spaces, please try again in a moment."
71
- )
72
 
73
- # ---------------------------------- Diagnostics ---------------------------------- #
74
- print(f"Python : {platform.python_version()}")
75
- print(f"PyTorch : {torch.__version__}")
76
- print(f" • built for CUDA : {torch.version.cuda}")
77
- if torch.backends.cudnn.is_available():
78
- print(f" • built for cuDNN: {torch.backends.cudnn.version()}")
79
- print(f"torch.cuda.is_available(): {torch.cuda.is_available()}")
80
  if torch.cuda.is_available():
81
  dev = torch.cuda.current_device()
82
- cc = torch.cuda.get_device_capability(dev)
83
- print(f"GPU {dev}: {torch.cuda.get_device_name(dev)} (compute {cc[0]}.{cc[1]})")
84
-
85
- # Enable TF32 for faster computation on Ampere+ GPUs
86
- if torch.cuda.get_device_properties(0).major >= 8:
87
- torch.backends.cuda.matmul.allow_tf32 = True
88
- torch.backends.cudnn.allow_tf32 = True
89
 
90
- print("Downloading weights (if needed)...")
91
  download_weights()
92
 
93
  print("Loading pipeline...")
94
  from fashn_vton import TryOnPipeline
95
 
96
- _pipeline = TryOnPipeline(weights_dir=WEIGHTS_DIR, device="cuda")
97
- print("Pipeline loaded on CUDA!")
 
 
 
 
98
 
99
  return _pipeline
100
 
101
 
102
  # ----------------- INFERENCE ----------------- #
103
 
104
-
105
- @spaces.GPU
106
  def try_on(
107
  person_image: Image.Image,
108
  garment_image: Image.Image,
@@ -113,26 +96,23 @@ def try_on(
113
  seed: int,
114
  segmentation_free: bool,
115
  ) -> Image.Image:
116
- """Run virtual try-on inference."""
117
  if person_image is None:
118
  raise gr.Error("Please upload a person image")
119
  if garment_image is None:
120
  raise gr.Error("Please upload a garment image")
121
 
122
- # Handle seed (guard against None or invalid values)
123
  if seed is None or seed < 0:
124
  seed = 42
125
 
126
- # Convert to RGB if needed
127
  if person_image.mode != "RGB":
128
  person_image = person_image.convert("RGB")
 
129
  if garment_image.mode != "RGB":
130
  garment_image = garment_image.convert("RGB")
131
 
132
- # Get pipeline (lazy loads on first call)
133
  pipeline = get_pipeline()
134
 
135
- # Run inference
136
  result = pipeline(
137
  person_image=person_image,
138
  garment_image=garment_image,
@@ -150,7 +130,6 @@ def try_on(
150
 
151
  # ----------------- UI ----------------- #
152
 
153
- # Custom CSS
154
  CUSTOM_CSS = """
155
  .contain img {
156
  object-fit: contain !important;
@@ -159,151 +138,40 @@ CUSTOM_CSS = """
159
  }
160
  """
161
 
162
- # Load HTML content
163
- with open(os.path.join(SCRIPT_DIR, "banner.html"), "r") as f:
164
- banner_html = f.read()
165
- with open(os.path.join(SCRIPT_DIR, "tips.html"), "r") as f:
166
- tips_html = f.read()
167
-
168
- # Build example paths
169
- examples_dir = os.path.join(ASSETS_DIR, "examples")
170
-
171
- # Paired examples: [person_path, garment_path, category, garment_photo_type]
172
- paired_examples = [
173
- [os.path.join(examples_dir, "person1.png"), os.path.join(examples_dir, "garment1.jpeg"), "one-pieces", "model"],
174
- [os.path.join(examples_dir, "person2.png"), os.path.join(examples_dir, "garment2.webp"), "tops", "model"],
175
- [os.path.join(examples_dir, "person3.png"), os.path.join(examples_dir, "garment3.jpeg"), "tops", "flat-lay"],
176
- [os.path.join(examples_dir, "person4.png"), os.path.join(examples_dir, "garment4.webp"), "tops", "model"],
177
- [os.path.join(examples_dir, "person5.png"), os.path.join(examples_dir, "garment5.jpeg"), "bottoms", "flat-lay"],
178
- [os.path.join(examples_dir, "person6.png"), os.path.join(examples_dir, "garment6.webp"), "one-pieces", "model"],
179
- ]
180
-
181
- # Individual examples (classic from repo)
182
- person_only_examples = [os.path.join(examples_dir, "person0.png")]
183
-
184
- # Garment examples with their settings: (image_path, category, photo_type)
185
- # Order matters - index in Gallery corresponds to this list
186
- garment_examples_data = [
187
- (os.path.join(examples_dir, "garment0.png"), "tops", "model"),
188
- (os.path.join(examples_dir, "garment7.jpg"), "tops", "flat-lay"),
189
- ]
190
- garment_gallery_images = [item[0] for item in garment_examples_data]
191
-
192
-
193
- def on_garment_gallery_select(evt: gr.SelectData):
194
- """Handle garment gallery selection - load image and update dropdowns."""
195
- idx = evt.index
196
- if idx < len(garment_examples_data):
197
- image_path, cat, photo_type = garment_examples_data[idx]
198
- return Image.open(image_path), cat, photo_type
199
- return None, "tops", "model"
200
-
201
-
202
- # Build UI
203
  with gr.Blocks(css=CUSTOM_CSS) as demo:
204
- # Header
205
- gr.HTML(banner_html)
206
- gr.HTML(tips_html)
207
-
208
- with gr.Row(equal_height=False):
209
- # Column 1: Person
210
- with gr.Column(scale=1):
211
- person_image = gr.Image(
212
- label="Person Image",
213
- type="pil",
214
- sources=["upload", "clipboard"],
215
- elem_classes=["contain"],
216
- )
217
 
218
- # Individual person examples
219
- gr.Examples(
220
- examples=person_only_examples,
221
- inputs=person_image,
222
- label="Person Examples",
223
- )
224
 
225
- # Column 2: Garment
226
- with gr.Column(scale=1):
227
- garment_image = gr.Image(
228
- label="Garment Image",
229
- type="pil",
230
- sources=["upload", "clipboard"],
231
- elem_classes=["contain"],
232
- )
233
 
234
- with gr.Row():
235
- category = gr.Dropdown(
236
- choices=CATEGORIES,
237
- value="tops",
238
- label="Category",
239
- )
240
- garment_photo_type = gr.Dropdown(
241
- choices=GARMENT_PHOTO_TYPES,
242
- value="model",
243
- label="Photo Type",
244
- )
245
-
246
- # Garment examples as clickable gallery
247
- gr.Markdown("**Garment Examples** (click to load with settings)")
248
- garment_gallery = gr.Gallery(
249
- value=garment_gallery_images,
250
- columns=2,
251
- rows=1,
252
- height="auto",
253
- object_fit="contain",
254
- show_label=False,
255
- allow_preview=False,
256
  )
257
 
258
- # Column 3: Result
259
- with gr.Column(scale=1):
260
- result_image = gr.Image(
261
- label="Try-On Result",
262
- type="pil",
263
- interactive=False,
264
- elem_classes=["contain"],
265
  )
266
 
267
- run_button = gr.Button("Try On", variant="primary", size="lg")
 
 
 
268
 
269
- # Advanced settings
270
  with gr.Accordion("Advanced Settings", open=False):
271
- num_timesteps = gr.Slider(
272
- minimum=10,
273
- maximum=50,
274
- value=50,
275
- step=5,
276
- label="Sampling Steps",
277
- info="Higher = better quality, slower.",
278
- )
279
- guidance_scale = gr.Slider(
280
- minimum=1.0,
281
- maximum=3.0,
282
- value=1.5,
283
- step=0.1,
284
- label="Guidance Scale",
285
- info="How closely to follow the garment. 1.5 recommended.",
286
- )
287
- seed = gr.Number(
288
- value=42,
289
- label="Seed",
290
- info="Random seed for reproducibility.",
291
- precision=0,
292
- )
293
- segmentation_free = gr.Checkbox(
294
- value=True,
295
- label="Segmentation Free",
296
- info="Preserves body features and allows unconstrained garment volume. Disable for tighter garment fitting.",
297
- )
298
-
299
- # Paired examples at the bottom
300
- gr.Examples(
301
- examples=paired_examples,
302
- inputs=[person_image, garment_image, category, garment_photo_type],
303
- label="Complete Examples (click to load person + garment + settings)",
304
- )
305
 
306
- # Event handlers
307
  run_button.click(
308
  fn=try_on,
309
  inputs=[
@@ -316,18 +184,10 @@ with gr.Blocks(css=CUSTOM_CSS) as demo:
316
  seed,
317
  segmentation_free,
318
  ],
319
- outputs=[result_image],
320
- )
321
-
322
- # Garment gallery selection - loads image and updates dropdowns
323
- garment_gallery.select(
324
- fn=on_garment_gallery_select,
325
- inputs=None,
326
- outputs=[garment_image, category, garment_photo_type],
327
  )
328
 
329
- # Configure queue with concurrency limit to prevent GPU OOM
330
- demo.queue(default_concurrency_limit=1, max_size=30)
331
 
332
  if __name__ == "__main__":
333
- demo.launch(share=False)
 
1
+ """FASHN VTON v1.5 HuggingFace Space Demo (CPU Compatible Version)."""
2
 
3
  import os
4
  import platform
 
5
  import gradio as gr
 
6
  import torch
7
  from huggingface_hub import hf_hub_download
8
  from PIL import Image
 
16
  CATEGORIES = ["tops", "bottoms", "one-pieces"]
17
  GARMENT_PHOTO_TYPES = ["model", "flat-lay"]
18
 
 
19
  _pipeline = None
20
 
21
 
22
  # ----------------- HELPERS ----------------- #
23
 
 
24
  def download_weights():
25
  """Download model weights from HuggingFace Hub."""
26
  os.makedirs(WEIGHTS_DIR, exist_ok=True)
27
  dwpose_dir = os.path.join(WEIGHTS_DIR, "dwpose")
28
  os.makedirs(dwpose_dir, exist_ok=True)
29
 
 
30
  tryon_path = os.path.join(WEIGHTS_DIR, "model.safetensors")
31
  if not os.path.exists(tryon_path):
32
  print("Downloading TryOnModel weights...")
 
36
  local_dir=WEIGHTS_DIR,
37
  )
38
 
 
39
  dwpose_files = ["yolox_l.onnx", "dw-ll_ucoco_384.onnx"]
40
  for filename in dwpose_files:
41
  filepath = os.path.join(dwpose_dir, filename)
 
47
  local_dir=dwpose_dir,
48
  )
49
 
50
+ print("Weights ready!")
51
 
52
 
53
  # ----------------- MODEL LOADING ----------------- #
54
 
 
55
  def get_pipeline():
56
+ """Lazy-load pipeline (CPU/GPU auto detect)."""
57
  global _pipeline
58
+
59
  if _pipeline is None:
 
 
 
 
 
 
60
 
61
+ device = "cuda" if torch.cuda.is_available() else "cpu"
62
+ print(f"Using device: {device}")
63
+
64
+ print(f"Python: {platform.python_version()}")
65
+ print(f"PyTorch: {torch.__version__}")
66
+
 
67
  if torch.cuda.is_available():
68
  dev = torch.cuda.current_device()
69
+ print(f"GPU: {torch.cuda.get_device_name(dev)}")
 
 
 
 
 
 
70
 
71
+ print("Downloading weights if needed...")
72
  download_weights()
73
 
74
  print("Loading pipeline...")
75
  from fashn_vton import TryOnPipeline
76
 
77
+ _pipeline = TryOnPipeline(
78
+ weights_dir=WEIGHTS_DIR,
79
+ device=device,
80
+ )
81
+
82
+ print(f"Pipeline loaded successfully on {device}")
83
 
84
  return _pipeline
85
 
86
 
87
  # ----------------- INFERENCE ----------------- #
88
 
 
 
89
  def try_on(
90
  person_image: Image.Image,
91
  garment_image: Image.Image,
 
96
  seed: int,
97
  segmentation_free: bool,
98
  ) -> Image.Image:
99
+
100
  if person_image is None:
101
  raise gr.Error("Please upload a person image")
102
  if garment_image is None:
103
  raise gr.Error("Please upload a garment image")
104
 
 
105
  if seed is None or seed < 0:
106
  seed = 42
107
 
 
108
  if person_image.mode != "RGB":
109
  person_image = person_image.convert("RGB")
110
+
111
  if garment_image.mode != "RGB":
112
  garment_image = garment_image.convert("RGB")
113
 
 
114
  pipeline = get_pipeline()
115
 
 
116
  result = pipeline(
117
  person_image=person_image,
118
  garment_image=garment_image,
 
130
 
131
  # ----------------- UI ----------------- #
132
 
 
133
  CUSTOM_CSS = """
134
  .contain img {
135
  object-fit: contain !important;
 
138
  }
139
  """
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  with gr.Blocks(css=CUSTOM_CSS) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
+ gr.Markdown("# 👕 FASHN VTON v1.5 (CPU Compatible)")
 
 
 
 
 
144
 
145
+ with gr.Row():
146
+ with gr.Column():
147
+ person_image = gr.Image(label="Person Image", type="pil")
 
 
 
 
 
148
 
149
+ with gr.Column():
150
+ garment_image = gr.Image(label="Garment Image", type="pil")
151
+
152
+ category = gr.Dropdown(
153
+ choices=CATEGORIES,
154
+ value="tops",
155
+ label="Category",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  )
157
 
158
+ garment_photo_type = gr.Dropdown(
159
+ choices=GARMENT_PHOTO_TYPES,
160
+ value="model",
161
+ label="Photo Type",
 
 
 
162
  )
163
 
164
+ with gr.Column():
165
+ result_image = gr.Image(label="Try-On Result", type="pil")
166
+
167
+ run_button = gr.Button("Try On", variant="primary")
168
 
 
169
  with gr.Accordion("Advanced Settings", open=False):
170
+ num_timesteps = gr.Slider(10, 50, value=50, step=5, label="Sampling Steps")
171
+ guidance_scale = gr.Slider(1.0, 3.0, value=1.5, step=0.1, label="Guidance Scale")
172
+ seed = gr.Number(value=42, label="Seed", precision=0)
173
+ segmentation_free = gr.Checkbox(value=True, label="Segmentation Free")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
 
175
  run_button.click(
176
  fn=try_on,
177
  inputs=[
 
184
  seed,
185
  segmentation_free,
186
  ],
187
+ outputs=result_image,
 
 
 
 
 
 
 
188
  )
189
 
190
+ demo.queue(default_concurrency_limit=1, max_size=20)
 
191
 
192
  if __name__ == "__main__":
193
+ demo.launch()