Files changed (1) hide show
  1. app.py +81 -13
app.py CHANGED
@@ -15,6 +15,8 @@ from tqdm import tqdm
15
  import cv2
16
  import numpy as np
17
  import torch
 
 
18
  from torch.nn import functional as F
19
  from PIL import Image
20
 
@@ -231,9 +233,30 @@ def interpolate_bits(frames_np, multiplier=2, scale=1.0):
231
 
232
  # WAN
233
 
234
- MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
 
 
 
 
235
  CACHE_DIR = os.path.expanduser("~/.cache/huggingface/")
236
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  MAX_DIM = 832
238
  MIN_DIM = 480
239
  SQUARE_DIM = 640
@@ -258,11 +281,43 @@ SCHEDULER_MAP = {
258
  }
259
 
260
  pipe = WanImageToVideoPipeline.from_pretrained(
261
- "TestOrganizationPleaseIgnore/WAMU_v2_WAN2.2_I2V_LIGHTNING",
262
  torch_dtype=torch.bfloat16,
263
  ).to('cuda')
264
  original_scheduler = copy.deepcopy(pipe.scheduler)
265
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  if os.path.exists(CACHE_DIR):
267
  shutil.rmtree(CACHE_DIR)
268
  print("Deleted Hugging Face cache.")
@@ -270,8 +325,11 @@ else:
270
  print("No hub cache found.")
271
 
272
  quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
 
273
  quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
 
274
  quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
 
275
 
276
  aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
277
  aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
@@ -283,10 +341,13 @@ default_prompt_i2v = "make this image come alive, cinematic motion, smooth anima
283
  default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
284
 
285
 
 
 
 
 
 
 
286
  def resize_image(image: Image.Image) -> Image.Image:
287
- """
288
- Resizes an image to fit within the model's constraints, preserving aspect ratio as much as possible.
289
- """
290
  width, height = image.size
291
  if width == height:
292
  return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)
@@ -322,7 +383,6 @@ def resize_image(image: Image.Image) -> Image.Image:
322
 
323
 
324
  def resize_and_crop_to_match(target_image, reference_image):
325
- """Resizes and center-crops the target image to match the reference image's dimensions."""
326
  ref_width, ref_height = reference_image.size
327
  target_width, target_height = target_image.size
328
  scale = max(ref_width / target_width, ref_height / target_height)
@@ -406,7 +466,7 @@ def run_inference(
406
  clear_vram()
407
 
408
  task_name = str(uuid.uuid4())[:8]
409
- print(f"Task: {task_name}, {duration_seconds}, {resized_image.size}, FM={frame_multiplier}")
410
  start = time.time()
411
  result = pipe(
412
  image=resized_image,
@@ -422,6 +482,7 @@ def run_inference(
422
  generator=torch.Generator(device="cuda").manual_seed(current_seed),
423
  output_type="np"
424
  )
 
425
 
426
  raw_frames_np = result.frames[0] # Returns (T, H, W, C) float32
427
  pipe.scheduler = original_scheduler
@@ -429,9 +490,11 @@ def run_inference(
429
  frame_factor = frame_multiplier // FIXED_FPS
430
  if frame_factor > 1:
431
  start = time.time()
 
432
  rife_model.device()
433
  rife_model.flownet = rife_model.flownet.half()
434
  final_frames = interpolate_bits(raw_frames_np, multiplier=int(frame_factor))
 
435
  else:
436
  final_frames = list(raw_frames_np)
437
 
@@ -445,6 +508,7 @@ def run_inference(
445
  pbar.update(2)
446
  export_to_video(final_frames, video_path, fps=final_fps, quality=quality)
447
  pbar.update(1)
 
448
 
449
  return video_path, task_name
450
 
@@ -559,10 +623,9 @@ CSS = """
559
  """
560
 
561
 
562
- with gr.Blocks(theme=gr.themes.Soft(), css=CSS, delete_cache=(3600, 10800)) as demo:
563
- gr.Markdown("## WAMU V2 - Wan 2.2 I2V (14B) 🐢🐢")
564
  gr.Markdown("#### ℹ️ **A Note on Performance:** This version prioritizes a straightforward setup over maximum speed, so performance may vary.")
565
- gr.Markdown('Try the previous version: [WAMU v1](https://huggingface.co/spaces/r3gm/wan2-2-fp8da-aoti-preview2)')
566
  gr.Markdown("Run Wan 2.2 in just 4-8 steps, fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU")
567
 
568
  with gr.Row():
@@ -571,7 +634,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CSS, delete_cache=(3600, 10800)) as d
571
  prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
572
  duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=3.5, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
573
  frame_multi = gr.Dropdown(
574
- choices=[FIXED_FPS, FIXED_FPS*2, FIXED_FPS*4],
575
  value=FIXED_FPS,
576
  label="Video Fluidity (Frames per Second)",
577
  info="Extra frames will be generated using flow estimation, which estimates motion between frames to make the video smoother."
@@ -593,12 +656,17 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CSS, delete_cache=(3600, 10800)) as d
593
  )
594
  flow_shift_slider = gr.Slider(minimum=0.5, maximum=15.0, step=0.1, value=3.0, label="Flow Shift")
595
  play_result_video = gr.Checkbox(label="Display result", value=True, interactive=True)
 
 
 
 
 
596
 
597
  generate_button = gr.Button("Generate Video", variant="primary")
598
 
599
  with gr.Column():
600
  # ASSIGNED elem_id="generated-video" so JS can find it
601
- video_output = gr.Video(label="Generated Video", autoplay=True, sources=["upload"], show_download_button=True, show_share_button=True, interactive=False, elem_id="generated-video")
602
 
603
  # --- Frame Grabbing UI ---
604
  with gr.Row():
@@ -641,6 +709,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CSS, delete_cache=(3600, 10800)) as d
641
  if __name__ == "__main__":
642
  demo.queue().launch(
643
  mcp_server=True,
644
- ssr_mode=False,
645
  show_error=True,
646
  )
 
15
  import cv2
16
  import numpy as np
17
  import torch
18
+ import torch._dynamo
19
+ from huggingface_hub import list_models
20
  from torch.nn import functional as F
21
  from PIL import Image
22
 
 
233
 
234
  # WAN
235
 
236
# WAN model selection: an explicit REPO_ID env var wins; otherwise pick a
# random WanImageToVideoPipeline checkpoint published under ORG_NAME.
ORG_NAME = "TestOrganizationPleaseIgnore"
# Upstream base model, for reference: "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
MODEL_ID = os.getenv("REPO_ID")
if not MODEL_ID:
    _candidates = list(list_models(author=ORG_NAME, filter='diffusers:WanImageToVideoPipeline'))
    if not _candidates:
        # random.choice on an empty list raises a bare IndexError;
        # fail with an actionable message instead.
        raise RuntimeError(f"No WanImageToVideoPipeline models found for organization {ORG_NAME!r}")
    # ModelInfo.id supersedes the deprecated .modelId attribute.
    MODEL_ID = random.choice(_candidates).id
241
  CACHE_DIR = os.path.expanduser("~/.cache/huggingface/")
242
 
243
+ LORA_MODELS = [
244
+ # {
245
+ # "repo_id": "exampleuser/example_lora_1",
246
+ # "high_tr": "example_lora_1_high.safetensors",
247
+ # "low_tr": "example_lora_1_low.safetensors",
248
+ # "high_scale": 0.5,
249
+ # "low_scale": 0.5
250
+ # },
251
+ # {
252
+ # "repo_id": "exampleuser/example_lora_2",
253
+ # "high_tr": "subfolder/example_lora_2_high.safetensors",
254
+ # "low_tr": "subfolder/example_lora_2_low.safetensors",
255
+ # "high_scale": 0.4,
256
+ # "low_scale": 0.4
257
+ # },
258
+ ]
259
+
260
  MAX_DIM = 832
261
  MIN_DIM = 480
262
  SQUARE_DIM = 640
 
281
  }
282
 
283
  pipe = WanImageToVideoPipeline.from_pretrained(
284
+ MODEL_ID,
285
  torch_dtype=torch.bfloat16,
286
  ).to('cuda')
287
  original_scheduler = copy.deepcopy(pipe.scheduler)
288
 
289
# Fuse optional LoRA adapters into the two WAN transformers before
# quantization. Each LORA_MODELS entry names a Hub repo plus separate
# high-noise / low-noise weight files and per-stage fuse scales.
for i, lora in enumerate(LORA_MODELS):
    # Adapter names come from the weight-file stem; the Hh/Ll suffixes keep
    # the high- and low-noise adapters distinct. Take the basename FIRST and
    # strip only the final extension, so dotted filenames such as
    # "name.v2_high.safetensors" keep their full stem ("name.v2_high")
    # instead of being truncated at the first dot (collision risk).
    name_high_tr = os.path.splitext(os.path.basename(lora["high_tr"]))[0] + "Hh"
    name_low_tr = os.path.splitext(os.path.basename(lora["low_tr"]))[0] + "Ll"

    try:
        # High-noise weights go into the primary transformer.
        pipe.load_lora_weights(
            lora["repo_id"],
            weight_name=lora["high_tr"],
            adapter_name=name_high_tr,
        )

        # Low-noise weights target the second-stage transformer_2.
        pipe.load_lora_weights(
            lora["repo_id"],
            weight_name=lora["low_tr"],
            adapter_name=name_low_tr,
            load_into_transformer_2=True,
        )

        pipe.set_adapters([name_high_tr, name_low_tr], adapter_weights=[1.0, 1.0])

        # Bake each adapter into its transformer at the configured scale,
        # then drop the LoRA bookkeeping so quantization sees plain weights.
        pipe.fuse_lora(adapter_names=[name_high_tr], lora_scale=lora["high_scale"], components=["transformer"])
        pipe.fuse_lora(adapter_names=[name_low_tr], lora_scale=lora["low_scale"], components=["transformer_2"])

        pipe.unload_lora_weights()

        print(f"Applied: {lora['high_tr']}, hs={lora['high_scale']}/ls={lora['low_scale']}, {i+1}/{len(LORA_MODELS)}")
    except Exception as e:
        # Best-effort: a bad LoRA must not take down the Space at startup;
        # unload whatever was partially loaded and continue.
        print("Error:", str(e))
        print("Failed LoRA:", name_high_tr)
        pipe.unload_lora_weights()
320
+
321
  if os.path.exists(CACHE_DIR):
322
  shutil.rmtree(CACHE_DIR)
323
  print("Deleted Hugging Face cache.")
 
325
  print("No hub cache found.")
326
 
327
  quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
328
+ torch._dynamo.reset()
329
  quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
330
+ torch._dynamo.reset()
331
  quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
332
+ torch._dynamo.reset()
333
 
334
  aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
335
  aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
 
341
  default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
342
 
343
 
344
def model_title(model_id=None):
    """Build the Markdown heading that links to the model this Space runs.

    Args:
        model_id: Hugging Face repo id ("org/name"). Defaults to the
            module-level MODEL_ID when omitted (backward compatible).

    Returns:
        A level-2 Markdown heading with a link to the model page; underscores
        in the repo name are shown as spaces.
    """
    if model_id is None:
        model_id = MODEL_ID
    repo_name = model_id.split('/')[-1].replace("_", " ")
    url = f"https://huggingface.co/{model_id}"
    return f"## This space is currently running [{repo_name}]({url}) 🐢"
348
+
349
+
350
  def resize_image(image: Image.Image) -> Image.Image:
 
 
 
351
  width, height = image.size
352
  if width == height:
353
  return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)
 
383
 
384
 
385
  def resize_and_crop_to_match(target_image, reference_image):
 
386
  ref_width, ref_height = reference_image.size
387
  target_width, target_height = target_image.size
388
  scale = max(ref_width / target_width, ref_height / target_height)
 
466
  clear_vram()
467
 
468
  task_name = str(uuid.uuid4())[:8]
469
+ print(f"Generating {num_frames} frames, task: {task_name}, {duration_seconds}, {resized_image.size}")
470
  start = time.time()
471
  result = pipe(
472
  image=resized_image,
 
482
  generator=torch.Generator(device="cuda").manual_seed(current_seed),
483
  output_type="np"
484
  )
485
+ print("gen time passed:", time.time() - start)
486
 
487
  raw_frames_np = result.frames[0] # Returns (T, H, W, C) float32
488
  pipe.scheduler = original_scheduler
 
490
  frame_factor = frame_multiplier // FIXED_FPS
491
  if frame_factor > 1:
492
  start = time.time()
493
+ print(f"Processing frames (RIFE Multiplier: {frame_factor}x)...")
494
  rife_model.device()
495
  rife_model.flownet = rife_model.flownet.half()
496
  final_frames = interpolate_bits(raw_frames_np, multiplier=int(frame_factor))
497
+ print("Interpolation time passed:", time.time() - start)
498
  else:
499
  final_frames = list(raw_frames_np)
500
 
 
508
  pbar.update(2)
509
  export_to_video(final_frames, video_path, fps=final_fps, quality=quality)
510
  pbar.update(1)
511
+ print(f"Export time passed, {final_fps} FPS:", time.time() - start)
512
 
513
  return video_path, task_name
514
 
 
623
  """
624
 
625
 
626
+ with gr.Blocks(delete_cache=(3600, 10800)) as demo:
627
+ gr.Markdown(model_title())
628
  gr.Markdown("#### ℹ️ **A Note on Performance:** This version prioritizes a straightforward setup over maximum speed, so performance may vary.")
 
629
  gr.Markdown("Run Wan 2.2 in just 4-8 steps, fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU")
630
 
631
  with gr.Row():
 
634
  prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
635
  duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=3.5, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
636
  frame_multi = gr.Dropdown(
637
+ choices=[FIXED_FPS, FIXED_FPS*2, FIXED_FPS*4, FIXED_FPS*8],
638
  value=FIXED_FPS,
639
  label="Video Fluidity (Frames per Second)",
640
  info="Extra frames will be generated using flow estimation, which estimates motion between frames to make the video smoother."
 
656
  )
657
  flow_shift_slider = gr.Slider(minimum=0.5, maximum=15.0, step=0.1, value=3.0, label="Flow Shift")
658
  play_result_video = gr.Checkbox(label="Display result", value=True, interactive=True)
659
+ gr.Markdown(f"[ZeroGPU help, tips and troubleshooting](https://huggingface.co/datasets/{ORG_NAME}/help/blob/main/gpu_help.md)")
660
+ gr.Markdown( # TestOrganizationPleaseIgnore/wamu-tools
661
+ "To use a different model, **duplicate this Space** first, then change the `REPO_ID` environment variable. "
662
+ "[See compatible models here](https://huggingface.co/models?other=diffusers:WanImageToVideoPipeline&sort=trending&search=WAN2.2_I2V_LIGHTNING)."
663
+ )
664
 
665
  generate_button = gr.Button("Generate Video", variant="primary")
666
 
667
  with gr.Column():
668
  # ASSIGNED elem_id="generated-video" so JS can find it
669
+ video_output = gr.Video(label="Generated Video", autoplay=True, sources=["upload"], buttons=["download", "share"], interactive=True, elem_id="generated-video")
670
 
671
  # --- Frame Grabbing UI ---
672
  with gr.Row():
 
709
  if __name__ == "__main__":
710
  demo.queue().launch(
711
  mcp_server=True,
712
+ css=CSS,
713
  show_error=True,
714
  )