JS6969 committed on
Commit
0f30e70
·
verified ·
1 Parent(s): 4ff6c0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -208
app.py CHANGED
@@ -5,28 +5,20 @@
5
  # ------------------------------
6
  # 0) Imports & environment
7
  # ------------------------------
8
- import os
9
  os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
10
  os.environ.setdefault("HF_HOME", "/home/user/.cache/huggingface")
11
  os.makedirs(os.environ["HF_HOME"], exist_ok=True)
12
- import io, csv, time, json, base64, re, zipfile
13
- from typing import Generator, List, Tuple, Dict, Any
14
- from threading import Thread
15
- # Persist model caches between restarts
16
 
 
 
17
 
18
  import gradio as gr
19
  from PIL import Image
20
  import torch
21
- from transformers import LlavaForConditionalGeneration, AutoProcessor, TextIteratorStreamer
22
-
23
- # Optional deps for import/export (we handle gracefully if missing)
24
- try:
25
- import pandas as pd # not required at runtime; kept for future use
26
- except Exception:
27
- pd = None
28
 
29
- # Liger is optional; skip if missing
30
  try:
31
  from liger_kernel.transformers import apply_liger_kernel_to_llama
32
  except Exception:
@@ -46,11 +38,10 @@ except Exception:
46
  # ------------------------------
47
  APP_DIR = os.getcwd()
48
  SESSION_FILE = "/tmp/forge_session.json"
49
- # --- Branding
50
 
 
51
  LOGO_HEIGHT_PX = int(os.getenv("FORGE_LOGO_PX", 60))
52
 
53
-
54
  # Settings live in a user cache dir (persists better than /tmp)
55
  CONFIG_DIR = os.path.expanduser("~/.cache/forgecaptions")
56
  os.makedirs(CONFIG_DIR, exist_ok=True)
@@ -58,7 +49,7 @@ SETTINGS_FILE = os.path.join(CONFIG_DIR, "settings.json")
58
 
59
  JOURNAL_FILE = "/tmp/forge_journal.json"
60
 
61
- # IMPORTANT: keep generated assets in /tmp so Gradio can serve them safely
62
  THUMB_CACHE = "/tmp/forgecaptions/thumbs"
63
  EXCEL_THUMB_DIR = "/tmp/forge_excel_thumbs"
64
  TXT_EXPORT_DIR = "/tmp/forge_txt"
@@ -80,8 +71,6 @@ print(f"[ForgeCaptions] Gradio version: {gr.__version__}")
80
 
81
  # ------------------------------
82
  # 2) Model loader (GPU-safe lazy init)
83
- # - Processor on CPU
84
- # - Model is ONLY created inside @gpu functions to satisfy Stateless GPU
85
  # ------------------------------
86
  processor = AutoProcessor.from_pretrained(MODEL_PATH)
87
  _MODEL = None
@@ -104,8 +93,8 @@ def get_model():
104
  low_cpu_mem_usage=True,
105
  device_map=0,
106
  )
 
107
  try:
108
- from liger_kernel.transformers import apply_liger_kernel_to_llama
109
  lm = getattr(_MODEL, "language_model", None) or getattr(_MODEL, "model", None)
110
  if lm is not None:
111
  ok = apply_liger_kernel_to_llama(lm)
@@ -137,19 +126,19 @@ STYLE_OPTIONS = [
137
  "Flux.1-Dev",
138
  "Stable Diffusion",
139
  "MidJourney",
140
- "E-commerce product",
141
  "Portrait (photography)",
142
  "Landscape (photography)",
143
  "Art analysis (no artist names)",
144
  "Social caption",
145
  "Aesthetic tags (comma-sep)"
146
- ]
147
 
148
  CAPTION_TYPE_MAP: Dict[str, str] = {
149
  "Descriptive": "Write a detailed description for this image.",
150
  "Character training": (
151
  "Write a thorough, training-ready caption for a character dataset. "
152
- "Describe subject appearance (physique, face/hair), clothing and accessories, actions/pose/gesture, camera angle/focal cues."
153
  "If multiple subjects are present, describe each briefly (most prominent first) and distinguish them by visible traits."
154
  ),
155
  "Flux.1-Dev": "Write a Flux.1-Dev style prompt that would reproduce this image faithfully.",
@@ -163,7 +152,6 @@ CAPTION_TYPE_MAP: Dict[str, str] = {
163
  "Social caption": "Write an engaging caption describing the visible content. No hashtags.",
164
  }
165
 
166
-
167
  LENGTH_CHOICES = ["any", "very short", "short", "medium-length", "long", "very long"] + [str(i) for i in range(20, 261, 10)]
168
 
169
  _LENGTH_HINTS = {
@@ -173,7 +161,6 @@ _LENGTH_HINTS = {
173
  "long": "Write a detailed caption (≈80–120 words).",
174
  "very long": "Write a very detailed caption (≈150–250 words).",
175
  }
176
-
177
  def _length_hint(choice: str) -> str:
178
  if not choice or choice == "any":
179
  return ""
@@ -181,10 +168,9 @@ def _length_hint(choice: str) -> str:
181
  return f"Limit the caption to at most {choice} words."
182
  return _LENGTH_HINTS.get(choice, "")
183
 
184
-
185
  EXTRA_CHOICES = [
186
  "Only include a character's modifiable, style-level attributes (hair style/color, makeup, clothing/accessories, pose, expression). Do NOT mention identity traits (skin tone, age, body type).",
187
- "Use profanity when describing sexual acts or genitalia (pussy, fucking, cum, cock, etc.).", # You do NOT have permission to remove
188
  "Be sexually graphic and describe sexual position when visible.",
189
  "Exclude brand names, watermarks, artist names, file/meta info, quality terms, hashtags, and resolution notes.",
190
  "Arrange information in the order: Subject; Clothing/Accessories; Action/Pose; Setting/Environment; Lighting/Camera.",
@@ -204,7 +190,6 @@ EXTRA_CHOICES = [
204
  ]
205
  NAME_OPTION = "If there is a person/character in the image you must refer to them as {name}."
206
 
207
- # Place near EXTRA_CHOICES
208
  DEFAULT_EXTRA_CHOICES = [
209
  "Use profanity when describing sexual acts or genitalia (pussy, fucking, cum, cock, etc.).",
210
  "Be sexually graphic and describe sexual position when visible.",
@@ -259,15 +244,16 @@ def load_settings() -> dict:
259
  "shape_aliases_enabled": True,
260
  "shape_aliases": [],
261
  "excel_thumb_px": 128,
262
- "logo_px": 60,
263
  "shape_aliases_persist": True,
264
  "extras": DEFAULT_EXTRA_CHOICES,
 
265
  }
266
 
267
  for k, v in defaults.items():
268
  cfg.setdefault(k, v)
269
 
270
- # Normalize styles to a valid list
271
  styles = cfg.get("styles") or []
272
  if not isinstance(styles, list):
273
  styles = [styles]
@@ -348,13 +334,6 @@ def logo_b64_img() -> str:
348
  # 6) Shape Aliases (plural-aware + '-shaped' variants)
349
  # ------------------------------
350
  def _plural_token_regex(tok: str) -> str:
351
- """
352
- Build a regex for a token that also matches simple English plurals.
353
- Rules:
354
- - endswith s/x/z/ch/sh → add '(?:es)?'
355
- - consonant + y → '(?:y|ies)'
356
- - default → 's?'
357
- """
358
  t = (tok or "").strip()
359
  if not t: return ""
360
  t_low = t.lower()
@@ -365,11 +344,6 @@ def _plural_token_regex(tok: str) -> str:
365
  return re.escape(t) + r"s?"
366
 
367
  def _compile_shape_aliases_from_file():
368
- """
369
- Build regex list from settings["shape_aliases"].
370
- Left cell accepts comma OR pipe separated synonyms (multi-word OK).
371
- Matches are case-insensitive, catches simple plurals, and allows '-shaped' or ' shaped'.
372
- """
373
  s = load_settings()
374
  if not s.get("shape_aliases_enabled", True):
375
  return []
@@ -409,11 +383,6 @@ def get_shape_alias_rows_ui_defaults():
409
  return rows, enabled
410
 
411
  def save_shape_alias_rows(enabled, df_rows, persist):
412
- """
413
- Save or just apply alias rows.
414
- - If persist=True → write to SETTINGS_FILE and recompile from file
415
- - If persist=False → do NOT touch disk; just compile & apply in-memory
416
- """
417
  cleaned = []
418
  for r in (df_rows or []):
419
  if not r:
@@ -431,7 +400,6 @@ def save_shape_alias_rows(enabled, df_rows, persist):
431
  save_settings(cfg)
432
  status = "💾 Saved to disk (will persist across restarts)."
433
 
434
- # Recompile in-memory, regardless of persist
435
  global _SHAPE_ALIASES
436
  if bool(enabled):
437
  compiled = []
@@ -453,7 +421,7 @@ def save_shape_alias_rows(enabled, df_rows, persist):
453
 
454
 
455
  # ------------------------------
456
- # 7) Prompt builder (instruction text shown/used for model)
457
  # ------------------------------
458
  def final_instruction(style_list: List[str], extra_opts: List[str], name_value: str, length_choice: str = "long") -> str:
459
  styles = style_list or ["Character training"]
@@ -463,10 +431,10 @@ def final_instruction(style_list: List[str], extra_opts: List[str], name_value:
463
  core += " " + " ".join(extra_opts)
464
  if NAME_OPTION in (extra_opts or []):
465
  core = core.replace("{name}", (name_value or "{NAME}").strip())
466
- if "Aesthetic tags (comma-sep)" not in styles: # If they're asking for comma-separated tags, ignore word-length guidance.
467
  lh = _length_hint(length_choice or "any")
468
  if lh:
469
- core += " " + lh
470
  return core
471
 
472
 
@@ -486,9 +454,6 @@ def _build_inputs(im: Image.Image, instr: str, dtype) -> Dict[str, Any]:
486
 
487
  @torch.no_grad()
488
  def caption_once(im: Image.Image, instr: str, temp: float, top_p: float, max_tokens: int) -> str:
489
- """
490
- NOTE: Not GPU-decorated on purpose; call this only from within a @gpu function.
491
- """
492
  model, device, dtype = get_model()
493
  inputs = _build_inputs(im, instr, dtype)
494
  inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}
@@ -525,8 +490,8 @@ def run_batch(
525
  top_p: float,
526
  max_tokens: int,
527
  max_side: int,
528
- time_budget_s: float | None = None, # respects Zero-GPU window (None = unlimited)
529
- progress: gr.Progress = gr.Progress(track_tqdm=True), # drives the progress bar
530
  ) -> Tuple[List[dict], list, list, str, List[str], int, int]:
531
  """
532
  Returns:
@@ -580,17 +545,14 @@ def run_batch(
580
  total,
581
  )
582
 
583
- @gpu
584
- @torch.no_grad()
585
 
586
  # ------------------------------
587
- # 9) Export/Import helpers (CSV/XLSX/TXT ZIP)
588
  # ------------------------------
589
  def _rows_to_table(rows: List[dict]) -> list:
590
  return [[r.get("filename",""), r.get("caption","")] for r in (rows or [])]
591
 
592
  def _table_to_rows(table_value: Any, rows: List[dict]) -> List[dict]:
593
- # Expect list-of-lists (Dataframe type="array")
594
  tbl = table_value or []
595
  new = []
596
  for i, r in enumerate(rows or []):
@@ -706,79 +668,6 @@ def export_txt_zip(table_value: Any, dataset_name: str) -> str:
706
  z.write(os.path.join(TXT_EXPORT_DIR, fn), arcname=fn)
707
  return zpath
708
 
709
- def import_captions_file(file_path: str, session_rows: List[dict]) -> Tuple[List[dict], list, list, str]:
710
- """
711
- Import captions from CSV or XLSX and merge by filename into current session.
712
- - If filename exists → update its caption
713
- - Otherwise append a new row (without image path/thumbnail)
714
- """
715
- if not file_path or not os.path.exists(file_path):
716
- table_rows = _rows_to_table(session_rows)
717
- gallery_pairs = [((r.get("thumb_path") or r.get("path")), r.get("caption",""))
718
- for r in session_rows if (r.get("thumb_path") or r.get("path"))]
719
- return session_rows, gallery_pairs, table_rows, "No file selected."
720
-
721
- ext = os.path.splitext(file_path)[1].lower()
722
- imported: List[Tuple[str, str]] = []
723
-
724
- try:
725
- if ext == ".csv":
726
- with open(file_path, "r", encoding="utf-8") as f:
727
- reader = csv.reader(f)
728
- rows = list(reader)
729
- if rows and len(rows[0]) >= 2 and str(rows[0][0]).lower().strip() == "filename":
730
- rows = rows[1:]
731
- for r in rows:
732
- if not r or len(r) < 2:
733
- continue
734
- fn = str(r[0]).strip()
735
- cap = str(r[1]).strip()
736
- if fn:
737
- imported.append((fn, cap))
738
-
739
- elif ext in (".xlsx", ".xls"):
740
- try:
741
- from openpyxl import load_workbook
742
- except Exception as e:
743
- raise RuntimeError("XLSX import requires 'openpyxl' in requirements.txt.") from e
744
- wb = load_workbook(file_path, read_only=True, data_only=True)
745
- ws = wb.active
746
- rows = list(ws.iter_rows(values_only=True))
747
- if rows and len(rows[0]) >= 2 and str(rows[0][0]).lower().strip() == "filename":
748
- rows = rows[1:]
749
- for r in rows:
750
- if not r or len(r) < 2:
751
- continue
752
- fn = str(r[0]).strip() if r[0] is not None else ""
753
- cap = str(r[1]).strip() if r[1] is not None else ""
754
- if fn:
755
- imported.append((fn, cap))
756
- else:
757
- return session_rows, _rows_to_table(session_rows), _rows_to_table(session_rows), f"Unsupported file type: {ext}"
758
- except Exception as e:
759
- table_rows = _rows_to_table(session_rows)
760
- gallery_pairs = [((r.get("thumb_path") or r.get("path")), r.get("caption",""))
761
- for r in session_rows if (r.get("thumb_path") or r.get("path"))]
762
- return session_rows, gallery_pairs, table_rows, f"Import failed: {e}"
763
-
764
- # Merge
765
- idx_by_name = {r.get("filename",""): i for i, r in enumerate(session_rows)}
766
- updates, inserts = 0, 0
767
- for fn, cap in imported:
768
- if fn in idx_by_name:
769
- session_rows[idx_by_name[fn]]["caption"] = cap
770
- updates += 1
771
- else:
772
- session_rows.append({"filename": fn, "caption": cap, "path": "", "thumb_path": ""})
773
- inserts += 1
774
-
775
- save_session(session_rows)
776
- gallery_pairs = [((r.get("thumb_path") or r.get("path")), r.get("caption",""))
777
- for r in session_rows if (r.get("thumb_path") or r.get("path"))]
778
- table_rows = _rows_to_table(session_rows)
779
- stamp = f"Imported {len(imported)} rows • updated {updates} • added {inserts} • {time.strftime('%H:%M:%S')}"
780
- return session_rows, gallery_pairs, table_rows, stamp
781
-
782
 
783
  # ------------------------------
784
  # 10) UI header helper (fixed logo size)
@@ -801,13 +690,15 @@ def _render_header_html(px: int) -> str:
801
  width: auto;
802
  object-fit: contain;
803
  display: block;
 
804
  }}
805
  @media (max-width: 640px) {{
806
- .cf-logo {{ height: {max(60, int(px) - 12)}px; }} /* optional small-screen tweak */
807
  }}
808
  </style>
809
  """
810
 
 
811
  # ------------------------------
812
  # 11) UI (Blocks)
813
  # ------------------------------
@@ -819,8 +710,8 @@ BASE_CSS = """
819
  .cf-hero{display:flex; align-items:center; justify-content:center; gap:16px;
820
  margin:4px 0 12px; text-align:center;}
821
  .cf-hero .cf-text{text-align:center;}
822
- .cf-title{margin:0;font-size:3.25rem;line-height:1;letter-spacing:.2px}
823
- .cf-sub{margin:6px 0 0;font-size:1.1rem;color:#cfd3da}
824
 
825
  /* Results area + robust scrollbars */
826
  .cf-scroll{border:1px solid #e6e6e6; border-radius:10px; padding:8px}
@@ -831,14 +722,9 @@ BASE_CSS = """
831
  """
832
 
833
  with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
834
- # Ensure Spaces sees a GPU function (without touching CUDA in main)
835
- demo.load(inputs=None, outputs=None)
836
-
837
  # ---- Header
838
  settings = load_settings()
839
- header_html = gr.HTML(_render_header_html(LOGO_HEIGHT_PX))
840
-
841
-
842
 
843
  # ---- Controls group
844
  with gr.Group():
@@ -867,19 +753,18 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
867
  trig = gr.Textbox(label="Trigger word", value=settings.get("trigger",""))
868
  add_start = gr.Textbox(label="Add text to start", value=settings.get("begin",""))
869
  add_end = gr.Textbox(label="Add text to end", value=settings.get("end",""))
870
-
871
 
872
- # RIGHT: instructions + dataset + general sliders + logo controls
873
  with gr.Column(scale=1):
874
  with gr.Accordion("Model Instructions", open=False):
875
  instruction_preview = gr.Textbox(
876
- label=None,
877
  lines=12,
878
  value=final_instruction(
879
  settings.get("styles", ["Character training"]),
880
  settings.get("extras", []),
881
  settings.get("name",""),
882
- settings.get("caption_length", "long"),
883
  ),
884
  )
885
  dataset_name = gr.Textbox(label="Dataset name (export title prefix)",
@@ -896,14 +781,13 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
896
  gpu_budget = gr.Slider(20, 110, value=55, step=5, label="Max seconds per GPU call")
897
  no_time_limit = gr.Checkbox(value=False, label="No time limit (ignore above)")
898
 
899
-
900
  # Persist instruction + general settings
901
  def _refresh_instruction(styles, extra, name_value, trigv, begv, endv, excel_px, ms, cap_len):
902
  instr = final_instruction(styles or ["Character training"], extra or [], name_value, cap_len)
903
  cfg = load_settings()
904
  cfg.update({
905
  "styles": styles or ["Character training"],
906
- "extras": extra or [],
907
  "name": name_value,
908
  "trigger": trigv, "begin": begv, "end": endv,
909
  "excel_thumb_px": int(excel_px),
@@ -914,9 +798,11 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
914
  return instr
915
 
916
  for comp in [style_checks, extra_opts, name_input, trig, add_start, add_end, excel_thumb_px, max_side, caption_length]:
917
- comp.change(_refresh_instruction,
918
- inputs=[style_checks, extra_opts, name_input, trig, add_start, add_end, excel_thumb_px, max_side, caption_length],
919
- outputs=[instruction_preview])
 
 
920
 
921
  def _save_dataset_name(name):
922
  cfg = load_settings()
@@ -925,8 +811,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
925
  return gr.update()
926
  dataset_name.change(_save_dataset_name, inputs=[dataset_name], outputs=[])
927
 
928
-
929
- # ---- Shape Aliases (with plural matching + persist) ----
930
  with gr.Accordion("Shape Aliases", open=False):
931
  gr.Markdown(
932
  "### 🔷 Shape Aliases\n"
@@ -934,8 +819,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
934
  "**How to use:**\n"
935
  "- Left column = a single token **or** comma/pipe-separated synonyms, e.g. `diamond, rhombus | lozenge`\n"
936
  "- Right column = replacement name, e.g. `family-emblem`\n"
937
- "Matches are case-insensitive, catches simple plurals (`box`→`boxes`, `lady`→`ladies`), "
938
- "and also matches `*-shaped` or `* shaped` variants."
939
  )
940
 
941
  init_rows, init_enabled = get_shape_alias_rows_ui_defaults()
@@ -959,10 +843,8 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
959
  clear_btn = gr.Button("Clear", variant="secondary")
960
  save_btn = gr.Button("💾 Save", variant="primary")
961
 
962
- # status line for saves
963
  save_status = gr.Markdown("")
964
 
965
- # --- local handlers (must stay inside Blocks context) ---
966
  def _add_row(cur):
967
  cur = (cur or []) + [["", ""]]
968
  return gr.update(value=cur, row_count=(max(1, len(cur)), "dynamic"))
@@ -980,7 +862,6 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
980
  return gr.update()
981
  persist_aliases.change(_save_alias_persist_flag, inputs=[persist_aliases], outputs=[])
982
 
983
- # Persist rows if persist_aliases checked; otherwise apply in-memory only
984
  save_btn.click(
985
  save_shape_alias_rows,
986
  inputs=[enable_aliases, alias_table, persist_aliases],
@@ -999,28 +880,16 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
999
  outputs=[single_caption_out]
1000
  )
1001
 
1002
- # with gr.Tab("Batch"):
1003
- # with gr.Accordion("Uploaded images", open=True):
1004
- # input_files = gr.File(label="Drop images (or click to select)", file_types=["image"], file_count="multiple", type="filepath")
1005
- # run_button = gr.Button("Caption batch", variant="primary")
1006
-
1007
- # with gr.Accordion("Import captions from CSV/XLSX (merge by filename)", open=False):
1008
- # import_file = gr.File(label="Choose .csv or .xlsx", file_types=[".csv", ".xlsx"], type="filepath")
1009
- # import_btn = gr.Button("Import into current session")
1010
-
1011
  with gr.Tab("Batch"):
1012
  with gr.Accordion("Uploaded images", open=True):
1013
- input_files = gr.File(label="Drop images (or click to select)", file_types=["image"], file_count="multiple",)
1014
- preview_gallery = gr.Gallery(
1015
- label="Preview (un-captioned)",
1016
- show_label=True,
1017
- columns=5,
1018
- height=220,
1019
- )
1020
- input_files.change(on_files_changed, inputs=[input_files], outputs=[preview_gallery])
1021
  run_button = gr.Button("Caption batch", variant="primary")
1022
 
1023
-
1024
  # ---- Results area (gallery left / table right)
1025
  rows_state = gr.State(load_session())
1026
  autosave_md = gr.Markdown("Ready.")
@@ -1043,7 +912,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
1043
  headers=["filename", "caption"],
1044
  interactive=True,
1045
  wrap=True,
1046
- type="array", # prevents pandas truth ambiguity
1047
  elem_id="cfTable"
1048
  )
1049
 
@@ -1066,7 +935,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
1066
  export_txt_btn = gr.Button("Export captions as .txt (zip)")
1067
  txt_zip = gr.File(label="TXT zip", visible=False)
1068
 
1069
- # ---- Robust scroll sync (works with Gradio v5 Gallery)
1070
  gr.HTML("""
1071
  <script>
1072
  (function () {
@@ -1075,9 +944,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
1075
  if (!host) return null;
1076
  return host.querySelector('[data-testid="gallery"]') || host;
1077
  }
1078
- function findTbl() {
1079
- return document.querySelector("#cfTableWrap");
1080
- }
1081
  function syncScroll(a, b) {
1082
  if (!a || !b) return;
1083
  let lock = false;
@@ -1121,7 +988,6 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
1121
  files = files or []
1122
  budget = None if no_limit else float(budget_s)
1123
 
1124
- # Manual step → process first chunk only
1125
  if mode == "Manual (step)" and files:
1126
  chunks = _split_chunks(files, int(csize))
1127
  batch = chunks[0]
@@ -1135,7 +1001,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
1135
  prog = f"Batch progress: {done}/{total} processed in this step • Remaining overall: {len(remaining)}"
1136
  return new_rows, gal, tbl, stamp, remaining, panel_vis, gr.update(value=msg), gr.update(value=prog)
1137
 
1138
- # Auto
1139
  new_rows, gal, tbl, stamp, leftover, done, total = run_batch(
1140
  files, rows or [], instr, t, p, m, int(ms), budget
1141
  )
@@ -1147,21 +1013,9 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
1147
  run_button.click(
1148
  _run_click,
1149
  inputs=[input_files, rows_state, instruction_preview, max_side, chunk_mode, chunk_size, gpu_budget, no_time_limit],
1150
- outputs=[rows_state, gallery, table, autosave_md, remaining_state, step_panel, step_msg, progress_md],
1151
- ).then(
1152
- lambda rows: [(Image.open(r["path"]).convert("RGB"), r["caption"]) for r in rows],
1153
- inputs=[rows_state],
1154
- outputs=[gallery],
1155
- )
1156
- table.change(
1157
- sync_table_to_session,
1158
- inputs=[table, rows_state],
1159
- outputs=[rows_state, captions_text],
1160
- ).then(
1161
- lambda rows: [(Image.open(r["path"]).convert("RGB"), r["caption"]) for r in rows],
1162
- inputs=[rows_state],
1163
- outputs=[gallery],
1164
  )
 
1165
  def _step_next(remain, rows, instr, ms, csize, budget_s, no_limit):
1166
  t, p, m = _tpms()
1167
  remain = remain or []
@@ -1171,7 +1025,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
1171
  return (
1172
  rows,
1173
  gr.update(value="No files remaining."),
1174
- gr.update(visible=True),
1175
  [],
1176
  [],
1177
  [],
@@ -1207,16 +1061,8 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
1207
  gallery_pairs = [((r.get("thumb_path") or r.get("path")), r.get("caption",""))
1208
  for r in session_rows if (r.get("thumb_path") or r.get("path"))]
1209
  return session_rows, gallery_pairs, f"Saved • {time.strftime('%H:%M:%S')}"
1210
- table.change(sync_table_to_session, inputs=[table, rows_state], outputs=[rows_state, gallery, autosave_md])
1211
-
1212
- def new_session() -> Tuple[List[dict], list, list, str]:
1213
- return [], [], _rows_to_table([]), ""
1214
 
1215
- # ---- Import hook
1216
- def _do_import(fpath, rows):
1217
- new_rows, gal, tbl, stamp = import_captions_file(fpath, rows or [])
1218
- return new_rows, gal, tbl, stamp
1219
- import_btn.click(_do_import, inputs=[import_file, rows_state], outputs=[rows_state, gallery, table, autosave_md])
1220
 
1221
  # ---- Exports
1222
  export_csv_btn.click(
@@ -1234,7 +1080,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
1234
 
1235
 
1236
  # ------------------------------
1237
- # 12) Launch (SSR disabled for stability on Spaces)
1238
  # ------------------------------
1239
  if __name__ == "__main__":
1240
  demo.queue(max_size=64).launch(
@@ -1243,6 +1089,5 @@ if __name__ == "__main__":
1243
  ssr_mode=False,
1244
  debug=True,
1245
  show_error=True,
1246
- # Allow Gradio to serve generated files from /tmp caches
1247
  allowed_paths=[THUMB_CACHE, EXCEL_THUMB_DIR, TXT_EXPORT_DIR],
1248
  )
 
5
  # ------------------------------
6
  # 0) Imports & environment
7
  # ------------------------------
8
+ import os
9
  os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
10
  os.environ.setdefault("HF_HOME", "/home/user/.cache/huggingface")
11
  os.makedirs(os.environ["HF_HOME"], exist_ok=True)
 
 
 
 
12
 
13
+ import io, csv, time, json, base64, re, zipfile
14
+ from typing import List, Tuple, Dict, Any
15
 
16
  import gradio as gr
17
  from PIL import Image
18
  import torch
19
+ from transformers import LlavaForConditionalGeneration, AutoProcessor
 
 
 
 
 
 
20
 
21
+ # Optional: Liger kernel (ignored if missing)
22
  try:
23
  from liger_kernel.transformers import apply_liger_kernel_to_llama
24
  except Exception:
 
38
  # ------------------------------
39
  APP_DIR = os.getcwd()
40
  SESSION_FILE = "/tmp/forge_session.json"
 
41
 
42
+ # Branding: fixed logo height
43
  LOGO_HEIGHT_PX = int(os.getenv("FORGE_LOGO_PX", 60))
44
 
 
45
  # Settings live in a user cache dir (persists better than /tmp)
46
  CONFIG_DIR = os.path.expanduser("~/.cache/forgecaptions")
47
  os.makedirs(CONFIG_DIR, exist_ok=True)
 
49
 
50
  JOURNAL_FILE = "/tmp/forge_journal.json"
51
 
52
+ # Generated assets in /tmp so Gradio can serve them safely
53
  THUMB_CACHE = "/tmp/forgecaptions/thumbs"
54
  EXCEL_THUMB_DIR = "/tmp/forge_excel_thumbs"
55
  TXT_EXPORT_DIR = "/tmp/forge_txt"
 
71
 
72
  # ------------------------------
73
  # 2) Model loader (GPU-safe lazy init)
 
 
74
  # ------------------------------
75
  processor = AutoProcessor.from_pretrained(MODEL_PATH)
76
  _MODEL = None
 
93
  low_cpu_mem_usage=True,
94
  device_map=0,
95
  )
96
+ # Try to enable Liger on the LLM submodule (best-effort, silent if missing)
97
  try:
 
98
  lm = getattr(_MODEL, "language_model", None) or getattr(_MODEL, "model", None)
99
  if lm is not None:
100
  ok = apply_liger_kernel_to_llama(lm)
 
126
  "Flux.1-Dev",
127
  "Stable Diffusion",
128
  "MidJourney",
129
+ "E-commerce product",
130
  "Portrait (photography)",
131
  "Landscape (photography)",
132
  "Art analysis (no artist names)",
133
  "Social caption",
134
  "Aesthetic tags (comma-sep)"
135
+ ]
136
 
137
  CAPTION_TYPE_MAP: Dict[str, str] = {
138
  "Descriptive": "Write a detailed description for this image.",
139
  "Character training": (
140
  "Write a thorough, training-ready caption for a character dataset. "
141
+ "Describe subject appearance (physique, face/hair), clothing and accessories, actions/pose/gesture, camera angle/focal cues. "
142
  "If multiple subjects are present, describe each briefly (most prominent first) and distinguish them by visible traits."
143
  ),
144
  "Flux.1-Dev": "Write a Flux.1-Dev style prompt that would reproduce this image faithfully.",
 
152
  "Social caption": "Write an engaging caption describing the visible content. No hashtags.",
153
  }
154
 
 
155
  LENGTH_CHOICES = ["any", "very short", "short", "medium-length", "long", "very long"] + [str(i) for i in range(20, 261, 10)]
156
 
157
  _LENGTH_HINTS = {
 
161
  "long": "Write a detailed caption (≈80–120 words).",
162
  "very long": "Write a very detailed caption (≈150–250 words).",
163
  }
 
164
  def _length_hint(choice: str) -> str:
165
  if not choice or choice == "any":
166
  return ""
 
168
  return f"Limit the caption to at most {choice} words."
169
  return _LENGTH_HINTS.get(choice, "")
170
 
 
171
  EXTRA_CHOICES = [
172
  "Only include a character's modifiable, style-level attributes (hair style/color, makeup, clothing/accessories, pose, expression). Do NOT mention identity traits (skin tone, age, body type).",
173
+ "Use profanity when describing sexual acts or genitalia (pussy, fucking, cum, cock, etc.).",
174
  "Be sexually graphic and describe sexual position when visible.",
175
  "Exclude brand names, watermarks, artist names, file/meta info, quality terms, hashtags, and resolution notes.",
176
  "Arrange information in the order: Subject; Clothing/Accessories; Action/Pose; Setting/Environment; Lighting/Camera.",
 
190
  ]
191
  NAME_OPTION = "If there is a person/character in the image you must refer to them as {name}."
192
 
 
193
  DEFAULT_EXTRA_CHOICES = [
194
  "Use profanity when describing sexual acts or genitalia (pussy, fucking, cum, cock, etc.).",
195
  "Be sexually graphic and describe sexual position when visible.",
 
244
  "shape_aliases_enabled": True,
245
  "shape_aliases": [],
246
  "excel_thumb_px": 128,
247
+ "logo_px": LOGO_HEIGHT_PX,
248
  "shape_aliases_persist": True,
249
  "extras": DEFAULT_EXTRA_CHOICES,
250
+ "caption_length": "long",
251
  }
252
 
253
  for k, v in defaults.items():
254
  cfg.setdefault(k, v)
255
 
256
+ # Normalize
257
  styles = cfg.get("styles") or []
258
  if not isinstance(styles, list):
259
  styles = [styles]
 
334
  # 6) Shape Aliases (plural-aware + '-shaped' variants)
335
  # ------------------------------
336
  def _plural_token_regex(tok: str) -> str:
 
 
 
 
 
 
 
337
  t = (tok or "").strip()
338
  if not t: return ""
339
  t_low = t.lower()
 
344
  return re.escape(t) + r"s?"
345
 
346
  def _compile_shape_aliases_from_file():
 
 
 
 
 
347
  s = load_settings()
348
  if not s.get("shape_aliases_enabled", True):
349
  return []
 
383
  return rows, enabled
384
 
385
  def save_shape_alias_rows(enabled, df_rows, persist):
 
 
 
 
 
386
  cleaned = []
387
  for r in (df_rows or []):
388
  if not r:
 
400
  save_settings(cfg)
401
  status = "💾 Saved to disk (will persist across restarts)."
402
 
 
403
  global _SHAPE_ALIASES
404
  if bool(enabled):
405
  compiled = []
 
421
 
422
 
423
  # ------------------------------
424
+ # 7) Prompt builder
425
  # ------------------------------
426
  def final_instruction(style_list: List[str], extra_opts: List[str], name_value: str, length_choice: str = "long") -> str:
427
  styles = style_list or ["Character training"]
 
431
  core += " " + " ".join(extra_opts)
432
  if NAME_OPTION in (extra_opts or []):
433
  core = core.replace("{name}", (name_value or "{NAME}").strip())
434
+ if "Aesthetic tags (comma-sep)" not in styles:
435
  lh = _length_hint(length_choice or "any")
436
  if lh:
437
+ core += " " + lh
438
  return core
439
 
440
 
 
454
 
455
  @torch.no_grad()
456
  def caption_once(im: Image.Image, instr: str, temp: float, top_p: float, max_tokens: int) -> str:
 
 
 
457
  model, device, dtype = get_model()
458
  inputs = _build_inputs(im, instr, dtype)
459
  inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}
 
490
  top_p: float,
491
  max_tokens: int,
492
  max_side: int,
493
+ time_budget_s: float | None = None,
494
+ progress: gr.Progress = gr.Progress(track_tqdm=True),
495
  ) -> Tuple[List[dict], list, list, str, List[str], int, int]:
496
  """
497
  Returns:
 
545
  total,
546
  )
547
 
 
 
548
 
549
  # ------------------------------
550
+ # 9) Export helpers (CSV/XLSX/TXT ZIP)
551
  # ------------------------------
552
  def _rows_to_table(rows: List[dict]) -> list:
553
  return [[r.get("filename",""), r.get("caption","")] for r in (rows or [])]
554
 
555
  def _table_to_rows(table_value: Any, rows: List[dict]) -> List[dict]:
 
556
  tbl = table_value or []
557
  new = []
558
  for i, r in enumerate(rows or []):
 
668
  z.write(os.path.join(TXT_EXPORT_DIR, fn), arcname=fn)
669
  return zpath
670
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
671
 
672
  # ------------------------------
673
  # 10) UI header helper (fixed logo size)
 
690
  width: auto;
691
  object-fit: contain;
692
  display: block;
693
+ max-width: 320px; /* cap very wide logos */
694
  }}
695
  @media (max-width: 640px) {{
696
+ .cf-logo {{ height: {max(48, int(px) - 8)}px; }}
697
  }}
698
  </style>
699
  """
700
 
701
+
702
  # ------------------------------
703
  # 11) UI (Blocks)
704
  # ------------------------------
 
710
  .cf-hero{display:flex; align-items:center; justify-content:center; gap:16px;
711
  margin:4px 0 12px; text-align:center;}
712
  .cf-hero .cf-text{text-align:center;}
713
+ .cf-title{margin:0;font-size:3.0rem;line-height:1;letter-spacing:.2px}
714
+ .cf-sub{margin:6px 0 0;font-size:1.05rem;color:#cfd3da}
715
 
716
  /* Results area + robust scrollbars */
717
  .cf-scroll{border:1px solid #e6e6e6; border-radius:10px; padding:8px}
 
722
  """
723
 
724
  with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
 
 
 
725
  # ---- Header
726
  settings = load_settings()
727
+ header_html = gr.HTML(_render_header_html(settings.get("logo_px", LOGO_HEIGHT_PX)))
 
 
728
 
729
  # ---- Controls group
730
  with gr.Group():
 
753
  trig = gr.Textbox(label="Trigger word", value=settings.get("trigger",""))
754
  add_start = gr.Textbox(label="Add text to start", value=settings.get("begin",""))
755
  add_end = gr.Textbox(label="Add text to end", value=settings.get("end",""))
 
756
 
757
+ # RIGHT: instructions + dataset + general sliders
758
  with gr.Column(scale=1):
759
  with gr.Accordion("Model Instructions", open=False):
760
  instruction_preview = gr.Textbox(
761
+ label=None,
762
  lines=12,
763
  value=final_instruction(
764
  settings.get("styles", ["Character training"]),
765
  settings.get("extras", []),
766
  settings.get("name",""),
767
+ settings.get("caption_length", "long"),
768
  ),
769
  )
770
  dataset_name = gr.Textbox(label="Dataset name (export title prefix)",
 
781
  gpu_budget = gr.Slider(20, 110, value=55, step=5, label="Max seconds per GPU call")
782
  no_time_limit = gr.Checkbox(value=False, label="No time limit (ignore above)")
783
 
 
784
  # Persist instruction + general settings
785
  def _refresh_instruction(styles, extra, name_value, trigv, begv, endv, excel_px, ms, cap_len):
786
  instr = final_instruction(styles or ["Character training"], extra or [], name_value, cap_len)
787
  cfg = load_settings()
788
  cfg.update({
789
  "styles": styles or ["Character training"],
790
+ "extras": _valid_extras(extra),
791
  "name": name_value,
792
  "trigger": trigv, "begin": begv, "end": endv,
793
  "excel_thumb_px": int(excel_px),
 
798
  return instr
799
 
800
  for comp in [style_checks, extra_opts, name_input, trig, add_start, add_end, excel_thumb_px, max_side, caption_length]:
801
+ comp.change(
802
+ _refresh_instruction,
803
+ inputs=[style_checks, extra_opts, name_input, trig, add_start, add_end, excel_thumb_px, max_side, caption_length],
804
+ outputs=[instruction_preview]
805
+ )
806
 
807
  def _save_dataset_name(name):
808
  cfg = load_settings()
 
811
  return gr.update()
812
  dataset_name.change(_save_dataset_name, inputs=[dataset_name], outputs=[])
813
 
814
+ # ---- Shape Aliases (with plural matching + persist)
 
815
  with gr.Accordion("Shape Aliases", open=False):
816
  gr.Markdown(
817
  "### 🔷 Shape Aliases\n"
 
819
  "**How to use:**\n"
820
  "- Left column = a single token **or** comma/pipe-separated synonyms, e.g. `diamond, rhombus | lozenge`\n"
821
  "- Right column = replacement name, e.g. `family-emblem`\n"
822
+ "Matches are case-insensitive, catches simple plurals, and also matches `*-shaped` / `* shaped` variants."
 
823
  )
824
 
825
  init_rows, init_enabled = get_shape_alias_rows_ui_defaults()
 
843
  clear_btn = gr.Button("Clear", variant="secondary")
844
  save_btn = gr.Button("💾 Save", variant="primary")
845
 
 
846
  save_status = gr.Markdown("")
847
 
 
848
  def _add_row(cur):
849
  cur = (cur or []) + [["", ""]]
850
  return gr.update(value=cur, row_count=(max(1, len(cur)), "dynamic"))
 
862
  return gr.update()
863
  persist_aliases.change(_save_alias_persist_flag, inputs=[persist_aliases], outputs=[])
864
 
 
865
  save_btn.click(
866
  save_shape_alias_rows,
867
  inputs=[enable_aliases, alias_table, persist_aliases],
 
880
  outputs=[single_caption_out]
881
  )
882
 
 
 
 
 
 
 
 
 
 
883
  with gr.Tab("Batch"):
884
  with gr.Accordion("Uploaded images", open=True):
885
+ input_files = gr.File(
886
+ label="Drop images (or click to select)",
887
+ file_types=["image"],
888
+ file_count="multiple",
889
+ type="filepath"
890
+ )
 
 
891
  run_button = gr.Button("Caption batch", variant="primary")
892
 
 
893
  # ---- Results area (gallery left / table right)
894
  rows_state = gr.State(load_session())
895
  autosave_md = gr.Markdown("Ready.")
 
912
  headers=["filename", "caption"],
913
  interactive=True,
914
  wrap=True,
915
+ type="array",
916
  elem_id="cfTable"
917
  )
918
 
 
935
  export_txt_btn = gr.Button("Export captions as .txt (zip)")
936
  txt_zip = gr.File(label="TXT zip", visible=False)
937
 
938
+ # ---- Robust scroll sync
939
  gr.HTML("""
940
  <script>
941
  (function () {
 
944
  if (!host) return null;
945
  return host.querySelector('[data-testid="gallery"]') || host;
946
  }
947
+ function findTbl() { return document.querySelector("#cfTableWrap"); }
 
 
948
  function syncScroll(a, b) {
949
  if (!a || !b) return;
950
  let lock = false;
 
988
  files = files or []
989
  budget = None if no_limit else float(budget_s)
990
 
 
991
  if mode == "Manual (step)" and files:
992
  chunks = _split_chunks(files, int(csize))
993
  batch = chunks[0]
 
1001
  prog = f"Batch progress: {done}/{total} processed in this step • Remaining overall: {len(remaining)}"
1002
  return new_rows, gal, tbl, stamp, remaining, panel_vis, gr.update(value=msg), gr.update(value=prog)
1003
 
1004
+ # Auto
1005
  new_rows, gal, tbl, stamp, leftover, done, total = run_batch(
1006
  files, rows or [], instr, t, p, m, int(ms), budget
1007
  )
 
1013
  run_button.click(
1014
  _run_click,
1015
  inputs=[input_files, rows_state, instruction_preview, max_side, chunk_mode, chunk_size, gpu_budget, no_time_limit],
1016
+ outputs=[rows_state, gallery, table, autosave_md, remaining_state, step_panel, step_msg, progress_md]
 
 
 
 
 
 
 
 
 
 
 
 
 
1017
  )
1018
+
1019
  def _step_next(remain, rows, instr, ms, csize, budget_s, no_limit):
1020
  t, p, m = _tpms()
1021
  remain = remain or []
 
1025
  return (
1026
  rows,
1027
  gr.update(value="No files remaining."),
1028
+ gr.update(visible=False),
1029
  [],
1030
  [],
1031
  [],
 
1061
  gallery_pairs = [((r.get("thumb_path") or r.get("path")), r.get("caption",""))
1062
  for r in session_rows if (r.get("thumb_path") or r.get("path"))]
1063
  return session_rows, gallery_pairs, f"Saved • {time.strftime('%H:%M:%S')}"
 
 
 
 
1064
 
1065
+ table.change(sync_table_to_session, inputs=[table, rows_state], outputs=[rows_state, gallery, autosave_md])
 
 
 
 
1066
 
1067
  # ---- Exports
1068
  export_csv_btn.click(
 
1080
 
1081
 
1082
  # ------------------------------
1083
+ # 12) Launch
1084
  # ------------------------------
1085
  if __name__ == "__main__":
1086
  demo.queue(max_size=64).launch(
 
1089
  ssr_mode=False,
1090
  debug=True,
1091
  show_error=True,
 
1092
  allowed_paths=[THUMB_CACHE, EXCEL_THUMB_DIR, TXT_EXPORT_DIR],
1093
  )