Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import torch | |
| from diffusers import StableDiffusionPipeline | |
| from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput | |
def void(*args, **kwargs) -> None:
    """Accept any positional/keyword arguments, do nothing, return ``None``.

    The script uses this to swallow the value of an expression evaluated
    inside a ``lambda`` so that the lambda itself evaluates to ``None``.
    """
    return None
# Page heading.
st.title("AI 元火娘")

# Model picker. The final entry is a sentinel that swaps the choice for a
# free-form path typed by the user.
# NOTE(review): the original indentation was lost; the custom-path input is
# assumed to live inside the sidebar together with the selectbox -- confirm.
model_choices = [
    "wybxc/yanhuo-v1-dreambooth",
    "wybxc/yanyuan-v1-dreambooth",
    "wybxc/yuanhuo-v1-dreambooth",
    "<Custom>",
]
with st.sidebar:
    model = st.selectbox("Model Name", model_choices)
    if model == "<Custom>":
        # Surrounding whitespace is trimmed; the result may be an empty
        # string, which later code treats as "no model selected".
        custom_path = st.text_input("Model Path", value="")
        model = custom_path.strip()
# Caching model
# The loaded pipeline and the name it was loaded from are kept in
# st.session_state so the pipeline is reused while the selection is unchanged.
if "model" not in st.session_state:
    st.session_state.model = model
if "pipeline" not in st.session_state:
    st.session_state.pipeline = None

if model != st.session_state.model or st.session_state.pipeline is None:
    # A (re)load is needed. `pipeline` stays None when no model name was
    # given or when loading fails, which disables generation further down.
    pipeline = None
    if model:
        with st.spinner("Loading Model..."):
            try:
                candidate = StableDiffusionPipeline.from_pretrained(model)
            except (OSError, ValueError) as exc:
                # Presumably the errors raised for an unknown repo / bad path
                # or a malformed model id -- TODO confirm against the
                # installed diffusers version. Report instead of aborting.
                candidate = None
                st.error(f"Could not load model {model!r}: {exc}")

            # Explicit check instead of `assert`, which is stripped under -O.
            if candidate is not None and not isinstance(
                candidate, StableDiffusionPipeline
            ):
                st.error(
                    f"Model {model!r} did not load as a StableDiffusionPipeline "
                    f"(got {type(candidate).__name__})."
                )
                candidate = None

            if candidate is not None and torch.cuda.is_available():
                candidate = candidate.to("cuda")

        if candidate is not None:
            # Only a successful load replaces the cached entry.
            st.session_state.model = model
            st.session_state.pipeline = candidate
            pipeline = candidate
else:
    pipeline = st.session_state.pipeline
    if not isinstance(pipeline, StableDiffusionPipeline):
        raise TypeError(
            "cached pipeline has unexpected type "
            f"{type(pipeline).__name__}; expected StableDiffusionPipeline"
        )
# Default texts pre-filled into the two prompt boxes.
_DEFAULT_PROMPT = (
    "(yanhuo), 1girl, masterpiece, best quality, "
    "white hair, ahoge, snowy street, [smile], dynamic angle, full body, "
    "[blue eyes], flat chest, cinematic light"
)
_DEFAULT_NEGATIVE_PROMPT = (
    "lowres, bad anatomy, bad hands, "
    "text, error, missing fingers, extra digit, fewer digits, cropped, "
    "worst quality, low quality, normal quality, jpeg artifacts, signature, "
    "watermark, username, blurry"
)

prompt = st.text_area("Prompt", value=_DEFAULT_PROMPT)
negative_prompt = st.text_area("Negative Prompt", value=_DEFAULT_NEGATIVE_PROMPT)

# Image size and sampling-step controls.
# NOTE(review): original indentation was lost; all three sliders are assumed
# to sit inside the sidebar block -- confirm.
with st.sidebar:
    height = st.slider("Height", min_value=256, max_value=1024, value=512, step=64)
    width = st.slider("Width", min_value=256, max_value=1024, value=512, step=64)
    steps = st.slider("Steps", min_value=1, max_value=100, value=20, step=1)
# Run the pipeline when a model is loaded and the user presses the button.
# The button is only rendered when a pipeline is available (short-circuit).
if pipeline is not None and st.button("Generate"):
    progress = st.progress(0)

    def _report_progress(step, *_):
        """Advance the progress bar; extra callback arguments are ignored."""
        # Clamp to 1.0: the step index handed to the callback is presumably
        # below `steps`, but that depends on the scheduler/diffusers version,
        # and the progress bar rejects fractions above 1.0 -- TODO confirm.
        progress.progress(min(step / steps, 1.0))

    result = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_inference_steps=steps,
        callback=_report_progress,
    )
    # Explicit check instead of `assert`, which is stripped under -O.
    if not isinstance(result, StableDiffusionPipelineOutput):
        raise TypeError(
            "pipeline returned unexpected type "
            f"{type(result).__name__}; expected StableDiffusionPipelineOutput"
        )

    image = result.images[0]
    progress.progress(1.0)
    st.image(image)