JS6969 committed on
Commit
6f8b14c
·
verified ·
1 Parent(s): 89ff26c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +263 -448
app.py CHANGED
@@ -1,7 +1,10 @@
1
  import os, io, csv, time, json, hashlib, base64, zipfile, re
2
- from typing import List, Tuple, Dict
 
 
 
 
3
 
4
- import spaces
5
  import gradio as gr
6
  from PIL import Image
7
  import torch
@@ -10,17 +13,14 @@ from transformers import LlavaForConditionalGeneration, AutoProcessor
10
  # ────────────────────────────────────────────────────────
11
  # Paths & caches
12
  # ────────────────────────────────────────────────────────
13
- os.environ.setdefault("HF_HOME", "/home/user/.cache/huggingface")
14
- os.makedirs(os.environ["HF_HOME"], exist_ok=True)
15
-
16
  APP_DIR = os.getcwd()
17
  SESSION_FILE = "/tmp/session.json"
18
  SETTINGS_FILE = "/tmp/cf_settings.json"
19
  JOURNAL_FILE = "/tmp/cf_journal.json"
20
- TXT_EXPORT_DIR = "/tmp/cf_txt"
21
- THUMB_CACHE = os.path.expanduser("~/.cache/captionforge/thumbs")
22
  os.makedirs(THUMB_CACHE, exist_ok=True)
23
- os.makedirs(TXT_EXPORT_DIR, exist_ok=True)
24
 
25
  # ────────────────────────────────────────────────────────
26
  # Model setup
@@ -35,8 +35,8 @@ def _detect_gpu():
35
 
36
  BACKEND, VRAM_GB, GPU_NAME = _detect_gpu()
37
  DTYPE = torch.bfloat16 if BACKEND == "cuda" else torch.float32
38
- MAX_SIDE_CAP = 1024 if BACKEND == "cuda" else 640
39
  DEVICE = "cuda" if BACKEND == "cuda" else "cpu"
 
40
 
41
  processor = AutoProcessor.from_pretrained(MODEL_PATH)
42
  model = LlavaForConditionalGeneration.from_pretrained(
@@ -47,8 +47,8 @@ model = LlavaForConditionalGeneration.from_pretrained(
47
  )
48
  model.eval()
49
 
50
- print(f"[CaptionForge] Backend={BACKEND} GPU={GPU_NAME} VRAM={VRAM_GB}GB dtype={DTYPE}")
51
- print(f"[CaptionForge] Gradio version: {gr.__version__}")
52
 
53
  # ────────────────────────────────────────────────────────
54
  # Instruction templates & options
@@ -66,8 +66,8 @@ STYLE_OPTIONS = [
66
  ]
67
 
68
  CAPTION_TYPE_MAP = {
69
- "Descriptive (short)": "One sentence (≤25 words) describing the most important visible elements only. No speculation.",
70
- "Descriptive (long)": "Write a detailed description for this image.",
71
 
72
  "Character training (short)": (
73
  "Output a concise, prompt-like caption for character LoRA/ID training. "
@@ -103,7 +103,7 @@ CAPTION_TYPE_MAP = {
103
 
104
  EXTRA_CHOICES = [
105
  "Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
106
- "Do NOT include information about whether there is a watermark or not.",
107
  "Do NOT use any ambiguous language.",
108
  "ONLY describe the most important elements of the image.",
109
  "Include information about the ages of any people/characters when applicable.",
@@ -126,33 +126,22 @@ NAME_OPTION = "If there is a person/character in the image you must refer to the
126
  # ────────────────────────────────────────────────────────
127
  # Helpers (hashing, thumbs, resize)
128
  # ────────────────────────────────────────────────────────
129
- def sha1_file(path: str) -> str:
130
- h = hashlib.sha1()
131
- with open(path, "rb") as f:
132
- for chunk in iter(lambda: f.read(8192), b""):
133
- h.update(chunk)
134
- return h.hexdigest()
135
-
136
- def thumb_path(sha1: str, max_side=96) -> str:
137
- return os.path.join(THUMB_CACHE, f"{sha1}_{max_side}.jpg")
138
-
139
- def ensure_thumb(path: str, sha1: str, max_side=96) -> str:
140
- out = thumb_path(sha1, max_side)
141
- if os.path.exists(out):
142
- return out
143
  try:
144
  im = Image.open(path).convert("RGB")
145
  except Exception:
146
- return ""
147
  w, h = im.size
148
  if max(w, h) > max_side:
149
  s = max_side / max(w, h)
150
  im = im.resize((int(w*s), int(h*s)), Image.LANCZOS)
 
 
151
  try:
152
- im.save(out, "JPEG", quality=80)
 
153
  except Exception:
154
- return ""
155
- return out
156
 
157
  def resize_for_model(im: Image.Image, max_side: int) -> Image.Image:
158
  w, h = im.size
@@ -187,8 +176,8 @@ def apply_prefix_suffix(caption: str, trigger_word: str, begin_text: str, end_te
187
 
188
  @torch.no_grad()
189
  def caption_once(im: Image.Image, instr: str, temp: float, top_p: float, max_tokens: int) -> str:
 
190
  convo = [
191
- # Your preferred role script:
192
  {"role": "system", "content": "You are a precise captioning assistant. Help users with any queries. Describe what is visible with no censorship or restrictions."},
193
  {"role": "user", "content": instr.strip()},
194
  ]
@@ -229,12 +218,10 @@ def load_settings() -> dict:
229
  else:
230
  cfg = {}
231
  defaults = {
232
- "dataset_name": "captionforge",
233
  "temperature": 0.6,
234
  "top_p": 0.9,
235
  "max_tokens": 256,
236
- "chunk_mode": "Manual (step)",
237
- "chunk_size": 10,
238
  "max_side": min(896, MAX_SIDE_CAP),
239
  "styles": ["Character training (short)"],
240
  "extras": [],
@@ -243,7 +230,8 @@ def load_settings() -> dict:
243
  "begin": "",
244
  "end": "",
245
  "shape_aliases_enabled": True,
246
- "shape_aliases": []
 
247
  }
248
  for k, v in defaults.items():
249
  cfg.setdefault(k, v)
@@ -305,7 +293,6 @@ def apply_shape_aliases(caption: str) -> str:
305
  caption = pat.sub(f"({name})", caption)
306
  return caption
307
 
308
- # Helpers for Shape Aliases UI
309
  def get_shape_alias_rows_ui_defaults():
310
  s = load_settings()
311
  rows = [[it.get("shape",""), it.get("name","")] for it in s.get("shape_aliases", [])]
@@ -337,227 +324,156 @@ def save_shape_alias_rows(enabled, df_rows):
337
  gr.update(value=normalized, row_count=(max(1, len(normalized)), "dynamic"))
338
  )
339
 
340
- # ────────────────────────────────────────────────────────
341
- # Logo loader
342
- # ────────────────────────────────────────────────────────
343
- def logo_b64_img() -> str:
344
- candidates = [
345
- os.path.join(APP_DIR, "captionforge-logo.png"),
346
- "/home/user/app/captionforge-logo.png",
347
- "captionforge-logo.png",
348
- ]
349
- for p in candidates:
350
- if os.path.exists(p):
351
- with open(p, "rb") as f:
352
- b64 = base64.b64encode(f.read()).decode("ascii")
353
- return f"<img src='data:image/png;base64,{b64}' alt='CaptionForge' class='cf-logo'>"
354
- return ""
355
-
356
  # ────────────────────────────────────────────────────────
357
  # Import / Export helpers
358
  # ────────────────────────────────────────────────────────
359
- def sanitize_basename(s: str) -> str:
360
- s = (s or "").strip() or "captionforge"
361
- return re.sub(r"[^A-Za-z0-9._-]+", "_", s)[:120]
362
-
363
- def ts_stamp() -> str:
364
- return time.strftime("%Y%m%d_%H%M%S")
365
-
366
- def export_csv(rows: List[dict], basename: str) -> str:
367
- base = sanitize_basename(basename)
368
- out = f"/tmp/{base}_{ts_stamp()}.csv"
369
  with open(out, "w", newline="", encoding="utf-8") as f:
370
  w = csv.writer(f)
371
- w.writerow(["sha1", "filename", "caption"])
372
- for r in rows:
373
- w.writerow([r.get("sha1",""), r.get("filename",""), r.get("caption","")])
374
  return out
375
 
376
- def export_excel(rows: List[dict], basename: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
  try:
378
  from openpyxl import Workbook
379
  from openpyxl.drawing.image import Image as XLImage
380
  except Exception as e:
381
  raise RuntimeError("Excel export requires 'openpyxl' in requirements.txt.") from e
382
- base = sanitize_basename(basename)
383
- out = f"/tmp/{base}_{ts_stamp()}.xlsx"
384
- wb = Workbook(); ws = wb.active; ws.title = "CaptionForge"
 
 
 
 
 
 
 
 
 
385
  ws.append(["image", "filename", "caption"])
386
  ws.column_dimensions["A"].width = 24
387
  ws.column_dimensions["B"].width = 42
388
  ws.column_dimensions["C"].width = 100
 
 
 
 
389
  r_i = 2
390
- for r in rows:
391
- fn = r.get("filename",""); cap = r.get("caption","")
 
392
  ws.cell(row=r_i, column=2, value=fn)
393
  ws.cell(row=r_i, column=3, value=cap)
394
  img_path = r.get("thumb_path") or r.get("path")
395
  if img_path and os.path.exists(img_path):
396
  try:
397
- ximg = XLImage(img_path); ws.add_image(ximg, f"A{r_i}"); ws.row_dimensions[r_i].height = 110
 
 
 
398
  except Exception:
399
  pass
400
  r_i += 1
401
- wb.save(out); return out
402
-
403
- def export_txt_zip(rows: List[dict], basename: str) -> str:
404
- base = sanitize_basename(basename)
405
- for name in os.listdir(TXT_EXPORT_DIR):
406
- try: os.remove(os.path.join(TXT_EXPORT_DIR, name))
407
- except Exception: pass
408
- used: Dict[str,int] = {}
409
- for r in rows:
410
- orig = r.get("filename") or r.get("sha1") or "item"
411
- stem = re.sub(r"\.[A-Za-z0-9]+$", "", orig)
412
- stem = sanitize_basename(stem)
413
- if stem in used:
414
- used[stem] += 1
415
- stem = f"{stem}_{used[stem]}"
416
- else:
417
- used[stem] = 0
418
- txt_path = os.path.join(TXT_EXPORT_DIR, f"{stem}.txt")
419
- with open(txt_path, "w", encoding="utf-8") as f:
420
- f.write(r.get("caption", ""))
421
- zpath = f"/tmp/{base}_{ts_stamp()}_txt.zip"
422
- with zipfile.ZipFile(zpath, "w", zipfile.ZIP_DEFLATED) as z:
423
- for name in os.listdir(TXT_EXPORT_DIR):
424
- if name.endswith(".txt"):
425
- z.write(os.path.join(TXT_EXPORT_DIR, name), arcname=name)
426
- return zpath
427
-
428
- def import_csv_jsonl(path: str, rows: List[dict]) -> List[dict]:
429
- """Merge captions from CSV/JSONL into current rows by sha1→filename, updating only 'caption'."""
430
- if not path or not os.path.exists(path):
431
- return rows or []
432
- rows = rows or []
433
- by_sha = {r.get("sha1"): r for r in rows if r.get("sha1")}
434
- by_fn = {r.get("filename"): r for r in rows if r.get("filename")}
435
-
436
- def _apply(item: dict):
437
- if not isinstance(item, dict):
438
- return
439
- cap = (item.get("caption") or "").strip()
440
- if not cap:
441
- return
442
- sha = (item.get("sha1") or "").strip()
443
- fn = (item.get("filename") or "").strip()
444
- tgt = by_sha.get(sha) if sha else None
445
- if not tgt and fn:
446
- tgt = by_fn.get(fn)
447
- if tgt is not None:
448
- prev = tgt.get("caption", "")
449
- if cap != prev:
450
- tgt.setdefault("history", []).append(prev)
451
- tgt["caption"] = cap
452
-
453
- ext = os.path.splitext(path)[1].lower()
454
- try:
455
- if ext == ".csv":
456
- with open(path, "r", encoding="utf-8") as f:
457
- for row in csv.DictReader(f):
458
- _apply(row)
459
- elif ext in (".jsonl", ".ndjson"):
460
- with open(path, "r", encoding="utf-8") as f:
461
- for line in f:
462
- line = line.strip()
463
- if not line:
464
- continue
465
- try:
466
- obj = json.loads(line)
467
- except Exception:
468
- continue
469
- _apply(obj)
470
- else:
471
- with open(path, "r", encoding="utf-8") as f:
472
- data = json.load(f)
473
- if isinstance(data, list):
474
- for obj in data:
475
- _apply(obj)
476
- except Exception as e:
477
- print("[CaptionForge] Import merge failed:", e)
478
- save_session(rows)
479
- return rows
480
 
481
  # ────────────────────────────────────────────────────────
482
- # Batching / processing
483
  # ────────────────────────────────────────────────────────
484
- def adaptive_defaults(vram_gb: int) -> Tuple[int, int]:
485
- if vram_gb >= 40: return 24, min(1024, MAX_SIDE_CAP)
486
- if vram_gb >= 24: return 16, min(1024, MAX_SIDE_CAP)
487
- if vram_gb >= 12: return 10, 896
488
- if vram_gb >= 8: return 6, 768
489
- return 4, 640
490
-
491
- def warmup_if_needed(temp: float, top_p: float, max_tokens: int, max_side: int):
492
- try:
493
- im = Image.new("RGB", (min(128, max_side), min(128, max_side)), (127,127,127))
494
- _ = caption_once(im, "Warm up.", temp=0.0, top_p=top_p, max_tokens=min(16, max_tokens))
495
- except Exception as e:
496
- print("[CaptionForge] Warm-up skipped:", e)
497
-
498
- @spaces.GPU()
499
  @torch.no_grad()
500
- def process_batch(files: List[str], include_dups: bool, rows: List[dict],
501
- instr_text: str, temp: float, top_p: float, max_tokens: int,
502
- trigger_word: str, begin_text: str, end_text: str,
503
- mode: str, chunk_size: int, max_side: int, step_once: bool,
504
- progress=gr.Progress(track_tqdm=True)) -> Tuple[List[dict], List[str], List[str]]:
 
 
 
 
505
  if torch.cuda.is_available():
506
  torch.cuda.empty_cache()
507
- files = files or []
508
- if not files: return rows, [], []
509
 
510
- auto_mode = mode == "Auto"
511
- if auto_mode:
512
- chunk_size, max_side = adaptive_defaults(VRAM_GB)
 
 
 
 
 
513
 
514
- sha_seen = {r.get("sha1") for r in rows}
515
- f_infos = []
516
  for f in files:
 
 
 
517
  try:
518
- s = sha1_file(f)
519
  except Exception:
520
  continue
521
- if (not include_dups) and (s in sha_seen):
522
- continue
523
- f_infos.append((f, os.path.basename(f), s))
524
- if not f_infos: return rows, [], []
525
-
526
- warmup_if_needed(temp, top_p, max_tokens, max_side)
527
-
528
- chunks = [f_infos[i:i+chunk_size] for i in range(0, len(f_infos), chunk_size)]
529
- if step_once:
530
- chunks = chunks[:1]
531
-
532
- error_log: List[str] = []
533
- remaining: List[str] = []
534
-
535
- for ci, chunk in enumerate(chunks):
536
- for (f, fname, sha) in progress.tqdm(chunk, desc=f"Chunk {ci+1}/{len(chunks)}"):
537
- thumb = ensure_thumb(f, sha, 96)
538
- attempt = 0
539
- cap = None
540
- while attempt < 2:
541
- attempt += 1
542
- try:
543
- im = Image.open(f).convert("RGB")
544
- im = resize_for_model(im, max_side)
545
- cap = caption_once(im, instr_text, temp, top_p, max_tokens)
546
- cap = apply_shape_aliases(cap) # Apply custom shape→name mappings
547
- cap = apply_prefix_suffix(cap, trigger_word, begin_text, end_text)
548
- break
549
- except Exception as e:
550
- error_log.append(f"{fname} ({sha[:8]}): {type(e).__name__}: {e} [attempt {attempt}]")
551
- rows.append({
552
- "sha1": sha, "filename": fname, "path": f, "thumb_path": thumb,
553
- "caption": cap or "", "history": []
554
- })
555
- save_session(rows)
556
-
557
- if step_once and len(f_infos) > len(chunks[0]):
558
- remaining = [fi[0] for fi in f_infos[len(chunks[0]):]]
559
- save_journal({"remaining_files": remaining})
560
- return rows, remaining, error_log
561
 
562
  # ────────────────────────────────────────────────────────
563
  # UI
@@ -580,12 +496,24 @@ BASE_CSS = """
580
  .cf-scroll{max-height:70vh; overflow-y:auto; border:1px solid #e6e6e6; border-radius:10px; padding:8px}
581
  /* Uniform sizes */
582
  #cfGal .grid > div { height: 96px; }
583
- /* Hide SHA1 in table (first column) */
584
- #cfTable table thead tr th:first-child { display:none; }
585
- #cfTable table tbody tr td:first-child { display:none; }
586
  """
587
 
588
- with gr.Blocks(css=BASE_CSS, title="CaptionForge") as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
589
  settings = load_settings()
590
  settings["styles"] = [s for s in settings.get("styles", []) if s in STYLE_OPTIONS] or ["Character training (short)"]
591
 
@@ -593,10 +521,10 @@ with gr.Blocks(css=BASE_CSS, title="CaptionForge") as demo:
593
  <div class="cf-hero">
594
  {logo_b64_img()}
595
  <div>
596
- <h1 class="cf-title">CaptionForge</h1>
597
  <div class="cf-sub">Batch captioning</div>
598
  <div class="cf-sub">Scrollable editor & autosave</div>
599
- <div class="cf-sub">CSV / Excel / TXT export</div>
600
  </div>
601
  </div>
602
  <hr>
@@ -619,123 +547,116 @@ with gr.Blocks(css=BASE_CSS, title="CaptionForge") as demo:
619
  )
620
  with gr.Accordion("Name & Prefix/Suffix", open=True):
621
  name_input = gr.Textbox(label="Person / Character Name", value=settings.get("name", ""))
622
- trig = gr.Textbox(label="Trigger word")
623
- add_start = gr.Textbox(label="Add text to start")
624
- add_end = gr.Textbox(label="Add text to end")
625
 
626
  with gr.Column(scale=1):
627
- instruction_preview = gr.Textbox(label="Model Instructions", lines=14)
628
- dataset_name = gr.Textbox(label="Dataset name (used for export file titles)", value=settings.get("dataset_name", "captionforge"))
 
 
 
629
 
630
- with gr.Row():
631
- chunk_mode = gr.Radio(
632
- choices=["Auto", "Manual (all at once)", "Manual (step)"],
633
- value=settings.get("chunk_mode", "Manual (step)"),
634
- label="Batch mode"
635
- )
636
- chunk_size = gr.Slider(1, 50, settings.get("chunk_size", 10), step=1, label="Chunk size")
637
- max_side = gr.Slider(256, MAX_SIDE_CAP, settings.get("max_side", min(896, MAX_SIDE_CAP)), step=32, label="Max side (resize)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
638
 
639
- # === Shape Aliases UI (Accordion) ===
640
  with gr.Accordion("Shape Aliases", open=False):
641
  gr.Markdown(
642
  "### 🔷 Shape Aliases\n"
643
  "Replace literal **shape tokens** in captions with a preferred **name**.\n\n"
644
- "**How to use:**\n"
645
- "1) Left column: exact shape token (e.g., `diamond`, `heart`).\n"
646
- "2) Right column: replacement name (e.g., `starkey-emblem`).\n"
647
- "3) Click **Save**.\n\n"
648
- "_Rows with blanks are ignored; add as many rows as you like._"
649
  )
650
-
651
  init_rows, init_enabled = get_shape_alias_rows_ui_defaults()
652
-
653
- enable_aliases = gr.Checkbox(
654
- label="Enable shape alias replacements",
655
- value=init_enabled,
656
- )
657
-
658
  alias_table = gr.Dataframe(
659
  headers=["shape (literal token)", "name to insert"],
660
- value=init_rows, # set value at creation time
661
- col_count=(2, "fixed"), # exactly two columns
662
- row_count=(max(1, len(init_rows)), "dynamic"), # user can add/remove rows
663
  datatype=["str", "str"],
664
- type="array", # treat rows as simple lists
665
  interactive=True
666
  )
667
-
668
  with gr.Row():
669
  add_row_btn = gr.Button("+ Add row", variant="secondary")
670
  clear_btn = gr.Button("Clear", variant="secondary")
671
  save_btn = gr.Button("💾 Save", variant="primary")
672
-
673
  save_status = gr.Markdown("")
674
-
675
- # Button handlers
676
  def _add_row(cur):
677
  cur = (cur or []) + [["", ""]]
678
  return gr.update(value=cur, row_count=(max(1, len(cur)), "dynamic"))
679
-
680
  def _clear_rows():
681
  return gr.update(value=[["", ""]], row_count=(1, "dynamic"))
682
-
683
  add_row_btn.click(_add_row, inputs=[alias_table], outputs=[alias_table])
684
  clear_btn.click(_clear_rows, outputs=[alias_table])
685
- save_btn.click(save_shape_alias_rows, inputs=[enable_aliases, alias_table],
686
- outputs=[save_status, alias_table])
687
-
688
- # Auto-refresh instruction & persist key controls
689
- def _refresh_instruction(styles, extra, name_value, trigv, begv, endv):
690
- instr = final_instruction(styles or ["Character training (short)"], extra or [], name_value)
691
- cfg = load_settings()
692
- cfg.update({
693
- "styles": styles or ["Character training (short)"],
694
- "extras": extra or [],
695
- "name": name_value,
696
- "trigger": trigv, "begin": begv, "end": endv
697
- })
698
- save_settings(cfg)
699
- return instr
 
 
 
 
 
 
 
 
 
 
 
 
700
 
701
- for comp in [style_checks, extra_opts, name_input, trig, add_start, add_end]:
702
- comp.change(_refresh_instruction, inputs=[style_checks, extra_opts, name_input, trig, add_start, add_end], outputs=[instruction_preview])
 
 
703
 
704
- # ===== File inputs & actions
705
- with gr.Accordion("Uploaded images", open=True) as uploads_acc:
706
- input_files = gr.File(label="Drop images", file_types=["image"], file_count="multiple", type="filepath")
707
- import_file = gr.File(label="Import CSV or JSONL (merge captions)")
708
 
709
- with gr.Row():
710
- import_btn = gr.Button("Import → Merge")
711
- resume_btn = gr.Button("Resume last run", variant="secondary")
712
- run_button = gr.Button("Caption batch", variant="primary")
713
-
714
- # Step-chunk controls
715
- step_panel = gr.Group(visible=False)
716
- with step_panel:
717
- step_msg = gr.Markdown("")
718
- step_next = gr.Button("Process next chunk")
719
- step_finish = gr.Button("Finish")
720
-
721
- # States
722
- rows_state = gr.State(load_session())
723
- visible_count_state = gr.State(100)
724
- pending_files = gr.State([])
725
- remaining_state = gr.State([])
726
- autosave_md = gr.Markdown("Ready.")
727
-
728
- # Error log
729
- with gr.Accordion("Error log", open=False):
730
- log_md = gr.Markdown("")
731
-
732
- # Side-by-side gallery + table with shared scroll
733
  with gr.Row():
734
  with gr.Column(scale=1):
735
  gallery = gr.Gallery(
736
- label="Results (images)",
737
  show_label=True,
738
- columns=10,
739
  height=520,
740
  elem_id="cfGal",
741
  elem_classes=["cf-scroll"]
@@ -743,49 +664,30 @@ with gr.Blocks(css=BASE_CSS, title="CaptionForge") as demo:
743
  with gr.Column(scale=1):
744
  table = gr.Dataframe(
745
  label="Editable captions (whole session)",
746
- value=[],
747
- headers=["sha1", "filename", "caption"], # sha1 hidden via CSS
748
  interactive=True,
 
749
  elem_id="cfTable",
750
  elem_classes=["cf-scroll"]
751
  )
752
 
753
- # Exports aligned under buttons
754
  with gr.Row():
755
  with gr.Column():
756
  export_csv_btn = gr.Button("Export CSV")
757
  csv_file = gr.File(label="CSV file", visible=False)
758
  with gr.Column():
759
- export_xlsx_btn = gr.Button("Export Excel (.xlsx)")
760
  xlsx_file = gr.File(label="Excel file", visible=False)
761
- with gr.Column():
762
- export_txt_btn = gr.Button("Export captions as .txt (zip)")
763
- txt_zip = gr.File(label="TXT zip", visible=False)
764
 
765
- # Render helpers
766
- def _render(rows, vis_n):
767
  rows = rows or []
768
- n = max(0, int(vis_n) if vis_n else 0)
769
- win = rows[:n]
770
-
771
- # Gallery expects list of paths
772
- gal = []
773
- for r in win:
774
- p = r.get("thumb_path") or r.get("path")
775
- if isinstance(p, str) and os.path.exists(p):
776
- gal.append(p)
777
-
778
- # Dataframe rows (sha1 kept for mapping; hidden via CSS)
779
- df = [[str(r.get("sha1","")), str(r.get("filename","")), str(r.get("caption",""))] for r in win]
780
- info = f"Showing {len(win)} of {len(rows)}"
781
- return gal, df, info
782
-
783
- # Initial loads
784
- demo.load(lambda s,e,n: final_instruction(s or ["Character training (short)"], e or [], n),
785
- inputs=[style_checks, extra_opts, name_input], outputs=[instruction_preview])
786
- demo.load(_render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md])
787
 
788
- # Scroll sync (Gallery + Table)
789
  gr.HTML("""
790
  <script>
791
  (function () {
@@ -830,123 +732,36 @@ with gr.Blocks(css=BASE_CSS, title="CaptionForge") as demo:
830
  </script>
831
  """)
832
 
833
- # Table edits -> save + rerender
834
- def on_table_edit(df_rows, rows):
835
- by_sha = {r.get("sha1"): r for r in (rows or [])}
836
- for row in (df_rows or []):
837
- if not row:
838
- continue
839
- sha = str(row[0]) if len(row) > 0 else ""
840
- tgt = by_sha.get(sha)
841
- if not tgt:
842
- continue
843
- prev = tgt.get("caption", "")
844
- newc = row[2] if len(row) > 2 else prev
845
- if newc != prev:
846
- hist = tgt.setdefault("history", [])
847
- hist.append(prev)
848
- tgt["caption"] = newc
849
- save_session(rows or [])
850
- return rows, f"Saved • {time.strftime('%H:%M:%S')}"
851
- table.change(on_table_edit, inputs=[table, rows_state], outputs=[rows_state, autosave_md])\
852
- .then(_render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md])
853
-
854
- # Import merge
855
- def do_import(file, rows):
856
- if not file:
857
- return rows, gr.update(value="No file"), rows
858
- rows = import_csv_jsonl(file.name, rows or [])
859
- return rows, gr.update(value=f"Merged from {os.path.basename(file.name)} at {time.strftime('%H:%M:%S')}"), rows
860
- import_btn.click(do_import, inputs=[import_file, rows_state], outputs=[rows_state, autosave_md, rows_state])\
861
- .then(_render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md])
862
-
863
- # Prepare run (hash dupes)
864
- def prepare_run(files, rows, pending):
865
- files = files or []
866
- rows = rows or []
867
- pending = pending or []
868
- if not files:
869
- return [], gr.update(open=True), []
870
- existing = {r.get("sha1") for r in rows if r.get("sha1")}
871
- pending_sha = set()
872
- for f in pending:
873
- try:
874
- pending_sha.add(sha1_file(f))
875
- except Exception:
876
- pass
877
- keep = []
878
- for f in files:
879
- try:
880
- s = sha1_file(f)
881
- except Exception:
882
- continue
883
- if s not in existing and s not in pending_sha:
884
- keep.append(f)
885
- return keep, gr.update(open=False), keep
886
- input_files.change(prepare_run, inputs=[input_files, rows_state, pending_files], outputs=[pending_files, uploads_acc, pending_files])
887
-
888
- # Step visibility helper (dynamic message)
889
- def _step_visibility(remain, mode):
890
- if (mode == "Manual (step)") and remain:
891
- return gr.update(visible=True), gr.update(value=f"{len(remain)} files remain. Process next chunk?")
892
- return gr.update(visible=False), gr.update(value="")
893
-
894
- # Run helpers (use settings for gen params)
895
- def _run_once(pending, rows, instr, trigv, begv, endv, mode, csize, mside):
896
- cfg = load_settings()
897
- t = cfg.get("temperature", 0.6)
898
- p = cfg.get("top_p", 0.9)
899
- m = cfg.get("max_tokens", 256)
900
- step_once = (mode == "Manual (step)")
901
- new_rows, remaining, errors = process_batch(
902
- pending, False, rows or [], instr, t, p, m, trigv, begv, endv,
903
- mode, int(csize), int(mside), step_once
904
- )
905
- log = "\n".join(errors) if errors else "(no errors)"
906
- return new_rows, remaining, log, f"Saved • {time.strftime('%H:%M:%S')}"
907
 
908
  run_button.click(
909
- _run_once,
910
- inputs=[pending_files, rows_state, instruction_preview, trig, add_start, add_end, chunk_mode, chunk_size, max_side],
911
- outputs=[rows_state, remaining_state, log_md, autosave_md]
912
- ).then(
913
- _render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md]
914
- ).then(
915
- _step_visibility, inputs=[remaining_state, chunk_mode], outputs=[step_panel, step_msg]
916
  )
917
 
918
- def _resume(rows):
919
- j = load_journal()
920
- rem = j.get("remaining_files", [])
921
- return rem, ("Nothing to resume" if not rem else f"Loaded {len(rem)} remaining")
922
- resume_btn.click(_resume, inputs=[rows_state], outputs=[pending_files, autosave_md])
923
-
924
- def _step_next(remain, rows, instr, trigv, begv, endv, mode, csize, mside):
925
- return _run_once(remain, rows, instr, trigv, begv, endv, mode, csize, mside)
926
- step_next.click(
927
- _step_next,
928
- inputs=[remaining_state, rows_state, instruction_preview, trig, add_start, add_end, chunk_mode, chunk_size, max_side],
929
- outputs=[rows_state, remaining_state, log_md, autosave_md]
930
- ).then(
931
- _render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md]
932
- ).then(
933
- _step_visibility, inputs=[remaining_state, chunk_mode], outputs=[step_panel, step_msg]
934
  )
935
 
936
- step_finish.click(lambda: (gr.update(visible=False), gr.update(value=""), []), outputs=[step_panel, step_msg, remaining_state])
937
-
938
- # Exports (reveal file widgets when created)
939
  export_csv_btn.click(
940
- lambda rows, base: (export_csv(rows or [], base), gr.update(visible=True)),
941
- inputs=[rows_state, dataset_name], outputs=[csv_file, csv_file]
942
  )
943
  export_xlsx_btn.click(
944
- lambda rows, base: (export_excel(rows or [], base), gr.update(visible=True)),
945
- inputs=[rows_state, dataset_name], outputs=[xlsx_file, xlsx_file]
946
- )
947
- export_txt_btn.click(
948
- lambda rows, base: (export_txt_zip(rows or [], base), gr.update(visible=True)),
949
- inputs=[rows_state, dataset_name], outputs=[txt_zip, txt_zip]
950
  )
951
 
952
  # Launch
 
1
  import os, io, csv, time, json, hashlib, base64, zipfile, re
2
+ from typing import List, Tuple, Dict, Any
3
+
4
+ # Caches
5
+ os.environ.setdefault("HF_HOME", "/home/user/.cache/huggingface")
6
+ os.makedirs(os.environ["HF_HOME"], exist_ok=True)
7
 
 
8
  import gradio as gr
9
  from PIL import Image
10
  import torch
 
13
  # ────────────────────────────────────────────────────────
14
  # Paths & caches
15
  # ────────────────────────────────────────────────────────
 
 
 
16
  APP_DIR = os.getcwd()
17
  SESSION_FILE = "/tmp/session.json"
18
  SETTINGS_FILE = "/tmp/cf_settings.json"
19
  JOURNAL_FILE = "/tmp/cf_journal.json"
20
+ THUMB_CACHE = os.path.expanduser("~/.cache/forgecaptions/thumbs")
21
+ EXCEL_THUMB_DIR = "/tmp/forge_excel_thumbs"
22
  os.makedirs(THUMB_CACHE, exist_ok=True)
23
+ os.makedirs(EXCEL_THUMB_DIR, exist_ok=True)
24
 
25
  # ────────────────────────────────────────────────────────
26
  # Model setup
 
35
 
36
  BACKEND, VRAM_GB, GPU_NAME = _detect_gpu()
37
  DTYPE = torch.bfloat16 if BACKEND == "cuda" else torch.float32
 
38
  DEVICE = "cuda" if BACKEND == "cuda" else "cpu"
39
+ MAX_SIDE_CAP = 1024 if BACKEND == "cuda" else 640
40
 
41
  processor = AutoProcessor.from_pretrained(MODEL_PATH)
42
  model = LlavaForConditionalGeneration.from_pretrained(
 
47
  )
48
  model.eval()
49
 
50
+ print(f"[ForgeCaptions] Backend={BACKEND} GPU={GPU_NAME} VRAM={VRAM_GB}GB dtype={DTYPE}")
51
+ print(f"[ForgeCaptions] Gradio version: {gr.__version__}")
52
 
53
  # ────────────────────────────────────────────────────────
54
  # Instruction templates & options
 
66
  ]
67
 
68
  CAPTION_TYPE_MAP = {
69
+ "Descriptive (short)": "Write a short describition of the most important visible elements only. No speculation.",
70
+ "Descriptive (long)": "Write a long, highly detailed description for this image.",
71
 
72
  "Character training (short)": (
73
  "Output a concise, prompt-like caption for character LoRA/ID training. "
 
103
 
104
  EXTRA_CHOICES = [
105
  "Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
106
+ "IGNORE all watermarks.",
107
  "Do NOT use any ambiguous language.",
108
  "ONLY describe the most important elements of the image.",
109
  "Include information about the ages of any people/characters when applicable.",
 
126
  # ────────────────────────────────────────────────────────
127
  # Helpers (hashing, thumbs, resize)
128
  # ────────────────────────────────────────────────────────
129
def ensure_thumb(path: str, max_side=256) -> str:
    """Create a JPEG thumbnail of ``path`` in ``THUMB_CACHE`` and return its path.

    The image is converted to RGB and, when its longer side exceeds
    ``max_side``, scaled down proportionally with Lanczos resampling. The
    thumbnail is rewritten on every call.

    Args:
        path: Location of the source image.
        max_side: Upper bound, in pixels, for the thumbnail's longer side.

    Returns:
        The thumbnail path on success; ``path`` itself when the image cannot
        be opened, resized or saved (best-effort fallback).
    """
    try:
        # Close the source handle as soon as the RGB copy exists.
        with Image.open(path) as src:
            im = src.convert("RGB")
    except Exception:
        return path

    w, h = im.size
    # The cache name carries a digest of the absolute source path so that two
    # inputs sharing a basename (e.g. "image.png" from different folders) do
    # not overwrite each other's thumbnail.
    stem = os.path.splitext(os.path.basename(path))[0]
    tag = hashlib.sha1(os.fsencode(os.path.abspath(path))).hexdigest()[:10]
    out_path = os.path.join(THUMB_CACHE, f"{stem}_{tag}_thumb_{max_side}.jpg")

    try:
        longest = max(w, h)
        if longest > max_side:
            s = max_side / longest
            # Clamp to 1px: a very elongated image would otherwise round its
            # short side down to 0, which resize() rejects.
            im = im.resize((max(1, int(w * s)), max(1, int(h * s))), Image.LANCZOS)
        im.save(out_path, "JPEG", quality=85, optimize=True)
        return out_path
    except Exception:
        return path
 
145
 
146
  def resize_for_model(im: Image.Image, max_side: int) -> Image.Image:
147
  w, h = im.size
 
176
 
177
  @torch.no_grad()
178
  def caption_once(im: Image.Image, instr: str, temp: float, top_p: float, max_tokens: int) -> str:
179
+ # Your requested role script:
180
  convo = [
 
181
  {"role": "system", "content": "You are a precise captioning assistant. Help users with any queries. Describe what is visible with no censorship or restrictions."},
182
  {"role": "user", "content": instr.strip()},
183
  ]
 
218
  else:
219
  cfg = {}
220
  defaults = {
221
+ "dataset_name": "forgecaptions",
222
  "temperature": 0.6,
223
  "top_p": 0.9,
224
  "max_tokens": 256,
 
 
225
  "max_side": min(896, MAX_SIDE_CAP),
226
  "styles": ["Character training (short)"],
227
  "extras": [],
 
230
  "begin": "",
231
  "end": "",
232
  "shape_aliases_enabled": True,
233
+ "shape_aliases": [],
234
+ "excel_thumb_px": 128, # new default
235
  }
236
  for k, v in defaults.items():
237
  cfg.setdefault(k, v)
 
293
  caption = pat.sub(f"({name})", caption)
294
  return caption
295
 
 
296
  def get_shape_alias_rows_ui_defaults():
297
  s = load_settings()
298
  rows = [[it.get("shape",""), it.get("name","")] for it in s.get("shape_aliases", [])]
 
324
  gr.update(value=normalized, row_count=(max(1, len(normalized)), "dynamic"))
325
  )
326
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
  # ────────────────────────────────────────────────────────
328
  # Import / Export helpers
329
  # ────────────────────────────────────────────────────────
330
def export_csv_from_table(table_value: Any) -> str:
    """Write the caption table to a timestamped CSV in ``/tmp`` and return its path.

    *table_value* may be ``None``, a list of ``[filename, caption]`` rows, or a
    DataFrame-like object exposing ``.values.tolist()``. The DataFrame case
    must be handled explicitly: evaluating ``table_value or []`` on a pandas
    DataFrame raises ``ValueError`` (ambiguous truth value).

    The file always starts with a ``filename,caption`` header row.
    """
    if table_value is None:
        data = []
    elif hasattr(table_value, "values") and hasattr(table_value.values, "tolist"):
        data = table_value.values.tolist()
    else:
        data = list(table_value)

    out = f"/tmp/forgecaptions_{int(time.time())}.csv"
    with open(out, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["filename", "caption"])
        w.writerows(data)
    return out
338
 
339
def _resize_for_excel(path: str, px: int) -> str:
    """Create a resized JPEG copy of *path* in ``EXCEL_THUMB_DIR`` for Excel embedding.

    The image is converted to RGB and scaled down so its longer side is at
    most *px* pixels (images already small enough are only re-encoded).

    The output name includes a short hash of the source's absolute path so
    that distinct source files sharing a basename do not overwrite each other.

    Best-effort: if the source cannot be opened or the copy cannot be written,
    the original *path* is returned unchanged.
    """
    try:
        # Close the file handle promptly; convert() yields a detached copy.
        with Image.open(path) as src:
            im = src.convert("RGB")
    except Exception:
        return path

    w, h = im.size
    longest = max(w, h)
    if longest > px:
        s = px / longest
        # Never let a side round down to 0 (resize() would raise).
        im = im.resize((max(1, int(w * s)), max(1, int(h * s))), Image.LANCZOS)

    stem = os.path.splitext(os.path.basename(path))[0]
    digest = hashlib.sha1(
        os.path.abspath(path).encode("utf-8", "surrogateescape")
    ).hexdigest()[:10]
    out_path = os.path.join(EXCEL_THUMB_DIR, f"{stem}_{digest}_{px}px.jpg")
    try:
        # Make sure the target directory exists before writing.
        os.makedirs(EXCEL_THUMB_DIR, exist_ok=True)
        im.save(out_path, "JPEG", quality=85, optimize=True)
        return out_path
    except Exception:
        return path
356
+
357
def export_excel_with_thumbs(table_value: Any, session_rows: List[dict], thumb_px: int) -> str:
    """Export the session to an ``.xlsx`` workbook with one embedded thumbnail per row.

    Columns are ``image``, ``filename`` and ``caption``. One sheet row is
    written per entry in *session_rows*; when the editable table contains a
    row with the same filename, the table's caption wins over the stored one.

    *table_value* may be ``None``, a list of rows, or a DataFrame-like object
    exposing ``.values.tolist()`` (``table_value or []`` would raise on a
    pandas DataFrame, so it is normalized explicitly).

    Returns the path of the written workbook under ``/tmp``.

    Raises:
        RuntimeError: if ``openpyxl`` is not installed.
    """
    try:
        from openpyxl import Workbook
        from openpyxl.drawing.image import Image as XLImage
    except Exception as e:
        raise RuntimeError("Excel export requires 'openpyxl' in requirements.txt.") from e

    if table_value is None:
        table_rows = []
    elif hasattr(table_value, "values") and hasattr(table_value.values, "tolist"):
        table_rows = table_value.values.tolist()
    else:
        table_rows = list(table_value)

    # Respect user edits (table wins)
    caption_by_file = {}
    for row in table_rows:
        if not row:
            continue
        fn = str(row[0]) if len(row) > 0 else ""
        cap = str(row[1]) if len(row) > 1 and row[1] is not None else ""
        if fn:
            caption_by_file[fn] = cap

    wb = Workbook()
    ws = wb.active
    ws.title = "ForgeCaptions"
    ws.append(["image", "filename", "caption"])
    ws.column_dimensions["A"].width = 24
    ws.column_dimensions["B"].width = 42
    ws.column_dimensions["C"].width = 100

    thumb_px = int(thumb_px)
    # Approx px→points (Excel row height is points; ~0.75 pt per px @ 96dpi)
    row_h = int(thumb_px * 0.75)

    # Row 1 is the header, so data starts at sheet row 2.
    for r_i, r in enumerate(session_rows or [], start=2):
        fn = r.get("filename", "")
        cap = caption_by_file.get(fn, r.get("caption", ""))
        ws.cell(row=r_i, column=2, value=fn)
        ws.cell(row=r_i, column=3, value=cap)
        img_path = r.get("thumb_path") or r.get("path")
        if img_path and os.path.exists(img_path):
            try:
                resized = _resize_for_excel(img_path, thumb_px)
                xlimg = XLImage(resized)
                ws.add_image(xlimg, f"A{r_i}")
                ws.row_dimensions[r_i].height = row_h
            except Exception:
                # Deliberately best-effort: an image that cannot be embedded
                # must not abort the export; the row keeps its text cells.
                pass

    out = f"/tmp/forgecaptions_{int(time.time())}.xlsx"
    wb.save(out)
    return out
403
+
404
+ # Rows<->Table helpers
405
+ def _rows_to_table(rows: List[dict]) -> list:
406
+ return [[r.get("filename",""), r.get("caption","")] for r in (rows or [])]
407
+
408
+ def _table_to_rows(table_value: Any, rows: List[dict]) -> List[dict]:
409
+ tbl = table_value or []
410
+ new = []
411
+ for i, r in enumerate(rows or []):
412
+ r = dict(r)
413
+ if i < len(tbl) and len(tbl[i]) >= 2:
414
+ r["filename"] = str(tbl[i][0]) if tbl[i][0] is not None else r.get("filename","")
415
+ r["caption"] = str(tbl[i][1]) if tbl[i][1] is not None else r.get("caption","")
416
+ new.append(r)
417
+ return new
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
 
419
  # ────────────────────────────────────────────────────────
420
+ # Batch captioning (returns rows, gallery, table)
421
  # ────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
422
  @torch.no_grad()
423
+ def run_batch(
424
+ files: List[Any],
425
+ session_rows: List[dict],
426
+ instr_text: str,
427
+ temp: float,
428
+ top_p: float,
429
+ max_tokens: int,
430
+ max_side: int,
431
+ ) -> Tuple[List[dict], list, list, str]:
432
  if torch.cuda.is_available():
433
  torch.cuda.empty_cache()
 
 
434
 
435
+ session_rows = session_rows or []
436
+ files = files or []
437
+ if not files:
438
+ gallery_pairs = [
439
+ ((r.get("thumb_path") or r.get("path")), r.get("caption",""))
440
+ for r in session_rows if (r.get("thumb_path") or r.get("path"))
441
+ ]
442
+ return session_rows, gallery_pairs, _rows_to_table(session_rows), f"Saved • {time.strftime('%H:%M:%S')}"
443
 
 
 
444
  for f in files:
445
+ path = f if isinstance(f, str) else getattr(f, "name", None) or getattr(f, "path", None)
446
+ if not path or not os.path.exists(path):
447
+ continue
448
  try:
449
+ im = Image.open(path).convert("RGB")
450
  except Exception:
451
  continue
452
+ im = resize_for_model(im, max_side)
453
+ cap = caption_once(im, instr_text, temp, top_p, max_tokens)
454
+ cap = apply_shape_aliases(cap)
455
+ s = load_settings()
456
+ cap = apply_prefix_suffix(cap, s.get("trigger",""), s.get("begin",""), s.get("end",""))
457
+
458
+ filename = os.path.basename(path)
459
+ thumb = ensure_thumb(path, 256)
460
+ session_rows.append({"filename": filename, "caption": cap, "path": path, "thumb_path": thumb})
461
+
462
+ save_session(session_rows)
463
+ gallery_pairs = [
464
+ ((r.get("thumb_path") or r.get("path")), r.get("caption",""))
465
+ for r in session_rows if (r.get("thumb_path") or r.get("path"))
466
+ ]
467
+ return session_rows, gallery_pairs, _rows_to_table(session_rows), f"Saved {time.strftime('%H:%M:%S')}"
468
+
469
def sync_table_to_session(table_value: Any, session_rows: List[dict]) -> Tuple[List[dict], list, str]:
    """Fold table edits into the session rows, save them, and rebuild the gallery.

    Returns ``(rows, gallery_pairs, status_text)``; rows without any image
    path are left out of the gallery pairs.
    """
    merged = _table_to_rows(table_value, session_rows or [])
    save_session(merged)

    pairs = []
    for entry in merged:
        image = entry.get("thumb_path") or entry.get("path")
        if image:
            pairs.append((image, entry.get("caption", "")))

    return merged, pairs, f"Saved {time.strftime('%H:%M:%S')}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
 
478
  # ────────────────────────────────────────────────────────
479
  # UI
 
496
  .cf-scroll{max-height:70vh; overflow-y:auto; border:1px solid #e6e6e6; border-radius:10px; padding:8px}
497
  /* Uniform sizes */
498
  #cfGal .grid > div { height: 96px; }
 
 
 
499
  """
500
 
501
def logo_b64_img() -> str:
    """Return an inline ``<img>`` tag for the first logo file found, or ``""``.

    Candidate locations are checked in order; the first existing file is
    embedded as a base64 PNG data URI.
    """
    search_order = (
        os.path.join(APP_DIR, "forgecaptions-logo.png"),
        os.path.join(APP_DIR, "captionforge-logo.png"),  # fallback if you kept the old name
        "/home/user/app/forgecaptions-logo.png",
        "forgecaptions-logo.png",
        "captionforge-logo.png",
    )
    found = next((p for p in search_order if os.path.exists(p)), None)
    if found is None:
        return ""
    with open(found, "rb") as fh:
        encoded = base64.b64encode(fh.read()).decode("ascii")
    return f"<img src='data:image/png;base64,{encoded}' alt='ForgeCaptions' class='cf-logo'>"
515
+
516
+ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
517
  settings = load_settings()
518
  settings["styles"] = [s for s in settings.get("styles", []) if s in STYLE_OPTIONS] or ["Character training (short)"]
519
 
 
521
  <div class="cf-hero">
522
  {logo_b64_img()}
523
  <div>
524
+ <h1 class="cf-title">ForgeCaptions</h1>
525
  <div class="cf-sub">Batch captioning</div>
526
  <div class="cf-sub">Scrollable editor & autosave</div>
527
+ <div class="cf-sub">CSV / Excel export</div>
528
  </div>
529
  </div>
530
  <hr>
 
547
  )
548
  with gr.Accordion("Name & Prefix/Suffix", open=True):
549
  name_input = gr.Textbox(label="Person / Character Name", value=settings.get("name", ""))
550
+ trig = gr.Textbox(label="Trigger word", value=settings.get("trigger",""))
551
+ add_start = gr.Textbox(label="Add text to start", value=settings.get("begin",""))
552
+ add_end = gr.Textbox(label="Add text to end", value=settings.get("end",""))
553
 
554
  with gr.Column(scale=1):
555
+ instruction_preview = gr.Textbox(label="Model Instructions", lines=12)
556
+ dataset_name = gr.Textbox(label="Dataset name (used for export file titles)", value=settings.get("dataset_name", "forgecaptions"))
557
+ max_side = gr.Slider(256, MAX_SIDE_CAP, settings.get("max_side", min(896, MAX_SIDE_CAP)), step=32, label="Max side (resize)")
558
+ excel_thumb_px = gr.Slider(64, 256, value=settings.get("excel_thumb_px", 128), step=8, label="Excel thumbnail size (px)")
559
+ gr.Markdown("Generation (saved in settings): temperature 0.6 • top-p 0.9 • max_tokens 256")
560
 
561
+ # Auto-refresh instruction & persist key controls
562
def _refresh_instruction(styles, extra, name_value, trigv, begv, endv, excel_px, ms):
    """Rebuild the instruction text and persist the current control values.

    Falls back to the "Character training (short)" style when no style is
    selected. Returns the instruction string for the preview textbox.
    """
    chosen_styles = styles or ["Character training (short)"]
    chosen_extras = extra or []
    instr = final_instruction(chosen_styles, chosen_extras, name_value)

    cfg = load_settings()
    updates = {
        "styles": chosen_styles,
        "extras": chosen_extras,
        "name": name_value,
        "trigger": trigv,
        "begin": begv,
        "end": endv,
        "excel_thumb_px": int(excel_px),
        "max_side": int(ms),
    }
    cfg.update(updates)
    save_settings(cfg)
    return instr
575
+
576
+ for comp in [style_checks, extra_opts, name_input, trig, add_start, add_end, excel_thumb_px, max_side]:
577
+ comp.change(
578
+ _refresh_instruction,
579
+ inputs=[style_checks, extra_opts, name_input, trig, add_start, add_end, excel_thumb_px, max_side],
580
+ outputs=[instruction_preview]
581
+ )
582
+
583
+ # Set initial instruction text on load
584
+ demo.load(lambda s,e,n: final_instruction(s or ["Character training (short)"], e or [], n),
585
+ inputs=[style_checks, extra_opts, name_input], outputs=[instruction_preview])
586
 
587
+ # ===== Shape Aliases
588
  with gr.Accordion("Shape Aliases", open=False):
589
  gr.Markdown(
590
  "### 🔷 Shape Aliases\n"
591
  "Replace literal **shape tokens** in captions with a preferred **name**.\n\n"
592
+ "**How to use:** left column = shape token (e.g., `diamond`), right column = replacement name (e.g., `starkey-emblem`)."
 
 
 
 
593
  )
 
594
  init_rows, init_enabled = get_shape_alias_rows_ui_defaults()
595
+ enable_aliases = gr.Checkbox(label="Enable shape alias replacements", value=init_enabled)
 
 
 
 
 
596
  alias_table = gr.Dataframe(
597
  headers=["shape (literal token)", "name to insert"],
598
+ value=init_rows,
599
+ col_count=(2, "fixed"),
600
+ row_count=(max(1, len(init_rows)), "dynamic"),
601
  datatype=["str", "str"],
602
+ type="array",
603
  interactive=True
604
  )
 
605
  with gr.Row():
606
  add_row_btn = gr.Button("+ Add row", variant="secondary")
607
  clear_btn = gr.Button("Clear", variant="secondary")
608
  save_btn = gr.Button("💾 Save", variant="primary")
 
609
  save_status = gr.Markdown("")
 
 
610
  def _add_row(cur):
611
  cur = (cur or []) + [["", ""]]
612
  return gr.update(value=cur, row_count=(max(1, len(cur)), "dynamic"))
 
613
  def _clear_rows():
614
  return gr.update(value=[["", ""]], row_count=(1, "dynamic"))
 
615
  add_row_btn.click(_add_row, inputs=[alias_table], outputs=[alias_table])
616
  clear_btn.click(_clear_rows, outputs=[alias_table])
617
+ save_btn.click(save_shape_alias_rows, inputs=[enable_aliases, alias_table], outputs=[save_status, alias_table])
618
+
619
+ # ===== Tabs: Single & Batch (keeps gallery/table position below)
620
+ with gr.Tabs():
621
+ with gr.Tab("Single"):
622
+ input_image_single = gr.Image(type="pil", label="Input Image", height=512, width=512)
623
+ single_caption_btn = gr.Button("Caption")
624
+ single_caption_out = gr.Textbox(label="Caption (single)")
625
+
626
def _caption_single(img, instr):
    """Caption one image using the saved generation settings.

    Returns a fixed message when no image is supplied; otherwise the caption
    after shape-alias replacement and trigger/prefix/suffix decoration.
    """
    if img is None:
        return "No image provided."

    cfg = load_settings()
    side_limit = int(cfg.get("max_side", MAX_SIDE_CAP))
    prepared = resize_for_model(img, side_limit)

    temperature = cfg.get("temperature", 0.6)
    nucleus = cfg.get("top_p", 0.9)
    token_budget = cfg.get("max_tokens", 256)

    text = caption_once(prepared, instr, temperature, nucleus, token_budget)
    text = apply_shape_aliases(text)
    return apply_prefix_suffix(text, cfg.get("trigger", ""), cfg.get("begin", ""), cfg.get("end", ""))
638
+
639
+ single_caption_btn.click(
640
+ _caption_single,
641
+ inputs=[input_image_single, instruction_preview],
642
+ outputs=[single_caption_out]
643
+ )
644
 
645
+ with gr.Tab("Batch"):
646
+ with gr.Accordion("Uploaded images", open=True):
647
+ input_files = gr.File(label="Drop images", file_types=["image"], file_count="multiple", type="filepath")
648
+ run_button = gr.Button("Caption batch", variant="primary")
649
 
650
+ # ===== Results/Table (position unchanged)
651
+ rows_state = gr.State(load_session())
652
+ autosave_md = gr.Markdown("Ready.")
 
653
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
654
  with gr.Row():
655
  with gr.Column(scale=1):
656
  gallery = gr.Gallery(
657
+ label="Results (image + caption)",
658
  show_label=True,
659
+ columns=3,
660
  height=520,
661
  elem_id="cfGal",
662
  elem_classes=["cf-scroll"]
 
664
  with gr.Column(scale=1):
665
  table = gr.Dataframe(
666
  label="Editable captions (whole session)",
667
+ value=_rows_to_table(load_session()),
668
+ headers=["filename", "caption"],
669
  interactive=True,
670
+ wrap=True,
671
  elem_id="cfTable",
672
  elem_classes=["cf-scroll"]
673
  )
674
 
675
+ # Exports
676
  with gr.Row():
677
  with gr.Column():
678
  export_csv_btn = gr.Button("Export CSV")
679
  csv_file = gr.File(label="CSV file", visible=False)
680
  with gr.Column():
681
+ export_xlsx_btn = gr.Button("Export Excel (.xlsx) with thumbnails")
682
  xlsx_file = gr.File(label="Excel file", visible=False)
 
 
 
683
 
684
+ # Initial gallery render
685
+ def _initial_gallery(rows):
686
  rows = rows or []
687
+ return [((r.get("thumb_path") or r.get("path")), r.get("caption","")) for r in rows if (r.get("thumb_path") or r.get("path"))]
688
+ demo.load(_initial_gallery, inputs=[rows_state], outputs=[gallery])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
689
 
690
+ # Scroll sync
691
  gr.HTML("""
692
  <script>
693
  (function () {
 
732
  </script>
733
  """)
734
 
735
+ # Run batch
736
def _run_click(files, rows, instr, ms):
    """Run a batch with the saved sampling settings and the slider's max side.

    Returns the four values produced by ``run_batch``: rows, gallery pairs,
    table rows and the status text.
    """
    cfg = load_settings()
    rows_out, gallery_items, table_rows, status = run_batch(
        files,
        rows or [],
        instr,
        cfg.get("temperature", 0.6),
        cfg.get("top_p", 0.9),
        cfg.get("max_tokens", 256),
        int(ms),
    )
    return rows_out, gallery_items, table_rows, status
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
743
 
744
  run_button.click(
745
+ _run_click,
746
+ inputs=[input_files, rows_state, instruction_preview, max_side],
747
+ outputs=[rows_state, gallery, table, autosave_md]
 
 
 
 
748
  )
749
 
750
+ # Table edits sync → rows + gallery
751
+ table.change(
752
+ sync_table_to_session,
753
+ inputs=[table, rows_state],
754
+ outputs=[rows_state, gallery, autosave_md]
 
 
 
 
 
 
 
 
 
 
 
755
  )
756
 
757
+ # Exports (use slider value)
 
 
758
  export_csv_btn.click(
759
+ lambda tbl: (export_csv_from_table(tbl), gr.update(visible=True)),
760
+ inputs=[table], outputs=[csv_file, csv_file]
761
  )
762
  export_xlsx_btn.click(
763
+ lambda tbl, rows, px: (export_excel_with_thumbs(tbl, rows or [], int(px)), gr.update(visible=True)),
764
+ inputs=[table, rows_state, excel_thumb_px], outputs=[xlsx_file, xlsx_file]
 
 
 
 
765
  )
766
 
767
  # Launch