import os
import sys
import re
import json
import random
import logging
import warnings
from dataclasses import dataclass

import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont
import spaces
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from transformers import AutoModelForCausalLM, AutoTokenizer
# ------------------------- Optional dependency: Prompt Enhancer template -------------------------
# If the original project ships a pe.py it is picked up automatically; if not, this fallback
# keeps the app running (enhancement is disabled by default anyway).
try:
    sys.path.append(os.path.dirname(os.path.abspath(__file__)))
    from pe import prompt_template  # type: ignore
except Exception:
    prompt_template = (
        "You are a helpful prompt engineer. Expand the user prompt into a richer, detailed prompt. "
        "Return JSON with key revised_prompt."
    )
# ------------------------- Z-Image classes (require a diffusers build that ships them) -------------------------
from diffusers import ZImagePipeline
from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
# ==================== Environment Variables ==================================
MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true"
ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true"
ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3")
DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")
# =============================================================================

os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
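# bfloat16 halves memory versus float32 on GPU; CPU runs stay in float32.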
RES_CHOICES = {
    "1024": [
        "1024x1024 ( 1:1 )",
        "1152x896 ( 9:7 )",
        "896x1152 ( 7:9 )",
        "1152x864 ( 4:3 )",
        "864x1152 ( 3:4 )",
        "1248x832 ( 3:2 )",
        "832x1248 ( 2:3 )",
        "1280x720 ( 16:9 )",
        "720x1280 ( 9:16 )",
        "1344x576 ( 21:9 )",
        "576x1344 ( 9:21 )",
    ],
    "1280": [
        "1280x1280 ( 1:1 )",
        "1440x1120 ( 9:7 )",
        "1120x1440 ( 7:9 )",
        "1472x1104 ( 4:3 )",
        "1104x1472 ( 3:4 )",
        "1536x1024 ( 3:2 )",
        "1024x1536 ( 2:3 )",
        "1536x864 ( 16:9 )",
        "864x1536 ( 9:16 )",
        "1680x720 ( 21:9 )",
        "720x1680 ( 9:21 )",
    ],
    "1536": [
        "1536x1536 ( 1:1 )",
        "1728x1344 ( 9:7 )",
        "1344x1728 ( 7:9 )",
        "1728x1296 ( 4:3 )",
        "1296x1728 ( 3:4 )",
        "1872x1248 ( 3:2 )",
        "1248x1872 ( 2:3 )",
        "2048x1152 ( 16:9 )",
        "1152x2048 ( 9:16 )",
        "2016x864 ( 21:9 )",
        "864x2016 ( 9:21 )",
    ],
}

RESOLUTION_SET = []
for _k, v in RES_CHOICES.items():
    RESOLUTION_SET.extend(v)
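# RESOLUTION_SET is the flattened list of every "WxH ( ratio )" option; it feeds the resolution dropdown in the UI.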
EXAMPLE_PROMPTS = [
    ["一位男士和他的贵宾犬穿着配套的服装参加狗狗秀,室内灯光,背景中有观众。"],
    ["极具氛围感的暗调人像,一位优雅的中国美女在黑暗的房间里。一束强光通过遮光板,在她的脸上投射出一个清晰的闪电形状的光影,正好照亮一只眼睛。高对比度,明暗交界清晰,神秘感,莱卡相机色调。"],
]
# ------------------------- HF token compatibility helper -------------------------
def _hf_token_kwargs(token: str | None):
    """
    Recent transformers / diffusers releases expect `token` in from_pretrained and raise an
    error if the deprecated `use_auth_token` is passed alongside it, so only `token` is
    forwarded here. Returns an empty dict when no token is configured.
    """
    if not token:
        return {}
    return {"token": token}
def get_resolution(resolution: str):
    match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
    if match:
        return int(match.group(1)), int(match.group(2))
    return 1024, 1024
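# Example: get_resolution("1280x720 ( 16:9 )") -> (1280, 720); strings without a WxH pair fall back to (1024, 1024).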
def _make_blocked_image(width=1024, height=1024, text="Blocked by Safety Checker"):
    img = Image.new("RGB", (width, height), (20, 20, 20))
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.load_default()
    except Exception:
        font = None
    draw.rectangle([0, 0, width, 90], fill=(160, 0, 0))
    draw.text((20, 30), text, fill=(255, 255, 255), font=font)
    return img
def _load_nsfw_placeholder(width=1024, height=1024):
    """
    When an image is flagged as NSFW, prefer nsfw.png from the working directory;
    if it is missing, generate a placeholder so the absent file does not trigger
    a second error.
    """
    if os.path.exists("nsfw.png"):
        try:
            return Image.open("nsfw.png").convert("RGB")
        except Exception:
            pass
    return _make_blocked_image(width, height, "NSFW blocked")
def load_models(model_path: str, enable_compile=False, attention_backend="native"):
    print(f"[Init] Loading models from: {model_path}")
    print(f"[Init] DEVICE={DEVICE}, DTYPE={DTYPE}, ENABLE_COMPILE={enable_compile}, ATTENTION_BACKEND={attention_backend}")

    # Remote repo id (path that does not exist locally) vs. local directory
    is_local_dir = os.path.exists(model_path)
    token_kwargs = _hf_token_kwargs(HF_TOKEN) if not is_local_dir else {}

    # 1) VAE
    if not is_local_dir:
        vae = AutoencoderKL.from_pretrained(
            model_path,
            subfolder="vae",
            torch_dtype=DTYPE if DEVICE == "cuda" else torch.float32,
            **token_kwargs,
        )
    else:
        vae = AutoencoderKL.from_pretrained(
            os.path.join(model_path, "vae"),
            torch_dtype=DTYPE if DEVICE == "cuda" else torch.float32,
        )

    # 2) Text Encoder + Tokenizer
    if not is_local_dir:
        text_encoder = AutoModelForCausalLM.from_pretrained(
            model_path,
            subfolder="text_encoder",
            torch_dtype=DTYPE if DEVICE == "cuda" else torch.float32,
            **token_kwargs,
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            subfolder="tokenizer",
            **token_kwargs,
        )
    else:
        text_encoder = AutoModelForCausalLM.from_pretrained(
            os.path.join(model_path, "text_encoder"),
            torch_dtype=DTYPE if DEVICE == "cuda" else torch.float32,
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))
    # Left padding keeps prompt tokens right-aligned for the decoder-only text encoder.
    tokenizer.padding_side = "left"
    # torch.compile optimizations (only recommended on CUDA)
    if enable_compile and DEVICE == "cuda":
        print("[Init] Enabling torch.compile optimizations...")
        torch._inductor.config.conv_1x1_as_mm = True
        torch._inductor.config.coordinate_descent_tuning = True
        torch._inductor.config.epilogue_fusion = False
        torch._inductor.config.coordinate_descent_check_all_directions = True
        torch._inductor.config.max_autotune_gemm = True
        torch._inductor.config.max_autotune_gemm_backends = "TRITON,ATEN"
        torch._inductor.config.triton.cudagraphs = False

    pipe = ZImagePipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None)

    # 3) Transformer
    if not is_local_dir:
        transformer = ZImageTransformer2DModel.from_pretrained(
            model_path,
            subfolder="transformer",
            **token_kwargs,
        )
    else:
        transformer = ZImageTransformer2DModel.from_pretrained(os.path.join(model_path, "transformer"))
    transformer = transformer.to(DEVICE, DTYPE)
    pipe.transformer = transformer

    # The requested attention backend may be unsupported in some environments; fall back gracefully.
    try:
        pipe.transformer.set_attention_backend(attention_backend)
    except Exception as e:
        print(f"[Init] set_attention_backend('{attention_backend}') failed, falling back to 'native'. Error: {e}")
        try:
            pipe.transformer.set_attention_backend("native")
        except Exception as e2:
            print(f"[Init] fallback set_attention_backend('native') failed: {e2}")

    if enable_compile and DEVICE == "cuda":
        try:
            print("[Init] Compiling transformer...")
            pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False)
        except Exception as e:
            print(f"[Init] torch.compile failed, continuing without compile. Error: {e}")

    pipe = pipe.to(DEVICE, DTYPE)
    # 4) Safety checker (post-generation NSFW filtering)
    try:
        from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
        try:
            from transformers import CLIPImageProcessor as _CLIPProcessor
        except Exception:
            # Compatibility with older transformers releases
            from transformers import CLIPFeatureExtractor as _CLIPProcessor  # type: ignore
        safety_model_id = "CompVis/stable-diffusion-safety-checker"
        safety_feature_extractor = _CLIPProcessor.from_pretrained(safety_model_id, **_hf_token_kwargs(HF_TOKEN))
        safety_checker = StableDiffusionSafetyChecker.from_pretrained(
            safety_model_id,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            **_hf_token_kwargs(HF_TOKEN),
        ).to(DEVICE)
        pipe.safety_feature_extractor = safety_feature_extractor
        pipe.safety_checker = safety_checker
        print("[Init] Safety checker loaded.")
    except Exception as e:
        print(f"[Init] Safety checker init failed. NSFW filtering will be skipped. Error: {e}")
        pipe.safety_feature_extractor = None
        pipe.safety_checker = None

    return pipe
def generate_image(
    pipe,
    prompt: str,
    resolution="1024x1024",
    seed=42,
    guidance_scale=5.0,
    num_inference_steps=50,
    shift=3.0,
    max_sequence_length=512,
    progress=gr.Progress(track_tqdm=True),
):
    width, height = get_resolution(resolution)
    if DEVICE == "cuda":
        generator = torch.Generator(device="cuda").manual_seed(int(seed))
    else:
        generator = torch.Generator().manual_seed(int(seed))
    # Rebuild the scheduler on every call so the requested time shift actually takes effect.
    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=float(shift))
    pipe.scheduler = scheduler
    out = pipe(
        prompt=prompt,
        height=int(height),
        width=int(width),
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        max_sequence_length=int(max_sequence_length),
    )
    image = out.images[0]
    return image
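# Warmup issues a couple of short, low-step generations per resolution so compiled kernels
# and attention backends are initialized before real requests arrive.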
def warmup_model(pipe, resolutions):
    print("[Warmup] Starting warmup phase...")
    dummy_prompt = "warmup"
    for res_str in resolutions:
        print(f"[Warmup] Resolution: {res_str}")
        try:
            for i in range(2):
                generate_image(
                    pipe,
                    prompt=dummy_prompt,
                    resolution=res_str.split(" ")[0],
                    num_inference_steps=6,
                    guidance_scale=0.0,
                    seed=42 + i,
                )
        except Exception as e:
            print(f"[Warmup] Failed for {res_str}: {e}")
    print("[Warmup] Completed.")
# ==================== Prompt Expander (kept, but not enabled by default) ====================
@dataclass
class PromptOutput:
    status: bool
    prompt: str
    seed: int
    system_prompt: str
    message: str
class PromptExpander:
    def __init__(self, backend="api", **kwargs):
        self.backend = backend

    def decide_system_prompt(self, template_name=None):
        return prompt_template
class APIPromptExpander(PromptExpander):
    def __init__(self, api_config=None, **kwargs):
        super().__init__(backend="api", **kwargs)
        self.api_config = api_config or {}
        self.client = self._init_api_client()

    def _init_api_client(self):
        try:
            from openai import OpenAI
            api_key = self.api_config.get("api_key") or DASHSCOPE_API_KEY
            base_url = self.api_config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
            if not api_key:
                print("[PE] Warning: DASHSCOPE_API_KEY not found. Prompt enhance unavailable.")
                return None
            return OpenAI(api_key=api_key, base_url=base_url)
        except ImportError:
            print("[PE] Please install openai: pip install openai")
            return None
        except Exception as e:
            print(f"[PE] Failed to initialize API client: {e}")
            return None

    def __call__(self, prompt, system_prompt=None, seed=-1, **kwargs):
        return self.extend(prompt, system_prompt, seed, **kwargs)

    def extend(self, prompt, system_prompt=None, seed=-1, **kwargs):
        if self.client is None:
            return PromptOutput(False, "", seed, system_prompt or "", "API client not initialized")
        if system_prompt is None:
            system_prompt = self.decide_system_prompt()
        if "{prompt}" in system_prompt:
            system_prompt = system_prompt.format(prompt=prompt)
            prompt = " "
        try:
            model = self.api_config.get("model", "qwen3-max-preview")
            response = self.client.chat.completions.create(
                model=model,
                messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
                temperature=0.7,
                top_p=0.8,
            )
            content = response.choices[0].message.content or ""
            # Try to parse revised_prompt from a ```json block in the response
            expanded_prompt = content
            json_start = content.find("```json")
            if json_start != -1:
                json_end = content.find("```", json_start + 7)
                if json_end != -1:
                    json_str = content[json_start + 7 : json_end].strip()
                    try:
                        data = json.loads(json_str)
                        expanded_prompt = data.get("revised_prompt", content)
                    except Exception:
                        expanded_prompt = content
            return PromptOutput(True, expanded_prompt, seed, system_prompt, content)
        except Exception as e:
            return PromptOutput(False, "", seed, system_prompt, str(e))
def create_prompt_expander(backend="api", **kwargs):
    if backend == "api":
        return APIPromptExpander(**kwargs)
    raise ValueError("Only 'api' backend is supported.")


pipe = None
prompt_expander = None
def init_app():
    global pipe, prompt_expander
    try:
        pipe = load_models(MODEL_PATH, enable_compile=ENABLE_COMPILE, attention_backend=ATTENTION_BACKEND)
        print("[Init] Model loaded.")
        if ENABLE_WARMUP and pipe is not None:
            all_res = []
            for cat in RES_CHOICES.values():
                all_res.extend(cat)
            warmup_model(pipe, all_res)
    except Exception as e:
        print(f"[Init] Error loading model: {e}")
        pipe = None
    try:
        prompt_expander = create_prompt_expander(backend="api", api_config={"model": "qwen3-max-preview"})
        print("[Init] Prompt expander ready (disabled by default).")
    except Exception as e:
        print(f"[Init] Error initializing prompt expander: {e}")
        prompt_expander = None
def prompt_enhance(prompt, enable_enhance: bool):
    if not enable_enhance or not prompt_expander:
        return prompt, "Enhancement disabled or unavailable."
    if not prompt.strip():
        return "", "Please enter a prompt."
    try:
        result = prompt_expander(prompt)
        if result.status:
            return result.prompt, result.message
        return prompt, f"Enhancement failed: {result.message}"
    except Exception as e:
        return prompt, f"Error: {str(e)}"
def try_enable_aoti(pipe):
    """
    Enable AoTI (ZeroGPU ahead-of-time compiled blocks) when available; otherwise skip
    without affecting the main flow.
    """
    if pipe is None:
        return
    try:
        # Try the structure used by the original code first: pipe.transformer.layers
        if hasattr(pipe, "transformer") and pipe.transformer is not None:
            target = None
            if hasattr(pipe.transformer, "layers"):
                target = pipe.transformer.layers
                if hasattr(target, "_repeated_blocks"):
                    target._repeated_blocks = ["ZImageTransformerBlock"]
            else:
                # Fallback: set the attribute on the transformer itself
                target = pipe.transformer
                if hasattr(target, "_repeated_blocks"):
                    target._repeated_blocks = ["ZImageTransformerBlock"]
            if target is not None:
                spaces.aoti_blocks_load(target, "zerogpu-aoti/Z-Image", variant="fa3")
                print("[Init] AoTI blocks loaded.")
    except Exception as e:
        print(f"[Init] AoTI not enabled (safe to ignore). Error: {e}")
@spaces.GPU  # Request a ZeroGPU slot for the duration of this call (no-op outside ZeroGPU Spaces).
def generate(
    prompt,
    resolution="1024x1024 ( 1:1 )",
    seed=42,
    steps=9,
    shift=3.0,
    random_seed=True,
    gallery_images=None,
    enhance=False,  # disabled by default
    progress=gr.Progress(track_tqdm=True),
):
    if random_seed:
        new_seed = random.randint(1, 1000000)
    else:
        new_seed = int(seed) if int(seed) != -1 else random.randint(1, 1000000)
    if pipe is None:
        raise gr.Error("Model not loaded. Please check logs.")

    final_prompt = prompt or ""
    if enhance:
        # The original comment marked this DISABLED; the capability is kept but off by default.
        final_prompt, _msg = prompt_enhance(final_prompt, True)
        print(f"[PE] Enhanced prompt: {final_prompt}")

    # Parse "1024x1024 ( 1:1 )" -> "1024x1024"
    try:
        resolution_str = str(resolution).split(" ")[0]
    except Exception:
        resolution_str = "1024x1024"
    width, height = get_resolution(resolution_str)

    # Generation
    image = generate_image(
        pipe=pipe,
        prompt=final_prompt,
        resolution=resolution_str,
        seed=new_seed,
        guidance_scale=0.0,
        num_inference_steps=int(steps) + 1,
        shift=float(shift),
    )

    # Post-generation NSFW safety check (prompt_check has been removed)
    try:
        if getattr(pipe, "safety_feature_extractor", None) is not None and getattr(pipe, "safety_checker", None) is not None:
            # CLIP input, matched to the safety checker's device and dtype
            clip_inputs = pipe.safety_feature_extractor([image], return_tensors="pt")
            clip_input = clip_inputs.pixel_values.to(DEVICE, dtype=pipe.safety_checker.dtype)
            # The safety checker expects numpy images of shape (batch, H, W, C), float32 in [0, 1]
            import numpy as np
            img_np = np.array(image).astype("float32") / 255.0
            img_np = img_np[None, ...]
            checked_images, has_nsfw = pipe.safety_checker(images=img_np, clip_input=clip_input)
            # has_nsfw is normally a list[bool]
            if isinstance(has_nsfw, (list, tuple)) and len(has_nsfw) > 0 and bool(has_nsfw[0]):
                image = _load_nsfw_placeholder(width, height)
    except Exception as e:
        # A safety checker failure must not block the main flow
        print(f"[Safety] Check failed (ignored): {e}")

    if gallery_images is None:
        gallery_images = []
    gallery_images = [image] + list(gallery_images)
    return gallery_images, str(new_seed), int(new_seed)
# ------------------------- Startup initialization -------------------------
init_app()
try_enable_aoti(pipe)
# ==================== Gradio UI ====================
css = """
.fillable{max-width: 1230px !important}
"""

with gr.Blocks(title="Z-Image Demo", css=css) as demo:
    gr.Markdown(
        """<div align="center">

# Z-Image Generation Demo
*An Efficient Image Generation Foundation Model with Single-Stream Diffusion Transformer*
</div>"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt here...")
            with gr.Row():
                choices = [int(k) for k in RES_CHOICES.keys()]
                res_cat = gr.Dropdown(value=1024, choices=choices, label="Resolution Category")
                initial_res_choices = RES_CHOICES["1024"]
                resolution = gr.Dropdown(
                    value=initial_res_choices[0],
                    choices=RESOLUTION_SET,
                    label="Width x Height (Ratio)",
                )
            with gr.Row():
                seed = gr.Number(label="Seed", value=42, precision=0)
                random_seed = gr.Checkbox(label="Random Seed", value=True)
            with gr.Row():
                steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=8, step=1, interactive=False)
                shift = gr.Slider(label="Time Shift", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
            # Note: prompt enhancement is off by default (the original code also marked it DISABLED)
            # enhance = gr.Checkbox(label="Enhance Prompt (DashScope)", value=False)
            generate_btn = gr.Button("Generate", variant="primary")
            gr.Markdown("### 📝 Example Prompts")
            gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input, label=None)
        with gr.Column(scale=1):
            output_gallery = gr.Gallery(
                label="Generated Images",
                columns=2,
                rows=2,
                height=600,
                object_fit="contain",
                format="png",
                interactive=False,
            )
            used_seed = gr.Textbox(label="Seed Used", interactive=False)

    def update_res_choices(_res_cat):
        if str(_res_cat) in RES_CHOICES:
            res_choices = RES_CHOICES[str(_res_cat)]
        else:
            res_choices = RES_CHOICES["1024"]
        return gr.update(value=res_choices[0], choices=res_choices)

    res_cat.change(update_res_choices, inputs=res_cat, outputs=resolution)

    generate_btn.click(
        generate,
        inputs=[prompt_input, resolution, seed, steps, shift, random_seed, output_gallery],
        outputs=[output_gallery, used_seed, seed],
    )
if __name__ == "__main__":
    # Recent Gradio releases support mcp_server; drop mcp_server=True if your Gradio version is older and errors out.
    # Custom CSS is passed to gr.Blocks(css=...) above, since demo.launch() does not accept a css argument.
    demo.launch(mcp_server=True)