| |
| |
| |
| |
| |
| |
| |
|
|
| import gradio as gr |
| import torch |
| import numpy as np |
| from diffusers import StableDiffusionPipeline, DDIMScheduler |
| from sklearn.decomposition import PCA |
| import plotly.graph_objects as go |
| from PIL import Image |
| import time |
| import warnings |
|
|
# Suppress every warning category for the whole process.
warnings.filterwarnings("ignore")

# ---- Constants ---------------------------------------------------------

# Device the pipeline, the RNG generator and decoded latents are placed on.
DEVICE: str = "cpu"

# Model identifier handed to StableDiffusionPipeline.from_pretrained().
MODEL_ID: str = "CompVis/stable-diffusion-v1-4"

# Process-wide pipeline instance; stays None until get_pipe() fills it.
PIPE_CACHE = None

# ---- Backend switches --------------------------------------------------

# Turn off the MKL-DNN backend for torch.
# NOTE(review): presumably a workaround for a CPU-kernel issue on the
# target host -- confirm before removing.
torch.backends.mkldnn.enabled = False
|
|
|
|
| |
|
|
def get_pipe():
    """
    Return the shared Stable Diffusion pipeline, building it on first use.

    On the first call the pipeline for ``MODEL_ID`` is loaded in float32
    with the safety checker disabled, its scheduler is replaced by a
    ``DDIMScheduler`` built from the existing scheduler config, and the
    pipeline is moved to ``DEVICE``.  The result is stored in the
    module-level ``PIPE_CACHE`` and every later call returns that object.
    """
    global PIPE_CACHE

    if PIPE_CACHE is None:
        load_options = dict(
            torch_dtype=torch.float32,
            safety_checker=None,
            requires_safety_checker=False,
        )
        sd_pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, **load_options)

        # Swap in DDIM while keeping the checkpoint's scheduler settings.
        sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)
        sd_pipe.to(DEVICE)

        # Publish only after the pipeline is fully configured.
        PIPE_CACHE = sd_pipe

    return PIPE_CACHE
|
|
|
|
| |
|
|
def compute_pca(latents):
    """
    Project a sequence of latents onto their first two principal components.

    latents: list of numpy arrays (one per step); each is flattened to a
             single row before fitting.
    Returns an (N, 2) array with one 2-D point per step.  The result is
    all zeros when fewer than two latents are given or when the PCA fit
    raises, and has shape (0, 2) for an empty list.
    """
    if not latents:
        return np.zeros((0, 2))

    rows = np.stack([z.flatten() for z in latents])
    n_rows = rows.shape[0]

    # A 2-component PCA needs at least two samples.
    if n_rows < 2:
        return np.zeros((n_rows, 2))

    try:
        return PCA(n_components=2).fit_transform(rows)
    except Exception:
        # Best-effort: a failed fit degrades to a flat trajectory.
        return np.zeros((n_rows, 2))
|
|
|
|
def compute_norm(latents):
    """
    Return the L2 norm of every latent, taken over all of its elements.

    latents: list of numpy arrays.  An empty (or falsy) input yields [].
    Returns a list of Python floats, one per latent, in input order.
    """
    if not latents:
        return []

    magnitudes = []
    for z in latents:
        # Flatten first so the norm runs over every element of the array.
        magnitudes.append(float(np.linalg.norm(z.flatten())))
    return magnitudes
|
|
|
|
| |
|
|
def decode_latent(pipe, latent_np):
    """
    Decode one latent numpy array into an RGB PIL image using the pipeline VAE.

    pipe:      pipeline object exposing ``vae`` (with ``decode`` and
               ``config.scaling_factor``).
    latent_np: a single un-batched latent; a batch axis is added here.
    Returns a ``PIL.Image`` built from the uint8 (H, W, C) decode result.
    For the latents recorded by this app's 256x256 runs the original
    docstring states the output is 256x256.
    """
    vae = pipe.vae
    batch = torch.from_numpy(latent_np).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        # Undo the latent scaling before handing the tensor to the decoder.
        decoded = vae.decode(batch / vae.config.scaling_factor).sample

    # Rescale with x/2 + 0.5 and clamp into [0, 1].
    decoded = (decoded / 2 + 0.5).clamp(0, 1)

    # (C, H, W) -> (H, W, C), then to 8-bit pixel values.
    pixels = decoded[0].permute(1, 2, 0).cpu().numpy() * 255
    return Image.fromarray(pixels.astype("uint8"))
|
|
|
|
| |
|
|
def _failed_run(message, state):
    """Return the seven UI outputs for a run that produced no image."""
    return (
        None,                           # final image
        message,                        # explanation / status text
        gr.update(maximum=0, value=0),  # collapse the step slider
        None,                           # image at current step
        None,                           # PCA plot
        None,                           # norm plot
        state,
    )


def _explanation_text(simple, runtime_s, n_steps, seed_val):
    """Compose the Markdown explanation followed by the runtime/steps/seed footer."""
    if simple:
        body = (
            "🧒 **Simple explanation of what you see:**\n\n"
            "1. The model starts from pure noise.\n"
            "2. At each step, it removes some noise and makes the picture clearer.\n"
            "3. Your text prompt tells it what kind of picture to create.\n"
            "4. You can move the slider to see the image at different steps.\n"
        )
    else:
        body = (
            "🔬 **Technical explanation:**\n\n"
            "- We run a DDIM diffusion process over the latent space.\n"
            "- At each timestep `t`, the UNet predicts noise εₜ and the scheduler updates `zₜ → zₜ₋₁`.\n"
            "- We record `zₜ` at every step and decode it with the VAE.\n"
            "- PCA over flattened latents gives a 2D trajectory of the diffusion path.\n"
            "- The L2 norm plot shows how the latent magnitude evolves per step.\n"
        )
    return body + f"\n⏱ **Runtime:** {runtime_s:.2f}s • **Steps:** {n_steps} • Seed: {seed_val}"


def run_diffusion(prompt, steps, guidance, seed, simple):
    """
    Run SD v1-4 at 256x256, capturing the latents after EVERY step via callback.

    Parameters:
      prompt   -- text prompt; blank/whitespace-only prompts are rejected.
      steps    -- number of inference steps (coerced to int).
      guidance -- guidance scale (coerced to float).
      seed     -- RNG seed; None or a negative value picks a time-based seed.
      simple   -- truthy selects the simple explanation, else the technical one.

    Returns a 7-tuple matching the UI outputs:
      - final image
      - explanation text
      - step slider update
      - image decoded at the current (last) step, or None if decoding failed
      - PCA plot
      - norm plot
      - state dict with keys "latents", "pca", "norms" (for slider updates)

    Failures (empty prompt, model load error, pipeline error, no latents
    captured) do not raise; they return a tuple whose explanation slot
    carries the message and whose image/plot slots are None.
    """
    if not prompt or not prompt.strip():
        return _failed_run("⚠️ Please enter a prompt.", {})

    # Loading can fail (e.g. the weights cannot be fetched); report it in
    # the UI the same way a diffusion failure is reported.
    try:
        pipe = get_pipe()
    except Exception as e:
        return _failed_run(f"❌ Model load error: {e}", {"error": str(e)})

    steps = int(steps)
    guidance = float(guidance)

    if seed is None or seed < 0:
        seed_val = int(time.time())
    else:
        seed_val = int(seed)

    generator = torch.Generator(device=DEVICE).manual_seed(seed_val)

    latents_list = []

    def callback(step: int, timestep: int, latents: torch.FloatTensor):
        # Keep the first batch element of the latents as a numpy array.
        latents_list.append(latents.detach().cpu().numpy()[0])

    # NOTE(review): `callback` / `callback_steps` are the legacy per-step
    # hook of the pipeline call; check the installed diffusers version
    # still accepts them before upgrading.
    t0 = time.perf_counter()
    try:
        result = pipe(
            prompt,
            height=256,
            width=256,
            num_inference_steps=steps,
            guidance_scale=guidance,
            generator=generator,
            callback=callback,
            callback_steps=1,
        )
    except Exception as e:
        return _failed_run(f"❌ Diffusion error: {e}", {"error": str(e)})
    total = time.perf_counter() - t0

    if not latents_list:
        return _failed_run(
            "❌ No latents collected. Something went wrong inside the pipeline.",
            {"error": "no_latents"},
        )

    final_image = result.images[0]

    pca_pts = compute_pca(latents_list)
    norms = compute_norm(latents_list)

    # The slider starts on the last recorded step.
    current_idx = len(latents_list) - 1

    # Decoding the preview is best-effort; the final image is still shown.
    try:
        step_image = decode_latent(pipe, latents_list[current_idx])
    except Exception:
        step_image = None

    explanation = _explanation_text(simple, total, len(latents_list), seed_val)

    pca_fig = plot_pca(pca_pts, current_idx) if len(pca_pts) > 0 else None
    norm_fig = plot_norm(norms, current_idx) if norms else None

    state = {
        "latents": latents_list,
        "pca": pca_pts,
        "norms": norms,
    }

    step_slider_update = gr.update(maximum=len(latents_list) - 1, value=current_idx)

    return (
        final_image,
        explanation,
        step_slider_update,
        step_image,
        pca_fig,
        norm_fig,
        state,
    )
|
|
|
|
| |
|
|
def plot_pca(points, idx):
    """
    Draw the PCA trajectory across steps and mark the current step in red.

    points: (N, 2) array of PCA coordinates, one row per step.
    idx:    index of the step to highlight; out-of-range values simply
            leave the highlight marker off.
    Returns a plotly Figure, or None when ``points`` has no rows.
    """
    n_points = points.shape[0]
    if n_points == 0:
        return None

    figure = go.Figure()

    # Full path through all steps, with a hover label per point.
    trajectory = go.Scatter(
        x=points[:, 0],
        y=points[:, 1],
        mode="lines+markers",
        name="steps",
        text=[f"step {i}" for i in range(n_points)],
    )
    figure.add_trace(trajectory)

    if 0 <= idx < n_points:
        highlight = go.Scatter(
            x=[points[idx, 0]],
            y=[points[idx, 1]],
            mode="markers+text",
            text=[f"step {idx}"],
            textposition="top center",
            marker=dict(size=12, color="red"),
            name="current",
        )
        figure.add_trace(highlight)

    figure.update_layout(
        title="Latent PCA trajectory",
        xaxis_title="PC1",
        yaxis_title="PC2",
        height=350,
    )
    return figure
|
|
|
|
def plot_norm(norms, idx):
    """
    Draw latent L2 norm against step index and mark the current step in red.

    norms: sequence of per-step norm values.
    idx:   index of the step to highlight; out-of-range values leave the
           highlight marker off.
    Returns a plotly Figure, or None when ``norms`` is empty/falsy.
    """
    if not norms:
        return None

    n_steps = len(norms)
    figure = go.Figure()

    curve = go.Scatter(
        x=list(range(n_steps)),
        y=norms,
        mode="lines+markers",
        name="‖latent‖₂",
    )
    figure.add_trace(curve)

    if 0 <= idx < n_steps:
        highlight = go.Scatter(
            x=[idx],
            y=[norms[idx]],
            mode="markers",
            marker=dict(size=12, color="red"),
            name="current",
        )
        figure.add_trace(highlight)

    figure.update_layout(
        title="Latent L2 norm vs step",
        xaxis_title="Step index",
        yaxis_title="‖latent‖₂",
        height=350,
    )
    return figure
|
|
|
|
| |
|
|
def update_step(state, idx):
    """
    Refresh the step image and both plots when the user moves the slider.

    state: dict saved by a previous run, expected to hold "latents",
           "pca" and "norms"; anything falsy or lacking "latents" clears
           the three outputs.
    idx:   requested step index; it is converted to int and clamped to
           the range of stored latents.
    Returns three ``gr.update`` objects: step image, PCA plot, norm plot.
    """

    def cleared():
        # Fresh update objects that blank all three outputs.
        return gr.update(value=None), gr.update(value=None), gr.update(value=None)

    has_run = bool(state) and "latents" in state
    if not has_run:
        return cleared()

    stored = state["latents"]
    pca_pts = state["pca"]
    norms = state["norms"]

    if not stored:
        return cleared()

    # Clamp the slider value into [0, len(stored) - 1].
    position = min(max(int(idx), 0), len(stored) - 1)

    pipe = get_pipe()

    # Decoding is best-effort; the plots are still refreshed on failure.
    try:
        preview = decode_latent(pipe, stored[position])
    except Exception:
        preview = None

    pca_fig = None if pca_pts is None else plot_pca(pca_pts, position)
    norm_fig = plot_norm(norms, position) if norms else None

    return (
        gr.update(value=preview),
        gr.update(value=pca_fig),
        gr.update(value=norm_fig),
    )
|
|
|
|
| |
|
|
# UI definition: controls and final result side by side, then the
# step-by-step explorer, then the event wiring.
with gr.Blocks(title="Stable Diffusion v1-4 — CPU Diffusion Visualizer") as demo:

    # ---- Header ----
    gr.Markdown("# 🧠 Stable Diffusion v1-4 — CPU Visualizer (256×256)")
    gr.Markdown(
        "This app shows **how a real Stable Diffusion model** turns noise into an image, step by step.\n"
        "- Uses `CompVis/stable-diffusion-v1-4` on CPU\n"
        "- 256×256 resolution for speed\n"
        "- You can scrub through all diffusion steps\n"
    )

    with gr.Row():
        # ---- Left column: generation controls ----
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                value="a small cozy cabin in the forest, digital art",
                lines=3,
            )
            steps = gr.Slider(
                minimum=10,
                maximum=30,
                value=20,
                step=1,
                label="Number of diffusion steps",
            )
            guidance = gr.Slider(
                minimum=1.0,
                maximum=12.0,
                value=7.5,
                step=0.5,
                label="Guidance scale",
            )
            seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
            simple = gr.Checkbox(label="Simple explanation", value=True)
            run = gr.Button("Run diffusion", variant="primary")

        # ---- Right column: final result and explanation ----
        with gr.Column():
            final = gr.Image(label="Final generated image")
            expl = gr.Markdown(label="Explanation")

    # ---- Step explorer ----
    gr.Markdown("### 🔍 Explore the denoising process step-by-step")

    # The slider range starts empty; run_diffusion sets the real maximum.
    step_slider = gr.Slider(
        minimum=0,
        maximum=0,
        value=0,
        step=1,
        label="View step (0 = early noise, max = final)",
    )
    step_img = gr.Image(label="Image at this diffusion step")
    pca_plot = gr.Plot(label="Latent PCA trajectory")
    norm_plot = gr.Plot(label="Latent norm vs step")

    # Holds the dict returned by run_diffusion so update_step can reuse it.
    state = gr.State()

    # ---- Event wiring ----
    run.click(
        fn=run_diffusion,
        inputs=[prompt, steps, guidance, seed, simple],
        outputs=[final, expl, step_slider, step_img, pca_plot, norm_plot, state],
    )

    step_slider.change(
        fn=update_step,
        inputs=[state, step_slider],
        outputs=[step_img, pca_plot, norm_plot],
    )


# Start the server on all interfaces, port 7860, with debug output and
# the PWA option enabled.
demo.launch(
    debug=True,
    server_name="0.0.0.0",
    server_port=7860,
    pwa=True,
)