Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -61,7 +61,7 @@ except ImportError:
 
 # Import Lyra loader module
 try:
-    from geofractal.model.vae.
+    from geofractal.model.vae.loader import load_vae_lyra, load_lyra_illustrious
     LYRA_LOADER_AVAILABLE = True
 except ImportError:
     print("⚠️ Lyra loader module not available, using fallback")
@@ -154,7 +154,10 @@ class LazyT5Encoder:
 
 
 class LazyLyraModel:
-    """Lazy loader for Lyra VAE - only downloads/loads when first accessed."""
+    """Lazy loader for Lyra VAE - only downloads/loads when first accessed.
+
+    Exposes config with modality_seq_lens for proper tokenization lengths.
+    """
 
     def __init__(
         self,
@@ -167,7 +170,118 @@ class LazyLyraModel:
         self.checkpoint = checkpoint
         self._model = None
         self._info = None
+        self._config = None
         self._loaded = False
+
+        # Pre-fetch config without loading model (lightweight)
+        self._prefetch_config()
+
+    def _prefetch_config(self):
+        """Fetch config.json to get sequence lengths without loading the full model."""
+        try:
+            config_path = hf_hub_download(
+                repo_id=self.repo_id,
+                filename="config.json",
+                repo_type="model"
+            )
+            with open(config_path, 'r') as f:
+                self._config = json.load(f)
+
+            # Detect version from config
+            is_v2 = 'modality_seq_lens' in self._config or 'binding_config' in self._config
+            version = "v2" if is_v2 else "v1"
+
+            print(f"📋 Lyra config prefetched: {self.repo_id} ({version})")
+
+            if is_v2:
+                print(f"   Sequence lengths: {self._config.get('modality_seq_lens', {})}")
+            else:
+                print(f"   Sequence length: {self._config.get('seq_len', 77)}")
+
+        except Exception as e:
+            print(f"⚠️ Could not prefetch Lyra config: {e}")
+            # Detect version from repo name and use appropriate defaults
+            is_illustrious = 'illustrious' in self.repo_id.lower() or 'xl' in self.repo_id.lower()
+
+            if is_illustrious:
+                # v2 defaults for SDXL/Illustrious
+                self._config = {
+                    "modality_dims": {
+                        "clip_l": 768,
+                        "clip_g": 1280,
+                        "t5_xl_l": 2048,
+                        "t5_xl_g": 2048
+                    },
+                    "modality_seq_lens": {
+                        "clip_l": 77,
+                        "clip_g": 77,
+                        "t5_xl_l": 512,
+                        "t5_xl_g": 512
+                    },
+                    "fusion_strategy": "adaptive_cantor",
+                    "latent_dim": 2048
+                }
+            else:
+                # v1 defaults for SD1.5
+                self._config = {
+                    "modality_dims": {
+                        "clip": 768,
+                        "t5": 768
+                    },
+                    "seq_len": 77,
+                    "fusion_strategy": "cantor",
+                    "latent_dim": 768
+                }
+
+    @property
+    def config(self) -> Dict:
+        """Get model config (available before full model load)."""
+        return self._config or {}
+
+    @property
+    def modality_seq_lens(self) -> Dict[str, int]:
+        """Get sequence lengths for each modality.
+
+        Handles both v1 (seq_len) and v2 (modality_seq_lens) config formats.
+        """
+        # v2 format: modality_seq_lens dict
+        if 'modality_seq_lens' in self.config:
+            return self.config['modality_seq_lens']
+
+        # v1 format: derive from single seq_len
+        seq_len = self.config.get('seq_len', 77)
+        modality_dims = self.config.get('modality_dims', {})
+
+        # Return seq_len for all modalities in v1
+        return {name: seq_len for name in modality_dims.keys()}
+
+    @property
+    def t5_max_length(self) -> int:
+        """Get T5 max sequence length from config.
+
+        Handles both v1 (seq_len) and v2 (modality_seq_lens) config formats.
+        """
+        # v2 format: modality_seq_lens dict
+        if 'modality_seq_lens' in self.config:
+            seq_lens = self.config['modality_seq_lens']
+            return seq_lens.get('t5_xl_l', seq_lens.get('t5_xl_g', 512))
+
+        # v1 format: single seq_len
+        return self.config.get('seq_len', 77)
+
+    @property
+    def clip_max_length(self) -> int:
+        """Get CLIP max sequence length from config.
+
+        Handles both v1 (seq_len) and v2 (modality_seq_lens) config formats.
+        """
+        # v2 format: modality_seq_lens dict
+        if 'modality_seq_lens' in self.config:
+            seq_lens = self.config['modality_seq_lens']
+            return seq_lens.get('clip_l', 77)
+
+        # v1 format: single seq_len (same for all modalities)
+        return self.config.get('seq_len', 77)
 
     @property
     def model(self):
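Reviewer note: the prefetch above downloads only config.json, so sequence lengths are known before any weights move. A minimal standalone sketch of the same idea, assuming only that the target repo ships a config.json (the repo id below is a placeholder, not the Space's actual repo):

```python
# Sketch of the config-prefetch pattern, under the assumptions stated above.
import json
from huggingface_hub import hf_hub_download

def prefetch_seq_lens(repo_id: str) -> dict:
    """Download only config.json and derive per-modality sequence lengths."""
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json", repo_type="model")
    with open(config_path, "r") as f:
        config = json.load(f)
    # v2 configs carry an explicit map; v1 configs have a single seq_len
    if "modality_seq_lens" in config:
        return config["modality_seq_lens"]
    seq_len = config.get("seq_len", 77)
    return {name: seq_len for name in config.get("modality_dims", {})}

# Usage (placeholder repo id):
# seq_lens = prefetch_seq_lens("your-org/your-lyra-repo")
```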
@@ -182,10 +296,13 @@ class LazyLyraModel:
                 device=self.device,
                 return_info=True
             )
+            # Update config from loaded info
+            if self._info and 'config' in self._info:
+                self._config = self._info['config']
         else:
             # Fallback to manual loading
             self._model = self._load_fallback()
-            self._info = {"repo_id": self.repo_id, "version": "v2"}
+            self._info = {"repo_id": self.repo_id, "version": "v2", "config": self._config}
 
         self._model.eval()
         self._loaded = True
@@ -194,8 +311,8 @@ class LazyLyraModel:
 
     @property
    def info(self) -> Optional[Dict]:
-        if self._info is None:
-            return {"repo_id": self.repo_id}
+        if self._info is None:
+            return {"repo_id": self.repo_id, "config": self._config}
         return self._info
 
     @property
@@ -207,18 +324,11 @@ class LazyLyraModel:
         if not LYRA_V2_AVAILABLE:
             raise ImportError("Lyra VAE v2 not available")
 
-        config_path = hf_hub_download(
-            repo_id=self.repo_id,
-            filename="config.json",
-            repo_type="model"
-        )
-
-        with open(config_path, 'r') as f:
-            config_dict = json.load(f)
+        # Config already prefetched
+        config_dict = self._config
 
         # Find checkpoint
         from huggingface_hub import list_repo_files
-        import re
 
         repo_files = list_repo_files(self.repo_id, repo_type="model")
         checkpoint_files = [f for f in repo_files if f.endswith('.safetensors') or f.endswith('.pt')]
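Reviewer note: the fallback loader discovers a checkpoint by listing the repo's files. A sketch of that discovery step; the tie-break preferring .safetensors over .pt is an assumption, the diff only shows the filtering:

```python
# Sketch of checkpoint discovery via the HF Hub file listing.
from huggingface_hub import list_repo_files

def find_checkpoint(repo_id: str) -> str:
    files = list_repo_files(repo_id, repo_type="model")
    candidates = [f for f in files if f.endswith(".safetensors") or f.endswith(".pt")]
    if not candidates:
        raise FileNotFoundError(f"No checkpoint files found in {repo_id}")
    # Assumption: prefer safetensors when both formats are present
    candidates.sort(key=lambda f: (not f.endswith(".safetensors"), f))
    return candidates[0]
```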
@@ -245,7 +355,7 @@ class LazyLyraModel:
         checkpoint = torch.load(checkpoint_path, map_location="cpu")
         state_dict = checkpoint.get('model_state_dict', checkpoint)
 
-        # Build config
+        # Build config with all fields from prefetched config
         vae_config = LyraV2Config(
             modality_dims=config_dict.get('modality_dims'),
             modality_seq_lens=config_dict.get('modality_seq_lens'),
@@ -253,6 +363,13 @@ class LazyLyraModel:
             latent_dim=config_dict.get('latent_dim', 2048),
             hidden_dim=config_dict.get('hidden_dim', 2048),
             fusion_strategy=config_dict.get('fusion_strategy', 'adaptive_cantor'),
+            encoder_layers=config_dict.get('encoder_layers', 3),
+            decoder_layers=config_dict.get('decoder_layers', 3),
+            fusion_heads=config_dict.get('fusion_heads', 8),
+            cantor_depth=config_dict.get('cantor_depth', 8),
+            cantor_local_window=config_dict.get('cantor_local_window', 3),
+            alpha_init=config_dict.get('alpha_init', 1.0),
+            beta_init=config_dict.get('beta_init', 0.3),
         )
 
         model = LyraV2(vae_config)
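Reviewer note: the pattern here is "typed config from a prefetched dict, with the same defaults on every field". The dataclass below is a hypothetical stand-in (the real LyraV2Config lives in the geofractal package); it only illustrates the construction pattern:

```python
# Stand-in illustration; field names mirror the diff, the class itself is hypothetical.
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class LyraV2ConfigSketch:
    modality_dims: Optional[Dict[str, int]] = None
    modality_seq_lens: Optional[Dict[str, int]] = None
    latent_dim: int = 2048
    hidden_dim: int = 2048
    fusion_strategy: str = "adaptive_cantor"
    encoder_layers: int = 3
    decoder_layers: int = 3
    fusion_heads: int = 8
    cantor_depth: int = 8
    cantor_local_window: int = 3
    alpha_init: float = 1.0
    beta_init: float = 0.3

def build_config(config_dict: dict) -> LyraV2ConfigSketch:
    # Unknown keys are dropped; missing keys fall back to the dataclass defaults
    known = set(LyraV2ConfigSketch.__dataclass_fields__)
    return LyraV2ConfigSketch(**{k: v for k, v in config_dict.items() if k in known})
```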
@@ -337,39 +454,6 @@ def get_scheduler(
 # UTILITIES
 # ============================================================================
 
-def extract_comfyui_components(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
-    """Extract UNet, CLIP-L, CLIP-G, and VAE from ComfyUI single-file checkpoint."""
-
-    components = {
-        "unet": {},
-        "clip_l": {},
-        "clip_g": {},
-        "vae": {}
-    }
-
-    for key, value in state_dict.items():
-        if key.startswith(COMFYUI_UNET_PREFIX):
-            new_key = key[len(COMFYUI_UNET_PREFIX):]
-            components["unet"][new_key] = value
-        elif key.startswith(COMFYUI_CLIP_L_PREFIX):
-            new_key = key[len(COMFYUI_CLIP_L_PREFIX):]
-            components["clip_l"][new_key] = value
-        elif key.startswith(COMFYUI_CLIP_G_PREFIX):
-            new_key = key[len(COMFYUI_CLIP_G_PREFIX):]
-            components["clip_g"][new_key] = value
-        elif key.startswith(COMFYUI_VAE_PREFIX):
-            new_key = key[len(COMFYUI_VAE_PREFIX):]
-            components["vae"][new_key] = value
-
-    print(f"  Extracted components:")
-    print(f"    UNet: {len(components['unet'])} keys")
-    print(f"    CLIP-L: {len(components['clip_l'])} keys")
-    print(f"    CLIP-G: {len(components['clip_g'])} keys")
-    print(f"    VAE: {len(components['vae'])} keys")
-
-    return components
-
-
 def get_clip_hidden_state(
     model_output,
     clip_skip: int = 1,
@@ -551,40 +635,77 @@ class SDXLFlowMatchingPipeline:
 
     def encode_prompt_lyra(
         self,
-        prompt: str,
+        prompt: str,
         negative_prompt: str = "",
         clip_skip: int = 1,
         t5_summary: str = "",
-        lyra_strength: float = 0.3
+        lyra_strength: float = 0.3,
+        use_separator: bool = True,
+        clip_include_summary: bool = False
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """Encode prompts using Lyra VAE v2 fusion (CLIP + T5).
 
+        CLIP encoders receive tags only (prompt field).
+        T5 encoder receives tags + separator + summary.
+
+        Args:
+            prompt: Tags/keywords for CLIP encoding
+            negative_prompt: Negative tags
+            clip_skip: CLIP skip layers
+            t5_summary: Natural language summary for T5
+            lyra_strength: Blend factor (0=pure CLIP, 1=pure Lyra)
+            use_separator: If True, use ¶ separator between tags and summary
+            clip_include_summary: If True, append summary to CLIP input (default False)
+
         This triggers lazy loading of T5 and Lyra if not already loaded.
+        Uses sequence lengths from Lyra config for proper tokenization.
         """
         if not self.lyra_available:
             raise ValueError("Lyra VAE components not configured")
 
+        # Get sequence lengths from Lyra config (available before full load)
+        t5_max_length = self.lyra_loader.t5_max_length  # 512 for Illustrious
+        clip_max_length = self.lyra_loader.clip_max_length  # 77 for Illustrious
+
+        print(f"[Lyra] Using sequence lengths: CLIP={clip_max_length}, T5={t5_max_length}")
+
         # Access properties triggers lazy load
         t5_encoder = self.t5_encoder
         t5_tokenizer = self.t5_tokenizer
         lyra_model = self.lyra_model
 
-        #
+        # === CLIP ENCODING ===
+        # CLIP sees tags only (unless clip_include_summary is True)
+        if clip_include_summary and t5_summary.strip():
+            clip_prompt = f"{prompt} {t5_summary}"
+        else:
+            clip_prompt = prompt
+
+        # Get CLIP embeddings with tags only
         prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt(
-            prompt, negative_prompt, clip_skip
+            clip_prompt, negative_prompt, clip_skip
         )
 
-        #
+        # === T5 ENCODING ===
+        # T5 sees tags + separator + summary (or tags + summary if no separator)
         SUMMARY_SEPARATOR = "¶"
+
         if t5_summary.strip():
-            t5_prompt = f"{prompt} {SUMMARY_SEPARATOR} {t5_summary}"
+            if use_separator:
+                t5_prompt = f"{prompt} {SUMMARY_SEPARATOR} {t5_summary}"
+            else:
+                t5_prompt = f"{prompt} {t5_summary}"
         else:
-            t5_prompt = prompt
+            # No summary provided - T5 just sees the tags
+            t5_prompt = prompt
 
-
+        print(f"[Lyra] CLIP input: {clip_prompt[:80]}...")
+        print(f"[Lyra] T5 input: {t5_prompt[:80]}...")
+
+        # Get T5 embeddings with config-specified max_length
         t5_inputs = t5_tokenizer(
             t5_prompt,
-            max_length=
+            max_length=t5_max_length,
             padding='max_length',
             truncation=True,
             return_tensors='pt'
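Reviewer note: the tags-vs-summary split is the core of this hunk: CLIP sees the tag prompt, T5 sees tags plus an optional ¶-separated summary, tokenized to the config's max length. A sketch of that construction; the tokenizer checkpoint is an arbitrary T5 stand-in for illustration, not the one the Space uses:

```python
# Sketch of the T5 prompt construction used above.
from transformers import AutoTokenizer

SUMMARY_SEPARATOR = "¶"

def build_t5_prompt(tags: str, summary: str, use_separator: bool = True) -> str:
    """CLIP gets `tags` as-is; T5 gets tags joined with the summary."""
    if not summary.strip():
        return tags
    joiner = f" {SUMMARY_SEPARATOR} " if use_separator else " "
    return f"{tags}{joiner}{summary}"

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")  # stand-in checkpoint
t5_prompt = build_t5_prompt("1girl, blue hair", "A girl with blue hair under cherry blossoms")
inputs = tokenizer(t5_prompt, max_length=512, padding="max_length",
                   truncation=True, return_tensors="pt")
```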
@@ -593,6 +714,7 @@ class SDXLFlowMatchingPipeline:
         with torch.no_grad():
             t5_embeds = t5_encoder(**t5_inputs).last_hidden_state
 
+        # === LYRA FUSION ===
         clip_l_dim = 768
         clip_g_dim = 1280
 
@@ -632,14 +754,17 @@ class SDXLFlowMatchingPipeline:
 
         prompt_embeds_fused = torch.cat([fused_clip_l, fused_clip_g], dim=-1)
 
-        #
+        # === NEGATIVE PROMPT ===
+        # Negative uses same logic: CLIP sees negative tags only
         if negative_prompt:
             neg_strength = lyra_strength * 0.5  # Less aggressive for negative
 
-            t5_neg_prompt = negative_prompt
+            # T5 negative: tags only (no summary for negative)
+            t5_neg_prompt = negative_prompt
+
             t5_inputs_neg = t5_tokenizer(
                 t5_neg_prompt,
-                max_length=
+                max_length=t5_max_length,
                 padding='max_length',
                 truncation=True,
                 return_tensors='pt'
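Reviewer note: lyra_strength blends the raw CLIP embeddings with the Lyra-fused reconstruction, and the negative branch halves it. A sketch assuming the blend is a plain linear interpolation, consistent with the UI hint "0.0 = pure CLIP, 1.0 = pure Lyra" (the actual fusion math lives in LyraV2 and may differ):

```python
# Assumed linear blend between CLIP embeddings and Lyra reconstruction.
import torch

def blend(clip_embeds: torch.Tensor, lyra_embeds: torch.Tensor, strength: float) -> torch.Tensor:
    return (1.0 - strength) * clip_embeds + strength * lyra_embeds

clip_e = torch.randn(1, 77, 2048)
lyra_e = torch.randn(1, 77, 2048)
positive = blend(clip_e, lyra_e, 0.3)
negative = blend(clip_e, lyra_e, 0.3 * 0.5)  # negative uses half strength, as in the diff
```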
@@ -710,9 +835,19 @@ class SDXLFlowMatchingPipeline:
         clip_skip: int = 1,
         t5_summary: str = "",
         lyra_strength: float = 1.0,
+        use_separator: bool = True,
+        clip_include_summary: bool = False,
         progress_callback=None
     ):
-        """Generate image using SDXL architecture."""
+        """Generate image using SDXL architecture.
+
+        Args:
+            prompt: Tags/keywords for image generation
+            negative_prompt: Negative tags
+            t5_summary: Natural language summary (T5 only, unless clip_include_summary=True)
+            use_separator: Use ¶ separator between tags and summary in T5 input
+            clip_include_summary: If True, append summary to CLIP input (default False)
+        """
 
         if seed is not None:
             generator = torch.Generator(device=self.device).manual_seed(seed)
@@ -722,7 +857,9 @@ class SDXLFlowMatchingPipeline:
         # Encode prompts (Lyra triggers lazy load only if use_lyra=True)
         if use_lyra and self.lyra_available:
             prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt_lyra(
-                prompt, negative_prompt, clip_skip, t5_summary, lyra_strength
+                prompt, negative_prompt, clip_skip, t5_summary, lyra_strength,
+                use_separator=use_separator,
+                clip_include_summary=clip_include_summary
             )
         else:
             prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt(
@@ -912,10 +1049,19 @@ class SD15FlowMatchingPipeline:
         return prompt_embeds, negative_prompt_embeds
 
     def encode_prompt_lyra(self, prompt: str, negative_prompt: str = ""):
-        """Encode using Lyra VAE (CLIP + T5 fusion)."""
+        """Encode using Lyra VAE v1 (CLIP + T5 fusion).
+
+        Uses sequence lengths from Lyra config for proper tokenization.
+        """
         if not self.lyra_available:
             raise ValueError("Lyra VAE components not configured")
 
+        # Get sequence length from config (v1 uses same length for clip and t5)
+        # Default to 77 for SD1.5/v1
+        t5_max_length = self.lyra_loader.config.get('seq_len', 77)
+
+        print(f"[Lyra v1] Using sequence length: {t5_max_length}")
+
         t5_encoder = self.t5_encoder
         t5_tokenizer = self.t5_tokenizer
         lyra_model = self.lyra_model
@@ -933,10 +1079,10 @@ class SD15FlowMatchingPipeline:
         with torch.no_grad():
             clip_embeds = self.text_encoder(text_input_ids)[0]
 
-        # T5
+        # T5 with config-specified max_length
         t5_inputs = t5_tokenizer(
             prompt,
-            max_length=
+            max_length=t5_max_length,
             padding='max_length',
             truncation=True,
             return_tensors='pt'
@@ -971,7 +1117,7 @@ class SD15FlowMatchingPipeline:
 
         t5_inputs_uncond = t5_tokenizer(
             negative_prompt,
-            max_length=
+            max_length=t5_max_length,
             padding='max_length',
             truncation=True,
             return_tensors='pt'
@@ -1363,7 +1509,7 @@ def estimate_duration(num_steps: int, width: int, height: int, use_lyra: bool =
 
 
 @spaces.GPU(duration=lambda *args: estimate_duration(
-    args[6], args[8], args[9], args[
+    args[6], args[8], args[9], args[14],
     "SDXL" in args[3] or "Illustrious" in args[3]
 ))
 def generate_image(
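Reviewer note: the duration lambda receives the same positional args Gradio passes to generate_image, so every hard-coded index must track the inputs list; inserting use_separator and clip_include_summary shifts everything after lyra_strength by two. A small sanity-check sketch of that mapping (parameter order copied from the updated signature):

```python
# Positional-argument index map for the ZeroGPU duration lambda.
PARAM_ORDER = [
    "prompt", "t5_summary", "negative_prompt", "model_choice",
    "scheduler_choice", "clip_skip", "num_steps", "cfg_scale",
    "width", "height", "shift", "use_flow_matching", "use_lyra",
    "lyra_strength", "use_separator", "clip_include_summary",
    "seed", "randomize_seed",
]

def arg_index(name: str) -> int:
    return PARAM_ORDER.index(name)

assert arg_index("num_steps") == 6
assert arg_index("width") == 8 and arg_index("height") == 9
assert arg_index("use_separator") == 14  # the index the updated lambda passes
```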
@@ -1381,11 +1527,20 @@ def generate_image(
     use_flow_matching: bool,
     use_lyra: bool,
     lyra_strength: float,
+    use_separator: bool,
+    clip_include_summary: bool,
     seed: int,
     randomize_seed: bool,
     progress=gr.Progress()
 ):
-    """Generate image with ZeroGPU support."""
+    """Generate image with ZeroGPU support.
+
+    Args:
+        prompt: Tags/keywords (CLIP input)
+        t5_summary: Natural language summary (T5 input, unless clip_include_summary)
+        use_separator: Use ¶ separator between tags and summary
+        clip_include_summary: If True, CLIP also sees the summary
+    """
 
     if randomize_seed:
         seed = np.random.randint(0, 2**32 - 1)
@@ -1464,6 +1619,8 @@ def generate_image(
             clip_skip=clip_skip,
             t5_summary=t5_summary,
             lyra_strength=lyra_strength,
+            use_separator=use_separator,
+            clip_include_summary=clip_include_summary,
             progress_callback=lambda s, t, d: progress(0.5 + (s/t) * 0.45, desc=d)
         )
 
@@ -1507,14 +1664,15 @@ def create_demo():
         prompt = gr.TextArea(
             label="Prompt (Tags for CLIP)",
             value="masterpiece, best quality, 1girl, blue hair, school uniform, cherry blossoms, detailed background",
-            lines=3
+            lines=3,
+            info="CLIP encoders see these tags. T5 also sees these + the summary below."
         )
 
         t5_summary = gr.TextArea(
-            label="T5 Summary (Natural Language
+            label="T5 Summary (Natural Language - T5 Only)",
             value="A beautiful anime girl with flowing blue hair wearing a school uniform, surrounded by delicate pink cherry blossoms against a bright sky",
             lines=2,
-            info="
+            info="T5 sees: tags ¶ summary. CLIP sees: tags only (unless 'Include Summary in CLIP' is enabled)."
         )
 
         negative_prompt = gr.TextArea(
@@ -1565,6 +1723,19 @@ def create_demo():
             info="0.0 = pure CLIP, 1.0 = pure Lyra reconstruction"
         )
 
+        with gr.Accordion("Lyra Advanced Settings", open=False):
+            use_separator = gr.Checkbox(
+                label="Use ¶ Separator",
+                value=True,
+                info="Insert ¶ between tags and summary in T5 input"
+            )
+
+            clip_include_summary = gr.Checkbox(
+                label="Include Summary in CLIP",
+                value=False,
+                info="By default CLIP sees tags only. Enable to append summary to CLIP input."
+            )
+
         with gr.Accordion("Generation Settings", open=True):
             num_steps = gr.Slider(
                 label="Steps",
@@ -1725,7 +1896,8 @@ def create_demo():
         inputs=[
             prompt, t5_summary, negative_prompt, model_choice, scheduler_choice, clip_skip,
             num_steps, cfg_scale, width, height, shift,
-            use_flow_matching, use_lyra, lyra_strength,
+            use_flow_matching, use_lyra, lyra_strength, use_separator, clip_include_summary,
+            seed, randomize_seed
         ],
         outputs=[output_image_standard, output_image_lyra, output_seed]
     )
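Reviewer note: since the click wiring passes inputs positionally, the list order must match generate_image's parameter order exactly. An abbreviated, self-contained sketch of the new accordion and wiring; everything outside the two checkboxes is a stand-in for the real components in create_demo():

```python
# Minimal Gradio sketch of the Lyra accordion and positional input wiring.
import gradio as gr

with gr.Blocks() as demo:
    prompt = gr.TextArea(label="Prompt (Tags for CLIP)", lines=3)
    t5_summary = gr.TextArea(label="T5 Summary (Natural Language - T5 Only)", lines=2)

    with gr.Accordion("Lyra Advanced Settings", open=False):
        use_separator = gr.Checkbox(label="Use ¶ Separator", value=True)
        clip_include_summary = gr.Checkbox(label="Include Summary in CLIP", value=False)

    out = gr.Textbox(label="Echo")

    def run(p, s, sep, inc):
        # The inputs list below must match this signature positionally
        return f"sep={sep} inc={inc} :: {p[:20]} / {s[:20]}"

    btn = gr.Button("Generate")
    btn.click(run, inputs=[prompt, t5_summary, use_separator, clip_include_summary], outputs=[out])
```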