"""
ZeroGPU AoTI (Ahead-of-Time Inductor) compilation module for Z-Image-Turbo.
This module provides the compile_transformer_aoti() function that handles
the Z-Image-Turbo transformer's specific forward signature:
forward(x, t, cap_feats, return_dict=True)
Where:
- x: hidden states / latent sequence (dynamic shape)
- t: timestep
- cap_feats: caption/text features
- return_dict: whether to return dict or tuple
"""
import logging
import inspect
import torch
import spaces
logger = logging.getLogger(__name__)
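
# Note: this module assumes the Hugging Face ZeroGPU `spaces` package, which
# provides the aoti_capture() and aoti_compile() helpers used below. The core
# idea, per the module docstring: a transformer call captured positionally as
# transformer(x, t, cap_feats) is re-expressed as keyword arguments
# {"x": x, "t": t, "cap_feats": cap_feats} so that torch.export can trace it
# with named, optionally dynamic, input shapes.
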
def compile_transformer_aoti(
    pipe,
    example_prompt: str = "example prompt for compilation",
    height: int = 1024,
    width: int = 1024,
    num_inference_steps: int = 1,
    inductor_configs: dict | None = None,
    min_seq_len: int = 15360,
    max_seq_len: int = 65536,
):
"""
Compile transformer ahead-of-time for 1.3x-1.8x speedup.
This function correctly handles the Z-Image-Turbo transformer's forward
signature which uses positional args (x, t, cap_feats) rather than kwargs.
Args:
pipe: The DiffusionPipeline instance
example_prompt: Prompt to use for capturing example inputs
height: Example image height
width: Example image width
num_inference_steps: Steps for example inference
inductor_configs: PyTorch Inductor configuration dict
min_seq_len: Minimum sequence length for dynamic shapes
max_seq_len: Maximum sequence length for dynamic shapes
Returns:
Compiled model path or None if compilation fails
"""
logger.info("Starting AoTI compilation for transformer...")
if inductor_configs is None:
inductor_configs = {}
try:
# Step 1: Capture example inputs
logger.info("Step 1/4: Capturing example inputs...")
with spaces.aoti_capture(pipe.transformer) as call:
pipe(
example_prompt,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=0.0,
)
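
        # At this point `call.args` / `call.kwargs` hold the exact inputs the
        # transformer received during the example run (typically the positional
        # (x, t, cap_feats) described in the module docstring).
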
        # Step 2: Map positional args to parameter names
        logger.info("Step 2/4: Mapping positional args and configuring dynamic shapes...")

        # Get the transformer's forward signature to map positional args
        sig = inspect.signature(pipe.transformer.forward)
        param_names = [p.name for p in sig.parameters.values() if p.name != 'self']
        logger.info(f"Forward signature params: {param_names}")
        logger.info(f"Captured positional args: {len(call.args)}")
        logger.info(f"Captured kwargs keys: {list(call.kwargs.keys())}")

        # Convert positional args to named kwargs
        args_as_kwargs = {}
        for i, arg in enumerate(call.args):
            if i < len(param_names):
                args_as_kwargs[param_names[i]] = arg

        # Combine with the actual kwargs
        combined_kwargs = {**args_as_kwargs, **call.kwargs}
        logger.info(f"Combined kwargs keys: {list(combined_kwargs.keys())}")
        # Define dynamic shapes for the batch and sequence dimensions
        # (still part of step 2)
        from torch.export import Dim
        from torch.utils._pytree import tree_map

        # Create base dynamic shapes (all None, i.e. static)
        dynamic_shapes = tree_map(lambda v: None, combined_kwargs)

        # Define dynamic dimensions for batch size and sequence length
        batch_dim = Dim("batch", min=1, max=4)
        seq_len_dim = Dim("seq_len", min=min_seq_len, max=max_seq_len)

        # Apply dynamic shapes to the latent input (x)
        # x shape is typically (batch, seq_len, hidden_dim)
        if 'x' in combined_kwargs:
            x_tensor = combined_kwargs['x']
            if hasattr(x_tensor, 'shape') and len(x_tensor.shape) >= 2:
                dynamic_shapes['x'] = {0: batch_dim, 1: seq_len_dim}
                logger.info(f"Set dynamic shapes for 'x': batch={batch_dim}, seq_len={seq_len_dim}")
        # Step 3: Export the model
        logger.info("Step 3/4: Exporting model with torch.export...")

        # Export with all inputs as kwargs (no positional args)
        exported = torch.export.export(
            pipe.transformer,
            args=(),  # Empty - all inputs via kwargs
            kwargs=combined_kwargs,
            dynamic_shapes=dynamic_shapes,
        )
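
        # `exported` is a torch.export.ExportedProgram: a graph of the
        # transformer specialized to these example inputs and dynamic dims.
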
        # Step 4: Compile with Inductor
        logger.info("Step 4/4: Compiling with PyTorch Inductor (this takes several minutes)...")
        compiled = spaces.aoti_compile(exported, inductor_configs)

        logger.info("AoTI compilation completed successfully!")
        return compiled
    except Exception as e:
        logger.error(f"AoTI compilation failed: {type(e).__name__}: {e}")
        import traceback
        logger.error(traceback.format_exc())
        logger.warning("Falling back to non-compiled transformer")
        return None
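

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not executed on import). The model
# id below is a placeholder, and spaces.aoti_apply() is assumed to be the
# ZeroGPU helper that swaps the compiled artifact into the pipeline; adapt
# both to your Space's actual setup. On ZeroGPU, the compile call would
# normally run at startup inside a @spaces.GPU-decorated function.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from diffusers import DiffusionPipeline

    # Placeholder checkpoint id - replace with the real Z-Image-Turbo repo.
    pipe = DiffusionPipeline.from_pretrained(
        "<z-image-turbo-model-id>", torch_dtype=torch.bfloat16
    ).to("cuda")

    compiled = compile_transformer_aoti(pipe)
    if compiled is not None:
        # Assumed ZeroGPU helper: binds the AoTI-compiled graph to the
        # pipeline's transformer for subsequent inference calls.
        spaces.aoti_apply(compiled, pipe.transformer)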