Spaces: Running on Zero
v22: Add prompt polishing feature (AI-enhanced prompts)
app.py CHANGED
@@ -1,5 +1,6 @@
-"""Z-Image-Turbo
+"""Z-Image-Turbo v22 - With prompt polishing feature"""
 
+import os
 import torch
 import spaces
 import gradio as gr
@@ -7,12 +8,44 @@ import requests
 import io
 from PIL import Image
 from diffusers import DiffusionPipeline, ZImageImg2ImgPipeline
+from huggingface_hub import InferenceClient
 
 # Enable optimized backends
 torch.backends.cuda.enable_flash_sdp(True)
 torch.backends.cuda.enable_mem_efficient_sdp(True)
 torch.backends.cudnn.benchmark = True
 
+# Prompt polishing using HF Inference API
+def polish_prompt(original_prompt):
+    """Expand short prompts into detailed, high-quality prompts using AI."""
+    if not original_prompt or not original_prompt.strip():
+        return "Ultra HD, 4K, cinematic composition, highly detailed"
+
+    api_key = os.environ.get("HF_TOKEN")
+    if not api_key:
+        return original_prompt  # Return original if no token
+
+    system_prompt = """You are a prompt optimizer for AI image generation.
+Rewrite the user's input into a detailed, expressive prompt that will produce stunning images.
+Keep it under 150 words. Be descriptive about lighting, atmosphere, style, and details.
+Do not explain - just output the improved prompt directly."""
+
+    try:
+        client = InferenceClient(api_key=api_key)
+        completion = client.chat.completions.create(
+            model="Qwen/Qwen2.5-72B-Instruct",
+            max_tokens=200,
+            messages=[
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": original_prompt}
+            ],
+        )
+        polished = completion.choices[0].message.content
+        return polished.strip().replace("\n", " ")
+    except Exception as e:
+        print(f"Prompt polish error: {e}")
+        return original_prompt
+
 print("Loading Z-Image-Turbo pipeline...")
 
 # Load text-to-image pipeline
@@ -113,17 +146,35 @@ def upload_to_hf_cdn(image):
     except Exception as e:
         return f"Error: {str(e)}"
 
+def prepare_prompt(prompt, style, do_polish):
+    """Prepare the final prompt with optional polishing and style."""
+    if not prompt or not prompt.strip():
+        return "", ""
+
+    base_prompt = prompt.strip()
+
+    # Polish if enabled
+    if do_polish:
+        polished = polish_prompt(base_prompt)
+    else:
+        polished = base_prompt
+
+    # Add style suffix
+    final_prompt = polished + STYLE_SUFFIXES.get(style, "")
+
+    return final_prompt, polished
+
 @spaces.GPU
-def generate(prompt, style, ratio, steps, seed, randomize, progress=gr.Progress(track_tqdm=True)):
+def generate(prompt, style, ratio, steps, seed, randomize, do_polish, progress=gr.Progress(track_tqdm=True)):
     if randomize:
         seed = torch.randint(0, 2**32 - 1, (1,)).item()
     seed = int(seed)
 
     if not prompt or not prompt.strip():
-        return None, seed
+        return None, seed, ""
 
     w, h = RATIO_DIMS.get(ratio, (1024, 1024))
-    full_prompt = prompt
+    full_prompt, polished_prompt = prepare_prompt(prompt, style, do_polish)
 
     generator = torch.Generator("cuda").manual_seed(seed)
     image = pipe_t2i(
@@ -135,7 +186,7 @@ def generate(prompt, style, ratio, steps, seed, randomize, progress=gr.Progress(
         generator=generator,
     ).images[0]
 
-    return image, seed
+    return image, seed, polished_prompt if do_polish else ""
 
 @spaces.GPU
 def transform(input_image, prompt, style, strength, steps, seed, randomize, progress=gr.Progress(track_tqdm=True)):
@@ -265,6 +316,7 @@ with gr.Blocks(title="Z-Image Turbo", css=css) as demo:
         with gr.Row():
             with gr.Column():
                 gen_prompt = gr.Textbox(label="Prompt", placeholder="Describe your image...", lines=3)
+                gen_polish = gr.Checkbox(label="✨ Polish Prompt (AI-enhanced)", value=False)
                 gen_style = gr.Dropdown(choices=STYLES, value="None", label="Style")
                 gen_ratio = gr.Dropdown(choices=RATIOS, value="1:1 Square (1024x1024)", label="Aspect Ratio")
                 gen_steps = gr.Slider(minimum=4, maximum=16, value=8, step=1, label="Steps")
@@ -275,6 +327,7 @@ with gr.Blocks(title="Z-Image Turbo", css=css) as demo:
 
             with gr.Column():
                 gen_output = gr.Image(label="Generated Image", type="pil", format="png", interactive=False)
+                gen_polished_prompt = gr.Textbox(label="Polished Prompt", interactive=False, visible=True, lines=2)
                 gen_seed_out = gr.Number(label="Seed Used", interactive=False)
                 with gr.Row():
                     gen_share_btn = gr.Button("📤 Share Image Link", variant="secondary")
@@ -282,8 +335,8 @@ with gr.Blocks(title="Z-Image Turbo", css=css) as demo:
 
         gr.Examples(examples=EXAMPLES_GENERATE, inputs=[gen_prompt, gen_style, gen_ratio, gen_steps, gen_seed, gen_randomize])
 
-        gen_btn.click(fn=generate, inputs=[gen_prompt, gen_style, gen_ratio, gen_steps, gen_seed, gen_randomize], outputs=[gen_output, gen_seed_out])
-        gen_prompt.submit(fn=generate, inputs=[gen_prompt, gen_style, gen_ratio, gen_steps, gen_seed, gen_randomize], outputs=[gen_output, gen_seed_out])
+        gen_btn.click(fn=generate, inputs=[gen_prompt, gen_style, gen_ratio, gen_steps, gen_seed, gen_randomize, gen_polish], outputs=[gen_output, gen_seed_out, gen_polished_prompt])
+        gen_prompt.submit(fn=generate, inputs=[gen_prompt, gen_style, gen_ratio, gen_steps, gen_seed, gen_randomize, gen_polish], outputs=[gen_output, gen_seed_out, gen_polished_prompt])
        gen_share_btn.click(fn=upload_to_hf_cdn, inputs=[gen_output], outputs=[gen_share_link])
 
    # TAB 2: Transform Image
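
For quick verification outside the Space, the new polish call can be exercised on its own. The sketch below is not part of the commit: it assumes huggingface_hub is installed and a valid HF_TOKEN is exported, and it mirrors the polish_prompt() request from the diff above (the shortened system prompt and example user prompt are illustrative placeholders).

import os
from huggingface_hub import InferenceClient

# Sketch: mirrors the polish_prompt() request from the diff above.
# Assumes HF_TOKEN is set; system/user strings here are placeholders.
client = InferenceClient(api_key=os.environ["HF_TOKEN"])
completion = client.chat.completions.create(
    model="Qwen/Qwen2.5-72B-Instruct",
    max_tokens=200,
    messages=[
        {"role": "system", "content": "Rewrite the user's input into a detailed, expressive image prompt."},
        {"role": "user", "content": "a cat on a windowsill at dusk"},
    ],
)
# Same post-processing as polish_prompt(): trim and collapse to a single line.
print(completion.choices[0].message.content.strip().replace("\n", " "))

Note the fallback behavior in the committed code: with no HF_TOKEN, or on any API error, polish_prompt() returns the original prompt unchanged, so generation still works without a token.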