JS6969 committed on
Commit
317d33d
·
verified ·
1 Parent(s): 427a18c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +300 -124
app.py CHANGED
@@ -17,6 +17,12 @@ from PIL import Image
17
  import torch
18
  from transformers import LlavaForConditionalGeneration, AutoProcessor
19
 
 
 
 
 
 
 
20
  # Hugging Face Spaces GPU decorator (no-op locally)
21
  try:
22
  import spaces
@@ -30,7 +36,9 @@ except Exception:
30
  # ------------------------------
31
  APP_DIR = os.getcwd()
32
  SESSION_FILE = "/tmp/forge_session.json"
33
- SETTINGS_FILE = "/tmp/forge_settings.json"
 
 
34
  JOURNAL_FILE = "/tmp/forge_journal.json"
35
 
36
  # IMPORTANT: keep caches in /tmp so Gradio will serve files safely
@@ -138,8 +146,8 @@ CAPTION_TYPE_MAP: Dict[str, str] = {
138
  EXTRA_CHOICES = [
139
  "Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
140
  "IGNORE all information about watermarks.",
141
- "Use profanity when describing sexual acts or genitalia (pussy, fucking, cum, cock, dick, cunt, etc.)", #You do NOT have permission to remove
142
- "Be sexually graphic and highly detailed.", #You do NOT have permission to remove
143
  "Do NOT use any ambiguous language.",
144
  "ONLY describe the most important elements of the image.",
145
  "Include information about the ages of any people/characters when applicable.",
@@ -177,12 +185,15 @@ def save_settings(cfg: dict):
177
  json.dump(cfg, f, ensure_ascii=False, indent=2)
178
 
179
  def load_settings() -> dict:
 
180
  if os.path.exists(SETTINGS_FILE):
181
- with open(SETTINGS_FILE, "r", encoding="utf-8") as f:
182
- cfg = json.load(f)
183
- else:
184
- cfg = {}
185
- # defaults
 
 
186
  defaults = {
187
  "dataset_name": "forgecaptions",
188
  "temperature": 0.6,
@@ -198,17 +209,26 @@ def load_settings() -> dict:
198
  "shape_aliases_enabled": True,
199
  "shape_aliases": [],
200
  "excel_thumb_px": 128,
201
- # header controls
202
  "logo_auto": True,
203
  "logo_px": 180,
204
  "logo_scale": 1.0,
 
 
205
  }
 
 
206
  for k, v in defaults.items():
207
  cfg.setdefault(k, v)
 
 
208
  styles = cfg.get("styles") or []
209
- cfg["styles"] = [s for s in (styles if isinstance(styles, list) else [styles]) if s in STYLE_OPTIONS] or ["Character training (long)"]
 
 
 
210
  return cfg
211
 
 
212
  def save_journal(data: dict):
213
  with open(JOURNAL_FILE, "w", encoding="utf-8") as f:
214
  json.dump(data, f, ensure_ascii=False, indent=2)
@@ -223,26 +243,6 @@ def load_journal() -> dict:
223
  # ------------------------------
224
  # 5) Small utilities (thumbs, resize, prefix/suffix, names)
225
  # ------------------------------
226
- # put near your other small utils
227
- def _normalize_table_value(table_value):
228
- """Return a list-of-lists no matter what Gradio gives us."""
229
- if table_value is None:
230
- return []
231
- # handle pandas.DataFrame without importing pandas explicitly
232
- if hasattr(table_value, "to_dict") and hasattr(table_value, "values"):
233
- try:
234
- return table_value.reset_index(drop=True).values.tolist()
235
- except Exception:
236
- try:
237
- return table_value.values.tolist()
238
- except Exception:
239
- return []
240
- # sometimes Gradio hands a dict with "data"
241
- if isinstance(table_value, dict) and "data" in table_value:
242
- return table_value["data"] or []
243
- # already a list (or something list-like)
244
- return table_value
245
-
246
  def sanitize_basename(s: str) -> str:
247
  s = (s or "").strip() or "forgecaptions"
248
  return re.sub(r"[^A-Za-z0-9._-]+", "_", s)[:120]
@@ -296,11 +296,37 @@ def logo_b64_img() -> str:
296
  return f"<img src='data:image/png;base64,{b64}' alt='ForgeCaptions' class='cf-logo'>"
297
  return ""
298
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
  # ------------------------------
301
  # 6) Shape Aliases (comma/pipe synonyms per row)
302
  # ------------------------------
303
  def _compile_shape_aliases_from_file():
 
 
 
 
 
304
  s = load_settings()
305
  if not s.get("shape_aliases_enabled", True):
306
  return []
@@ -313,11 +339,17 @@ def _compile_shape_aliases_from_file():
313
  tokens = [t.strip() for t in re.split(r"[|,]", raw) if t.strip()]
314
  if not tokens:
315
  continue
316
- tokens = sorted(set(tokens), key=lambda t: -len(t))
317
- pat = r"\b(?:" + "|".join(re.escape(t) for t in tokens) + r")(?:[-\s]?shaped)?\b"
 
 
 
 
 
318
  compiled.append((re.compile(pat, flags=re.I), name))
319
  return compiled
320
 
 
321
  _SHAPE_ALIASES = _compile_shape_aliases_from_file()
322
  def _refresh_shape_aliases_cache():
323
  global _SHAPE_ALIASES
@@ -336,9 +368,12 @@ def get_shape_alias_rows_ui_defaults():
336
  rows = [["", ""]]
337
  return rows, enabled
338
 
339
- def save_shape_alias_rows(enabled, df_rows):
340
- cfg = load_settings()
341
- cfg["shape_aliases_enabled"] = bool(enabled)
 
 
 
342
  cleaned = []
343
  for r in (df_rows or []):
344
  if not r:
@@ -347,12 +382,43 @@ def save_shape_alias_rows(enabled, df_rows):
347
  name = (r[1] or "").strip()
348
  if shape and name:
349
  cleaned.append({"shape": shape, "name": name})
350
- cfg["shape_aliases"] = cleaned
351
- save_settings(cfg)
352
- _refresh_shape_aliases_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  normalized = [[it["shape"], it["name"]] for it in cleaned] + [["", ""]]
354
- return ("✅ Saved shape alias options.",
355
- gr.update(value=normalized, row_count=(max(1, len(normalized)), "dynamic")))
 
 
 
356
 
357
 
358
  # ------------------------------
@@ -461,7 +527,7 @@ def run_batch(
461
  session_rows.append({"filename": filename, "caption": cap, "path": path, "thumb_path": thumb})
462
  processed += 1
463
 
464
- if time_budget_s and (time.time() - start) >= float(time_budget_s):
465
  leftover = files[idx+1:]
466
  break
467
 
@@ -491,13 +557,14 @@ def _gpu_startup_warm():
491
 
492
 
493
  # ------------------------------
494
- # 9) Export helpers (CSV/XLSX/TXT ZIP)
495
  # ------------------------------
496
  def _rows_to_table(rows: List[dict]) -> list:
497
  return [[r.get("filename",""), r.get("caption","")] for r in (rows or [])]
498
 
499
  def _table_to_rows(table_value: Any, rows: List[dict]) -> List[dict]:
500
- tbl = _normalize_table_value(table_value)
 
501
  new = []
502
  for i, r in enumerate(rows or []):
503
  r = dict(r)
@@ -508,7 +575,7 @@ def _table_to_rows(table_value: Any, rows: List[dict]) -> List[dict]:
508
  return new
509
 
510
  def export_csv_from_table(table_value: Any, dataset_name: str) -> str:
511
- data = _normalize_table_value(table_value)
512
  name = sanitize_basename(dataset_name)
513
  out = f"/tmp/{name}_{int(time.time())}.csv"
514
  with open(out, "w", newline="", encoding="utf-8") as f:
@@ -540,7 +607,7 @@ def export_excel_with_thumbs(table_value: Any, session_rows: List[dict], thumb_p
540
  raise RuntimeError("Excel export requires 'openpyxl' in requirements.txt.") from e
541
 
542
  caption_by_file = {}
543
- for row in _normalize_table_value(table_value):
544
  if not row:
545
  continue
546
  fn = str(row[0]) if len(row) > 0 else ""
@@ -580,7 +647,7 @@ def export_txt_zip(table_value: Any, dataset_name: str) -> str:
580
  """
581
  Create one .txt per caption, zip them.
582
  """
583
- data = _normalize_table_value(table_value)
584
  # wipe old
585
  for fn in os.listdir(TXT_EXPORT_DIR):
586
  try:
@@ -612,6 +679,82 @@ def export_txt_zip(table_value: Any, dataset_name: str) -> str:
612
  z.write(os.path.join(TXT_EXPORT_DIR, fn), arcname=fn)
613
  return zpath
614
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
615
 
616
  # ------------------------------
617
  # 10) UI header helper (logo auto-fit)
@@ -619,75 +762,48 @@ def export_txt_zip(table_value: Any, dataset_name: str) -> str:
619
  def _render_header_html(auto: bool, px: int, scale: float) -> str:
620
  auto_js = "true" if auto else "false"
621
  return f"""
622
- <div class="cf-hero" id="cfHero">
623
- {logo_b64_img() or '<div class="cf-logo"></div>'}
624
- <div class="cf-text" id="cfText">
625
  <h1 class="cf-title">ForgeCaptions</h1>
626
  <div class="cf-sub">Batch captioning</div>
627
  <div class="cf-sub">Scrollable editor & autosave</div>
628
- <div class="cf-sub">CSV / Excel export</div>
629
  </div>
630
  </div>
631
  <hr>
632
  <style>
633
- /* Let the hero stretch to tallest child; we’ll force exact height via JS */
634
- #cfHero {{ align-items: stretch; }}
635
- .cf-logo {{
636
- display: block;
637
- width: auto;
638
- height: auto; /* JS will set explicit px height */
639
- object-fit: contain;
640
- max-height: 480px; /* safety clamp */
641
- }}
642
  </style>
643
  <script>
644
  (function() {{
645
- const AUTO = {auto_js};
646
- const PX = {int(px)};
647
  const SCALE = {float(scale)};
648
-
649
  function fit() {{
650
- const hero = document.getElementById("cfHero");
651
- const text = document.getElementById("cfText");
652
  const logo = document.querySelector(".cf-logo");
653
- if (!hero || !text || !logo) return;
654
-
655
- // Include vertical margins so we truly match the whole text stack
656
- const cs = getComputedStyle(text);
657
- const rect = text.getBoundingClientRect();
658
- const fullH = rect.height + parseFloat(cs.marginTop) + parseFloat(cs.marginBottom);
659
-
660
- const target = AUTO
661
- ? Math.max(80, Math.min(480, Math.round(fullH * SCALE)))
662
- : Math.max(80, Math.min(480, PX));
663
-
664
- hero.style.minHeight = target + "px"; // ensure the flex row is that tall
665
- logo.style.height = target + "px"; // logo matches exactly
666
- logo.style.width = "auto"; // preserve aspect ratio
667
  }}
668
-
669
- // Recompute on font load, hydration, resizes, and text changes
670
- if (document.fonts && document.fonts.ready) {{
671
- document.fonts.ready.then(fit);
672
  }}
673
- window.addEventListener("load", fit, {{ passive: true }});
674
  window.addEventListener("resize", fit, {{ passive: true }});
675
-
676
- const textEl = document.getElementById("cfText");
677
- if (window.ResizeObserver && textEl) {{
678
- new ResizeObserver(fit).observe(textEl);
679
- }}
680
-
681
- // A couple of deferred passes to catch late reflows
682
  setTimeout(fit, 0);
683
- setTimeout(fit, 200);
684
- setTimeout(fit, 800);
685
  }})();
686
  </script>
687
  """
688
 
689
 
690
-
691
  # ------------------------------
692
  # 11) UI (Blocks)
693
  # ------------------------------
@@ -696,17 +812,11 @@ BASE_CSS = """
696
  .gradio-container{max-width:100%!important}
697
 
698
  /* Header */
699
- .cf-hero{
700
- display:flex;
701
- align-items:stretch; /* was center; stretch lets the logo fill the row */
702
- justify-content:center;
703
- gap:16px;
704
- margin:4px 0 12px;
705
- text-align:center;
706
- }
707
- .cf-hero .cf-text{ text-align:center; }
708
- .cf-title{ margin:0; font-size:3.25rem; line-height:1; letter-spacing:.2px; }
709
- .cf-sub{ margin:6px 0 0; font-size:1.1rem; color:#cfd3da; }
710
 
711
  /* Results area + robust scrollbars */
712
  .cf-scroll{border:1px solid #e6e6e6; border-radius:10px; padding:8px}
@@ -722,8 +832,8 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
722
  # ---- Header
723
  settings = load_settings()
724
  header_html = gr.HTML(_render_header_html(settings.get("logo_auto", True),
725
- settings.get("logo_px", 180),
726
- settings.get("logo_scale", 1.0)))
727
 
728
  # ---- Controls group
729
  with gr.Group():
@@ -766,20 +876,16 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
766
  value="Manual (step)", label="Batch mode"
767
  )
768
  chunk_size = gr.Slider(1, 50, value=10, step=1, label="Chunk size")
769
- gpu_budget = gr.Slider(20, 120, value=55, step=5, label="Max seconds per GPU call")
770
- no_time_limit = gr.Checkbox(
771
- value=True,
772
- label="No time limit (may time out on Zero GPU)"
773
- )
774
-
775
 
776
  # Logo controls
777
  logo_auto = gr.Checkbox(value=settings.get("logo_auto", True),
778
  label="Auto-match logo height to text")
779
- logo_px = gr.Slider(80, 420, value=settings.get("logo_px", 180),
780
  step=4, label="Logo height (px, if Auto off)")
781
- logo_scale = gr.Slider(0.6, 1.6, value=settings.get("logo_scale", 1.0),
782
- step=0.05, label="Logo scale × (if Auto on)")
783
 
784
  # Persist instruction + general settings
785
  def _refresh_instruction(styles, extra, name_value, trigv, begv, endv, excel_px, ms):
@@ -822,6 +928,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
822
  logo_auto.change(_update_header, inputs=[logo_auto, logo_px, logo_scale], outputs=[header_html])
823
  logo_scale.change(_update_header, inputs=[logo_auto, logo_px, logo_scale], outputs=[header_html])
824
 
 
825
  # ---- Shape Aliases block (placed WITH other settings, before uploads)
826
  with gr.Accordion("Shape Aliases", open=False):
827
  gr.Markdown(
@@ -830,10 +937,15 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
830
  "**How to use:**\n"
831
  "- Left column = a single token **or** comma/pipe-separated synonyms, e.g. `diamond, rhombus | lozenge`\n"
832
  "- Right column = replacement name, e.g. `starkey-emblem`\n"
833
- "Matches are case-insensitive, whole-word, and also catch `*-shaped` or `* shaped`."
 
834
  )
835
  init_rows, init_enabled = get_shape_alias_rows_ui_defaults()
836
  enable_aliases = gr.Checkbox(label="Enable shape alias replacements", value=init_enabled)
 
 
 
 
837
  alias_table = gr.Dataframe(
838
  headers=["shape (token or synonyms)", "name to insert"],
839
  value=init_rows,
@@ -842,11 +954,49 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
842
  type="array",
843
  interactive=True
844
  )
 
845
  with gr.Row():
846
  add_row_btn = gr.Button("+ Add row", variant="secondary")
847
  clear_btn = gr.Button("Clear", variant="secondary")
848
  save_btn = gr.Button("💾 Save", variant="primary")
 
 
849
  save_status = gr.Markdown("")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
850
  def _add_row(cur):
851
  cur = (cur or []) + [["", ""]]
852
  return gr.update(value=cur, row_count=(max(1, len(cur)), "dynamic"))
@@ -873,6 +1023,10 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
873
  input_files = gr.File(label="Drop images", file_types=["image"], file_count="multiple", type="filepath")
874
  run_button = gr.Button("Caption batch", variant="primary")
875
 
 
 
 
 
876
  # ---- Results area (gallery left / table right)
877
  rows_state = gr.State(load_session())
878
  autosave_md = gr.Markdown("Ready.")
@@ -895,6 +1049,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
895
  headers=["filename", "caption"],
896
  interactive=True,
897
  wrap=True,
 
898
  elem_id="cfTable"
899
  )
900
 
@@ -967,17 +1122,18 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
967
  s = load_settings()
968
  return s.get("temperature", 0.6), s.get("top_p", 0.9), s.get("max_tokens", 256)
969
 
970
- def _run_click(files, rows, instr, ms, mode, csize, budget_s):
971
  t, p, m = _tpms()
972
  files = files or []
 
973
 
974
- #Manual (Step)
975
  if mode == "Manual (step)" and files:
976
  chunks = _split_chunks(files, int(csize))
977
  batch = chunks[0]
978
  remaining = sum(chunks[1:], [])
979
  new_rows, gal, tbl, stamp, leftover_from_batch, done, total = run_batch(
980
- batch, rows or [], instr, t, p, m, int(ms), float(budget_s)
981
  )
982
  remaining = (leftover_from_batch or []) + remaining
983
  panel_vis = gr.update(visible=bool(remaining))
@@ -985,9 +1141,9 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
985
  prog = f"Batch progress: {done}/{total} processed in this step • Remaining overall: {len(remaining)}"
986
  return new_rows, gal, tbl, stamp, remaining, panel_vis, gr.update(value=msg), gr.update(value=prog)
987
 
988
- #Auto/all-at-once
989
  new_rows, gal, tbl, stamp, leftover, done, total = run_batch(
990
- files, rows or [], instr, t, p, m, int(ms), float(budget_s)
991
  )
992
  panel_vis = gr.update(visible=bool(leftover))
993
  msg = f"{len(leftover)} files remain. Process next chunk?" if leftover else ""
@@ -1000,13 +1156,22 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
1000
  outputs=[rows_state, gallery, table, autosave_md, remaining_state, step_panel, step_msg, progress_md]
1001
  )
1002
 
1003
- def _step_next(remain, rows, instr, ms, csize, budget_s):
1004
  t, p, m = _tpms()
1005
  remain = remain or []
1006
  budget = None if no_limit else float(budget_s)
1007
-
1008
  if not remain:
1009
- return rows, gr.update(value="No files remaining."), gr.update(visible=False), [], [], [], "Saved.", gr.update(value="")
 
 
 
 
 
 
 
 
 
1010
  batch = remain[:int(csize)]
1011
  leftover = remain[int(csize):]
1012
  new_rows, gal, tbl, stamp, leftover_from_batch, done, total = run_batch(
@@ -1039,6 +1204,17 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
1039
 
1040
  table.change(sync_table_to_session, inputs=[table, rows_state], outputs=[rows_state, gallery, autosave_md])
1041
 
 
 
 
 
 
 
 
 
 
 
 
1042
  # ---- Exports
1043
  export_csv_btn.click(
1044
  lambda tbl, ds: (export_csv_from_table(tbl, ds), gr.update(visible=True)),
 
17
  import torch
18
  from transformers import LlavaForConditionalGeneration, AutoProcessor
19
 
20
+ # Optional deps for import/export (we handle gracefully if missing)
21
+ try:
22
+ import pandas as pd
23
+ except Exception:
24
+ pd = None
25
+
26
  # Hugging Face Spaces GPU decorator (no-op locally)
27
  try:
28
  import spaces
 
36
  # ------------------------------
37
  APP_DIR = os.getcwd()
38
  SESSION_FILE = "/tmp/forge_session.json"
39
+ CONFIG_DIR = os.path.expanduser("~/.cache/forgecaptions")
40
+ os.makedirs(CONFIG_DIR, exist_ok=True)
41
+ SETTINGS_FILE = os.path.join(CONFIG_DIR, "settings.json")
42
  JOURNAL_FILE = "/tmp/forge_journal.json"
43
 
44
  # IMPORTANT: keep caches in /tmp so Gradio will serve files safely
 
146
  EXTRA_CHOICES = [
147
  "Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
148
  "IGNORE all information about watermarks.",
149
+ "Use profanity when describing sexual acts or genitalia (pussy, fucking, cum, cock, dick, cunt, etc.)", # You do NOT have permission to remove
150
+ "Be sexually graphic and highly detailed.", # You do NOT have permission to remove
151
  "Do NOT use any ambiguous language.",
152
  "ONLY describe the most important elements of the image.",
153
  "Include information about the ages of any people/characters when applicable.",
 
185
  json.dump(cfg, f, ensure_ascii=False, indent=2)
186
 
187
  def load_settings() -> dict:
188
+ cfg = {}
189
  if os.path.exists(SETTINGS_FILE):
190
+ try:
191
+ with open(SETTINGS_FILE, "r", encoding="utf-8") as f:
192
+ cfg = json.load(f) or {}
193
+ except Exception:
194
+ cfg = {}
195
+
196
+ # Defaults (add whatever else your app uses)
197
  defaults = {
198
  "dataset_name": "forgecaptions",
199
  "temperature": 0.6,
 
209
  "shape_aliases_enabled": True,
210
  "shape_aliases": [],
211
  "excel_thumb_px": 128,
 
212
  "logo_auto": True,
213
  "logo_px": 180,
214
  "logo_scale": 1.0,
215
+ # NEW: persist flag for aliases
216
+ "shape_aliases_persist": True,
217
  }
218
+
219
+ # Merge defaults without overwriting existing values
220
  for k, v in defaults.items():
221
  cfg.setdefault(k, v)
222
+
223
+ # Normalize styles to a valid list
224
  styles = cfg.get("styles") or []
225
+ if not isinstance(styles, list):
226
+ styles = [styles]
227
+ cfg["styles"] = [s for s in styles if s in STYLE_OPTIONS] or ["Character training (long)"]
228
+
229
  return cfg
230
 
231
+
232
  def save_journal(data: dict):
233
  with open(JOURNAL_FILE, "w", encoding="utf-8") as f:
234
  json.dump(data, f, ensure_ascii=False, indent=2)
 
243
  # ------------------------------
244
  # 5) Small utilities (thumbs, resize, prefix/suffix, names)
245
  # ------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  def sanitize_basename(s: str) -> str:
247
  s = (s or "").strip() or "forgecaptions"
248
  return re.sub(r"[^A-Za-z0-9._-]+", "_", s)[:120]
 
296
  return f"<img src='data:image/png;base64,{b64}' alt='ForgeCaptions' class='cf-logo'>"
297
  return ""
298
 
299
def _plural_token_regex(tok: str) -> str:
    """
    Build a regex fragment for a token that also matches simple English plurals.

    Rules (tested against the lower-cased token, in this order):
      - consonant letter + y (lady)             → stem + (?:y|ies)
      - ends in s, x, z, ch, sh (box, rhombus)  → token + (?:es)?
      - anything else                           → token + s?

    The token is passed through re.escape, so regex metacharacters inside it
    match literally. This function does no case folding of the output; the
    caller is expected to compile the final pattern with re.I.

    Returns "" for an empty, whitespace-only, or None token.
    """
    t = (tok or "").strip()
    if not t:
        return ""
    t_low = t.lower()
    # consonant + y → y|ies
    # Only a non-vowel *letter* counts as the consonant. A bare [^aeiou] would
    # also accept digits/punctuation, turning e.g. "3y" into a pattern for "3ies".
    if re.search(r"[^\W\d_aeiou]y$", t_low):
        return re.escape(t[:-1]) + r"(?:y|ies)"
    # s, x, z, ch, sh → +es
    if re.search(r"(?:s|x|z|ch|sh)$", t_low):
        return re.escape(t) + r"(?:es)?"
    # default → +s
    return re.escape(t) + r"s?"
320
 
321
  # ------------------------------
322
  # 6) Shape Aliases (comma/pipe synonyms per row)
323
  # ------------------------------
324
  def _compile_shape_aliases_from_file():
325
+ """
326
+ Build regex list from settings["shape_aliases"].
327
+ Left cell accepts comma OR pipe separated synonyms (multi-word OK).
328
+ Matches are case-insensitive, whole-word, captures simple plurals, and allows '-shaped' or ' shaped'.
329
+ """
330
  s = load_settings()
331
  if not s.get("shape_aliases_enabled", True):
332
  return []
 
339
  tokens = [t.strip() for t in re.split(r"[|,]", raw) if t.strip()]
340
  if not tokens:
341
  continue
342
+ # Build plural-aware alternation for all synonyms
343
+ alts = [_plural_token_regex(t) for t in tokens]
344
+ alts = [a for a in alts if a]
345
+ if not alts:
346
+ continue
347
+ # word-boundaries + optional "-shaped"/" shaped"
348
+ pat = r"\b(?:" + "|".join(alts) + r")(?:[-\s]?shaped)?\b"
349
  compiled.append((re.compile(pat, flags=re.I), name))
350
  return compiled
351
 
352
+
353
  _SHAPE_ALIASES = _compile_shape_aliases_from_file()
354
  def _refresh_shape_aliases_cache():
355
  global _SHAPE_ALIASES
 
368
  rows = [["", ""]]
369
  return rows, enabled
370
 
371
+ def save_shape_alias_rows(enabled, df_rows, persist):
372
+ """
373
+ Save or just apply alias rows.
374
+ - If persist=True → write to SETTINGS_FILE and recompile from file
375
+ - If persist=False → do NOT touch disk; just compile & apply in-memory
376
+ """
377
  cleaned = []
378
  for r in (df_rows or []):
379
  if not r:
 
382
  name = (r[1] or "").strip()
383
  if shape and name:
384
  cleaned.append({"shape": shape, "name": name})
385
+
386
+ status = "✅ Applied for this session only."
387
+ if persist:
388
+ cfg = load_settings()
389
+ cfg["shape_aliases_enabled"] = bool(enabled)
390
+ cfg["shape_aliases"] = cleaned
391
+ save_settings(cfg)
392
+ status = "💾 Saved to disk (will persist across restarts)."
393
+
394
+ # Recompile in-memory, regardless of persist
395
+ global _SHAPE_ALIASES
396
+ if bool(enabled):
397
+ # Build from the just-edited rows
398
+ tokens = []
399
+ compiled = []
400
+ # Emulate compiler using cleaned rows directly
401
+ compiled = []
402
+ for item in cleaned:
403
+ raw = item["shape"]
404
+ name = item["name"]
405
+ toks = [t.strip() for t in re.split(r"[|,]", raw) if t.strip()]
406
+ alts = [_plural_token_regex(t) for t in toks]
407
+ alts = [a for a in alts if a]
408
+ if not alts:
409
+ continue
410
+ pat = r"\b(?:" + "|".join(alts) + r")(?:[-\s]?shaped)?\b"
411
+ compiled.append((re.compile(pat, flags=re.I), name))
412
+ _SHAPE_ALIASES = compiled
413
+ else:
414
+ _SHAPE_ALIASES = []
415
+
416
  normalized = [[it["shape"], it["name"]] for it in cleaned] + [["", ""]]
417
+ return (
418
+ status,
419
+ gr.update(value=normalized, row_count=(max(1, len(normalized)), "dynamic"))
420
+ )
421
+
422
 
423
 
424
  # ------------------------------
 
527
  session_rows.append({"filename": filename, "caption": cap, "path": path, "thumb_path": thumb})
528
  processed += 1
529
 
530
+ if time_budget_s is not None and (time.time() - start) >= float(time_budget_s):
531
  leftover = files[idx+1:]
532
  break
533
 
 
557
 
558
 
559
  # ------------------------------
560
+ # 9) Export/Import helpers (CSV/XLSX/TXT ZIP)
561
  # ------------------------------
562
  def _rows_to_table(rows: List[dict]) -> list:
563
  return [[r.get("filename",""), r.get("caption","")] for r in (rows or [])]
564
 
565
  def _table_to_rows(table_value: Any, rows: List[dict]) -> List[dict]:
566
+ # Expecting list-of-lists (Dataframe type="array")
567
+ tbl = table_value or []
568
  new = []
569
  for i, r in enumerate(rows or []):
570
  r = dict(r)
 
575
  return new
576
 
577
  def export_csv_from_table(table_value: Any, dataset_name: str) -> str:
578
+ data = table_value or []
579
  name = sanitize_basename(dataset_name)
580
  out = f"/tmp/{name}_{int(time.time())}.csv"
581
  with open(out, "w", newline="", encoding="utf-8") as f:
 
607
  raise RuntimeError("Excel export requires 'openpyxl' in requirements.txt.") from e
608
 
609
  caption_by_file = {}
610
+ for row in (table_value or []):
611
  if not row:
612
  continue
613
  fn = str(row[0]) if len(row) > 0 else ""
 
647
  """
648
  Create one .txt per caption, zip them.
649
  """
650
+ data = table_value or []
651
  # wipe old
652
  for fn in os.listdir(TXT_EXPORT_DIR):
653
  try:
 
679
  z.write(os.path.join(TXT_EXPORT_DIR, fn), arcname=fn)
680
  return zpath
681
 
682
def import_captions_file(file_path: str, session_rows: List[dict]) -> Tuple[List[dict], list, list, str]:
    """
    Import captions from a CSV or XLSX file and merge them by filename into the session.

    - If a filename already exists in the session → its caption is replaced.
    - Otherwise a new row is appended (with empty image path / thumbnail).
    - The first row is skipped when its first cell is "filename" (case-insensitive).
    - Rows with fewer than two cells, or with an empty filename, are ignored.

    Args:
        file_path: Path of the file to import (.csv, .xlsx or .xls extension).
        session_rows: Current session rows (dicts with "filename", "caption",
            "path", "thumb_path"). Updated in place and also returned.

    Returns:
        (session_rows, gallery_pairs, table_rows, status_message).
        When nothing is imported (no file, unsupported extension, read error)
        the session is returned unchanged and status_message says why.
    """
    if session_rows is None:
        session_rows = []

    def _gallery_pairs() -> list:
        # Only rows that actually have an image on disk can be shown in the gallery.
        return [((r.get("thumb_path") or r.get("path")), r.get("caption", ""))
                for r in session_rows if (r.get("thumb_path") or r.get("path"))]

    def _unchanged(message: str) -> Tuple[List[dict], list, list, str]:
        # Single "no change" return shape, so every early exit hands the gallery
        # real (path, caption) pairs rather than table rows.
        return session_rows, _gallery_pairs(), _rows_to_table(session_rows), message

    if not file_path or not os.path.exists(file_path):
        return _unchanged("No file selected.")

    ext = os.path.splitext(file_path)[1].lower()
    imported: List[Tuple[str, str]] = []

    try:
        if ext == ".csv":
            # utf-8-sig drops a leading BOM (common in Excel-exported CSVs) that would
            # otherwise hide the "filename" header; newline="" is what csv expects.
            with open(file_path, "r", encoding="utf-8-sig", newline="") as f:
                rows = list(csv.reader(f))
        elif ext in (".xlsx", ".xls"):
            try:
                from openpyxl import load_workbook
            except Exception as e:
                raise RuntimeError("XLSX import requires 'openpyxl' in requirements.txt.") from e
            wb = load_workbook(file_path, read_only=True, data_only=True)
            try:
                rows = list(wb.active.iter_rows(values_only=True))
            finally:
                # Read-only workbooks keep the file handle open until closed explicitly.
                wb.close()
        else:
            return _unchanged(f"Unsupported file type: {ext}")

        # Skip header if present
        if rows and len(rows[0]) >= 2 and str(rows[0][0]).lower().strip() == "filename":
            rows = rows[1:]
        for r in rows:
            if not r or len(r) < 2:
                continue
            # Spreadsheet cells may be None; CSV cells are always strings.
            fn = str(r[0]).strip() if r[0] is not None else ""
            cap = str(r[1]).strip() if r[1] is not None else ""
            if fn:
                imported.append((fn, cap))
    except Exception as e:
        # Best-effort graceful fallback: report the problem, keep the session as-is.
        return _unchanged(f"Import failed: {e}")

    # Merge
    idx_by_name = {r.get("filename", ""): i for i, r in enumerate(session_rows)}
    updates, inserts = 0, 0
    for fn, cap in imported:
        if fn in idx_by_name:
            session_rows[idx_by_name[fn]]["caption"] = cap
            updates += 1
        else:
            session_rows.append({"filename": fn, "caption": cap, "path": "", "thumb_path": ""})
            # Track the new row so a filename repeated later in the same import
            # updates it instead of appending a duplicate.
            idx_by_name[fn] = len(session_rows) - 1
            inserts += 1

    save_session(session_rows)
    stamp = f"Imported {len(imported)} rows • updated {updates} • added {inserts} • {time.strftime('%H:%M:%S')}"
    return session_rows, _gallery_pairs(), _rows_to_table(session_rows), stamp
757
+
758
 
759
  # ------------------------------
760
  # 10) UI header helper (logo auto-fit)
 
762
  def _render_header_html(auto: bool, px: int, scale: float) -> str:
763
  auto_js = "true" if auto else "false"
764
  return f"""
765
+ <div class="cf-hero">
766
+ {logo_b64_img()}
767
+ <div class="cf-text">
768
  <h1 class="cf-title">ForgeCaptions</h1>
769
  <div class="cf-sub">Batch captioning</div>
770
  <div class="cf-sub">Scrollable editor & autosave</div>
771
+ <div class="cf-sub">CSV / Excel / TXT export · CSV/XLSX import</div>
772
  </div>
773
  </div>
774
  <hr>
775
  <style>
776
+ .cf-logo {{ height: auto; width: auto; object-fit: contain; }}
 
 
 
 
 
 
 
 
777
  </style>
778
  <script>
779
  (function() {{
780
+ const AUTO = {auto_js};
781
+ const PX = {int(px)};
782
  const SCALE = {float(scale)};
 
783
  function fit() {{
 
 
784
  const logo = document.querySelector(".cf-logo");
785
+ const text = document.querySelector(".cf-text");
786
+ if (!logo || !text) return;
787
+ if (AUTO) {{
788
+ const h = text.getBoundingClientRect().height || 200;
789
+ const target = Math.max(80, Math.min(480, Math.round(h * SCALE)));
790
+ logo.style.height = target + "px";
791
+ }} else {{
792
+ logo.style.height = Math.max(80, Math.min(480, PX)) + "px";
793
+ }}
 
 
 
 
 
794
  }}
795
+ const textNode = document.querySelector(".cf-text");
796
+ if (window.ResizeObserver && textNode) {{
797
+ const ro = new ResizeObserver(fit);
798
+ ro.observe(textNode);
799
  }}
 
800
  window.addEventListener("resize", fit, {{ passive: true }});
 
 
 
 
 
 
 
801
  setTimeout(fit, 0);
 
 
802
  }})();
803
  </script>
804
  """
805
 
806
 
 
807
  # ------------------------------
808
  # 11) UI (Blocks)
809
  # ------------------------------
 
812
  .gradio-container{max-width:100%!important}
813
 
814
  /* Header */
815
+ .cf-hero{display:flex; align-items:center; justify-content:center; gap:16px;
816
+ margin:4px 0 12px; text-align:center;}
817
+ .cf-hero .cf-text{text-align:center;}
818
+ .cf-title{margin:0;font-size:3.25rem;line-height:1;letter-spacing:.2px}
819
+ .cf-sub{margin:6px 0 0;font-size:1.1rem;color:#cfd3da}
 
 
 
 
 
 
820
 
821
  /* Results area + robust scrollbars */
822
  .cf-scroll{border:1px solid #e6e6e6; border-radius:10px; padding:8px}
 
832
  # ---- Header
833
  settings = load_settings()
834
  header_html = gr.HTML(_render_header_html(settings.get("logo_auto", True),
835
+ settings.get("logo_px", 200),
836
+ settings.get("logo_scale", 1.10)))
837
 
838
  # ---- Controls group
839
  with gr.Group():
 
876
  value="Manual (step)", label="Batch mode"
877
  )
878
  chunk_size = gr.Slider(1, 50, value=10, step=1, label="Chunk size")
879
+ gpu_budget = gr.Slider(20, 110, value=55, step=5, label="Max seconds per GPU call")
880
+ no_time_limit = gr.Checkbox(value=False, label="No time limit (ignore above)")
 
 
 
 
881
 
882
# Logo controls: an auto-fit toggle, a manual height, and an auto-fit
# multiplier. Each starts from the saved setting, else a built-in default.
logo_auto = gr.Checkbox(
    label="Auto-match logo height to text",
    value=settings.get("logo_auto", True),
)
logo_px = gr.Slider(
    80,
    480,
    step=4,
    value=settings.get("logo_px", 200),
    label="Logo height (px, if Auto off)",
)
logo_scale = gr.Slider(
    0.7,
    1.6,
    step=0.02,
    value=settings.get("logo_scale", 1.10),
    label="Logo scale × (if Auto on)",
)
889
 
890
  # Persist instruction + general settings
891
  def _refresh_instruction(styles, extra, name_value, trigv, begv, endv, excel_px, ms):
 
928
# Re-render the header whenever the auto-fit toggle or the scale slider
# changes; both events pass all three logo values to _update_header.
# NOTE(review): no logo_px.change hook is visible in this span — confirm it
# is registered in the lines just above.
logo_auto.change(_update_header, inputs=[logo_auto, logo_px, logo_scale], outputs=[header_html])
logo_scale.change(_update_header, inputs=[logo_auto, logo_px, logo_scale], outputs=[header_html])
930
 
931
# ---- Shape Aliases block (placed WITH other settings, before uploads)
933
  with gr.Accordion("Shape Aliases", open=False):
934
  gr.Markdown(
 
937
  "**How to use:**\n"
938
  "- Left column = a single token **or** comma/pipe-separated synonyms, e.g. `diamond, rhombus | lozenge`\n"
939
  "- Right column = replacement name, e.g. `starkey-emblem`\n"
940
+ "Matches are case-insensitive, whole-word, also catches simple plurals (`box`→`boxes`, `lady`→`ladies`), "
941
+ "and `*-shaped` or `* shaped` variants."
942
  )
943
  init_rows, init_enabled = get_shape_alias_rows_ui_defaults()
944
  enable_aliases = gr.Checkbox(label="Enable shape alias replacements", value=init_enabled)
945
+ persist_aliases = gr.Checkbox(
946
+ label="Save aliases across sessions",
947
+ value=load_settings().get("shape_aliases_persist", True)
948
+ )
949
  alias_table = gr.Dataframe(
950
  headers=["shape (token or synonyms)", "name to insert"],
951
  value=init_rows,
 
954
  type="array",
955
  interactive=True
956
  )
957
+
958
# Action buttons for the alias table, laid out side by side.
with gr.Row():
    add_row_btn = gr.Button("+ Add row", variant="secondary")
    clear_btn = gr.Button("Clear", variant="secondary")
    save_btn = gr.Button("💾 Save", variant="primary")

# status line for saves (the first output of the save_btn click handler)
save_status = gr.Markdown("")
965
+
966
def _add_row(cur):
    """Return a table update with one blank row appended to ``cur``.

    ``cur`` may be empty or ``None``; the result is then a single blank
    ``["", ""]`` row. The row count is set to the new length.
    """
    existing = cur or []
    grown = existing + [["", ""]]
    n_rows = max(1, len(grown))
    return gr.update(value=grown, row_count=(n_rows, "dynamic"))
969
+
970
def _clear_rows():
    """Return a table update that resets the table to one blank row."""
    blank_table = [["", ""]]
    return gr.update(value=blank_table, row_count=(1, "dynamic"))
972
+
973
+ add_row_btn.click(_add_row, inputs=[alias_table], outputs=[alias_table])
974
+ clear_btn.click(_clear_rows, outputs=[alias_table])
975
+
976
# Persist the "persist" toggle itself.
# Defined and registered exactly once: a second identical definition plus a
# second .change() registration would run the handler twice per toggle and
# rewrite the settings file twice.
def _save_alias_persist_flag(v):
    """Store the state of the "Save aliases across sessions" checkbox.

    Loads the current settings, sets ``shape_aliases_persist`` to
    ``bool(v)``, and writes the settings back.

    Args:
        v: Current checkbox value.

    Returns:
        A no-op ``gr.update()``; the event has no outputs wired.
    """
    cfg = load_settings()
    cfg["shape_aliases_persist"] = bool(v)
    save_settings(cfg)
    return gr.update()

persist_aliases.change(_save_alias_persist_flag, inputs=[persist_aliases], outputs=[])

# Save/apply rows. The persist checkbox is passed along so the handler can
# decide whether to write to disk (per the original note: only if checked).
save_btn.click(
    save_shape_alias_rows,
    inputs=[enable_aliases, alias_table, persist_aliases],
    outputs=[save_status, alias_table]
)
1000
# NOTE(review): this is a second, identical definition of `_add_row`; an
# earlier one already exists in this accordion and this one rebinds the name.
# The statements that follow are outside this view — once confirmed, keep
# only one definition and one click registration.
def _add_row(cur):
    # Append a blank row; `cur or []` covers an empty or missing table value.
    cur = (cur or []) + [["", ""]]
    return gr.update(value=cur, row_count=(max(1, len(cur)), "dynamic"))
 
1023
# Multi-file image upload; type="filepath" is requested for the values.
input_files = gr.File(label="Drop images", file_types=["image"], file_count="multiple", type="filepath")
run_button = gr.Button("Caption batch", variant="primary")

# Optional import of existing captions from a .csv or .xlsx file.
with gr.Accordion("Import captions from CSV/XLSX (merge by filename)", open=False):
    import_file = gr.File(label="Choose .csv or .xlsx", file_types=[".csv", ".xlsx"], type="filepath")
    import_btn = gr.Button("Import into current session")
1029
+
1030
  # ---- Results area (gallery left / table right)
1031
  rows_state = gr.State(load_session())
1032
  autosave_md = gr.Markdown("Ready.")
 
1049
  headers=["filename", "caption"],
1050
  interactive=True,
1051
  wrap=True,
1052
+ type="array", # <— prevents DataFrame truth ambiguity
1053
  elem_id="cfTable"
1054
  )
1055
 
 
1122
  s = load_settings()
1123
  return s.get("temperature", 0.6), s.get("top_p", 0.9), s.get("max_tokens", 256)
1124
 
1125
+ def _run_click(files, rows, instr, ms, mode, csize, budget_s, no_limit):
1126
  t, p, m = _tpms()
1127
  files = files or []
1128
+ budget = None if no_limit else float(budget_s)
1129
 
1130
+ # Manual step → process first chunk only
1131
  if mode == "Manual (step)" and files:
1132
  chunks = _split_chunks(files, int(csize))
1133
  batch = chunks[0]
1134
  remaining = sum(chunks[1:], [])
1135
  new_rows, gal, tbl, stamp, leftover_from_batch, done, total = run_batch(
1136
+ batch, rows or [], instr, t, p, m, int(ms), budget
1137
  )
1138
  remaining = (leftover_from_batch or []) + remaining
1139
  panel_vis = gr.update(visible=bool(remaining))
 
1141
  prog = f"Batch progress: {done}/{total} processed in this step • Remaining overall: {len(remaining)}"
1142
  return new_rows, gal, tbl, stamp, remaining, panel_vis, gr.update(value=msg), gr.update(value=prog)
1143
 
1144
+ # Auto / all-at-once
1145
  new_rows, gal, tbl, stamp, leftover, done, total = run_batch(
1146
+ files, rows or [], instr, t, p, m, int(ms), budget
1147
  )
1148
  panel_vis = gr.update(visible=bool(leftover))
1149
  msg = f"{len(leftover)} files remain. Process next chunk?" if leftover else ""
 
1156
  outputs=[rows_state, gallery, table, autosave_md, remaining_state, step_panel, step_msg, progress_md]
1157
  )
1158
 
1159
+ def _step_next(remain, rows, instr, ms, csize, budget_s, no_limit):
1160
  t, p, m = _tpms()
1161
  remain = remain or []
1162
  budget = None if no_limit else float(budget_s)
1163
+
1164
  if not remain:
1165
+ return (
1166
+ rows,
1167
+ gr.update(value="No files remaining."),
1168
+ gr.update(visible=False),
1169
+ [],
1170
+ [],
1171
+ [],
1172
+ "Saved.",
1173
+ gr.update(value="")
1174
+ )
1175
  batch = remain[:int(csize)]
1176
  leftover = remain[int(csize):]
1177
  new_rows, gal, tbl, stamp, leftover_from_batch, done, total = run_batch(
 
1204
 
1205
  table.change(sync_table_to_session, inputs=[table, rows_state], outputs=[rows_state, gallery, autosave_md])
1206
 
1207
+ # ---- Import hook
1208
def _do_import(fpath, rows):
    """Feed an imported caption file into the current session.

    Delegates to ``import_captions_file``, substituting an empty list when
    ``rows`` is missing, and returns its four results in the same order.
    """
    current_rows = rows or []
    merged_rows, gallery_items, table_rows, status = import_captions_file(fpath, current_rows)
    return merged_rows, gallery_items, table_rows, status
1211
+
1212
# Import button: passes the chosen file and the session rows to _do_import,
# then refreshes the rows state, gallery, table, and autosave status line.
import_btn.click(
    _do_import,
    inputs=[import_file, rows_state],
    outputs=[rows_state, gallery, table, autosave_md]
)
1217
+
1218
  # ---- Exports
1219
  export_csv_btn.click(
1220
  lambda tbl, ds: (export_csv_from_table(tbl, ds), gr.update(visible=True)),