JS6969 commited on
Commit
f4c864c
·
verified ·
1 Parent(s): 6231519

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +231 -237
app.py CHANGED
@@ -1,9 +1,9 @@
1
- import os, io, csv, time, json, hashlib, base64, zipfile, re
2
  from typing import List, Tuple, Dict, Any
3
 
4
- # ────────────────────────────────────────────────────────
5
- # Cache locations (kept simple / persistent)
6
- # ────────────────────────────────────────────────────────
7
  os.environ.setdefault("HF_HOME", "/home/user/.cache/huggingface")
8
  os.makedirs(os.environ["HF_HOME"], exist_ok=True)
9
 
@@ -12,17 +12,13 @@ from PIL import Image
12
  import torch
13
  from transformers import LlavaForConditionalGeneration, AutoProcessor
14
 
15
- # Try to import spaces and define a GPU decorator that works on CPU too
16
  try:
17
  import spaces
18
  gpu = spaces.GPU()
19
- except Exception:
20
- def gpu(f): # no-op on CPU / local
21
- return f
22
 
23
- # ────────────────────────────────────────────────────────
24
- # Paths & files
25
- # ────────────────────────────────────────────────────────
26
  APP_DIR = os.getcwd()
27
  SESSION_FILE = "/tmp/session.json"
28
  SETTINGS_FILE = "/tmp/cf_settings.json"
@@ -32,37 +28,50 @@ EXCEL_THUMB_DIR = "/tmp/forge_excel_thumbs"
32
  os.makedirs(THUMB_CACHE, exist_ok=True)
33
  os.makedirs(EXCEL_THUMB_DIR, exist_ok=True)
34
 
35
- # ────────────────────────────────────────────────────────
36
- # Model setup
37
- # ────────────────────────────────────────────────────────
38
  MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
39
 
40
- def _detect_gpu():
41
- if torch.cuda.is_available():
42
- p = torch.cuda.get_device_properties(0)
43
- return "cuda", int(p.total_memory/(1024**3)), p.name
44
- return "cpu", 0, "CPU"
45
 
46
- BACKEND, VRAM_GB, GPU_NAME = _detect_gpu()
47
- DEVICE = "cuda" if BACKEND == "cuda" else "cpu"
48
- DTYPE = torch.bfloat16 if BACKEND == "cuda" else torch.float32
49
- MAX_SIDE_CAP = 1024 if BACKEND == "cuda" else 640
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- processor = AutoProcessor.from_pretrained(MODEL_PATH)
52
- model = LlavaForConditionalGeneration.from_pretrained(
53
- MODEL_PATH,
54
- torch_dtype=DTYPE,
55
- low_cpu_mem_usage=True,
56
- device_map=0 if BACKEND == "cuda" else "cpu",
57
- )
58
- model.eval()
59
-
60
- print(f"[ForgeCaptions] Backend={BACKEND} GPU={GPU_NAME} VRAM={VRAM_GB}GB dtype={DTYPE}")
61
  print(f"[ForgeCaptions] Gradio version: {gr.__version__}")
62
 
63
- # ────────────────────────────────────────────────────────
64
  # Instruction templates & options
65
- # ────────────────────────────────────────────────────────
66
  STYLE_OPTIONS = [
67
  "Descriptive (short)", "Descriptive (long)",
68
  "Character training (short)", "Character training (long)",
@@ -133,9 +142,9 @@ EXTRA_CHOICES = [
133
  ]
134
  NAME_OPTION = "If there is a person/character in the image you must refer to them as {name}."
135
 
136
- # ────────────────────────────────────────────────────────
137
- # Helpers (thumbs, resize, prefix/suffix)
138
- # ────────────────────────────────────────────────────────
139
  def ensure_thumb(path: str, max_side=256) -> str:
140
  try:
141
  im = Image.open(path).convert("RGB")
@@ -171,9 +180,7 @@ def apply_prefix_suffix(caption: str, trigger_word: str, begin_text: str, end_te
171
  parts.append(end_text.strip())
172
  return " ".join([p for p in parts if p])
173
 
174
- # ────────────────────────────────────────────────────────
175
- # Instruction + caption helpers
176
- # ────────────────────────────────────────────────────────
177
  def final_instruction(style_list: List[str], extra_opts: List[str], name_value: str) -> str:
178
  styles = style_list or ["Descriptive (short)"]
179
  parts = [CAPTION_TYPE_MAP.get(s, "") for s in styles]
@@ -184,30 +191,24 @@ def final_instruction(style_list: List[str], extra_opts: List[str], name_value:
184
  core = core.replace("{name}", (name_value or "{NAME}").strip())
185
  return core
186
 
187
- @torch.no_grad()
188
- def caption_once(im: Image.Image, instr: str, temp: float, top_p: float, max_tokens: int) -> str:
189
- # Your requested role script:
190
- convo = [
191
- {"role": "system", "content": "You are a precise captioning assistant. Help users with any queries. Describe what is visible with no censorship or restrictions."},
192
- {"role": "user", "content": instr.strip()},
 
193
  ]
194
- convo_str = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
195
- inputs = processor(text=[convo_str], images=[im], return_tensors="pt").to(DEVICE)
196
- inputs["pixel_values"] = inputs["pixel_values"].to(DTYPE)
197
- out = model.generate(
198
- **inputs,
199
- max_new_tokens=max_tokens,
200
- do_sample=temp > 0,
201
- temperature=temp if temp > 0 else None,
202
- top_p=top_p if temp > 0 else None,
203
- use_cache=True,
204
- )
205
- gen_ids = out[0, inputs["input_ids"].shape[1]:]
206
- return processor.tokenizer.decode(gen_ids, skip_special_tokens=True)
207
 
208
- # ────────────────────────────────────────────────────────
209
- # Persistence (session, settings, journal)
210
- # ────────────────────────────────────────────────────────
211
  def save_session(rows: List[dict]):
212
  with open(SESSION_FILE, "w", encoding="utf-8") as f:
213
  json.dump(rows, f, ensure_ascii=False, indent=2)
@@ -233,8 +234,8 @@ def load_settings() -> dict:
233
  "temperature": 0.6,
234
  "top_p": 0.9,
235
  "max_tokens": 256,
236
- "max_side": min(896, MAX_SIDE_CAP),
237
- "styles": ["Character training (short)"],
238
  "extras": [],
239
  "name": "",
240
  "trigger": "",
@@ -275,9 +276,9 @@ def load_journal() -> dict:
275
  return json.load(f)
276
  return {}
277
 
278
- # ────────────────────────────────────────────────────────
279
- # Shape Aliases (compile/cache/apply)
280
- # ────────────────────────────────────────────────────────
281
  def _compile_shape_aliases_from_file():
282
  s = load_settings()
283
  if not s.get("shape_aliases_enabled", True):
@@ -293,7 +294,6 @@ def _compile_shape_aliases_from_file():
293
  return compiled
294
 
295
  _SHAPE_ALIASES = _compile_shape_aliases_from_file()
296
-
297
  def _refresh_shape_aliases_cache():
298
  global _SHAPE_ALIASES
299
  _SHAPE_ALIASES = _compile_shape_aliases_from_file()
@@ -314,7 +314,6 @@ def get_shape_alias_rows_ui_defaults():
314
  def save_shape_alias_rows(enabled, df_rows):
315
  cfg = load_settings()
316
  cfg["shape_aliases_enabled"] = bool(enabled)
317
-
318
  cleaned = []
319
  for r in (df_rows or []):
320
  if not r:
@@ -323,20 +322,132 @@ def save_shape_alias_rows(enabled, df_rows):
323
  name = (r[1] or "").strip()
324
  if shape and name:
325
  cleaned.append({"shape": shape, "name": name})
326
-
327
  cfg["shape_aliases"] = cleaned
328
  save_settings(cfg)
329
  _refresh_shape_aliases_cache()
330
-
331
  normalized = [[it["shape"], it["name"]] for it in cleaned] + [["", ""]]
332
- return (
333
- "✅ Saved shape alias options.",
334
- gr.update(value=normalized, row_count=(max(1, len(normalized)), "dynamic"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
 
337
- # ────────────────────────────────────────────────────────
338
- # Exports
339
- # ────────────────────────────────────────────────────────
340
  def export_csv_from_table(table_value: Any) -> str:
341
  data = table_value or []
342
  out = f"/tmp/forgecaptions_{int(time.time())}.csv"
@@ -347,7 +458,6 @@ def export_csv_from_table(table_value: Any) -> str:
347
  return out
348
 
349
  def _resize_for_excel(path: str, px: int) -> str:
350
- """Create a temp resized copy for Excel embedding."""
351
  try:
352
  im = Image.open(path).convert("RGB")
353
  except Exception:
@@ -371,7 +481,6 @@ def export_excel_with_thumbs(table_value: Any, session_rows: List[dict], thumb_p
371
  except Exception as e:
372
  raise RuntimeError("Excel export requires 'openpyxl' in requirements.txt.") from e
373
 
374
- # Respect user edits (table wins)
375
  caption_by_file = {}
376
  for row in (table_value or []):
377
  if not row:
@@ -387,9 +496,7 @@ def export_excel_with_thumbs(table_value: Any, session_rows: List[dict], thumb_p
387
  ws.column_dimensions["B"].width = 42
388
  ws.column_dimensions["C"].width = 100
389
 
390
- # px→points (~0.75 pt per screen px @ ~96dpi)
391
- row_h = int(thumb_px * 0.75)
392
-
393
  r_i = 2
394
  for r in (session_rows or []):
395
  fn = r.get("filename","")
@@ -411,72 +518,6 @@ def export_excel_with_thumbs(table_value: Any, session_rows: List[dict], thumb_p
411
  wb.save(out)
412
  return out
413
 
414
- # Rows<->Table helpers
415
- def _rows_to_table(rows: List[dict]) -> list:
416
- return [[r.get("filename",""), r.get("caption","")] for r in (rows or [])]
417
-
418
- def _table_to_rows(table_value: Any, rows: List[dict]) -> List[dict]:
419
- tbl = table_value or []
420
- new = []
421
- for i, r in enumerate(rows or []):
422
- r = dict(r)
423
- if i < len(tbl) and len(tbl[i]) >= 2:
424
- r["filename"] = str(tbl[i][0]) if tbl[i][0] is not None else r.get("filename","")
425
- r["caption"] = str(tbl[i][1]) if tbl[i][1] is not None else r.get("caption","")
426
- new.append(r)
427
- return new
428
-
429
- # ────────────────────────────────────────────────────────
430
- # Batch captioning (GPU) + sync
431
- # ────────────────────────────────────────────────────────
432
- @gpu
433
- @torch.no_grad()
434
- def run_batch(
435
- files: List[Any],
436
- session_rows: List[dict],
437
- instr_text: str,
438
- temp: float,
439
- top_p: float,
440
- max_tokens: int,
441
- max_side: int,
442
- ) -> Tuple[List[dict], list, list, str]:
443
- if torch.cuda.is_available():
444
- torch.cuda.empty_cache()
445
-
446
- session_rows = session_rows or []
447
- files = files or []
448
- if not files:
449
- gallery_pairs = [
450
- ((r.get("thumb_path") or r.get("path")), r.get("caption",""))
451
- for r in session_rows if (r.get("thumb_path") or r.get("path"))
452
- ]
453
- return session_rows, gallery_pairs, _rows_to_table(session_rows), f"Saved • {time.strftime('%H:%M:%S')}"
454
-
455
- for f in files:
456
- path = f if isinstance(f, str) else getattr(f, "name", None) or getattr(f, "path", None)
457
- if not path or not os.path.exists(path):
458
- continue
459
- try:
460
- im = Image.open(path).convert("RGB")
461
- except Exception:
462
- continue
463
- im = resize_for_model(im, max_side)
464
- cap = caption_once(im, instr_text, temp, top_p, max_tokens)
465
- cap = apply_shape_aliases(cap)
466
- s = load_settings()
467
- cap = apply_prefix_suffix(cap, s.get("trigger",""), s.get("begin",""), s.get("end",""))
468
-
469
- filename = os.path.basename(path)
470
- thumb = ensure_thumb(path, 256)
471
- session_rows.append({"filename": filename, "caption": cap, "path": path, "thumb_path": thumb})
472
-
473
- save_session(session_rows)
474
- gallery_pairs = [
475
- ((r.get("thumb_path") or r.get("path")), r.get("caption",""))
476
- for r in session_rows if (r.get("thumb_path") or r.get("path"))
477
- ]
478
- return session_rows, gallery_pairs, _rows_to_table(session_rows), f"Saved • {time.strftime('%H:%M:%S')}"
479
-
480
  def sync_table_to_session(table_value: Any, session_rows: List[dict]) -> Tuple[List[dict], list, str]:
481
  session_rows = _table_to_rows(table_value, session_rows or [])
482
  save_session(session_rows)
@@ -486,61 +527,28 @@ def sync_table_to_session(table_value: Any, session_rows: List[dict]) -> Tuple[L
486
  ]
487
  return session_rows, gallery_pairs, f"Saved • {time.strftime('%H:%M:%S')}"
488
 
489
- # Tiny GPU warmup for HF Spaces detection
490
- @gpu
491
- @torch.no_grad()
492
- def _gpu_startup_warm():
493
- try:
494
- im = Image.new("RGB", (64, 64), (127, 127, 127))
495
- _ = caption_once(im, "Warm up.", temp=0.0, top_p=1.0, max_tokens=8)
496
- print("[ForgeCaptions] GPU warmup complete")
497
- except Exception as e:
498
- print("[ForgeCaptions] GPU warmup skipped:", e)
499
-
500
- # ────────────────────────────────────────────────────────
501
  # UI
502
- # ────────────────────────────────────────────────────────
503
  BASE_CSS = """
504
  :root{--galleryW:50%;--tableW:50%;}
505
  .gradio-container{max-width:100%!important}
506
- .cf-hero{
507
- display:flex; align-items:center; justify-content:center; gap:16px;
508
- margin:4px 0 12px; text-align:center;
509
- }
510
  .cf-hero > div { text-align:center; }
511
  .cf-logo{height:calc(3.25rem + 3 * 1.1rem + 18px);width:auto;object-fit:contain}
512
  .cf-title{margin:0;font-size:3.25rem;line-height:1;letter-spacing:.2px}
513
  .cf-sub{margin:6px 0 0;font-size:1.1rem;color:#cfd3da}
514
- .cf-row{display:flex;gap:12px}
515
- .cf-col-gallery{flex:0 0 var(--galleryW)}
516
- .cf-col-table{flex:0 0 var(--tableW)}
517
- /* Shared scroll look */
518
  .cf-scroll{max-height:70vh; overflow-y:auto; border:1px solid #e6e6e6; border-radius:10px; padding:8px}
519
- /* Uniform sizes */
520
  #cfGal .grid > div { height: 96px; }
521
  """
522
 
523
- def logo_b64_img() -> str:
524
- candidates = [
525
- os.path.join(APP_DIR, "forgecaptions-logo.png"),
526
- os.path.join(APP_DIR, "captionforge-logo.png"),
527
- "/home/user/app/forgecaptions-logo.png",
528
- "forgecaptions-logo.png",
529
- "captionforge-logo.png",
530
- ]
531
- for p in candidates:
532
- if os.path.exists(p):
533
- with open(p, "rb") as f:
534
- b64 = base64.b64encode(f.read()).decode("ascii")
535
- return f"<img src='data:image/png;base64,{b64}' alt='ForgeCaptions' class='cf-logo'>"
536
- return ""
537
-
538
  with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
539
- # Ensure HF GPU detection runs once UI starts
540
  demo.load(_gpu_startup_warm, inputs=None, outputs=None)
541
 
542
  settings = load_settings()
543
- settings["styles"] = [s for s in settings.get("styles", []) if s in STYLE_OPTIONS] or ["Character training (short)"]
544
 
545
  gr.HTML(value=f"""
546
  <div class="cf-hero">
@@ -552,43 +560,45 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
552
  <div class="cf-sub">CSV / Excel export</div>
553
  </div>
554
  </div>
555
- <hr>
556
- """)
557
 
558
- # ===== Controls (top)
559
  with gr.Group():
560
  with gr.Row():
561
  with gr.Column(scale=2):
562
- style_checks = gr.CheckboxGroup(
563
- choices=STYLE_OPTIONS,
564
- value=settings.get("styles", ["Character training (short)"]),
565
- label="Caption style (choose one or combine)"
566
- )
567
- with gr.Accordion("Extra options", open=True):
 
568
  extra_opts = gr.CheckboxGroup(
569
  choices=[NAME_OPTION] + EXTRA_CHOICES,
570
  value=settings.get("extras", []),
571
  label=None
572
  )
573
- with gr.Accordion("Name & Prefix/Suffix", open=True):
574
  name_input = gr.Textbox(label="Person / Character Name", value=settings.get("name", ""))
575
  trig = gr.Textbox(label="Trigger word", value=settings.get("trigger",""))
576
  add_start = gr.Textbox(label="Add text to start", value=settings.get("begin",""))
577
  add_end = gr.Textbox(label="Add text to end", value=settings.get("end",""))
578
 
579
  with gr.Column(scale=1):
580
- instruction_preview = gr.Textbox(label="Model Instructions", lines=12)
581
- dataset_name = gr.Textbox(label="Dataset name (used for export file titles)", value=settings.get("dataset_name", "forgecaptions"))
582
- max_side = gr.Slider(256, MAX_SIDE_CAP, settings.get("max_side", min(896, MAX_SIDE_CAP)), step=32, label="Max side (resize)")
583
- excel_thumb_px = gr.Slider(64, 256, value=settings.get("excel_thumb_px", 128), step=8, label="Excel thumbnail size (px)")
584
- gr.Markdown("Generation (settings): temperature 0.6 top-p 0.9 max_tokens 256")
 
 
 
585
 
586
- # Persist options + live instruction
587
  def _refresh_instruction(styles, extra, name_value, trigv, begv, endv, excel_px, ms):
588
- instr = final_instruction(styles or ["Character training (short)"], extra or [], name_value)
589
  cfg = load_settings()
590
  cfg.update({
591
- "styles": styles or ["Character training (short)"],
592
  "extras": extra or [],
593
  "name": name_value,
594
  "trigger": trigv, "begin": begv, "end": endv,
@@ -599,16 +609,14 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
599
  return instr
600
 
601
  for comp in [style_checks, extra_opts, name_input, trig, add_start, add_end, excel_thumb_px, max_side]:
602
- comp.change(
603
- _refresh_instruction,
604
- inputs=[style_checks, extra_opts, name_input, trig, add_start, add_end, excel_thumb_px, max_side],
605
- outputs=[instruction_preview]
606
- )
607
 
608
- demo.load(lambda s,e,n: final_instruction(s or ["Character training (short)"], e or [], n),
609
  inputs=[style_checks, extra_opts, name_input], outputs=[instruction_preview])
610
 
611
- # ===== Shape Aliases (improved UX: add row / clear / save)
612
  with gr.Accordion("Shape Aliases", open=False):
613
  gr.Markdown(
614
  "### 🔷 Shape Aliases\n"
@@ -622,7 +630,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
622
  value=init_rows,
623
  col_count=(2, "fixed"),
624
  row_count=(max(1, len(init_rows)), "dynamic"),
625
- datatype=["str", "str"],
626
  type="array",
627
  interactive=True
628
  )
@@ -640,28 +648,14 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
640
  clear_btn.click(_clear_rows, outputs=[alias_table])
641
  save_btn.click(save_shape_alias_rows, inputs=[enable_aliases, alias_table], outputs=[save_status, alias_table])
642
 
643
- # ===== Tabs (Single + Batch)
644
  with gr.Tabs():
645
  with gr.Tab("Single"):
646
  input_image_single = gr.Image(type="pil", label="Input Image", height=512, width=512)
647
  single_caption_btn = gr.Button("Caption")
648
  single_caption_out = gr.Textbox(label="Caption (single)")
649
-
650
- def _caption_single(img, instr):
651
- if img is None:
652
- return "No image provided."
653
- s = load_settings()
654
- im = resize_for_model(img, int(s.get("max_side", MAX_SIDE_CAP)))
655
- t = s.get("temperature", 0.6)
656
- p = s.get("top_p", 0.9)
657
- m = s.get("max_tokens", 256)
658
- cap = caption_once(im, instr, t, p, m)
659
- cap = apply_shape_aliases(cap)
660
- cap = apply_prefix_suffix(cap, s.get("trigger",""), s.get("begin",""), s.get("end",""))
661
- return cap
662
-
663
  single_caption_btn.click(
664
- _caption_single,
665
  inputs=[input_image_single, instruction_preview],
666
  outputs=[single_caption_out]
667
  )
@@ -671,7 +665,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
671
  input_files = gr.File(label="Drop images", file_types=["image"], file_count="multiple", type="filepath")
672
  run_button = gr.Button("Caption batch", variant="primary")
673
 
674
- # ===== Results + Table (kept in the same place)
675
  rows_state = gr.State(load_session())
676
  autosave_md = gr.Markdown("Ready.")
677
 
@@ -705,10 +699,10 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
705
  export_xlsx_btn = gr.Button("Export Excel (.xlsx) with thumbnails")
706
  xlsx_file = gr.File(label="Excel file", visible=False)
707
 
708
- # Initial gallery render
709
  def _initial_gallery(rows):
710
  rows = rows or []
711
- return [((r.get("thumb_path") or r.get("path")), r.get("caption","")) for r in rows if (r.get("thumb_path") or r.get("path"))]
 
712
  demo.load(_initial_gallery, inputs=[rows_state], outputs=[gallery])
713
 
714
  # Scroll sync
@@ -756,7 +750,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
756
  </script>
757
  """)
758
 
759
- # Run batch
760
  def _run_click(files, rows, instr, ms):
761
  s = load_settings()
762
  t = s.get("temperature", 0.6)
@@ -771,14 +765,14 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
771
  outputs=[rows_state, gallery, table, autosave_md]
772
  )
773
 
774
- # Table edits sync → rows + gallery
775
  table.change(
776
  sync_table_to_session,
777
  inputs=[table, rows_state],
778
  outputs=[rows_state, gallery, autosave_md]
779
  )
780
 
781
- # Exports (use slider value)
782
  export_csv_btn.click(
783
  lambda tbl: (export_csv_from_table(tbl), gr.update(visible=True)),
784
  inputs=[table], outputs=[csv_file, csv_file]
@@ -788,7 +782,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
788
  inputs=[table, rows_state, excel_thumb_px], outputs=[xlsx_file, xlsx_file]
789
  )
790
 
791
- # Launch (disable experimental SSR to reduce churn)
792
  if __name__ == "__main__":
793
  demo.queue(max_size=64).launch(
794
  server_name="0.0.0.0",
 
1
+ import os, io, csv, time, json, base64, re
2
  from typing import List, Tuple, Dict, Any
3
 
4
+ # ---------------------------------------------------------------------
5
+ # Caching
6
+ # ---------------------------------------------------------------------
7
  os.environ.setdefault("HF_HOME", "/home/user/.cache/huggingface")
8
  os.makedirs(os.environ["HF_HOME"], exist_ok=True)
9
 
 
12
  import torch
13
  from transformers import LlavaForConditionalGeneration, AutoProcessor
14
 
15
+ # ── HF Spaces GPU decorator (no-op on CPU/local) ─────────────────────
16
  try:
17
  import spaces
18
  gpu = spaces.GPU()
19
+ except Exception: # local/CPU
20
+ def gpu(f): return f
 
21
 
 
 
 
22
  APP_DIR = os.getcwd()
23
  SESSION_FILE = "/tmp/session.json"
24
  SETTINGS_FILE = "/tmp/cf_settings.json"
 
28
  os.makedirs(THUMB_CACHE, exist_ok=True)
29
  os.makedirs(EXCEL_THUMB_DIR, exist_ok=True)
30
 
31
+ # ---------------------------------------------------------------------
32
+ # Model identifiers
33
+ # ---------------------------------------------------------------------
34
  MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
35
 
36
+ # Load the processor on CPU (safe in stateless env)
37
+ processor = AutoProcessor.from_pretrained(MODEL_PATH)
 
 
 
38
 
39
+ # Lazy GPU/CPU model (created inside GPU worker only)
40
+ _MODEL = None
41
+ _DEVICE = "cpu"
42
+ _DTYPE = torch.float32
43
+
44
+ def get_model():
45
+ """Create/reuse model; only call this from inside @gpu functions."""
46
+ global _MODEL, _DEVICE, _DTYPE
47
+ if _MODEL is None:
48
+ if torch.cuda.is_available():
49
+ _DEVICE = "cuda"
50
+ _DTYPE = torch.bfloat16
51
+ _MODEL = LlavaForConditionalGeneration.from_pretrained(
52
+ MODEL_PATH,
53
+ torch_dtype=_DTYPE,
54
+ low_cpu_mem_usage=True,
55
+ device_map=0, # GPU:0 (inside GPU worker process)
56
+ )
57
+ else:
58
+ _DEVICE = "cpu"
59
+ _DTYPE = torch.float32
60
+ _MODEL = LlavaForConditionalGeneration.from_pretrained(
61
+ MODEL_PATH,
62
+ torch_dtype=_DTYPE,
63
+ low_cpu_mem_usage=True,
64
+ device_map="cpu",
65
+ )
66
+ _MODEL.eval()
67
+ print(f"[ForgeCaptions] Model ready on {_DEVICE} dtype={_DTYPE}")
68
+ return _MODEL, _DEVICE, _DTYPE
69
 
 
 
 
 
 
 
 
 
 
 
70
  print(f"[ForgeCaptions] Gradio version: {gr.__version__}")
71
 
72
+ # ---------------------------------------------------------------------
73
  # Instruction templates & options
74
+ # ---------------------------------------------------------------------
75
  STYLE_OPTIONS = [
76
  "Descriptive (short)", "Descriptive (long)",
77
  "Character training (short)", "Character training (long)",
 
142
  ]
143
  NAME_OPTION = "If there is a person/character in the image you must refer to them as {name}."
144
 
145
+ # ---------------------------------------------------------------------
146
+ # Helpers
147
+ # ---------------------------------------------------------------------
148
  def ensure_thumb(path: str, max_side=256) -> str:
149
  try:
150
  im = Image.open(path).convert("RGB")
 
180
  parts.append(end_text.strip())
181
  return " ".join([p for p in parts if p])
182
 
183
+ # Instruction + caption
 
 
184
  def final_instruction(style_list: List[str], extra_opts: List[str], name_value: str) -> str:
185
  styles = style_list or ["Descriptive (short)"]
186
  parts = [CAPTION_TYPE_MAP.get(s, "") for s in styles]
 
191
  core = core.replace("{name}", (name_value or "{NAME}").strip())
192
  return core
193
 
194
+ def logo_b64_img() -> str:
195
+ candidates = [
196
+ os.path.join(APP_DIR, "forgecaptions-logo.png"),
197
+ os.path.join(APP_DIR, "captionforge-logo.png"),
198
+ "/home/user/app/forgecaptions-logo.png",
199
+ "forgecaptions-logo.png",
200
+ "captionforge-logo.png",
201
  ]
202
+ for p in candidates:
203
+ if os.path.exists(p):
204
+ with open(p, "rb") as f:
205
+ b64 = base64.b64encode(f.read()).decode("ascii")
206
+ return f"<img src='data:image/png;base64,{b64}' alt='ForgeCaptions' class='cf-logo'>"
207
+ return ""
 
 
 
 
 
 
 
208
 
209
+ # ---------------------------------------------------------------------
210
+ # Persistence
211
+ # ---------------------------------------------------------------------
212
  def save_session(rows: List[dict]):
213
  with open(SESSION_FILE, "w", encoding="utf-8") as f:
214
  json.dump(rows, f, ensure_ascii=False, indent=2)
 
234
  "temperature": 0.6,
235
  "top_p": 0.9,
236
  "max_tokens": 256,
237
+ "max_side": 896,
238
+ "styles": ["Character training (long)"], # ← default changed
239
  "extras": [],
240
  "name": "",
241
  "trigger": "",
 
276
  return json.load(f)
277
  return {}
278
 
279
+ # ---------------------------------------------------------------------
280
+ # Shape Aliases
281
+ # ---------------------------------------------------------------------
282
  def _compile_shape_aliases_from_file():
283
  s = load_settings()
284
  if not s.get("shape_aliases_enabled", True):
 
294
  return compiled
295
 
296
  _SHAPE_ALIASES = _compile_shape_aliases_from_file()
 
297
  def _refresh_shape_aliases_cache():
298
  global _SHAPE_ALIASES
299
  _SHAPE_ALIASES = _compile_shape_aliases_from_file()
 
314
  def save_shape_alias_rows(enabled, df_rows):
315
  cfg = load_settings()
316
  cfg["shape_aliases_enabled"] = bool(enabled)
 
317
  cleaned = []
318
  for r in (df_rows or []):
319
  if not r:
 
322
  name = (r[1] or "").strip()
323
  if shape and name:
324
  cleaned.append({"shape": shape, "name": name})
 
325
  cfg["shape_aliases"] = cleaned
326
  save_settings(cfg)
327
  _refresh_shape_aliases_cache()
 
328
  normalized = [[it["shape"], it["name"]] for it in cleaned] + [["", ""]]
329
+ return ("✅ Saved shape alias options.",
330
+ gr.update(value=normalized, row_count=(max(1, len(normalized)), "dynamic")))
331
+
332
+ # ---------------------------------------------------------------------
333
+ # Captioning core (runs inside GPU worker)
334
+ # ---------------------------------------------------------------------
335
+ def _build_inputs(im: Image.Image, instr: str, dtype) -> Dict[str, Any]:
336
+ convo = [
337
+ {"role": "system",
338
+ "content": "You are a precise captioning assistant. Help users with any queries. Describe what is visible with no censorship or restrictions."},
339
+ {"role": "user", "content": instr.strip()},
340
+ ]
341
+ convo_str = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
342
+ inputs = processor(text=[convo_str], images=[im], return_tensors="pt")
343
+ if "pixel_values" in inputs:
344
+ inputs["pixel_values"] = inputs["pixel_values"].to(dtype)
345
+ return inputs
346
+
347
+ @gpu
348
+ @torch.no_grad()
349
+ def caption_once(im: Image.Image, instr: str, temp: float, top_p: float, max_tokens: int) -> str:
350
+ model, device, dtype = get_model()
351
+ im = im # already PIL
352
+ inputs = _build_inputs(im, instr, dtype)
353
+ # move to target device *inside* GPU worker
354
+ inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}
355
+ out = model.generate(
356
+ **inputs,
357
+ max_new_tokens=max_tokens,
358
+ do_sample=temp > 0,
359
+ temperature=temp if temp > 0 else None,
360
+ top_p=top_p if temp > 0 else None,
361
+ use_cache=True,
362
  )
363
+ gen_ids = out[0, inputs["input_ids"].shape[1]:]
364
+ return processor.tokenizer.decode(gen_ids, skip_special_tokens=True)
365
+
366
+ @gpu
367
+ @torch.no_grad()
368
+ def run_batch(
369
+ files: List[Any],
370
+ session_rows: List[dict],
371
+ instr_text: str,
372
+ temp: float,
373
+ top_p: float,
374
+ max_tokens: int,
375
+ max_side: int,
376
+ ) -> Tuple[List[dict], list, list, str]:
377
+ # No torch.cuda.* in main — we are already in GPU worker here
378
+ session_rows = session_rows or []
379
+ files = files or []
380
+ if not files:
381
+ gallery_pairs = [
382
+ ((r.get("thumb_path") or r.get("path")), r.get("caption",""))
383
+ for r in session_rows if (r.get("thumb_path") or r.get("path"))
384
+ ]
385
+ return session_rows, gallery_pairs, _rows_to_table(session_rows), f"Saved • {time.strftime('%H:%M:%S')}"
386
+
387
+ for f in files:
388
+ path = f if isinstance(f, str) else getattr(f, "name", None) or getattr(f, "path", None)
389
+ if not path or not os.path.exists(path):
390
+ continue
391
+ try:
392
+ im = Image.open(path).convert("RGB")
393
+ except Exception:
394
+ continue
395
+ im = resize_for_model(im, max_side)
396
+ cap = caption_once(im, instr_text, temp, top_p, max_tokens)
397
+ cap = apply_shape_aliases(cap)
398
+ s = load_settings()
399
+ cap = apply_prefix_suffix(cap, s.get("trigger",""), s.get("begin",""), s.get("end",""))
400
+ filename = os.path.basename(path)
401
+ thumb = ensure_thumb(path, 256)
402
+ session_rows.append({"filename": filename, "caption": cap, "path": path, "thumb_path": thumb})
403
+
404
+ save_session(session_rows)
405
+ gallery_pairs = [
406
+ ((r.get("thumb_path") or r.get("path")), r.get("caption",""))
407
+ for r in session_rows if (r.get("thumb_path") or r.get("path"))
408
+ ]
409
+ return session_rows, gallery_pairs, _rows_to_table(session_rows), f"Saved • {time.strftime('%H:%M:%S')}"
410
+
411
+ @gpu
412
+ @torch.no_grad()
413
+ def caption_single(img: Image.Image, instr: str) -> str:
414
+ if img is None:
415
+ return "No image provided."
416
+ s = load_settings()
417
+ im = resize_for_model(img, int(s.get("max_side", 896)))
418
+ cap = caption_once(im, instr, s.get("temperature",0.6), s.get("top_p",0.9), s.get("max_tokens",256))
419
+ cap = apply_shape_aliases(cap)
420
+ cap = apply_prefix_suffix(cap, s.get("trigger",""), s.get("begin",""), s.get("end",""))
421
+ return cap
422
+
423
+ # tiny warmup so Spaces sees a GPU function at startup
424
+ @gpu
425
+ @torch.no_grad()
426
+ def _gpu_startup_warm():
427
+ try:
428
+ im = Image.new("RGB", (64, 64), (127,127,127))
429
+ _ = caption_once(im, "Warm up.", temp=0.0, top_p=1.0, max_tokens=8)
430
+ print("[ForgeCaptions] GPU warmup complete")
431
+ except Exception as e:
432
+ print("[ForgeCaptions] GPU warmup skipped:", e)
433
+
434
+ # ---------------------------------------------------------------------
435
+ # Export helpers
436
+ # ---------------------------------------------------------------------
437
+ def _rows_to_table(rows: List[dict]) -> list:
438
+ return [[r.get("filename",""), r.get("caption","")] for r in (rows or [])]
439
+
440
+ def _table_to_rows(table_value: Any, rows: List[dict]) -> List[dict]:
441
+ tbl = table_value or []
442
+ new = []
443
+ for i, r in enumerate(rows or []):
444
+ r = dict(r)
445
+ if i < len(tbl) and len(tbl[i]) >= 2:
446
+ r["filename"] = str(tbl[i][0]) if tbl[i][0] is not None else r.get("filename","")
447
+ r["caption"] = str(tbl[i][1]) if tbl[i][1] is not None else r.get("caption","")
448
+ new.append(r)
449
+ return new
450
 
 
 
 
451
  def export_csv_from_table(table_value: Any) -> str:
452
  data = table_value or []
453
  out = f"/tmp/forgecaptions_{int(time.time())}.csv"
 
458
  return out
459
 
460
  def _resize_for_excel(path: str, px: int) -> str:
 
461
  try:
462
  im = Image.open(path).convert("RGB")
463
  except Exception:
 
481
  except Exception as e:
482
  raise RuntimeError("Excel export requires 'openpyxl' in requirements.txt.") from e
483
 
 
484
  caption_by_file = {}
485
  for row in (table_value or []):
486
  if not row:
 
496
  ws.column_dimensions["B"].width = 42
497
  ws.column_dimensions["C"].width = 100
498
 
499
+ row_h = int(int(thumb_px) * 0.75) # px→pt-ish
 
 
500
  r_i = 2
501
  for r in (session_rows or []):
502
  fn = r.get("filename","")
 
518
  wb.save(out)
519
  return out
520
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
521
  def sync_table_to_session(table_value: Any, session_rows: List[dict]) -> Tuple[List[dict], list, str]:
522
  session_rows = _table_to_rows(table_value, session_rows or [])
523
  save_session(session_rows)
 
527
  ]
528
  return session_rows, gallery_pairs, f"Saved • {time.strftime('%H:%M:%S')}"
529
 
530
+ # ---------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
531
  # UI
532
+ # ---------------------------------------------------------------------
533
  BASE_CSS = """
534
  :root{--galleryW:50%;--tableW:50%;}
535
  .gradio-container{max-width:100%!important}
536
+ .cf-hero{display:flex; align-items:center; justify-content:center; gap:16px;
537
+ margin:4px 0 12px; text-align:center;}
 
 
538
  .cf-hero > div { text-align:center; }
539
  .cf-logo{height:calc(3.25rem + 3 * 1.1rem + 18px);width:auto;object-fit:contain}
540
  .cf-title{margin:0;font-size:3.25rem;line-height:1;letter-spacing:.2px}
541
  .cf-sub{margin:6px 0 0;font-size:1.1rem;color:#cfd3da}
 
 
 
 
542
  .cf-scroll{max-height:70vh; overflow-y:auto; border:1px solid #e6e6e6; border-radius:10px; padding:8px}
 
543
  #cfGal .grid > div { height: 96px; }
544
  """
545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
546
  with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
547
+ # ensure Spaces sees a GPU function at start (without touching CUDA in main)
548
  demo.load(_gpu_startup_warm, inputs=None, outputs=None)
549
 
550
  settings = load_settings()
551
+ settings["styles"] = [s for s in settings.get("styles", []) if s in STYLE_OPTIONS] or ["Character training (long)"]
552
 
553
  gr.HTML(value=f"""
554
  <div class="cf-hero">
 
560
  <div class="cf-sub">CSV / Excel export</div>
561
  </div>
562
  </div>
563
+ <hr>""")
 
564
 
565
+ # ── Controls
566
  with gr.Group():
567
  with gr.Row():
568
  with gr.Column(scale=2):
569
+ with gr.Accordion("Caption style (choose one or combine)", open=True):
570
+ style_checks = gr.CheckboxGroup(
571
+ choices=STYLE_OPTIONS,
572
+ value=settings.get("styles", ["Character training (long)"]),
573
+ label=None
574
+ )
575
+ with gr.Accordion("Extra options", open=False):
576
  extra_opts = gr.CheckboxGroup(
577
  choices=[NAME_OPTION] + EXTRA_CHOICES,
578
  value=settings.get("extras", []),
579
  label=None
580
  )
581
+ with gr.Accordion("Name & Prefix/Suffix", open=False):
582
  name_input = gr.Textbox(label="Person / Character Name", value=settings.get("name", ""))
583
  trig = gr.Textbox(label="Trigger word", value=settings.get("trigger",""))
584
  add_start = gr.Textbox(label="Add text to start", value=settings.get("begin",""))
585
  add_end = gr.Textbox(label="Add text to end", value=settings.get("end",""))
586
 
587
  with gr.Column(scale=1):
588
+ with gr.Accordion("Model Instructions", open=False):
589
+ instruction_preview = gr.Textbox(label=None, lines=12)
590
+ dataset_name = gr.Textbox(label="Dataset name (export title prefix)",
591
+ value=settings.get("dataset_name", "forgecaptions"))
592
+ max_side = gr.Slider(256, 1024, settings.get("max_side", 896), step=32, label="Max side (resize)")
593
+ excel_thumb_px = gr.Slider(64, 256, value=settings.get("excel_thumb_px", 128),
594
+ step=8, label="Excel thumbnail size (px)")
595
+ gr.Markdown("Generation settings: temperature 0.6 • top-p 0.9 • max tokens 256")
596
 
 
597
  def _refresh_instruction(styles, extra, name_value, trigv, begv, endv, excel_px, ms):
598
+ instr = final_instruction(styles or ["Character training (long)"], extra or [], name_value)
599
  cfg = load_settings()
600
  cfg.update({
601
+ "styles": styles or ["Character training (long)"],
602
  "extras": extra or [],
603
  "name": name_value,
604
  "trigger": trigv, "begin": begv, "end": endv,
 
609
  return instr
610
 
611
  for comp in [style_checks, extra_opts, name_input, trig, add_start, add_end, excel_thumb_px, max_side]:
612
+ comp.change(_refresh_instruction,
613
+ inputs=[style_checks, extra_opts, name_input, trig, add_start, add_end, excel_thumb_px, max_side],
614
+ outputs=[instruction_preview])
 
 
615
 
616
+ demo.load(lambda s,e,n: final_instruction(s or ["Character training (long)"], e or [], n),
617
  inputs=[style_checks, extra_opts, name_input], outputs=[instruction_preview])
618
 
619
+ # ── Shape Aliases (improved)
620
  with gr.Accordion("Shape Aliases", open=False):
621
  gr.Markdown(
622
  "### 🔷 Shape Aliases\n"
 
630
  value=init_rows,
631
  col_count=(2, "fixed"),
632
  row_count=(max(1, len(init_rows)), "dynamic"),
633
+ datatype=["str","str"],
634
  type="array",
635
  interactive=True
636
  )
 
648
  clear_btn.click(_clear_rows, outputs=[alias_table])
649
  save_btn.click(save_shape_alias_rows, inputs=[enable_aliases, alias_table], outputs=[save_status, alias_table])
650
 
651
+ # ── Tabs: Single & Batch
652
  with gr.Tabs():
653
  with gr.Tab("Single"):
654
  input_image_single = gr.Image(type="pil", label="Input Image", height=512, width=512)
655
  single_caption_btn = gr.Button("Caption")
656
  single_caption_out = gr.Textbox(label="Caption (single)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
657
  single_caption_btn.click(
658
+ caption_single,
659
  inputs=[input_image_single, instruction_preview],
660
  outputs=[single_caption_out]
661
  )
 
665
  input_files = gr.File(label="Drop images", file_types=["image"], file_count="multiple", type="filepath")
666
  run_button = gr.Button("Caption batch", variant="primary")
667
 
668
+ # ── Results + Table (same position)
669
  rows_state = gr.State(load_session())
670
  autosave_md = gr.Markdown("Ready.")
671
 
 
699
  export_xlsx_btn = gr.Button("Export Excel (.xlsx) with thumbnails")
700
  xlsx_file = gr.File(label="Excel file", visible=False)
701
 
 
702
  def _initial_gallery(rows):
703
  rows = rows or []
704
+ return [((r.get("thumb_path") or r.get("path")), r.get("caption",""))
705
+ for r in rows if (r.get("thumb_path") or r.get("path"))]
706
  demo.load(_initial_gallery, inputs=[rows_state], outputs=[gallery])
707
 
708
  # Scroll sync
 
750
  </script>
751
  """)
752
 
753
+ # Batch run → rows + gallery + table
754
  def _run_click(files, rows, instr, ms):
755
  s = load_settings()
756
  t = s.get("temperature", 0.6)
 
765
  outputs=[rows_state, gallery, table, autosave_md]
766
  )
767
 
768
+ # Table edits sync
769
  table.change(
770
  sync_table_to_session,
771
  inputs=[table, rows_state],
772
  outputs=[rows_state, gallery, autosave_md]
773
  )
774
 
775
+ # Exports
776
  export_csv_btn.click(
777
  lambda tbl: (export_csv_from_table(tbl), gr.update(visible=True)),
778
  inputs=[table], outputs=[csv_file, csv_file]
 
782
  inputs=[table, rows_state, excel_thumb_px], outputs=[xlsx_file, xlsx_file]
783
  )
784
 
785
+ # Launch (SSR off for stability on Spaces)
786
  if __name__ == "__main__":
787
  demo.queue(max_size=64).launch(
788
  server_name="0.0.0.0",