Spaces: Running on Zero
| import spaces | |
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| import os | |
| from diffusers import ( | |
| DiffusionPipeline, | |
| AutoencoderKL, | |
| ControlNetModel, | |
| StableDiffusionControlNetPipeline, | |
| StableDiffusionControlNetImg2ImgPipeline, | |
| DPMSolverMultistepScheduler, | |
| EulerDiscreteScheduler | |
| ) | |
# -------- CONFIG -------- #
# Hugging Face Hub repository ids for the two checkpoints this app loads:
# the base text-to-image model and the ControlNet ("qrcode_monster") weights.
BASE_MODEL: str = "SG161222/Realistic_Vision_V5.1_noVAE"
CONTROLNET_MODEL: str = "monster-labs/control_v1p_sd15_qrcode_monster"
# -------- VAE & CONTROLNET -------- #
# Both models are loaded in half precision (fp16) and handed to the pipeline
# below. The VAE comes from its own repo; judging by the "_noVAE" suffix of
# BASE_MODEL the base checkpoint does not bundle one (TODO confirm).
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
controlnet = ControlNetModel.from_pretrained(CONTROLNET_MODEL, torch_dtype=torch.float16)
# -------- MAIN PIPELINE -------- #
# ControlNet text-to-image pipeline on top of BASE_MODEL, using the VAE and
# ControlNet loaded above. Safety checking is deliberately disabled: both the
# safety checker and its feature extractor are set to None. The pipeline is
# moved to the GPU at import time (`.to` returns the pipeline itself).
main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
    BASE_MODEL,
    vae=vae,
    controlnet=controlnet,
    safety_checker=None,
    feature_extractor=None,
    torch_dtype=torch.float16,
)
main_pipe = main_pipe.to("cuda")
# -------- IMG2IMG PIPELINE -------- #
# Build the ControlNet img2img pipeline from the very same component objects
# as `main_pipe` (including the shared VAE and ControlNet), so nothing is
# loaded a second time. `generate` later replaces only this pipeline's
# scheduler attribute.
image_pipe = StableDiffusionControlNetImg2ImgPipeline(
    **main_pipe.components
)
# -------- SCHEDULERS -------- #
# Maps the sampler names shown in the UI dropdown to factories. Each factory
# takes an existing scheduler config and returns a fresh scheduler built from it.
SAMPLER_MAP = {
    # DPM-Solver++ (SDE variant) with Karras sigmas. The diffusers keyword is
    # `use_karras_sigmas`. The previously used `use_karras` is not a
    # DPMSolverMultistepScheduler parameter, so it was never applied and this
    # sampler silently ran without Karras sigmas.
    "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(
        config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
    ),
    "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
}
# -------- GRADIO DEMO -------- #
# On ZeroGPU hardware the GPU is only attached inside functions decorated
# with `spaces.GPU`; without it the CUDA pipeline cannot run per request.
@spaces.GPU
def generate(prompt, control_image, strength=0.8, guidance=7.0, steps=30, sampler="Euler"):
    """Generate one image from a prompt, steered by a ControlNet control image.

    The uploaded control image is used both as the img2img starting image and
    as the ControlNet conditioning image.

    Args:
        prompt: Text prompt.
        control_image: PIL image from the UI, or None if nothing was uploaded.
        strength: img2img strength (how far to move away from the start image).
        guidance: Classifier-free guidance scale.
        steps: Number of denoising steps. It is cast to int because UI sliders
            may deliver a float.
        sampler: Key into SAMPLER_MAP selecting the scheduler.

    Returns:
        The first generated image from the pipeline output.

    Raises:
        gr.Error: If no control image was provided.
        KeyError: If `sampler` is not a key of SAMPLER_MAP.
    """
    if control_image is None:
        raise gr.Error("Please upload a control image.")

    # Swap in the requested scheduler, rebuilt from the current one's config.
    image_pipe.scheduler = SAMPLER_MAP[sampler](image_pipe.scheduler.config)

    result = image_pipe(
        prompt=prompt,
        image=control_image,
        # The ControlNet img2img pipeline requires a separate conditioning
        # image; omitting it makes input validation fail on every call.
        control_image=control_image,
        strength=strength,
        guidance_scale=guidance,
        num_inference_steps=int(steps),
    )
    return result.images[0]
with gr.Blocks() as demo:
    gr.Markdown(value="# ✨ Illusion Diffusion (Fixed)")

    # Inputs: the text prompt and the control image (delivered as PIL).
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        control_image = gr.Image(type="pil", label="Control Image")

    # Generation settings; initial values mirror generate()'s keyword defaults.
    with gr.Row():
        strength = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, label="Strength")
        guidance = gr.Slider(minimum=1.0, maximum=15.0, value=7.0, label="Guidance Scale")
        steps = gr.Slider(minimum=5, maximum=50, value=30, step=1, label="Steps")
        sampler = gr.Dropdown(choices=list(SAMPLER_MAP), value="Euler", label="Sampler")

    btn = gr.Button(value="Generate")
    output = gr.Image()

    # The order of `inputs` must match generate()'s positional parameters.
    btn.click(
        fn=generate,
        inputs=[prompt, control_image, strength, guidance, steps, sampler],
        outputs=output,
    )
# Enable Gradio's request queue on the app (queue() returns the same Blocks
# object), then start serving it.
demo.queue()
demo.launch()