""" ZeroGPU AoTI (Ahead-of-Time Inductor) compilation module for Z-Image-Turbo. This module provides the compile_transformer_aoti() function that handles the Z-Image-Turbo transformer's specific forward signature: forward(x, t, cap_feats, return_dict=True) Where: - x: hidden states / latent sequence (dynamic shape) - t: timestep - cap_feats: caption/text features - return_dict: whether to return dict or tuple """ import logging import inspect import torch import spaces logger = logging.getLogger(__name__) def compile_transformer_aoti( pipe, example_prompt: str = "example prompt for compilation", height: int = 1024, width: int = 1024, num_inference_steps: int = 1, inductor_configs: dict = None, min_seq_len: int = 15360, max_seq_len: int = 65536, ): """ Compile transformer ahead-of-time for 1.3x-1.8x speedup. This function correctly handles the Z-Image-Turbo transformer's forward signature which uses positional args (x, t, cap_feats) rather than kwargs. Args: pipe: The DiffusionPipeline instance example_prompt: Prompt to use for capturing example inputs height: Example image height width: Example image width num_inference_steps: Steps for example inference inductor_configs: PyTorch Inductor configuration dict min_seq_len: Minimum sequence length for dynamic shapes max_seq_len: Maximum sequence length for dynamic shapes Returns: Compiled model path or None if compilation fails """ logger.info("Starting AoTI compilation for transformer...") if inductor_configs is None: inductor_configs = {} try: # Step 1: Capture example inputs logger.info("Step 1/4: Capturing example inputs...") with spaces.aoti_capture(pipe.transformer) as call: pipe( example_prompt, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=0.0, ) # Step 2: Map positional args to parameter names logger.info("Step 2/4: Configuring dynamic shapes...") # Get the transformer's forward signature to map positional args sig = inspect.signature(pipe.transformer.forward) param_names = [p.name for p in sig.parameters.values() if p.name != 'self'] logger.info(f"Forward signature params: {param_names}") logger.info(f"Captured positional args: {len(call.args)}") logger.info(f"Captured kwargs keys: {list(call.kwargs.keys())}") # Convert positional args to named kwargs args_as_kwargs = {} for i, arg in enumerate(call.args): if i < len(param_names): args_as_kwargs[param_names[i]] = arg # Combine with actual kwargs combined_kwargs = {**args_as_kwargs, **call.kwargs} logger.info(f"Combined kwargs keys: {list(combined_kwargs.keys())}") # Step 3: Define dynamic shapes for the sequence dimension from torch.export import Dim from torch.utils._pytree import tree_map # Create base dynamic shapes (all None) dynamic_shapes = tree_map(lambda v: None, combined_kwargs) # Define dynamic dimension for sequence length batch_dim = Dim("batch", min=1, max=4) seq_len_dim = Dim("seq_len", min=min_seq_len, max=max_seq_len) # Apply dynamic shapes to the latent input (x) # x shape is typically (batch, seq_len, hidden_dim) if 'x' in combined_kwargs: x_tensor = combined_kwargs['x'] if hasattr(x_tensor, 'shape') and len(x_tensor.shape) >= 2: dynamic_shapes['x'] = {0: batch_dim, 1: seq_len_dim} logger.info(f"Set dynamic shapes for 'x': batch={batch_dim}, seq_len={seq_len_dim}") # Step 4: Export the model logger.info("Step 3/4: Exporting model with torch.export...") # Export with all inputs as kwargs (no positional args) exported = torch.export.export( pipe.transformer, args=(), # Empty - all inputs via kwargs kwargs=combined_kwargs, 
dynamic_shapes=dynamic_shapes, ) # Step 5: Compile with inductor logger.info("Step 4/4: Compiling with PyTorch Inductor (this takes several minutes)...") compiled = spaces.aoti_compile(exported, inductor_configs) logger.info("AoTI compilation completed successfully!") return compiled except Exception as e: logger.error(f"AoTI compilation failed: {type(e).__name__}: {str(e)}") import traceback logger.error(traceback.format_exc()) logger.warning("Falling back to non-compiled transformer") return None
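

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The model id, dtype, and device below are
# assumptions, not part of this module; adapt them to the actual Space setup.
# The compiled artifact is applied with spaces.aoti_apply(), the ZeroGPU
# helper intended to pair with spaces.aoti_compile().
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from diffusers import DiffusionPipeline

    # Hypothetical checkpoint id -- replace with the actual Z-Image-Turbo repo.
    pipe = DiffusionPipeline.from_pretrained(
        "Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16
    ).to("cuda")

    compiled = compile_transformer_aoti(pipe, height=1024, width=1024)
    if compiled is not None:
        # Swap the compiled graph into the live transformer module.
        spaces.aoti_apply(compiled, pipe.transformer)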