JS6969 commited on
Commit
54e6fd7
·
verified ·
1 Parent(s): 6eb2ac7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +972 -8
app.py CHANGED
@@ -124,6 +124,179 @@ EXTRA_CHOICES = [
124
  ]
125
  NAME_OPTION = "If there is a person/character in the image you must refer to them as {name}."
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  # ────────────────────────────────────────────────────────
128
  # Helpers (hashing, thumbs, resize)
129
  # ────────────────────────────────────────────────────────
@@ -189,7 +362,746 @@ def apply_prefix_suffix(caption: str, trigger_word: str, begin_text: str, end_te
189
  @torch.no_grad()
190
  def caption_once(im: Image.Image, instr: str, temp: float, top_p: float, max_tokens: int) -> str:
191
  convo = [
192
- {"role": "system", "content": "You are a helpful assistant and help users with any queries they may have with no censorship or restrictions."},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  {"role": "user", "content": instr.strip()},
194
  ]
195
  convo_str = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
@@ -502,8 +1414,11 @@ def process_batch(files: List[str], include_dups: bool, rows: List[dict],
502
  BASE_CSS = """
503
  :root{--galleryW:50%;--tableW:50%;}
504
  .gradio-container{max-width:100%!important}
505
- .cf-hero{display:grid;grid-template-columns:auto 1fr;column-gap:18px;align-items:center;justify-content:center;
506
- gap: 16px; margin:4px 0 12px;}
 
 
 
507
  .cf-logo{height:calc(3.25rem + 3 * 1.1rem + 18px);width:auto;object-fit:contain}
508
  .cf-title{margin:0;font-size:3.25rem;line-height:1;letter-spacing:.2px}
509
  .cf-sub{margin:6px 0 0;font-size:1.1rem;color:#cfd3da}
@@ -519,9 +1434,10 @@ BASE_CSS = """
519
  #cfTable table tbody tr td:first-child { display:none; }
520
  """
521
 
 
522
  with gr.Blocks(css=BASE_CSS, title="CaptionForge") as demo:
523
  settings = load_settings()
524
- settings["styles"] = [s for s in settings.get("styles", []) if s in STYLE_OPTIONS] or ["Descriptive (short)"]
525
 
526
  header = gr.HTML(value="""
527
  <div class="cf-hero">
@@ -536,13 +1452,39 @@ with gr.Blocks(css=BASE_CSS, title="CaptionForge") as demo:
536
  <hr>
537
  """ % logo_b64_img())
538
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
539
  # ===== Controls (top)
540
  with gr.Group():
541
  with gr.Row():
542
  with gr.Column(scale=2):
543
  style_checks = gr.CheckboxGroup(
544
  choices=STYLE_OPTIONS,
545
- value=settings.get("styles", ["Descriptive (short)"]),
546
  label="Caption style (choose one or combine)"
547
  )
548
  with gr.Accordion("Extra options", open=True):
@@ -572,10 +1514,10 @@ with gr.Blocks(css=BASE_CSS, title="CaptionForge") as demo:
572
 
573
  # Auto-refresh instruction & persist key controls
574
  def _refresh_instruction(styles, extra, name_value, trigv, begv, endv):
575
- instr = final_instruction(styles or ["Descriptive (short)"], extra or [], name_value)
576
  cfg = load_settings()
577
  cfg.update({
578
- "styles": styles or ["Descriptive (short)"],
579
  "extras": extra or [],
580
  "name": name_value,
581
  "trigger": trigv, "begin": begv, "end": endv
@@ -666,7 +1608,7 @@ with gr.Blocks(css=BASE_CSS, title="CaptionForge") as demo:
666
  return gal, df, info
667
 
668
  # Initial loads
669
- demo.load(lambda s,e,n: final_instruction(s or ["Descriptive (short)"], e or [], n),
670
  inputs=[style_checks, extra_opts, name_input], outputs=[instruction_preview])
671
  demo.load(_render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md])
672
 
@@ -737,6 +1679,28 @@ with gr.Blocks(css=BASE_CSS, title="CaptionForge") as demo:
737
  table.change(on_table_edit, inputs=[table, rows_state], outputs=[rows_state, autosave_md])\
738
  .then(_render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md])
739
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
740
  # Import merge
741
  def do_import(file, rows):
742
  if not file:
 
124
  ]
125
  NAME_OPTION = "If there is a person/character in the image you must refer to them as {name}."
126
 
127
+ # ────────────────────────────────────────────────────────
128
+ # Helpers (hashing, thumbs, resize)
129
+ # ────────────────────────────────────────────────────────
130
+ def sha1_file(path: str) -> str:
131
+ h = hashlib.sha1()
132
+ with open(path, "rb") as f:
133
+ for chunk in iter(lambda: f.read(8192), b""):
134
+ h.update(chunk)
135
+ return h.hexdigest()
136
+
137
+ def thumb_path(sha1: str, max_side=96) -> str:
138
+ return os.path.join(THUMB_CACHE, f"{sha1}_{max_side}.jpg")
139
+
140
+ def ensure_thumb(path: str, sha1: str, max_side=96) -> str:
141
+ out = thumb_path(sha1, max_side)
142
+ if os.path.exists(out):
143
+ return out
144
+ try:
145
+ im = Image.open(path).convert("RGB")
146
+ except Exception:
147
+ return ""
148
+ w, h = im.size
149
+ if max(w, h) > max_side:
150
+ s = max_side / max(w, h)
151
+ im = im.resize((int(w*s), int(h*s)), Image.LANCZOS)
152
+ try:
153
+ im.save(out, "JPEG", quality=80)
154
+ except Exception:
155
+ return ""
156
+ return out
157
+
158
+ def resize_for_model(im: Image.Image, max_side: int) -> Image.Image:
159
+ w, h = im.size
160
+ if max(w, h) <= max_side:
161
+ return im
162
+ s = max_side / max(w, h)
163
+ return im.resize((int(w*s), int(h*s)), Image.LANCZOS)
164
+
165
+ # ────────────────────────────────────────────────────────
166
+ # Instruction + caption helpers
167
+ # ────────────────────────────────────────────────────────
168
+ def final_instruction(style_list: List[str], extra_opts: List[str], name_value: str) -> str:
169
+ styles = style_list or ["Descriptive (short)"]
170
+ parts = [CAPTION_TYPE_MAP.get(s, "") for s in styles]
171
+ core = " ".join(p for p in parts if p).strip()
172
+ if extra_opts:
173
+ core += " " + " ".join(extra_opts)
174
+ if NAME_OPTION in (extra_opts or []):import os, io, csv, time, json, hashlib, base64, zipfile, re
175
+ from typing import List, Any, Tuple, Dict
176
+
177
+ import spaces
178
+ import gradio as gr
179
+ from PIL import Image
180
+ import torch
181
+ from transformers import LlavaForConditionalGeneration, AutoProcessor
182
+
183
+ # ────────────────────────────────────────────────────────
184
+ # Paths & caches
185
+ # ────────────────────────────────────────────────────────
186
+ os.environ.setdefault("HF_HOME", "/home/user/.cache/huggingface")
187
+ os.makedirs(os.environ["HF_HOME"], exist_ok=True)
188
+
189
+ APP_DIR = os.getcwd()
190
+ SESSION_FILE = "/tmp/session.json"
191
+ SETTINGS_FILE = "/tmp/cf_settings.json"
192
+ JOURNAL_FILE = "/tmp/cf_journal.json"
193
+ TXT_EXPORT_DIR = "/tmp/cf_txt"
194
+ THUMB_CACHE = os.path.expanduser("~/.cache/captionforge/thumbs")
195
+ os.makedirs(THUMB_CACHE, exist_ok=True)
196
+ os.makedirs(TXT_EXPORT_DIR, exist_ok=True)
197
+
198
+ # ────────────────────────────────────────────────────────
199
+ # Model setup
200
+ # ────────────────────────────────────────────────────────
201
+ MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
202
+
203
+ def _detect_gpu():
204
+ if torch.cuda.is_available():
205
+ p = torch.cuda.get_device_properties(0)
206
+ return "cuda", int(p.total_memory/(1024**3)), p.name
207
+ return "cpu", 0, "CPU"
208
+
209
+ BACKEND, VRAM_GB, GPU_NAME = _detect_gpu()
210
+ DTYPE = torch.bfloat16 if BACKEND == "cuda" else torch.float32
211
+ MAX_SIDE_CAP = 1024 if BACKEND == "cuda" else 640
212
+ DEVICE = "cuda" if BACKEND == "cuda" else "cpu"
213
+
214
+ processor = AutoProcessor.from_pretrained(MODEL_PATH)
215
+ model = LlavaForConditionalGeneration.from_pretrained(
216
+ MODEL_PATH,
217
+ torch_dtype=DTYPE,
218
+ low_cpu_mem_usage=True,
219
+ device_map=0 if BACKEND == "cuda" else "cpu",
220
+ )
221
+ model.eval()
222
+
223
+ print(f"[CaptionForge] Backend={BACKEND} GPU={GPU_NAME} VRAM={VRAM_GB}GB dtype={DTYPE}")
224
+ print(f"[CaptionForge] Gradio version: {gr.__version__}")
225
+
226
+ # ────────────────────────────────────────────────────────
227
+ # Instruction templates & options
228
+ # ────────────────────────────────────────────────────────
229
+ STYLE_OPTIONS = [
230
+ "Descriptive (short)", "Descriptive (long)",
231
+ "Character training (short)", "Character training (long)",
232
+ "LoRA (Flux_D Realism) (short)", "LoRA (Flux_D Realism) (long)",
233
+ "E-commerce product (short)", "E-commerce product (long)",
234
+ "Portrait (photography) (short)", "Portrait (photography) (long)",
235
+ "Landscape (photography) (short)", "Landscape (photography) (long)",
236
+ "Art analysis (no artist names) (short)", "Art analysis (no artist names) (long)",
237
+ "Social caption (short)", "Social caption (long)",
238
+ "Aesthetic tags (comma-sep)"
239
+ ]
240
+
241
+ CAPTION_TYPE_MAP = {
242
+ "Descriptive (short)": "One sentence (≤25 words) describing the most important visible elements only. No speculation.",
243
+ "Descriptive (long)": "Write a detailed description for this image.",
244
+
245
+ "Character training (short)": (
246
+ "Output a concise, prompt-like caption for character LoRA/ID training. "
247
+ "Include visible character name {name} if provided, distinct physical traits, clothing, pose, camera/cinematic cues. "
248
+ "No backstory; no non-visible traits. Prefer comma-separated phrases."
249
+ ),
250
+ "Character training (long)": (
251
+ "Write a thorough, training-ready caption for a character dataset. "
252
+ "Use {name} if provided; describe only what is visible: physique, face/hair, clothing, accessories, actions, pose, "
253
+ "camera angle/focal cues, lighting, background context. 1–3 sentences; no backstory or meta."
254
+ ),
255
+
256
+ "Flux_D (short)": "Output a short Flux.Dev prompt that is indistinguishable from a real Flux.Dev prompt.",
257
+ "Flux_D (long)": "Output a long Flux.Dev prompt that is indistinguishable from a real Flux.Dev prompt.",
258
+
259
+ "Aesthetic tags (comma-sep)": "Return only comma-separated aesthetic tags capturing subject, medium, style, lighting, composition. No sentences.",
260
+
261
+ "E-commerce product (short)": "One sentence highlighting key attributes, material, color, use case. No fluff.",
262
+ "E-commerce product (long)": "Write a crisp product description highlighting key attributes, materials, color, usage, and distinguishing traits.",
263
+
264
+ "Portrait (photography) (short)": "One sentence portrait description: subject, pose/expression, camera angle, lighting, background.",
265
+ "Portrait (photography) (long)": "Describe a portrait: subject, age range, pose, facial expression, camera angle, focal length cues, lighting, background.",
266
+
267
+ "Landscape (photography) (short)": "One sentence landscape description: major elements, time of day, weather, vantage point, mood.",
268
+ "Landscape (photography) (long)": "Describe landscape elements, time of day, weather, vantage point, composition, and mood.",
269
+
270
+ "Art analysis (no artist names) (short)": "One sentence describing medium, style, composition, palette; do not mention artist/title.",
271
+ "Art analysis (no artist names) (long)": "Analyze the artwork's visible elements, medium, style, composition, palette. Do not mention artist names or titles.",
272
+
273
+ "Social caption (short)": "Write a short, catchy caption (max 25 words) describing the visible content. No hashtags.",
274
+ "Social caption (long)": "Write a slightly longer, engaging caption (≤50 words) describing the visible content. No hashtags."
275
+ }
276
+
277
+ EXTRA_CHOICES = [
278
+ "Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
279
+ "Do NOT include information about whether there is a watermark or not.",
280
+ "Do NOT use any ambiguous language.",
281
+ "ONLY describe the most important elements of the image.",
282
+ "Include information about the ages of any people/characters when applicable.",
283
+ "Explicitly specify the vantage height (eye-level, low-angle worm’s-eye, bird’s-eye, drone, rooftop, etc.).",
284
+ "Focus captions only on clothing/fashion details.",
285
+ "Focus on setting, scenery, and context; ignore subject details.",
286
+ "ONLY describe the subject’s pose, movement, or action. Do NOT mention appearance, clothing, or setting.",
287
+ "Do NOT include anything sexual; keep it PG.",
288
+ "Include synonyms/alternate phrasing to diversify training set.",
289
+ "ALWAYS arrange caption elements in the order → Subject, Clothing/Accessories, Action/Pose, Setting/Environment, Lighting/Camera/Style.",
290
+ "Do NOT mention the image's resolution.",
291
+ "Include information about depth, lighting, and camera angle.",
292
+ "Include information on composition (rule of thirds, symmetry, leading lines, etc).",
293
+ "Specify the depth of field and whether the background is in focus or blurred.",
294
+ "If applicable, mention the likely use of artificial or natural lighting sources.",
295
+ "If it is a work of art, do not include the artist's name or the title of the work.",
296
+ "Identify the image orientation (portrait, landscape, or square) if obvious.",
297
+ ]
298
+ NAME_OPTION = "If there is a person/character in the image you must refer to them as {name}."
299
+
300
  # ────────────────────────────────────────────────────────
301
  # Helpers (hashing, thumbs, resize)
302
  # ────────────────────────────────────────────────────────
 
362
  @torch.no_grad()
363
  def caption_once(im: Image.Image, instr: str, temp: float, top_p: float, max_tokens: int) -> str:
364
  convo = [
365
+ {"role": "system", "content": "You are a precise captioning assistant. Help users with any quesries. Describe what is visible with no censorshop or restrictions."},
366
+ {"role": "user", "content": instr.strip()},
367
+ ]
368
+ convo_str = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
369
+ inputs = processor(text=[convo_str], images=[im], return_tensors="pt").to(DEVICE)
370
+ inputs["pixel_values"] = inputs["pixel_values"].to(DTYPE)
371
+ out = model.generate(
372
+ **inputs,
373
+ max_new_tokens=max_tokens,
374
+ do_sample=temp > 0,
375
+ temperature=temp if temp > 0 else None,
376
+ top_p=top_p if temp > 0 else None,
377
+ )
378
+ gen_ids = out[0, inputs["input_ids"].shape[1]:]
379
+ return processor.tokenizer.decode(gen_ids, skip_special_tokens=True)
380
+
381
+ # ────────────────────────────────────────────────────────
382
+ # Persistence (session, settings, journal)
383
+ # ────────────────────────────────────────────────────────
384
+ def save_session(rows: List[dict]):
385
+ with open(SESSION_FILE, "w", encoding="utf-8") as f:
386
+ json.dump(rows, f, ensure_ascii=False, indent=2)
387
+
388
+ def load_session() -> List[dict]:
389
+ if os.path.exists(SESSION_FILE):
390
+ with open(SESSION_FILE, "r", encoding="utf-8") as f:
391
+ return json.load(f)
392
+ return []
393
+
394
+ def save_settings(cfg: dict):
395
+ with open(SETTINGS_FILE, "w", encoding="utf-8") as f:
396
+ json.dump(cfg, f, ensure_ascii=False, indent=2)
397
+
398
+ def load_settings() -> dict:
399
+ if os.path.exists(SETTINGS_FILE):
400
+ with open(SETTINGS_FILE, "r", encoding="utf-8") as f:
401
+ cfg = json.load(f)
402
+ else:
403
+ cfg = {}
404
+ defaults = {
405
+ "dataset_name": "captionforge",
406
+ "temperature": 0.6,
407
+ "top_p": 0.9,
408
+ "max_tokens": 256,
409
+ "chunk_mode": "Manual (step)",
410
+ "chunk_size": 10,
411
+ "max_side": min(896, MAX_SIDE_CAP),
412
+ "styles": ["Character training (short)"],
413
+ "extras": [],
414
+ "name": "",
415
+ "trigger": "",
416
+ "begin": "",
417
+ "end": "",
418
+ "shape_aliases_enabled": True,
419
+ "shape_aliases": []
420
+ }
421
+ for k, v in defaults.items():
422
+ cfg.setdefault(k, v)
423
+
424
+ # migrate legacy names
425
+ legacy_map = {
426
+ "Descriptive": "Descriptive (short)",
427
+ "LoRA (Flux_D Realism)": "LoRA (Flux_D Realism) (short)",
428
+ "Portrait (photography)": "Portrait (photography) (short)",
429
+ "Landscape (photography)": "Landscape (photography) (short)",
430
+ "Art analysis (no artist names)": "Art analysis (no artist names) (short)",
431
+ "E-commerce product": "E-commerce product (short)",
432
+ }
433
+ styles = cfg.get("styles") or []
434
+ migrated = []
435
+ for s in styles if isinstance(styles, list) else [styles]:
436
+ migrated.append(legacy_map.get(s, s))
437
+ migrated = [s for s in migrated if s in STYLE_OPTIONS]
438
+ if not migrated:
439
+ migrated = ["Descriptive (short)"]
440
+ cfg["styles"] = migrated
441
+ return cfg
442
+
443
+ def save_journal(data: dict):
444
+ with open(JOURNAL_FILE, "w", encoding="utf-8") as f:
445
+ json.dump(data, f, ensure_ascii=False, indent=2)
446
+
447
+ def load_journal() -> dict:
448
+ if os.path.exists(JOURNAL_FILE):
449
+ with open(JOURNAL_FILE, "r", encoding="utf-8") as f:
450
+ return json.load(f)
451
+ return {}
452
+
453
+ # ────────────────────────────────────────────────────────
454
+ # Shape Aliases (compile/cache/apply)
455
+ # ────────────────────────────────────────────────────────
456
+ def _compile_shape_aliases_from_file():
457
+ s = load_settings()
458
+ if not s.get("shape_aliases_enabled", True):
459
+ return []
460
+ compiled = []
461
+ for item in s.get("shape_aliases", []):
462
+ shape = (item.get("shape") or "").strip()
463
+ name = (item.get("name") or "").strip()
464
+ if not shape or not name:
465
+ continue
466
+ pat = rf"\b{re.escape(shape)}(?:-?shaped)?\b"
467
+ compiled.append((re.compile(pat, flags=re.I), name))
468
+ return compiled
469
+
470
+ _SHAPE_ALIASES = _compile_shape_aliases_from_file()
471
+
472
+ def _refresh_shape_aliases_cache():
473
+ global _SHAPE_ALIASES
474
+ _SHAPE_ALIASES = _compile_shape_aliases_from_file()
475
+
476
+ def apply_shape_aliases(caption: str) -> str:
477
+ for pat, name in _SHAPE_ALIASES:
478
+ caption = pat.sub(f"({name})", caption)
479
+ return caption
480
+
481
+ # ────────────────────────────────────────────────────────
482
+ # Logo loader
483
+ # ────────────────────────────────────────────────────────
484
+ def logo_b64_img() -> str:
485
+ candidates = [
486
+ os.path.join(APP_DIR, "captionforge-logo.png"),
487
+ "/home/user/app/captionforge-logo.png",
488
+ "captionforge-logo.png",
489
+ ]
490
+ for p in candidates:
491
+ if os.path.exists(p):
492
+ with open(p, "rb") as f:
493
+ b64 = base64.b64encode(f.read()).decode("ascii")
494
+ return f"<img src='data:image/png;base64,{b64}' alt='CaptionForge' class='cf-logo'>"
495
+ return ""
496
+
497
+ # ────────────────────────────────────────────────────────
498
+ # Import / Export helpers
499
+ # ────────────────────────────────────────────────────────
500
+ def sanitize_basename(s: str) -> str:
501
+ s = (s or "").strip() or "captionforge"
502
+ return re.sub(r"[^A-Za-z0-9._-]+", "_", s)[:120]
503
+
504
+ def ts_stamp() -> str:
505
+ return time.strftime("%Y%m%d_%H%M%S")
506
+
507
+ def export_csv(rows: List[dict], basename: str) -> str:
508
+ base = sanitize_basename(basename)
509
+ out = f"/tmp/{base}_{ts_stamp()}.csv"
510
+ with open(out, "w", newline="", encoding="utf-8") as f:
511
+ w = csv.writer(f)
512
+ w.writerow(["sha1", "filename", "caption"])
513
+ for r in rows:
514
+ w.writerow([r.get("sha1",""), r.get("filename",""), r.get("caption","")])
515
+ return out
516
+
517
+ def export_excel(rows: List[dict], basename: str) -> str:
518
+ try:
519
+ from openpyxl import Workbook
520
+ from openpyxl.drawing.image import Image as XLImage
521
+ except Exception as e:
522
+ raise RuntimeError("Excel export requires 'openpyxl' in requirements.txt.") from e
523
+ base = sanitize_basename(basename)
524
+ out = f"/tmp/{base}_{ts_stamp()}.xlsx"
525
+ wb = Workbook(); ws = wb.active; ws.title = "CaptionForge"
526
+ ws.append(["image", "filename", "caption"])
527
+ ws.column_dimensions["A"].width = 24
528
+ ws.column_dimensions["B"].width = 42
529
+ ws.column_dimensions["C"].width = 100
530
+ r_i = 2
531
+ for r in rows:
532
+ fn = r.get("filename",""); cap = r.get("caption","")
533
+ ws.cell(row=r_i, column=2, value=fn)
534
+ ws.cell(row=r_i, column=3, value=cap)
535
+ img_path = r.get("thumb_path") or r.get("path")
536
+ if img_path and os.path.exists(img_path):
537
+ try:
538
+ ximg = XLImage(img_path); ws.add_image(ximg, f"A{r_i}"); ws.row_dimensions[r_i].height = 110
539
+ except Exception:
540
+ pass
541
+ r_i += 1
542
+ wb.save(out); return out
543
+
544
+ def export_txt_zip(rows: List[dict], basename: str) -> str:
545
+ base = sanitize_basename(basename)
546
+ for name in os.listdir(TXT_EXPORT_DIR):
547
+ try: os.remove(os.path.join(TXT_EXPORT_DIR, name))
548
+ except Exception: pass
549
+ used: Dict[str,int] = {}
550
+ for r in rows:
551
+ orig = r.get("filename") or r.get("sha1") or "item"
552
+ stem = re.sub(r"\.[A-Za-z0-9]+$", "", orig)
553
+ stem = sanitize_basename(stem)
554
+ if stem in used:
555
+ used[stem] += 1
556
+ stem = f"{stem}_{used[stem]}"
557
+ else:
558
+ used[stem] = 0
559
+ txt_path = os.path.join(TXT_EXPORT_DIR, f"{stem}.txt")
560
+ with open(txt_path, "w", encoding="utf-8") as f:
561
+ f.write(r.get("caption", ""))
562
+ zpath = f"/tmp/{base}_{ts_stamp()}_txt.zip"
563
+ with zipfile.ZipFile(zpath, "w", zipfile.ZIP_DEFLATED) as z:
564
+ for name in os.listdir(TXT_EXPORT_DIR):
565
+ if name.endswith(".txt"):
566
+ z.write(os.path.join(TXT_EXPORT_DIR, name), arcname=name)
567
+ return zpath
568
+
569
+ def import_csv_jsonl(path: str, rows: List[dict]) -> List[dict]:
570
+ """Merge captions from CSV/JSONL into current rows by sha1→filename, updating only 'caption'."""
571
+ if not path or not os.path.exists(path):
572
+ return rows or []
573
+ rows = rows or []
574
+ by_sha = {r.get("sha1"): r for r in rows if r.get("sha1")}
575
+ by_fn = {r.get("filename"): r for r in rows if r.get("filename")}
576
+
577
+ def _apply(item: dict):
578
+ if not isinstance(item, dict):
579
+ return
580
+ cap = (item.get("caption") or "").strip()
581
+ if not cap:
582
+ return
583
+ sha = (item.get("sha1") or "").strip()
584
+ fn = (item.get("filename") or "").strip()
585
+ tgt = by_sha.get(sha) if sha else None
586
+ if not tgt and fn:
587
+ tgt = by_fn.get(fn)
588
+ if tgt is not None:
589
+ prev = tgt.get("caption", "")
590
+ if cap != prev:
591
+ tgt.setdefault("history", []).append(prev)
592
+ tgt["caption"] = cap
593
+
594
+ ext = os.path.splitext(path)[1].lower()
595
+ try:
596
+ if ext == ".csv":
597
+ with open(path, "r", encoding="utf-8") as f:
598
+ for row in csv.DictReader(f):
599
+ _apply(row)
600
+ elif ext in (".jsonl", ".ndjson"):
601
+ with open(path, "r", encoding="utf-8") as f:
602
+ for line in f:
603
+ line = line.strip()
604
+ if not line:
605
+ continue
606
+ try:
607
+ obj = json.loads(line)
608
+ except Exception:
609
+ continue
610
+ _apply(obj)
611
+ else:
612
+ with open(path, "r", encoding="utf-8") as f:
613
+ data = json.load(f)
614
+ if isinstance(data, list):
615
+ for obj in data:
616
+ _apply(obj)
617
+ except Exception as e:
618
+ print("[CaptionForge] Import merge failed:", e)
619
+ save_session(rows)
620
+ return rows
621
+
622
+ # ────────────────────────────────────────────────────────
623
+ # Batching / processing
624
+ # ────────────────────────────────────────────────────────
625
+ def adaptive_defaults(vram_gb: int) -> Tuple[int, int]:
626
+ if vram_gb >= 40: return 24, min(1024, MAX_SIDE_CAP)
627
+ if vram_gb >= 24: return 16, min(1024, MAX_SIDE_CAP)
628
+ if vram_gb >= 12: return 10, 896
629
+ if vram_gb >= 8: return 6, 768
630
+ return 4, 640
631
+
632
+ def warmup_if_needed(temp: float, top_p: float, max_tokens: int, max_side: int):
633
+ try:
634
+ im = Image.new("RGB", (min(128, max_side), min(128, max_side)), (127,127,127))
635
+ _ = caption_once(im, "Warm up.", temp=0.0, top_p=top_p, max_tokens=min(16, max_tokens))
636
+ except Exception as e:
637
+ print("[CaptionForge] Warm-up skipped:", e)
638
+
639
+ @spaces.GPU()
640
+ @torch.no_grad()
641
+ def process_batch(files: List[str], include_dups: bool, rows: List[dict],
642
+ instr_text: str, temp: float, top_p: float, max_tokens: int,
643
+ trigger_word: str, begin_text: str, end_text: str,
644
+ mode: str, chunk_size: int, max_side: int, step_once: bool,
645
+ progress=gr.Progress(track_tqdm=True)) -> Tuple[List[dict], List[str], List[str]]:
646
+ if torch.cuda.is_available():
647
+ torch.cuda.empty_cache()
648
+ files = files or []
649
+ if not files: return rows, [], []
650
+
651
+ auto_mode = mode == "Auto"
652
+ if auto_mode:
653
+ chunk_size, max_side = adaptive_defaults(VRAM_GB)
654
+
655
+ sha_seen = {r.get("sha1") for r in rows}
656
+ f_infos = []
657
+ for f in files:
658
+ try:
659
+ s = sha1_file(f)
660
+ except Exception:
661
+ continue
662
+ if (not include_dups) and (s in sha_seen):
663
+ continue
664
+ f_infos.append((f, os.path.basename(f), s))
665
+ if not f_infos: return rows, [], []
666
+
667
+ warmup_if_needed(temp, top_p, max_tokens, max_side)
668
+
669
+ chunks = [f_infos[i:i+chunk_size] for i in range(0, len(f_infos), chunk_size)]
670
+ if step_once:
671
+ chunks = chunks[:1]
672
+
673
+ error_log: List[str] = []
674
+ remaining: List[str] = []
675
+
676
+ for ci, chunk in enumerate(chunks):
677
+ for (f, fname, sha) in progress.tqdm(chunk, desc=f"Chunk {ci+1}/{len(chunks)}"):
678
+ thumb = ensure_thumb(f, sha, 96)
679
+ attempt = 0
680
+ cap = None
681
+ while attempt < 2:
682
+ attempt += 1
683
+ try:
684
+ im = Image.open(f).convert("RGB")
685
+ im = resize_for_model(im, max_side)
686
+ cap = caption_once(im, instr_text, temp, top_p, max_tokens)
687
+ cap = apply_shape_aliases(cap) # ← apply custom name mappings
688
+ cap = apply_prefix_suffix(cap, trigger_word, begin_text, end_text)
689
+ break
690
+ except Exception as e:
691
+ error_log.append(f"{fname} ({sha[:8]}): {type(e).__name__}: {e} [attempt {attempt}]")
692
+ rows.append({
693
+ "sha1": sha, "filename": fname, "path": f, "thumb_path": thumb,
694
+ "caption": cap or "", "history": []
695
+ })
696
+ save_session(rows)
697
+
698
+ if step_once and len(f_infos) > len(chunks[0]):
699
+ remaining = [fi[0] for fi in f_infos[len(chunks[0]):]]
700
+ save_journal({"remaining_files": remaining})
701
+ return rows, remaining, error_log
702
+
703
+ # ────────────────────────────────────────────────────────
704
+ # UI
705
+ # ────────────────────────────────────────────────────────
706
+ BASE_CSS = """
707
+ :root{--galleryW:50%;--tableW:50%;}
708
+ .gradio-container{max-width:100%!important}
709
+ .cf-hero{
710
+ display:flex; align-items:center; justify-content:center; gap:16px;
711
+ margin:4px 0 12px; text-align:center;
712
+ }
713
+ .cf-hero > div { text-align:center; }
714
+ .cf-logo{height:calc(3.25rem + 3 * 1.1rem + 18px);width:auto;object-fit:contain}
715
+ .cf-title{margin:0;font-size:3.25rem;line-height:1;letter-spacing:.2px}
716
+ .cf-sub{margin:6px 0 0;font-size:1.1rem;color:#cfd3da}
717
+ .cf-row{display:flex;gap:12px}
718
+ .cf-col-gallery{flex:0 0 var(--galleryW)}
719
+ .cf-col-table{flex:0 0 var(--tableW)}
720
+ /* Shared scroll look */
721
+ .cf-scroll{max-height:70vh; overflow-y:auto; border:1px solid #e6e6e6; border-radius:10px; padding:8px}
722
+ /* Uniform sizes */
723
+ #cfGal .grid > div { height: 96px; }
724
+ /* Hide SHA1 in table (first column) */
725
+ #cfTable table thead tr th:first-child { display:none; }
726
+ #cfTable table tbody tr td:first-child { display:none; }
727
+ """
728
+
729
+ with gr.Blocks(css=BASE_CSS, title="CaptionForge") as demo:
730
+ settings = load_settings()
731
+ settings["styles"] = [s for s in settings.get("styles", []) if s in STYLE_OPTIONS] or ["Character training (short)"]
732
+
733
+ gr.HTML(value=f"""
734
+ <div class="cf-hero">
735
+ {logo_b64_img()}
736
+ <div>
737
+ <h1 class="cf-title">CaptionForge</h1>
738
+ <div class="cf-sub">Batch captioning</div>
739
+ <div class="cf-sub">Scrollable editor & autosave</div>
740
+ <div class="cf-sub">CSV / Excel / TXT export</div>
741
+ </div>
742
+ </div>
743
+ <hr>
744
+ """)
745
+
746
+ # ===== Controls (top)
747
+ with gr.Group():
748
+ with gr.Row():
749
+ with gr.Column(scale=2):
750
+ style_checks = gr.CheckboxGroup(
751
+ choices=STYLE_OPTIONS,
752
+ value=settings.get("styles", ["Character training (short)"]),
753
+ label="Caption style (choose one or combine)"
754
+ )
755
+ with gr.Accordion("Extra options", open=True):
756
+ extra_opts = gr.CheckboxGroup(
757
+ choices=[NAME_OPTION] + EXTRA_CHOICES,
758
+ value=settings.get("extras", []),
759
+ label=None
760
+ )
761
+ with gr.Accordion("Name & Prefix/Suffix", open=True):
762
+ name_input = gr.Textbox(label="Person / Character Name", value=settings.get("name", ""))
763
+ trig = gr.Textbox(label="Trigger word")
764
+ add_start = gr.Textbox(label="Add text to start")
765
+ add_end = gr.Textbox(label="Add text to end")
766
+
767
+ with gr.Column(scale=1):
768
+ instruction_preview = gr.Textbox(label="Model Instructions", lines=14)
769
+ dataset_name = gr.Textbox(label="Dataset name (used for export file titles)", value=settings.get("dataset_name", "captionforge"))
770
+
771
+ with gr.Row():
772
+ chunk_mode = gr.Radio(
773
+ choices=["Auto", "Manual (all at once)", "Manual (step)"],
774
+ value=settings.get("chunk_mode", "Manual (step)"),
775
+ label="Batch mode"
776
+ )
777
+ chunk_size = gr.Slider(1, 50, settings.get("chunk_size", 10), step=1, label="Chunk size")
778
+ max_side = gr.Slider(256, MAX_SIDE_CAP, settings.get("max_side", min(896, MAX_SIDE_CAP)), step=32, label="Max side (resize)")
779
+
780
+ # Shape Aliases UI (Accordion)
781
+ with gr.Accordion("Shape Aliases", open=False):
782
+ enable_aliases = gr.Checkbox(label="Enable shape alias replacements", value=True)
783
+ alias_table = gr.Dataframe(
784
+ headers=["shape (literal token)", "name to insert"],
785
+ datatype=["str","str"],
786
+ row_count=(1, "dynamic"),
787
+ wrap=True,
788
+ interactive=True
789
+ )
790
+ save_aliases_btn = gr.Button("Save aliases")
791
+ save_status = gr.Markdown()
792
+
793
+ # init values
794
+ init_rows = [[it.get("shape",""), it.get("name","")] for it in settings.get("shape_aliases", [])]
795
+ if not init_rows:
796
+ init_rows = [["diamond", "starkey-emblem"]]
797
+ enable_aliases.value = bool(settings.get("shape_aliases_enabled", True))
798
+ alias_table.value = init_rows
799
+
800
def set_shape_alias_rows(enabled, rows):
    """Persist the shape-alias toggle and table rows, then refresh the cache.

    Only rows where both the shape token and the replacement name are
    non-blank after stripping are kept. Returns a short status message
    for the settings UI.
    """
    settings = load_settings()
    settings["shape_aliases_enabled"] = bool(enabled)
    cleaned = []
    for row in rows or []:
        if not row:
            continue
        shape = (row[0] or "").strip()
        alias = (row[1] or "").strip()
        if shape and alias:
            cleaned.append({"shape": shape, "name": alias})
    settings["shape_aliases"] = cleaned
    save_settings(settings)
    _refresh_shape_aliases_cache()
    return "Saved shape alias options."
810
+
811
+ save_aliases_btn.click(
812
+ set_shape_alias_rows,
813
+ inputs=[enable_aliases, alias_table],
814
+ outputs=[save_status]
815
+ )
816
+
817
+ # Auto-refresh instruction & persist key controls
818
def _refresh_instruction(styles, extra, name_value, trigv, begv, endv):
    """Rebuild the model-instruction preview text and persist the controls.

    Falls back to the default style when none is selected, updates the
    saved settings with the current control values, and returns the
    regenerated instruction string for the preview textbox.
    """
    chosen_styles = styles or ["Character training (short)"]
    chosen_extras = extra or []
    instruction = final_instruction(chosen_styles, chosen_extras, name_value)
    cfg = load_settings()
    cfg["styles"] = chosen_styles
    cfg["extras"] = chosen_extras
    cfg["name"] = name_value
    cfg["trigger"] = trigv
    cfg["begin"] = begv
    cfg["end"] = endv
    save_settings(cfg)
    return instruction
829
+
830
+ for comp in [style_checks, extra_opts, name_input, trig, add_start, add_end]:
831
+ comp.change(_refresh_instruction, inputs=[style_checks, extra_opts, name_input, trig, add_start, add_end], outputs=[instruction_preview])
832
+
833
+ # ===== File inputs & actions
834
+ with gr.Accordion("Uploaded images", open=True) as uploads_acc:
835
+ input_files = gr.File(label="Drop images", file_types=["image"], file_count="multiple", type="filepath")
836
+ import_file = gr.File(label="Import CSV or JSONL (merge captions)")
837
+
838
+ with gr.Row():
839
+ import_btn = gr.Button("Import → Merge")
840
+ resume_btn = gr.Button("Resume last run", variant="secondary")
841
+ run_button = gr.Button("Caption batch", variant="primary")
842
+
843
+ # Step-chunk controls
844
+ step_panel = gr.Group(visible=False)
845
+ with step_panel:
846
+ step_msg = gr.Markdown("")
847
+ step_next = gr.Button("Process next chunk")
848
+ step_finish = gr.Button("Finish")
849
+
850
+ # States
851
+ rows_state = gr.State(load_session())
852
+ visible_count_state = gr.State(100)
853
+ pending_files = gr.State([])
854
+ remaining_state = gr.State([])
855
+ autosave_md = gr.Markdown("Ready.")
856
+
857
+ # Error log
858
+ with gr.Accordion("Error log", open=False):
859
+ log_md = gr.Markdown("")
860
+
861
+ # Side-by-side gallery + table with shared scroll
862
+ with gr.Row():
863
+ with gr.Column(scale=1):
864
+ gallery = gr.Gallery(
865
+ label="Results (images)",
866
+ show_label=True,
867
+ columns=10,
868
+ height=520,
869
+ elem_id="cfGal",
870
+ elem_classes=["cf-scroll"]
871
+ )
872
+ with gr.Column(scale=1):
873
+ table = gr.Dataframe(
874
+ label="Editable captions (whole session)",
875
+ value=[],
876
+ headers=["sha1", "filename", "caption"], # sha1 hidden via CSS
877
+ interactive=True,
878
+ elem_id="cfTable",
879
+ elem_classes=["cf-scroll"]
880
+ )
881
+
882
+ # Exports aligned under buttons
883
+ with gr.Row():
884
+ with gr.Column():
885
+ export_csv_btn = gr.Button("Export CSV")
886
+ csv_file = gr.File(label="CSV file", visible=False)
887
+ with gr.Column():
888
+ export_xlsx_btn = gr.Button("Export Excel (.xlsx)")
889
+ xlsx_file = gr.File(label="Excel file", visible=False)
890
+ with gr.Column():
891
+ export_txt_btn = gr.Button("Export captions as .txt (zip)")
892
+ txt_zip = gr.File(label="TXT zip", visible=False)
893
+
894
+ # Render helpers
895
+ def _render(rows, vis_n):
896
+ rows = rows or []
897
+ n = max(0, int(vis_n) if vis_n else 0)
898
+ win = rows[:n]
899
+
900
+ # Gallery expects list of paths
901
+ gal = []
902
+ for r in win:
903
+ p = r.get("thumb_path") or r.get("path")
904
+ if isinstance(p, str) and os.path.exists(p):
905
+ gal.append(p)
906
+
907
+ # Dataframe rows (sha1 kept for mapping; hidden via CSS)
908
+ df = [[str(r.get("sha1","")), str(r.get("filename","")), str(r.get("caption",""))] for r in win]
909
+ info = f"Showing {len(win)} of {len(rows)}"
910
+ return gal, df, info
911
+
912
+ # Initial loads
913
+ demo.load(lambda s,e,n: final_instruction(s or ["Character training (short)"], e or [], n),
914
+ inputs=[style_checks, extra_opts, name_input], outputs=[instruction_preview])
915
+ demo.load(_render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md])
916
+
917
+ # Scroll sync (Gallery + Table)
918
+ gr.HTML("""
919
+ <script>
920
+ (function () {
921
+ function findGalleryScrollRoot() {
922
+ const host = document.querySelector("#cfGal");
923
+ if (!host) return null;
924
+ return host.querySelector(".grid") || host.querySelector("[data-testid='gallery']") || host;
925
+ }
926
+ function findTableScrollRoot() {
927
+ const host = document.querySelector("#cfTable");
928
+ if (!host) return null;
929
+ return host.querySelector(".wrap") ||
930
+ host.querySelector(".dataframe-wrap") ||
931
+ (host.querySelector("table") ? host.querySelector("table").parentElement : null) ||
932
+ host;
933
+ }
934
+ function syncScroll(a, b) {
935
+ if (!a || !b) return;
936
+ let lock = false;
937
+ const onScrollA = () => { if (lock) return; lock = true; b.scrollTop = a.scrollTop; lock = false; };
938
+ const onScrollB = () => { if (lock) return; lock = true; a.scrollTop = b.scrollTop; lock = false; };
939
+ a.addEventListener("scroll", onScrollA, { passive: true });
940
+ b.addEventListener("scroll", onScrollB, { passive: true });
941
+ }
942
+ let tries = 0;
943
+ const timer = setInterval(() => {
944
+ tries++;
945
+ const gal = findGalleryScrollRoot();
946
+ const tab = findTableScrollRoot();
947
+ if (gal && tab) {
948
+ const H = Math.min(gal.clientHeight || 520, tab.clientHeight || 520);
949
+ gal.style.maxHeight = H + "px";
950
+ gal.style.overflowY = "auto";
951
+ tab.style.maxHeight = H + "px";
952
+ tab.style.overflowY = "auto";
953
+ syncScroll(gal, tab);
954
+ clearInterval(timer);
955
+ }
956
+ if (tries > 20) clearInterval(timer);
957
+ }, 100);
958
+ })();
959
+ </script>
960
+ """)
961
+
962
+ # Table edits -> save + rerender
963
def on_table_edit(df_rows, rows):
    """Apply caption edits from the dataframe back onto the session rows.

    Edited rows are matched to session rows by sha1 (first column). When a
    caption changed, the previous caption is pushed onto that row's
    "history" list before overwriting. The session is saved and a
    timestamped autosave message is returned.
    """
    session_rows = rows or []
    by_sha = {entry.get("sha1"): entry for entry in session_rows}
    for edited in df_rows or []:
        if not edited:
            continue
        sha = str(edited[0]) if len(edited) > 0 else ""
        target = by_sha.get(sha)
        if not target:
            continue
        old_caption = target.get("caption", "")
        new_caption = edited[2] if len(edited) > 2 else old_caption
        if new_caption != old_caption:
            target.setdefault("history", []).append(old_caption)
            target["caption"] = new_caption
    save_session(session_rows)
    return rows, f"Saved • {time.strftime('%H:%M:%S')}"
980
+ table.change(on_table_edit, inputs=[table, rows_state], outputs=[rows_state, autosave_md])\
981
+ .then(_render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md])
982
+
983
+ # Import merge
984
def do_import(file, rows):
    """Merge captions from an uploaded CSV/JSONL file into the session rows.

    Returns the (possibly merged) rows twice — once for state, once for the
    duplicated output slot — plus a status-message update.
    """
    if not file:
        return rows, gr.update(value="No file"), rows
    merged = import_csv_jsonl(file.name, rows or [])
    stamp = time.strftime('%H:%M:%S')
    message = f"Merged from {os.path.basename(file.name)} at {stamp}"
    return merged, gr.update(value=message), merged
989
+ import_btn.click(do_import, inputs=[import_file, rows_state], outputs=[rows_state, autosave_md, rows_state])\
990
+ .then(_render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md])
991
+
992
+ # Prepare run (hash dupes)
993
def prepare_run(files, rows, pending):
    """Filter newly dropped files down to ones not already captioned or queued.

    Dedupes by content sha1 against both the existing session rows and the
    current pending queue; unreadable files are skipped. Collapses the
    uploads accordion when there is work to queue.
    """
    dropped = list(files or [])
    session_rows = rows or []
    queued = pending or []
    if not dropped:
        return [], gr.update(open=True), []

    # Start with hashes already captioned, then fold in the pending queue.
    known_hashes = {row.get("sha1") for row in session_rows if row.get("sha1")}
    for path in queued:
        try:
            known_hashes.add(sha1_file(path))
        except Exception:
            # Unreadable queued file: ignore it for dedupe purposes.
            pass

    fresh = []
    for path in dropped:
        try:
            digest = sha1_file(path)
        except Exception:
            continue
        if digest not in known_hashes:
            fresh.append(path)
    return fresh, gr.update(open=False), fresh
1015
+ input_files.change(prepare_run, inputs=[input_files, rows_state, pending_files], outputs=[pending_files, uploads_acc, pending_files])
1016
+
1017
+ # Step visibility helper (dynamic message)
1018
def _step_visibility(remain, mode):
    """Show the step-chunk panel (with a remaining-count prompt) only in step mode."""
    show = bool(remain) and mode == "Manual (step)"
    if not show:
        return gr.update(visible=False), gr.update(value="")
    prompt = f"{len(remain)} files remain. Process next chunk?"
    return gr.update(visible=True), gr.update(value=prompt)
1022
+
1023
+ # Run helpers (use settings for gen params)
1024
def _run_once(pending, rows, instr, trigv, begv, endv, mode, csize, mside):
    """Run one captioning pass over *pending* using persisted generation params.

    Generation temperature/top_p/max_tokens come from the saved settings
    (with defaults). In "Manual (step)" mode only a single chunk is
    processed. Returns (updated rows, remaining files, error-log text,
    autosave message).
    """
    cfg = load_settings()
    temperature = cfg.get("temperature", 0.6)
    top_p = cfg.get("top_p", 0.9)
    max_tokens = cfg.get("max_tokens", 256)
    one_chunk_only = mode == "Manual (step)"
    new_rows, remaining, errors = process_batch(
        pending,
        False,
        rows or [],
        instr,
        temperature,
        top_p,
        max_tokens,
        trigv,
        begv,
        endv,
        mode,
        int(csize),
        int(mside),
        one_chunk_only,
    )
    error_log = "\n".join(errors) if errors else "(no errors)"
    stamp = time.strftime('%H:%M:%S')
    return new_rows, remaining, error_log, f"Saved • {stamp}"
1036
+
1037
+ run_button.click(
1038
+ _run_once,
1039
+ inputs=[pending_files, rows_state, instruction_preview, trig, add_start, add_end, chunk_mode, chunk_size, max_side],
1040
+ outputs=[rows_state, remaining_state, log_md, autosave_md]
1041
+ ).then(
1042
+ _render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md]
1043
+ ).then(
1044
+ _step_visibility, inputs=[remaining_state, chunk_mode], outputs=[step_panel, step_msg]
1045
+ )
1046
+
1047
def _resume(rows):
    """Reload the remaining-files list from the journal for a resumed run.

    *rows* is accepted for wiring symmetry with other handlers but is not
    used here.
    """
    journal = load_journal()
    remaining = journal.get("remaining_files", [])
    if not remaining:
        return remaining, "Nothing to resume"
    return remaining, f"Loaded {len(remaining)} remaining"
1051
+ resume_btn.click(_resume, inputs=[rows_state], outputs=[pending_files, autosave_md])
1052
+
1053
def _step_next(remain, rows, instr, trigv, begv, endv, mode, csize, mside):
    """Process the next pending chunk by delegating to the shared run helper."""
    args = (remain, rows, instr, trigv, begv, endv, mode, csize, mside)
    return _run_once(*args)
1055
+ step_next.click(
1056
+ _step_next,
1057
+ inputs=[remaining_state, rows_state, instruction_preview, trig, add_start, add_end, chunk_mode, chunk_size, max_side],
1058
+ outputs=[rows_state, remaining_state, log_md, autosave_md]
1059
+ ).then(
1060
+ _render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md]
1061
+ ).then(
1062
+ _step_visibility, inputs=[remaining_state, chunk_mode], outputs=[step_panel, step_msg]
1063
+ )
1064
+
1065
+ step_finish.click(lambda: (gr.update(visible=False), gr.update(value=""), []), outputs=[step_panel, step_msg, remaining_state])
1066
+
1067
+ # Exports (reveal file widgets when created)
1068
+ export_csv_btn.click(
1069
+ lambda rows, base: (export_csv(rows or [], base), gr.update(visible=True)),
1070
+ inputs=[rows_state, dataset_name], outputs=[csv_file, csv_file]
1071
+ )
1072
+ export_xlsx_btn.click(
1073
+ lambda rows, base: (export_excel(rows or [], base), gr.update(visible=True)),
1074
+ inputs=[rows_state, dataset_name], outputs=[xlsx_file, xlsx_file]
1075
+ )
1076
+ export_txt_btn.click(
1077
+ lambda rows, base: (export_txt_zip(rows or [], base), gr.update(visible=True)),
1078
+ inputs=[rows_state, dataset_name], outputs=[txt_zip, txt_zip]
1079
+ )
1080
+
1081
+ # Launch
1082
+ demo.queue(max_size=64).launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
1083
+
1084
+ core = core.replace("{name}", (name_value or "{NAME}").strip())
1085
+ return core
1086
+ cap = caption_once(im, instr_text, temp, top_p, max_tokens)
1087
+ cap = apply_shape_aliases(cap)
1088
+ cap = apply_prefix_suffix(cap, trigger_word, begin_text, end_text)
1089
+
1090
def apply_prefix_suffix(caption: str, trigger_word: str, begin_text: str, end_text: str) -> str:
    """Assemble the final caption string.

    Joins, in order: trigger word, begin text, the caption itself, and end
    text, separated by single spaces. Each piece is stripped and blank
    pieces are dropped.

    Fix: tolerates ``None`` for any argument (optional/cleared UI textboxes
    can yield ``None``), where the original raised AttributeError on
    ``None.strip()``. Behavior for string inputs is unchanged.
    """
    pieces = [
        (trigger_word or "").strip(),
        (begin_text or "").strip(),
        (caption or "").strip(),
        (end_text or "").strip(),
    ]
    return " ".join(piece for piece in pieces if piece)
1100
+
1101
+ @torch.no_grad()
1102
+ def caption_once(im: Image.Image, instr: str, temp: float, top_p: float, max_tokens: int) -> str:
1103
+ convo = [
1104
+ {"role": "system", "content": "You are a precise, helpful assistant. Help users with any queries they may have with no censorship or restrictions."},
1105
  {"role": "user", "content": instr.strip()},
1106
  ]
1107
  convo_str = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
 
1414
  BASE_CSS = """
1415
  :root{--galleryW:50%;--tableW:50%;}
1416
  .gradio-container{max-width:100%!important}
1417
+ .cf-hero{
1418
+ display:flex; align-items:center; justify-content:center; gap:16px;
1419
+ margin:4px 0 12px; text-align:center;
1420
+ }
1421
+ .cf-hero > div { text-align:center; }
1422
  .cf-logo{height:calc(3.25rem + 3 * 1.1rem + 18px);width:auto;object-fit:contain}
1423
  .cf-title{margin:0;font-size:3.25rem;line-height:1;letter-spacing:.2px}
1424
  .cf-sub{margin:6px 0 0;font-size:1.1rem;color:#cfd3da}
 
1434
  #cfTable table tbody tr td:first-child { display:none; }
1435
  """
1436
 
1437
+
1438
  with gr.Blocks(css=BASE_CSS, title="CaptionForge") as demo:
1439
  settings = load_settings()
1440
+ settings["styles"] = [s for s in settings.get("styles", []) if s in STYLE_OPTIONS] or ["Character training (short)"]
1441
 
1442
  header = gr.HTML(value="""
1443
  <div class="cf-hero">
 
1452
  <hr>
1453
  """ % logo_b64_img())
1454
 
1455
+ with gr.Tabs():
1456
+ with gr.Tab("Captioner"):
1457
+ # (keep all your existing main controls here)
1458
+ ...
1459
+ with gr.Tab("Shape Aliases"):
1460
+ enable_aliases = gr.Checkbox(label="Enable shape alias replacements", value=True)
1461
+ alias_table = gr.Dataframe(
1462
+ headers=["shape (literal token)", "name to insert"],
1463
+ datatype=["str","str"],
1464
+ row_count=(1, "dynamic"),
1465
+ wrap=True,
1466
+ interactive=True
1467
+ )
1468
+ save_aliases_btn = gr.Button("Save aliases")
1469
+ save_status = gr.Markdown()
1470
+
1471
+ init_rows, init_enabled = get_shape_alias_rows()
1472
+ enable_aliases.value = init_enabled
1473
+ alias_table.value = init_rows
1474
+
1475
+ save_aliases_btn.click(
1476
+ set_shape_alias_rows,
1477
+ inputs=[enable_aliases, alias_table],
1478
+ outputs=[save_status]
1479
+ )
1480
+
1481
  # ===== Controls (top)
1482
  with gr.Group():
1483
  with gr.Row():
1484
  with gr.Column(scale=2):
1485
  style_checks = gr.CheckboxGroup(
1486
  choices=STYLE_OPTIONS,
1487
+ value=settings.get("styles", ["Character training (short)"]),
1488
  label="Caption style (choose one or combine)"
1489
  )
1490
  with gr.Accordion("Extra options", open=True):
 
1514
 
1515
  # Auto-refresh instruction & persist key controls
1516
  def _refresh_instruction(styles, extra, name_value, trigv, begv, endv):
1517
+ instr = final_instruction(styles or ["Character training (short)"], extra or [], name_value)
1518
  cfg = load_settings()
1519
  cfg.update({
1520
+ "styles": styles or ["Character training (short)"],
1521
  "extras": extra or [],
1522
  "name": name_value,
1523
  "trigger": trigv, "begin": begv, "end": endv
 
1608
  return gal, df, info
1609
 
1610
  # Initial loads
1611
+ demo.load(lambda s,e,n: final_instruction(s or ["Character training (short)"], e or [], n),
1612
  inputs=[style_checks, extra_opts, name_input], outputs=[instruction_preview])
1613
  demo.load(_render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md])
1614
 
 
1679
  table.change(on_table_edit, inputs=[table, rows_state], outputs=[rows_state, autosave_md])\
1680
  .then(_render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md])
1681
 
1682
def _compile_shape_aliases_from_file():
    """Compile (regex, replacement-name) pairs from the saved alias settings.

    Returns an empty list when the feature is disabled. Each pattern
    matches the shape token as a whole word, optionally followed by
    "shaped" or "-shaped", case-insensitively.
    """
    cfg = load_settings()
    if not cfg.get("shape_aliases_enabled", True):
        return []
    pairs = []
    for entry in cfg.get("shape_aliases", []):
        token = (entry.get("shape") or "").strip()
        alias = (entry.get("name") or "").strip()
        if token and alias:
            pattern = re.compile(rf"\b{re.escape(token)}(?:-?shaped)?\b", flags=re.I)
            pairs.append((pattern, alias))
    return pairs
1695
+
1696
+ SHAPE_ALIASES = _compile_shape_aliases_from_file()
1697
+
1698
def apply_shape_aliases(caption: str) -> str:
    """Replace each configured shape token in *caption* with "(alias-name)".

    Uses the module-level SHAPE_ALIASES list of compiled (pattern, name)
    pairs; patterns are applied in order.
    """
    result = caption
    for pattern, alias in SHAPE_ALIASES:
        result = pattern.sub(f"({alias})", result)
    return result
1702
+
1703
+
1704
  # Import merge
1705
  def do_import(file, rows):
1706
  if not file: