lulavc committed on
Commit
425254d
·
verified ·
1 Parent(s): 858dcd4

v22: Add prompt polishing feature (AI-enhanced prompts)

Browse files
Files changed (1) hide show
  1. app.py +60 -7
app.py CHANGED
@@ -1,5 +1,6 @@
1
- """Z-Image-Turbo v21 - Stable baseline (no AoTI)"""
2
 
 
3
  import torch
4
  import spaces
5
  import gradio as gr
@@ -7,12 +8,44 @@ import requests
7
  import io
8
  from PIL import Image
9
  from diffusers import DiffusionPipeline, ZImageImg2ImgPipeline
 
10
 
11
  # Enable optimized backends
12
  torch.backends.cuda.enable_flash_sdp(True)
13
  torch.backends.cuda.enable_mem_efficient_sdp(True)
14
  torch.backends.cudnn.benchmark = True
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  print("Loading Z-Image-Turbo pipeline...")
17
 
18
  # Load text-to-image pipeline
@@ -113,17 +146,35 @@ def upload_to_hf_cdn(image):
113
  except Exception as e:
114
  return f"Error: {str(e)}"
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  @spaces.GPU
117
- def generate(prompt, style, ratio, steps, seed, randomize, progress=gr.Progress(track_tqdm=True)):
118
  if randomize:
119
  seed = torch.randint(0, 2**32 - 1, (1,)).item()
120
  seed = int(seed)
121
 
122
  if not prompt or not prompt.strip():
123
- return None, seed
124
 
125
  w, h = RATIO_DIMS.get(ratio, (1024, 1024))
126
- full_prompt = prompt.strip() + STYLE_SUFFIXES.get(style, "")
127
 
128
  generator = torch.Generator("cuda").manual_seed(seed)
129
  image = pipe_t2i(
@@ -135,7 +186,7 @@ def generate(prompt, style, ratio, steps, seed, randomize, progress=gr.Progress(
135
  generator=generator,
136
  ).images[0]
137
 
138
- return image, seed
139
 
140
  @spaces.GPU
141
  def transform(input_image, prompt, style, strength, steps, seed, randomize, progress=gr.Progress(track_tqdm=True)):
@@ -265,6 +316,7 @@ with gr.Blocks(title="Z-Image Turbo", css=css) as demo:
265
  with gr.Row():
266
  with gr.Column():
267
  gen_prompt = gr.Textbox(label="Prompt", placeholder="Describe your image...", lines=3)
 
268
  gen_style = gr.Dropdown(choices=STYLES, value="None", label="Style")
269
  gen_ratio = gr.Dropdown(choices=RATIOS, value="1:1 Square (1024x1024)", label="Aspect Ratio")
270
  gen_steps = gr.Slider(minimum=4, maximum=16, value=8, step=1, label="Steps")
@@ -275,6 +327,7 @@ with gr.Blocks(title="Z-Image Turbo", css=css) as demo:
275
 
276
  with gr.Column():
277
  gen_output = gr.Image(label="Generated Image", type="pil", format="png", interactive=False)
 
278
  gen_seed_out = gr.Number(label="Seed Used", interactive=False)
279
  with gr.Row():
280
  gen_share_btn = gr.Button("📤 Share Image Link", variant="secondary")
@@ -282,8 +335,8 @@ with gr.Blocks(title="Z-Image Turbo", css=css) as demo:
282
 
283
  gr.Examples(examples=EXAMPLES_GENERATE, inputs=[gen_prompt, gen_style, gen_ratio, gen_steps, gen_seed, gen_randomize])
284
 
285
- gen_btn.click(fn=generate, inputs=[gen_prompt, gen_style, gen_ratio, gen_steps, gen_seed, gen_randomize], outputs=[gen_output, gen_seed_out])
286
- gen_prompt.submit(fn=generate, inputs=[gen_prompt, gen_style, gen_ratio, gen_steps, gen_seed, gen_randomize], outputs=[gen_output, gen_seed_out])
287
  gen_share_btn.click(fn=upload_to_hf_cdn, inputs=[gen_output], outputs=[gen_share_link])
288
 
289
  # TAB 2: Transform Image
 
1
+ """Z-Image-Turbo v22 - With prompt polishing feature"""
2
 
3
+ import os
4
  import torch
5
  import spaces
6
  import gradio as gr
 
8
  import io
9
  from PIL import Image
10
  from diffusers import DiffusionPipeline, ZImageImg2ImgPipeline
11
+ from huggingface_hub import InferenceClient
12
 
13
  # Enable optimized backends
14
  torch.backends.cuda.enable_flash_sdp(True)
15
  torch.backends.cuda.enable_mem_efficient_sdp(True)
16
  torch.backends.cudnn.benchmark = True
17
 
18
# Prompt polishing using HF Inference API
def polish_prompt(original_prompt):
    """Expand a short user prompt into a detailed, high-quality image prompt.

    Calls the HF Inference API (Qwen2.5-72B-Instruct) when an ``HF_TOKEN``
    environment variable is set; otherwise this is a no-op that returns the
    input unchanged. Any API failure also falls back to the original prompt
    (best-effort enhancement, never a hard error).

    Args:
        original_prompt: Raw user prompt text; may be None, empty, or
            whitespace-only.

    Returns:
        A single-line polished prompt string. For blank input, a generic
        quality prompt is returned so generation still produces something.
    """
    if not original_prompt or not original_prompt.strip():
        return "Ultra HD, 4K, cinematic composition, highly detailed"

    api_key = os.environ.get("HF_TOKEN")
    if not api_key:
        return original_prompt  # Return original if no token

    system_prompt = """You are a prompt optimizer for AI image generation.
Rewrite the user's input into a detailed, expressive prompt that will produce stunning images.
Keep it under 150 words. Be descriptive about lighting, atmosphere, style, and details.
Do not explain - just output the improved prompt directly."""

    try:
        client = InferenceClient(api_key=api_key)
        completion = client.chat.completions.create(
            model="Qwen/Qwen2.5-72B-Instruct",
            max_tokens=200,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": original_prompt},
            ],
        )
        polished = completion.choices[0].message.content
        # The API may legitimately return None or empty content; guard it
        # explicitly instead of letting .strip() raise AttributeError into
        # the broad except below (which would mask the real cause).
        if not polished or not polished.strip():
            return original_prompt
        # Collapse ALL whitespace (newlines, \r, runs of spaces) to single
        # spaces — replace("\n", " ") alone misses \r and doubles spaces.
        return " ".join(polished.split())
    except Exception as e:
        # Best-effort feature: never fail generation because polishing broke.
        print(f"Prompt polish error: {e}")
        return original_prompt
48
+
49
  print("Loading Z-Image-Turbo pipeline...")
50
 
51
  # Load text-to-image pipeline
 
146
  except Exception as e:
147
  return f"Error: {str(e)}"
148
 
149
def prepare_prompt(prompt, style, do_polish):
    """Build the prompt actually sent to the pipeline.

    Optionally runs the user's text through polish_prompt, then appends
    the suffix for the selected style. Returns a (final_prompt, polished)
    pair; both are empty strings when the input is blank.
    """
    cleaned = (prompt or "").strip()
    if not cleaned:
        return "", ""

    # Use the AI-polished text when requested, otherwise the raw prompt.
    polished = polish_prompt(cleaned) if do_polish else cleaned

    # Unknown styles contribute no suffix.
    suffix = STYLE_SUFFIXES.get(style, "")
    return polished + suffix, polished
166
+
167
  @spaces.GPU
168
+ def generate(prompt, style, ratio, steps, seed, randomize, do_polish, progress=gr.Progress(track_tqdm=True)):
169
  if randomize:
170
  seed = torch.randint(0, 2**32 - 1, (1,)).item()
171
  seed = int(seed)
172
 
173
  if not prompt or not prompt.strip():
174
+ return None, seed, ""
175
 
176
  w, h = RATIO_DIMS.get(ratio, (1024, 1024))
177
+ full_prompt, polished_prompt = prepare_prompt(prompt, style, do_polish)
178
 
179
  generator = torch.Generator("cuda").manual_seed(seed)
180
  image = pipe_t2i(
 
186
  generator=generator,
187
  ).images[0]
188
 
189
+ return image, seed, polished_prompt if do_polish else ""
190
 
191
  @spaces.GPU
192
  def transform(input_image, prompt, style, strength, steps, seed, randomize, progress=gr.Progress(track_tqdm=True)):
 
316
  with gr.Row():
317
  with gr.Column():
318
  gen_prompt = gr.Textbox(label="Prompt", placeholder="Describe your image...", lines=3)
319
+ gen_polish = gr.Checkbox(label="✨ Polish Prompt (AI-enhanced)", value=False)
320
  gen_style = gr.Dropdown(choices=STYLES, value="None", label="Style")
321
  gen_ratio = gr.Dropdown(choices=RATIOS, value="1:1 Square (1024x1024)", label="Aspect Ratio")
322
  gen_steps = gr.Slider(minimum=4, maximum=16, value=8, step=1, label="Steps")
 
327
 
328
  with gr.Column():
329
  gen_output = gr.Image(label="Generated Image", type="pil", format="png", interactive=False)
330
+ gen_polished_prompt = gr.Textbox(label="Polished Prompt", interactive=False, visible=True, lines=2)
331
  gen_seed_out = gr.Number(label="Seed Used", interactive=False)
332
  with gr.Row():
333
  gen_share_btn = gr.Button("📤 Share Image Link", variant="secondary")
 
335
 
336
  gr.Examples(examples=EXAMPLES_GENERATE, inputs=[gen_prompt, gen_style, gen_ratio, gen_steps, gen_seed, gen_randomize])
337
 
338
+ gen_btn.click(fn=generate, inputs=[gen_prompt, gen_style, gen_ratio, gen_steps, gen_seed, gen_randomize, gen_polish], outputs=[gen_output, gen_seed_out, gen_polished_prompt])
339
+ gen_prompt.submit(fn=generate, inputs=[gen_prompt, gen_style, gen_ratio, gen_steps, gen_seed, gen_randomize, gen_polish], outputs=[gen_output, gen_seed_out, gen_polished_prompt])
340
  gen_share_btn.click(fn=upload_to_hf_cdn, inputs=[gen_output], outputs=[gen_share_link])
341
 
342
  # TAB 2: Transform Image