| ''' |
| from diffusers import utils |
| from diffusers.utils import deprecation_utils |
| from diffusers.models import cross_attention |
| utils.deprecate = lambda *arg, **kwargs: None |
| deprecation_utils.deprecate = lambda *arg, **kwargs: None |
| cross_attention.deprecate = lambda *arg, **kwargs: None |
| ''' |
|
|
| import os |
| import sys |
| ''' |
| MAIN_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) |
| sys.path.insert(0, MAIN_DIR) |
| os.chdir(MAIN_DIR) |
| ''' |
|
|
| import gradio as gr |
| import numpy as np |
| import torch |
| import random |
|
|
| from annotator.util import resize_image, HWC3 |
| from annotator.canny import CannyDetector |
| from diffusers.models.unet_2d_condition import UNet2DConditionModel |
| from diffusers.pipelines import DiffusionPipeline |
| from diffusers.schedulers import DPMSolverMultistepScheduler |
| |
|
|
# A single Canny edge detector instance, reused by every request.
apply_canny = CannyDetector()

# Use the GPU whenever torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
|
|
| ''' |
| pipeline = DiffusionPipeline.from_pretrained( |
| 'IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1', safety_checker=None |
| ) |
| pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) |
| pipeline = pipeline.to(device) |
| unet: UNet2DConditionModel = pipeline.unet |
| |
| #ckpt_path = "ckpts/sd-diffusiondb-canny-model-control-lora-zh" |
| ckpt_path = "svjack/canny-control-lora-zh" |
| control_lora = ControlLoRA.from_pretrained(ckpt_path) |
| control_lora = control_lora.to(device) |
| |
| # load control lora attention processors |
| lora_attn_procs = {} |
| lora_layers_list = list([list(layer_list) for layer_list in control_lora.lora_layers]) |
| n_ch = len(unet.config.block_out_channels) |
| control_ids = [i for i in range(n_ch)] |
| for name in pipeline.unet.attn_processors.keys(): |
| cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim |
| if name.startswith("mid_block"): |
| control_id = control_ids[-1] |
| elif name.startswith("up_blocks"): |
| block_id = int(name[len("up_blocks.")]) |
| control_id = list(reversed(control_ids))[block_id] |
| elif name.startswith("down_blocks"): |
| block_id = int(name[len("down_blocks.")]) |
| control_id = control_ids[block_id] |
| |
| lora_layers = lora_layers_list[control_id] |
| if len(lora_layers) != 0: |
| lora_layer: ControlLoRACrossAttnProcessor = lora_layers.pop(0) |
| lora_attn_procs[name] = lora_layer |
| |
| unet.set_attn_processor(lora_attn_procs) |
| ''' |
|
|
| from diffusers import ( |
| AutoencoderKL, |
| ControlNetModel, |
| DDPMScheduler, |
| StableDiffusionControlNetPipeline, |
| UNet2DConditionModel, |
| UniPCMultistepScheduler, |
| ) |
| import torch |
| from diffusers.utils import load_image |
|
|
# ControlNet weights conditioned on Canny edge maps (Hub repo id).
controlnet_model_name_or_path = "svjack/ControlNet-Canny-Zh"
controlnet = ControlNetModel.from_pretrained(controlnet_model_name_or_path)

# Chinese Stable Diffusion base model the ControlNet is attached to.
base_model_path = "IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1"
# safety_checker=None disables the checker at load time, so its weights are
# never loaded or moved to the device only to be thrown away afterwards.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
    safety_checker=None,
)

# Swap the default scheduler for UniPC, reusing the base scheduler's config.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

if device == "cuda":
    pipe = pipe.to("cuda")
|
|
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, sample_steps, scale, seed, eta, low_threshold, high_threshold):
    """Detect Canny edges in ``input_image`` and generate images conditioned on them.

    Gradio callback; the positional parameters follow the order of the
    ``ips`` input list used when wiring the button.

    Args:
        input_image: Uploaded image as a numpy array (``gr.Image(type="numpy")``).
        prompt: Main text prompt.
        a_prompt: Extra text appended to ``prompt`` after ``", "``.
        n_prompt: Negative prompt.
        num_samples: Number of images to generate.
        image_resolution: Target resolution forwarded to ``resize_image``.
        sample_steps: Number of inference steps for the pipeline.
        scale: Guidance scale forwarded as ``guidance_scale``.
        seed: Generator seed; ``-1`` draws a random seed in ``[0, 65535]``.
        eta: Forwarded to the pipeline as ``eta``. NOTE(review): likely
            ignored by the UniPC scheduler configured for ``pipe`` — confirm.
        low_threshold: Lower Canny threshold.
        high_threshold: Upper Canny threshold.

    Returns:
        A list of numpy images: ``255 - edge_map`` first, followed by one
        array per generated sample.
    """
    from PIL import Image

    # Gradio numeric widgets may hand back floats; range(),
    # torch.Generator.manual_seed() and the scheduler's step count need ints.
    num_samples = int(num_samples)
    sample_steps = int(sample_steps)
    seed = int(seed)

    with torch.no_grad():
        img = resize_image(HWC3(input_image), image_resolution)
        H, W, _ = img.shape

        detected_map = HWC3(apply_canny(img, low_threshold, high_threshold))
        control_image = Image.fromarray(detected_map)

        if seed == -1:
            seed = random.randint(0, 65535)
        # One generator for the whole loop: each sample consumes fresh random
        # draws, and the whole batch is derived from `seed`.
        generator = torch.Generator(device=device).manual_seed(seed)

        # Invariant across samples, so build it once.
        full_prompt = prompt + ', ' + a_prompt

        images = []
        for _ in range(num_samples):
            image = pipe(
                full_prompt,
                negative_prompt=n_prompt,
                num_inference_steps=sample_steps,
                guidance_scale=scale,
                eta=eta,
                image=control_image,
                generator=generator,
                height=H,
                width=W,
            ).images[0]
            images.append(np.asarray(image))

        # The edge map is shown inverted (255 - map) ahead of the samples.
        return [255 - detected_map] + images
|
|
|
|
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Control Stable Diffusion with Canny Edge Maps")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source='upload', type="numpy", value="house.png")
            # Default prompt: "pencil drawing of a house".
            prompt = gr.Textbox(label="Prompt", value="房屋铅笔画")
            # A Button's displayed text is its `value`; `label` is not a
            # gr.Button parameter in Gradio 3.x and would be ignored.
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
                low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
                high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
                sample_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                eta = gr.Number(label="eta", value=0.0)
                a_prompt = gr.Textbox(label="Added Prompt", value='')
                # Default negative prompt: "low quality, blurry, messy".
                n_prompt = gr.Textbox(label="Negative Prompt",
                                      value='低质量,模糊,混乱')
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')

    # Order must match the positional parameters of process().
    ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, sample_steps, scale, seed, eta, low_threshold, high_threshold]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery], show_progress=True)


block.launch(server_name='0.0.0.0')
|
|
| |
|
|