Spaces:
Running on Zero
Running on Zero
Update app.py
#5
by Aza72 - opened
app.py
CHANGED
|
@@ -15,6 +15,8 @@ from tqdm import tqdm
|
|
| 15 |
import cv2
|
| 16 |
import numpy as np
|
| 17 |
import torch
|
|
|
|
|
|
|
| 18 |
from torch.nn import functional as F
|
| 19 |
from PIL import Image
|
| 20 |
|
|
@@ -231,9 +233,30 @@ def interpolate_bits(frames_np, multiplier=2, scale=1.0):
|
|
| 231 |
|
| 232 |
# WAN
|
| 233 |
|
| 234 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
CACHE_DIR = os.path.expanduser("~/.cache/huggingface/")
|
| 236 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
MAX_DIM = 832
|
| 238 |
MIN_DIM = 480
|
| 239 |
SQUARE_DIM = 640
|
|
@@ -258,11 +281,43 @@ SCHEDULER_MAP = {
|
|
| 258 |
}
|
| 259 |
|
| 260 |
pipe = WanImageToVideoPipeline.from_pretrained(
|
| 261 |
-
|
| 262 |
torch_dtype=torch.bfloat16,
|
| 263 |
).to('cuda')
|
| 264 |
original_scheduler = copy.deepcopy(pipe.scheduler)
|
| 265 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
if os.path.exists(CACHE_DIR):
|
| 267 |
shutil.rmtree(CACHE_DIR)
|
| 268 |
print("Deleted Hugging Face cache.")
|
|
@@ -270,8 +325,11 @@ else:
|
|
| 270 |
print("No hub cache found.")
|
| 271 |
|
| 272 |
quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
|
|
|
|
| 273 |
quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
|
|
|
|
| 274 |
quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
|
|
|
|
| 275 |
|
| 276 |
aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
|
| 277 |
aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
|
|
@@ -283,10 +341,13 @@ default_prompt_i2v = "make this image come alive, cinematic motion, smooth anima
|
|
| 283 |
default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
|
| 284 |
|
| 285 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
def resize_image(image: Image.Image) -> Image.Image:
|
| 287 |
-
"""
|
| 288 |
-
Resizes an image to fit within the model's constraints, preserving aspect ratio as much as possible.
|
| 289 |
-
"""
|
| 290 |
width, height = image.size
|
| 291 |
if width == height:
|
| 292 |
return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)
|
|
@@ -322,7 +383,6 @@ def resize_image(image: Image.Image) -> Image.Image:
|
|
| 322 |
|
| 323 |
|
| 324 |
def resize_and_crop_to_match(target_image, reference_image):
|
| 325 |
-
"""Resizes and center-crops the target image to match the reference image's dimensions."""
|
| 326 |
ref_width, ref_height = reference_image.size
|
| 327 |
target_width, target_height = target_image.size
|
| 328 |
scale = max(ref_width / target_width, ref_height / target_height)
|
|
@@ -406,7 +466,7 @@ def run_inference(
|
|
| 406 |
clear_vram()
|
| 407 |
|
| 408 |
task_name = str(uuid.uuid4())[:8]
|
| 409 |
-
print(f"
|
| 410 |
start = time.time()
|
| 411 |
result = pipe(
|
| 412 |
image=resized_image,
|
|
@@ -422,6 +482,7 @@ def run_inference(
|
|
| 422 |
generator=torch.Generator(device="cuda").manual_seed(current_seed),
|
| 423 |
output_type="np"
|
| 424 |
)
|
|
|
|
| 425 |
|
| 426 |
raw_frames_np = result.frames[0] # Returns (T, H, W, C) float32
|
| 427 |
pipe.scheduler = original_scheduler
|
|
@@ -429,9 +490,11 @@ def run_inference(
|
|
| 429 |
frame_factor = frame_multiplier // FIXED_FPS
|
| 430 |
if frame_factor > 1:
|
| 431 |
start = time.time()
|
|
|
|
| 432 |
rife_model.device()
|
| 433 |
rife_model.flownet = rife_model.flownet.half()
|
| 434 |
final_frames = interpolate_bits(raw_frames_np, multiplier=int(frame_factor))
|
|
|
|
| 435 |
else:
|
| 436 |
final_frames = list(raw_frames_np)
|
| 437 |
|
|
@@ -445,6 +508,7 @@ def run_inference(
|
|
| 445 |
pbar.update(2)
|
| 446 |
export_to_video(final_frames, video_path, fps=final_fps, quality=quality)
|
| 447 |
pbar.update(1)
|
|
|
|
| 448 |
|
| 449 |
return video_path, task_name
|
| 450 |
|
|
@@ -559,10 +623,9 @@ CSS = """
|
|
| 559 |
"""
|
| 560 |
|
| 561 |
|
| 562 |
-
with gr.Blocks(
|
| 563 |
-
gr.Markdown(
|
| 564 |
gr.Markdown("#### ℹ️ **A Note on Performance:** This version prioritizes a straightforward setup over maximum speed, so performance may vary.")
|
| 565 |
-
gr.Markdown('Try the previous version: [WAMU v1](https://huggingface.co/spaces/r3gm/wan2-2-fp8da-aoti-preview2)')
|
| 566 |
gr.Markdown("Run Wan 2.2 in just 4-8 steps, fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU")
|
| 567 |
|
| 568 |
with gr.Row():
|
|
@@ -571,7 +634,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CSS, delete_cache=(3600, 10800)) as d
|
|
| 571 |
prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
|
| 572 |
duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=3.5, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
|
| 573 |
frame_multi = gr.Dropdown(
|
| 574 |
-
choices=[FIXED_FPS, FIXED_FPS*2, FIXED_FPS*4],
|
| 575 |
value=FIXED_FPS,
|
| 576 |
label="Video Fluidity (Frames per Second)",
|
| 577 |
info="Extra frames will be generated using flow estimation, which estimates motion between frames to make the video smoother."
|
|
@@ -593,12 +656,17 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CSS, delete_cache=(3600, 10800)) as d
|
|
| 593 |
)
|
| 594 |
flow_shift_slider = gr.Slider(minimum=0.5, maximum=15.0, step=0.1, value=3.0, label="Flow Shift")
|
| 595 |
play_result_video = gr.Checkbox(label="Display result", value=True, interactive=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 596 |
|
| 597 |
generate_button = gr.Button("Generate Video", variant="primary")
|
| 598 |
|
| 599 |
with gr.Column():
|
| 600 |
# ASSIGNED elem_id="generated-video" so JS can find it
|
| 601 |
-
video_output = gr.Video(label="Generated Video", autoplay=True, sources=["upload"],
|
| 602 |
|
| 603 |
# --- Frame Grabbing UI ---
|
| 604 |
with gr.Row():
|
|
@@ -641,6 +709,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CSS, delete_cache=(3600, 10800)) as d
|
|
| 641 |
if __name__ == "__main__":
|
| 642 |
demo.queue().launch(
|
| 643 |
mcp_server=True,
|
| 644 |
-
|
| 645 |
show_error=True,
|
| 646 |
)
|
|
|
|
| 15 |
import cv2
|
| 16 |
import numpy as np
|
| 17 |
import torch
|
| 18 |
+
import torch._dynamo
|
| 19 |
+
from huggingface_hub import list_models
|
| 20 |
from torch.nn import functional as F
|
| 21 |
from PIL import Image
|
| 22 |
|
|
|
|
| 233 |
|
| 234 |
# WAN
|
| 235 |
|
| 236 |
+
ORG_NAME = "TestOrganizationPleaseIgnore"
|
| 237 |
+
# MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
|
| 238 |
+
MODEL_ID = os.getenv("REPO_ID") or random.choice(
|
| 239 |
+
list(list_models(author=ORG_NAME, filter='diffusers:WanImageToVideoPipeline'))
|
| 240 |
+
).modelId
|
| 241 |
CACHE_DIR = os.path.expanduser("~/.cache/huggingface/")
|
| 242 |
|
| 243 |
+
LORA_MODELS = [
|
| 244 |
+
# {
|
| 245 |
+
# "repo_id": "exampleuser/example_lora_1",
|
| 246 |
+
# "high_tr": "example_lora_1_high.safetensors",
|
| 247 |
+
# "low_tr": "example_lora_1_low.safetensors",
|
| 248 |
+
# "high_scale": 0.5,
|
| 249 |
+
# "low_scale": 0.5
|
| 250 |
+
# },
|
| 251 |
+
# {
|
| 252 |
+
# "repo_id": "exampleuser/example_lora_2",
|
| 253 |
+
# "high_tr": "subfolder/example_lora_2_high.safetensors",
|
| 254 |
+
# "low_tr": "subfolder/example_lora_2_low.safetensors",
|
| 255 |
+
# "high_scale": 0.4,
|
| 256 |
+
# "low_scale": 0.4
|
| 257 |
+
# },
|
| 258 |
+
]
|
| 259 |
+
|
| 260 |
MAX_DIM = 832
|
| 261 |
MIN_DIM = 480
|
| 262 |
SQUARE_DIM = 640
|
|
|
|
| 281 |
}
|
| 282 |
|
| 283 |
pipe = WanImageToVideoPipeline.from_pretrained(
|
| 284 |
+
MODEL_ID,
|
| 285 |
torch_dtype=torch.bfloat16,
|
| 286 |
).to('cuda')
|
| 287 |
original_scheduler = copy.deepcopy(pipe.scheduler)
|
| 288 |
|
| 289 |
+
for i, lora in enumerate(LORA_MODELS):
|
| 290 |
+
name_high_tr = lora["high_tr"].split(".")[0].split("/")[-1] + "Hh"
|
| 291 |
+
name_low_tr = lora["low_tr"].split(".")[0].split("/")[-1] + "Ll"
|
| 292 |
+
|
| 293 |
+
try:
|
| 294 |
+
pipe.load_lora_weights(
|
| 295 |
+
lora["repo_id"],
|
| 296 |
+
weight_name=lora["high_tr"],
|
| 297 |
+
adapter_name=name_high_tr
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
kwargs_lora = {"load_into_transformer_2": True}
|
| 301 |
+
pipe.load_lora_weights(
|
| 302 |
+
lora["repo_id"],
|
| 303 |
+
weight_name=lora["low_tr"],
|
| 304 |
+
adapter_name=name_low_tr,
|
| 305 |
+
**kwargs_lora
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
pipe.set_adapters([name_high_tr, name_low_tr], adapter_weights=[1.0, 1.0])
|
| 309 |
+
|
| 310 |
+
pipe.fuse_lora(adapter_names=[name_high_tr], lora_scale=lora["high_scale"], components=["transformer"])
|
| 311 |
+
pipe.fuse_lora(adapter_names=[name_low_tr], lora_scale=lora["low_scale"], components=["transformer_2"])
|
| 312 |
+
|
| 313 |
+
pipe.unload_lora_weights()
|
| 314 |
+
|
| 315 |
+
print(f"Applied: {lora['high_tr']}, hs={lora['high_scale']}/ls={lora['low_scale']}, {i+1}/{len(LORA_MODELS)}")
|
| 316 |
+
except Exception as e:
|
| 317 |
+
print("Error:", str(e))
|
| 318 |
+
print("Failed LoRA:", name_high_tr)
|
| 319 |
+
pipe.unload_lora_weights()
|
| 320 |
+
|
| 321 |
if os.path.exists(CACHE_DIR):
|
| 322 |
shutil.rmtree(CACHE_DIR)
|
| 323 |
print("Deleted Hugging Face cache.")
|
|
|
|
| 325 |
print("No hub cache found.")
|
| 326 |
|
| 327 |
quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
|
| 328 |
+
torch._dynamo.reset()
|
| 329 |
quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
|
| 330 |
+
torch._dynamo.reset()
|
| 331 |
quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
|
| 332 |
+
torch._dynamo.reset()
|
| 333 |
|
| 334 |
aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
|
| 335 |
aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
|
|
|
|
| 341 |
default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
|
| 342 |
|
| 343 |
|
| 344 |
+
def model_title():
|
| 345 |
+
repo_name = MODEL_ID.split('/')[-1].replace("_", " ")
|
| 346 |
+
url = f"https://huggingface.co/{MODEL_ID}"
|
| 347 |
+
return f"## This space is currently running [{repo_name}]({url}) 🐢"
|
| 348 |
+
|
| 349 |
+
|
| 350 |
def resize_image(image: Image.Image) -> Image.Image:
|
|
|
|
|
|
|
|
|
|
| 351 |
width, height = image.size
|
| 352 |
if width == height:
|
| 353 |
return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)
|
|
|
|
| 383 |
|
| 384 |
|
| 385 |
def resize_and_crop_to_match(target_image, reference_image):
|
|
|
|
| 386 |
ref_width, ref_height = reference_image.size
|
| 387 |
target_width, target_height = target_image.size
|
| 388 |
scale = max(ref_width / target_width, ref_height / target_height)
|
|
|
|
| 466 |
clear_vram()
|
| 467 |
|
| 468 |
task_name = str(uuid.uuid4())[:8]
|
| 469 |
+
print(f"Generating {num_frames} frames, task: {task_name}, {duration_seconds}, {resized_image.size}")
|
| 470 |
start = time.time()
|
| 471 |
result = pipe(
|
| 472 |
image=resized_image,
|
|
|
|
| 482 |
generator=torch.Generator(device="cuda").manual_seed(current_seed),
|
| 483 |
output_type="np"
|
| 484 |
)
|
| 485 |
+
print("gen time passed:", time.time() - start)
|
| 486 |
|
| 487 |
raw_frames_np = result.frames[0] # Returns (T, H, W, C) float32
|
| 488 |
pipe.scheduler = original_scheduler
|
|
|
|
| 490 |
frame_factor = frame_multiplier // FIXED_FPS
|
| 491 |
if frame_factor > 1:
|
| 492 |
start = time.time()
|
| 493 |
+
print(f"Processing frames (RIFE Multiplier: {frame_factor}x)...")
|
| 494 |
rife_model.device()
|
| 495 |
rife_model.flownet = rife_model.flownet.half()
|
| 496 |
final_frames = interpolate_bits(raw_frames_np, multiplier=int(frame_factor))
|
| 497 |
+
print("Interpolation time passed:", time.time() - start)
|
| 498 |
else:
|
| 499 |
final_frames = list(raw_frames_np)
|
| 500 |
|
|
|
|
| 508 |
pbar.update(2)
|
| 509 |
export_to_video(final_frames, video_path, fps=final_fps, quality=quality)
|
| 510 |
pbar.update(1)
|
| 511 |
+
print(f"Export time passed, {final_fps} FPS:", time.time() - start)
|
| 512 |
|
| 513 |
return video_path, task_name
|
| 514 |
|
|
|
|
| 623 |
"""
|
| 624 |
|
| 625 |
|
| 626 |
+
with gr.Blocks(delete_cache=(3600, 10800)) as demo:
|
| 627 |
+
gr.Markdown(model_title())
|
| 628 |
gr.Markdown("#### ℹ️ **A Note on Performance:** This version prioritizes a straightforward setup over maximum speed, so performance may vary.")
|
|
|
|
| 629 |
gr.Markdown("Run Wan 2.2 in just 4-8 steps, fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU")
|
| 630 |
|
| 631 |
with gr.Row():
|
|
|
|
| 634 |
prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
|
| 635 |
duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=3.5, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
|
| 636 |
frame_multi = gr.Dropdown(
|
| 637 |
+
choices=[FIXED_FPS, FIXED_FPS*2, FIXED_FPS*4, FIXED_FPS*8],
|
| 638 |
value=FIXED_FPS,
|
| 639 |
label="Video Fluidity (Frames per Second)",
|
| 640 |
info="Extra frames will be generated using flow estimation, which estimates motion between frames to make the video smoother."
|
|
|
|
| 656 |
)
|
| 657 |
flow_shift_slider = gr.Slider(minimum=0.5, maximum=15.0, step=0.1, value=3.0, label="Flow Shift")
|
| 658 |
play_result_video = gr.Checkbox(label="Display result", value=True, interactive=True)
|
| 659 |
+
gr.Markdown(f"[ZeroGPU help, tips and troubleshooting](https://huggingface.co/datasets/{ORG_NAME}/help/blob/main/gpu_help.md)")
|
| 660 |
+
gr.Markdown( # TestOrganizationPleaseIgnore/wamu-tools
|
| 661 |
+
"To use a different model, **duplicate this Space** first, then change the `REPO_ID` environment variable. "
|
| 662 |
+
"[See compatible models here](https://huggingface.co/models?other=diffusers:WanImageToVideoPipeline&sort=trending&search=WAN2.2_I2V_LIGHTNING)."
|
| 663 |
+
)
|
| 664 |
|
| 665 |
generate_button = gr.Button("Generate Video", variant="primary")
|
| 666 |
|
| 667 |
with gr.Column():
|
| 668 |
# ASSIGNED elem_id="generated-video" so JS can find it
|
| 669 |
+
video_output = gr.Video(label="Generated Video", autoplay=True, sources=["upload"], buttons=["download", "share"], interactive=True, elem_id="generated-video")
|
| 670 |
|
| 671 |
# --- Frame Grabbing UI ---
|
| 672 |
with gr.Row():
|
|
|
|
| 709 |
if __name__ == "__main__":
|
| 710 |
demo.queue().launch(
|
| 711 |
mcp_server=True,
|
| 712 |
+
css=CSS,
|
| 713 |
show_error=True,
|
| 714 |
)
|