Update app.py
app.py
CHANGED
@@ -1,7 +1,10 @@
import os, io, csv, time, json, hashlib, base64, zipfile, re
-from typing import List, Tuple, Dict

-import spaces
import gradio as gr
from PIL import Image
import torch

@@ -10,17 +13,14 @@ from transformers import LlavaForConditionalGeneration, AutoProcessor
# ────────────────────────────────────────────────────────
# Paths & caches
# ────────────────────────────────────────────────────────
-os.environ.setdefault("HF_HOME", "/home/user/.cache/huggingface")
-os.makedirs(os.environ["HF_HOME"], exist_ok=True)
APP_DIR = os.getcwd()
SESSION_FILE = "/tmp/session.json"
SETTINGS_FILE = "/tmp/cf_settings.json"
JOURNAL_FILE = "/tmp/cf_journal.json"
os.makedirs(THUMB_CACHE, exist_ok=True)
-os.makedirs(

# ────────────────────────────────────────────────────────
# Model setup

@@ -35,8 +35,8 @@ def _detect_gpu():

BACKEND, VRAM_GB, GPU_NAME = _detect_gpu()
DTYPE = torch.bfloat16 if BACKEND == "cuda" else torch.float32
-MAX_SIDE_CAP = 1024 if BACKEND == "cuda" else 640
DEVICE = "cuda" if BACKEND == "cuda" else "cpu"

processor = AutoProcessor.from_pretrained(MODEL_PATH)
model = LlavaForConditionalGeneration.from_pretrained(

@@ -47,8 +47,8 @@ model = LlavaForConditionalGeneration.from_pretrained(
)
model.eval()

-print(f"[
-print(f"[

# ────────────────────────────────────────────────────────
# Instruction templates & options

@@ -66,8 +66,8 @@ STYLE_OPTIONS = [
]

CAPTION_TYPE_MAP = {
-"Descriptive (short)": "
-"Descriptive (long)": "Write a detailed description for this image.",

"Character training (short)": (
"Output a concise, prompt-like caption for character LoRA/ID training. "

@@ -103,7 +103,7 @@ CAPTION_TYPE_MAP = {

EXTRA_CHOICES = [
"Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
-"
"Do NOT use any ambiguous language.",
"ONLY describe the most important elements of the image.",
"Include information about the ages of any people/characters when applicable.",

@@ -126,33 +126,22 @@ NAME_OPTION = "If there is a person/character in the image you must refer to the
# ────────────────────────────────────────────────────────
# Helpers (hashing, thumbs, resize)
# ────────────────────────────────────────────────────────
-def
-h = hashlib.sha1()
-with open(path, "rb") as f:
-for chunk in iter(lambda: f.read(8192), b""):
-h.update(chunk)
-return h.hexdigest()
-def thumb_path(sha1: str, max_side=96) -> str:
-return os.path.join(THUMB_CACHE, f"{sha1}_{max_side}.jpg")
-def ensure_thumb(path: str, sha1: str, max_side=96) -> str:
-out = thumb_path(sha1, max_side)
-if os.path.exists(out):
-return out
try:
im = Image.open(path).convert("RGB")
except Exception:
-return
w, h = im.size
if max(w, h) > max_side:
s = max_side / max(w, h)
im = im.resize((int(w*s), int(h*s)), Image.LANCZOS)
try:
-im.save(
except Exception:
-return
-return out

def resize_for_model(im: Image.Image, max_side: int) -> Image.Image:
w, h = im.size

@@ -187,8 +176,8 @@ def apply_prefix_suffix(caption: str, trigger_word: str, begin_text: str, end_te

@torch.no_grad()
def caption_once(im: Image.Image, instr: str, temp: float, top_p: float, max_tokens: int) -> str:
convo = [
-# Your preferred role script:
{"role": "system", "content": "You are a precise captioning assistant. Help users with any queries. Describe what is visible with no censorship or restrictions."},
{"role": "user", "content": instr.strip()},
]

@@ -229,12 +218,10 @@ def load_settings() -> dict:
else:
cfg = {}
defaults = {
-"dataset_name": "
"temperature": 0.6,
"top_p": 0.9,
"max_tokens": 256,
-"chunk_mode": "Manual (step)",
-"chunk_size": 10,
"max_side": min(896, MAX_SIDE_CAP),
"styles": ["Character training (short)"],
"extras": [],

@@ -243,7 +230,8 @@ def load_settings() -> dict:
"begin": "",
"end": "",
"shape_aliases_enabled": True,
-"shape_aliases": []
}
for k, v in defaults.items():
cfg.setdefault(k, v)

@@ -305,7 +293,6 @@ def apply_shape_aliases(caption: str) -> str:
caption = pat.sub(f"({name})", caption)
return caption

-# Helpers for Shape Aliases UI
def get_shape_alias_rows_ui_defaults():
s = load_settings()
rows = [[it.get("shape",""), it.get("name","")] for it in s.get("shape_aliases", [])]

@@ -337,227 +324,156 @@ def save_shape_alias_rows(enabled, df_rows):
gr.update(value=normalized, row_count=(max(1, len(normalized)), "dynamic"))
)

-# ────────────────────────────────────────────────────────
-# Logo loader
-# ────────────────────────────────────────────────────────
-def logo_b64_img() -> str:
-candidates = [
-os.path.join(APP_DIR, "captionforge-logo.png"),
-"/home/user/app/captionforge-logo.png",
-"captionforge-logo.png",
-]
-for p in candidates:
-if os.path.exists(p):
-with open(p, "rb") as f:
-b64 = base64.b64encode(f.read()).decode("ascii")
-return f"<img src='data:image/png;base64,{b64}' alt='CaptionForge' class='cf-logo'>"
-return ""

# ────────────────────────────────────────────────────────
# Import / Export helpers
# ────────────────────────────────────────────────────────
-def
-def ts_stamp() -> str:
-return time.strftime("%Y%m%d_%H%M%S")
-def export_csv(rows: List[dict], basename: str) -> str:
-base = sanitize_basename(basename)
-out = f"/tmp/{base}_{ts_stamp()}.csv"
with open(out, "w", newline="", encoding="utf-8") as f:
w = csv.writer(f)
-w.writerow(["
-w.writerow([r.get("sha1",""), r.get("filename",""), r.get("caption","")])
return out

-def
try:
from openpyxl import Workbook
from openpyxl.drawing.image import Image as XLImage
except Exception as e:
raise RuntimeError("Excel export requires 'openpyxl' in requirements.txt.") from e
ws.append(["image", "filename", "caption"])
ws.column_dimensions["A"].width = 24
ws.column_dimensions["B"].width = 42
ws.column_dimensions["C"].width = 100
r_i = 2
-for r in
-fn = r.get("filename","")
ws.cell(row=r_i, column=2, value=fn)
ws.cell(row=r_i, column=3, value=cap)
img_path = r.get("thumb_path") or r.get("path")
if img_path and os.path.exists(img_path):
try:
except Exception:
pass
r_i += 1
-f.write(r.get("caption", ""))
-zpath = f"/tmp/{base}_{ts_stamp()}_txt.zip"
-with zipfile.ZipFile(zpath, "w", zipfile.ZIP_DEFLATED) as z:
-for name in os.listdir(TXT_EXPORT_DIR):
-if name.endswith(".txt"):
-z.write(os.path.join(TXT_EXPORT_DIR, name), arcname=name)
-return zpath
-def import_csv_jsonl(path: str, rows: List[dict]) -> List[dict]:
-"""Merge captions from CSV/JSONL into current rows by sha1→filename, updating only 'caption'."""
-if not path or not os.path.exists(path):
-return rows or []
-rows = rows or []
-by_sha = {r.get("sha1"): r for r in rows if r.get("sha1")}
-by_fn = {r.get("filename"): r for r in rows if r.get("filename")}
-def _apply(item: dict):
-if not isinstance(item, dict):
-return
-cap = (item.get("caption") or "").strip()
-if not cap:
-return
-sha = (item.get("sha1") or "").strip()
-fn = (item.get("filename") or "").strip()
-tgt = by_sha.get(sha) if sha else None
-if not tgt and fn:
-tgt = by_fn.get(fn)
-if tgt is not None:
-prev = tgt.get("caption", "")
-if cap != prev:
-tgt.setdefault("history", []).append(prev)
-tgt["caption"] = cap
-ext = os.path.splitext(path)[1].lower()
-try:
-if ext == ".csv":
-with open(path, "r", encoding="utf-8") as f:
-for row in csv.DictReader(f):
-_apply(row)
-elif ext in (".jsonl", ".ndjson"):
-with open(path, "r", encoding="utf-8") as f:
-for line in f:
-line = line.strip()
-if not line:
-continue
-try:
-obj = json.loads(line)
-except Exception:
-continue
-_apply(obj)
-else:
-with open(path, "r", encoding="utf-8") as f:
-data = json.load(f)
-if isinstance(data, list):
-for obj in data:
-_apply(obj)
-except Exception as e:
-print("[CaptionForge] Import merge failed:", e)
-save_session(rows)
-return rows

# ────────────────────────────────────────────────────────
-#
# ────────────────────────────────────────────────────────
-def adaptive_defaults(vram_gb: int) -> Tuple[int, int]:
-if vram_gb >= 40: return 24, min(1024, MAX_SIDE_CAP)
-if vram_gb >= 24: return 16, min(1024, MAX_SIDE_CAP)
-if vram_gb >= 12: return 10, 896
-if vram_gb >= 8: return 6, 768
-return 4, 640
-def warmup_if_needed(temp: float, top_p: float, max_tokens: int, max_side: int):
-try:
-im = Image.new("RGB", (min(128, max_side), min(128, max_side)), (127,127,127))
-_ = caption_once(im, "Warm up.", temp=0.0, top_p=top_p, max_tokens=min(16, max_tokens))
-except Exception as e:
-print("[CaptionForge] Warm-up skipped:", e)
-@spaces.GPU()
@torch.no_grad()
-def
if torch.cuda.is_available():
torch.cuda.empty_cache()
-files = files or []
-if not files: return rows, [], []

-sha_seen = {r.get("sha1") for r in rows}
-f_infos = []
for f in files:
try:
except Exception:
continue
-cap = apply_shape_aliases(cap) # Apply custom shape→name mappings
-cap = apply_prefix_suffix(cap, trigger_word, begin_text, end_text)
-break
-except Exception as e:
-error_log.append(f"{fname} ({sha[:8]}): {type(e).__name__}: {e} [attempt {attempt}]")
-rows.append({
-"sha1": sha, "filename": fname, "path": f, "thumb_path": thumb,
-"caption": cap or "", "history": []
-})
-save_session(rows)
-if step_once and len(f_infos) > len(chunks[0]):
-remaining = [fi[0] for fi in f_infos[len(chunks[0]):]]
-save_journal({"remaining_files": remaining})
-return rows, remaining, error_log

# ────────────────────────────────────────────────────────
# UI

@@ -580,12 +496,24 @@ BASE_CSS = """
.cf-scroll{max-height:70vh; overflow-y:auto; border:1px solid #e6e6e6; border-radius:10px; padding:8px}
/* Uniform sizes */
#cfGal .grid > div { height: 96px; }
-/* Hide SHA1 in table (first column) */
-#cfTable table thead tr th:first-child { display:none; }
-#cfTable table tbody tr td:first-child { display:none; }
"""

settings = load_settings()
settings["styles"] = [s for s in settings.get("styles", []) if s in STYLE_OPTIONS] or ["Character training (short)"]

@@ -593,10 +521,10 @@ with gr.Blocks(css=BASE_CSS, title="CaptionForge") as demo:
<div class="cf-hero">
{logo_b64_img()}
<div>
-<h1 class="cf-title">
<div class="cf-sub">Batch captioning</div>
<div class="cf-sub">Scrollable editor & autosave</div>
-<div class="cf-sub">CSV / Excel
</div>
</div>
<hr>

@@ -619,123 +547,116 @@ with gr.Blocks(css=BASE_CSS, title="CaptionForge") as demo:
)
with gr.Accordion("Name & Prefix/Suffix", open=True):
name_input = gr.Textbox(label="Person / Character Name", value=settings.get("name", ""))
-trig = gr.Textbox(label="Trigger word")
-add_start = gr.Textbox(label="Add text to start")
-add_end = gr.Textbox(label="Add text to end")

with gr.Column(scale=1):
-instruction_preview = gr.Textbox(label="Model Instructions", lines=
-dataset_name = gr.Textbox(label="Dataset name (used for export file titles)", value=settings.get("dataset_name", "
-)

-#
with gr.Accordion("Shape Aliases", open=False):
gr.Markdown(
"### 🔷 Shape Aliases\n"
"Replace literal **shape tokens** in captions with a preferred **name**.\n\n"
-"**How to use
-"1) Left column: exact shape token (e.g., `diamond`, `heart`).\n"
-"2) Right column: replacement name (e.g., `starkey-emblem`).\n"
-"3) Click **Save**.\n\n"
-"_Rows with blanks are ignored; add as many rows as you like._"
)
init_rows, init_enabled = get_shape_alias_rows_ui_defaults()
-enable_aliases = gr.Checkbox(
-label="Enable shape alias replacements",
-value=init_enabled,
-)
alias_table = gr.Dataframe(
headers=["shape (literal token)", "name to insert"],
-value=init_rows,
-col_count=(2, "fixed"),
-row_count=(max(1, len(init_rows)), "dynamic"),
datatype=["str", "str"],
-type="array",
interactive=True
)
with gr.Row():
add_row_btn = gr.Button("+ Add row", variant="secondary")
clear_btn = gr.Button("Clear", variant="secondary")
save_btn = gr.Button("💾 Save", variant="primary")
save_status = gr.Markdown("")
-# Button handlers
def _add_row(cur):
cur = (cur or []) + [["", ""]]
return gr.update(value=cur, row_count=(max(1, len(cur)), "dynamic"))
def _clear_rows():
return gr.update(value=[["", ""]], row_count=(1, "dynamic"))
add_row_btn.click(_add_row, inputs=[alias_table], outputs=[alias_table])
clear_btn.click(_clear_rows, outputs=[alias_table])
-save_btn.click(save_shape_alias_rows, inputs=[enable_aliases, alias_table],

-# =====
-import_file = gr.File(label="Import CSV or JSONL (merge captions)")

-with gr.Row():
-import_btn = gr.Button("Import → Merge")
-resume_btn = gr.Button("Resume last run", variant="secondary")
-run_button = gr.Button("Caption batch", variant="primary")
-# Step-chunk controls
-step_panel = gr.Group(visible=False)
-with step_panel:
-step_msg = gr.Markdown("")
-step_next = gr.Button("Process next chunk")
-step_finish = gr.Button("Finish")
-# States
-rows_state = gr.State(load_session())
-visible_count_state = gr.State(100)
-pending_files = gr.State([])
-remaining_state = gr.State([])
-autosave_md = gr.Markdown("Ready.")
-# Error log
-with gr.Accordion("Error log", open=False):
-log_md = gr.Markdown("")
-# Side-by-side gallery + table with shared scroll
with gr.Row():
with gr.Column(scale=1):
gallery = gr.Gallery(
-label="Results (
show_label=True,
-columns=
height=520,
elem_id="cfGal",
elem_classes=["cf-scroll"]

@@ -743,49 +664,30 @@ with gr.Blocks(css=BASE_CSS, title="CaptionForge") as demo:
with gr.Column(scale=1):
table = gr.Dataframe(
label="Editable captions (whole session)",
-value=
-headers=["
interactive=True,
elem_id="cfTable",
elem_classes=["cf-scroll"]
)

-# Exports
with gr.Row():
with gr.Column():
export_csv_btn = gr.Button("Export CSV")
csv_file = gr.File(label="CSV file", visible=False)
with gr.Column():
-export_xlsx_btn = gr.Button("Export Excel (.xlsx)")
xlsx_file = gr.File(label="Excel file", visible=False)
-with gr.Column():
-export_txt_btn = gr.Button("Export captions as .txt (zip)")
-txt_zip = gr.File(label="TXT zip", visible=False)

-#
-def
rows = rows or []
-# Gallery expects list of paths
-gal = []
-for r in win:
-p = r.get("thumb_path") or r.get("path")
-if isinstance(p, str) and os.path.exists(p):
-gal.append(p)
-# Dataframe rows (sha1 kept for mapping; hidden via CSS)
-df = [[str(r.get("sha1","")), str(r.get("filename","")), str(r.get("caption",""))] for r in win]
-info = f"Showing {len(win)} of {len(rows)}"
-return gal, df, info
-# Initial loads
-demo.load(lambda s,e,n: final_instruction(s or ["Character training (short)"], e or [], n),
-inputs=[style_checks, extra_opts, name_input], outputs=[instruction_preview])
-demo.load(_render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md])

-# Scroll sync
gr.HTML("""
<script>
(function () {

@@ -830,123 +732,36 @@ with gr.Blocks(css=BASE_CSS, title="CaptionForge") as demo:
</script>
""")

-#
-def
-if not tgt:
-continue
-prev = tgt.get("caption", "")
-newc = row[2] if len(row) > 2 else prev
-if newc != prev:
-hist = tgt.setdefault("history", [])
-hist.append(prev)
-tgt["caption"] = newc
-save_session(rows or [])
-return rows, f"Saved • {time.strftime('%H:%M:%S')}"
-table.change(on_table_edit, inputs=[table, rows_state], outputs=[rows_state, autosave_md])\
-.then(_render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md])
-# Import merge
-def do_import(file, rows):
-if not file:
-return rows, gr.update(value="No file"), rows
-rows = import_csv_jsonl(file.name, rows or [])
-return rows, gr.update(value=f"Merged from {os.path.basename(file.name)} at {time.strftime('%H:%M:%S')}"), rows
-import_btn.click(do_import, inputs=[import_file, rows_state], outputs=[rows_state, autosave_md, rows_state])\
-.then(_render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md])
-# Prepare run (hash dupes)
-def prepare_run(files, rows, pending):
-files = files or []
-rows = rows or []
-pending = pending or []
-if not files:
-return [], gr.update(open=True), []
-existing = {r.get("sha1") for r in rows if r.get("sha1")}
-pending_sha = set()
-for f in pending:
-try:
-pending_sha.add(sha1_file(f))
-except Exception:
-pass
-keep = []
-for f in files:
-try:
-s = sha1_file(f)
-except Exception:
-continue
-if s not in existing and s not in pending_sha:
-keep.append(f)
-return keep, gr.update(open=False), keep
-input_files.change(prepare_run, inputs=[input_files, rows_state, pending_files], outputs=[pending_files, uploads_acc, pending_files])
-# Step visibility helper (dynamic message)
-def _step_visibility(remain, mode):
-if (mode == "Manual (step)") and remain:
-return gr.update(visible=True), gr.update(value=f"{len(remain)} files remain. Process next chunk?")
-return gr.update(visible=False), gr.update(value="")
-# Run helpers (use settings for gen params)
-def _run_once(pending, rows, instr, trigv, begv, endv, mode, csize, mside):
-cfg = load_settings()
-t = cfg.get("temperature", 0.6)
-p = cfg.get("top_p", 0.9)
-m = cfg.get("max_tokens", 256)
-step_once = (mode == "Manual (step)")
-new_rows, remaining, errors = process_batch(
-pending, False, rows or [], instr, t, p, m, trigv, begv, endv,
-mode, int(csize), int(mside), step_once
-)
-log = "\n".join(errors) if errors else "(no errors)"
-return new_rows, remaining, log, f"Saved • {time.strftime('%H:%M:%S')}"

run_button.click(
-inputs=[
-outputs=[rows_state,
-).then(
-_render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md]
-).then(
-_step_visibility, inputs=[remaining_state, chunk_mode], outputs=[step_panel, step_msg]
)

-def _step_next(remain, rows, instr, trigv, begv, endv, mode, csize, mside):
-return _run_once(remain, rows, instr, trigv, begv, endv, mode, csize, mside)
-step_next.click(
-_step_next,
-inputs=[remaining_state, rows_state, instruction_preview, trig, add_start, add_end, chunk_mode, chunk_size, max_side],
-outputs=[rows_state, remaining_state, log_md, autosave_md]
-).then(
-_render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md]
-).then(
-_step_visibility, inputs=[remaining_state, chunk_mode], outputs=[step_panel, step_msg]
)

-# Exports (reveal file widgets when created)
export_csv_btn.click(
-lambda
-inputs=[
)
export_xlsx_btn.click(
-lambda rows,
-inputs=[rows_state,
-)
-export_txt_btn.click(
-lambda rows, base: (export_txt_zip(rows or [], base), gr.update(visible=True)),
-inputs=[rows_state, dataset_name], outputs=[txt_zip, txt_zip]
)

# Launch
@@ -1,7 +1,10 @@
import os, io, csv, time, json, hashlib, base64, zipfile, re
+from typing import List, Tuple, Dict, Any
+
+# Caches
+os.environ.setdefault("HF_HOME", "/home/user/.cache/huggingface")
+os.makedirs(os.environ["HF_HOME"], exist_ok=True)

import gradio as gr
from PIL import Image
import torch

@@ -10,17 +13,14 @@ from transformers import LlavaForConditionalGeneration, AutoProcessor
# ────────────────────────────────────────────────────────
# Paths & caches
# ────────────────────────────────────────────────────────
APP_DIR = os.getcwd()
SESSION_FILE = "/tmp/session.json"
SETTINGS_FILE = "/tmp/cf_settings.json"
JOURNAL_FILE = "/tmp/cf_journal.json"
+THUMB_CACHE = os.path.expanduser("~/.cache/forgecaptions/thumbs")
+EXCEL_THUMB_DIR = "/tmp/forge_excel_thumbs"
os.makedirs(THUMB_CACHE, exist_ok=True)
+os.makedirs(EXCEL_THUMB_DIR, exist_ok=True)

# ────────────────────────────────────────────────────────
# Model setup

@@ -35,8 +35,8 @@ def _detect_gpu():

BACKEND, VRAM_GB, GPU_NAME = _detect_gpu()
DTYPE = torch.bfloat16 if BACKEND == "cuda" else torch.float32
DEVICE = "cuda" if BACKEND == "cuda" else "cpu"
+MAX_SIDE_CAP = 1024 if BACKEND == "cuda" else 640

processor = AutoProcessor.from_pretrained(MODEL_PATH)
model = LlavaForConditionalGeneration.from_pretrained(

@@ -47,8 +47,8 @@ model = LlavaForConditionalGeneration.from_pretrained(
)
model.eval()

+print(f"[ForgeCaptions] Backend={BACKEND} GPU={GPU_NAME} VRAM={VRAM_GB}GB dtype={DTYPE}")
+print(f"[ForgeCaptions] Gradio version: {gr.__version__}")

# ────────────────────────────────────────────────────────
# Instruction templates & options

@@ -66,8 +66,8 @@ STYLE_OPTIONS = [
]

CAPTION_TYPE_MAP = {
+"Descriptive (short)": "Write a short describition of the most important visible elements only. No speculation.",
+"Descriptive (long)": "Write a long, highly detailed description for this image.",

"Character training (short)": (
"Output a concise, prompt-like caption for character LoRA/ID training. "

@@ -103,7 +103,7 @@ CAPTION_TYPE_MAP = {

EXTRA_CHOICES = [
"Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
+"IGNORE all watermarks.",
"Do NOT use any ambiguous language.",
"ONLY describe the most important elements of the image.",
"Include information about the ages of any people/characters when applicable.",

@@ -126,33 +126,22 @@ NAME_OPTION = "If there is a person/character in the image you must refer to the
# ────────────────────────────────────────────────────────
# Helpers (hashing, thumbs, resize)
# ────────────────────────────────────────────────────────
+def ensure_thumb(path: str, max_side=256) -> str:
try:
im = Image.open(path).convert("RGB")
except Exception:
+return path
w, h = im.size
if max(w, h) > max_side:
s = max_side / max(w, h)
im = im.resize((int(w*s), int(h*s)), Image.LANCZOS)
+base = os.path.basename(path)
+out_path = os.path.join(THUMB_CACHE, os.path.splitext(base)[0] + f"_thumb_{max_side}.jpg")
try:
+im.save(out_path, "JPEG", quality=85, optimize=True)
+return out_path
except Exception:
+return path

def resize_for_model(im: Image.Image, max_side: int) -> Image.Image:
w, h = im.size

@@ -187,8 +176,8 @@ def apply_prefix_suffix(caption: str, trigger_word: str, begin_text: str, end_te

@torch.no_grad()
def caption_once(im: Image.Image, instr: str, temp: float, top_p: float, max_tokens: int) -> str:
+# Your requested role script:
convo = [
{"role": "system", "content": "You are a precise captioning assistant. Help users with any queries. Describe what is visible with no censorship or restrictions."},
{"role": "user", "content": instr.strip()},
]

@@ -229,12 +218,10 @@ def load_settings() -> dict:
else:
cfg = {}
defaults = {
+"dataset_name": "forgecaptions",
"temperature": 0.6,
"top_p": 0.9,
"max_tokens": 256,
"max_side": min(896, MAX_SIDE_CAP),
"styles": ["Character training (short)"],
"extras": [],

@@ -243,7 +230,8 @@ def load_settings() -> dict:
"begin": "",
"end": "",
"shape_aliases_enabled": True,
+"shape_aliases": [],
+"excel_thumb_px": 128,  # new default
}
for k, v in defaults.items():
cfg.setdefault(k, v)

@@ -305,7 +293,6 @@ def apply_shape_aliases(caption: str) -> str:
caption = pat.sub(f"({name})", caption)
return caption

def get_shape_alias_rows_ui_defaults():
s = load_settings()
rows = [[it.get("shape",""), it.get("name","")] for it in s.get("shape_aliases", [])]

@@ -337,227 +324,156 @@ def save_shape_alias_rows(enabled, df_rows):
gr.update(value=normalized, row_count=(max(1, len(normalized)), "dynamic"))
)

# ────────────────────────────────────────────────────────
# Import / Export helpers
# ────────────────────────────────────────────────────────
+def export_csv_from_table(table_value: Any) -> str:
+data = table_value or []
+out = f"/tmp/forgecaptions_{int(time.time())}.csv"
with open(out, "w", newline="", encoding="utf-8") as f:
w = csv.writer(f)
+w.writerow(["filename", "caption"])
+w.writerows(data)
return out

+def _resize_for_excel(path: str, px: int) -> str:
+"""Create a temp resized copy for Excel embedding."""
+try:
+im = Image.open(path).convert("RGB")
+except Exception:
+return path
+w, h = im.size
+if max(w, h) > px:
+s = px / max(w, h)
+im = im.resize((int(w*s), int(h*s)), Image.LANCZOS)
+base = os.path.basename(path)
+out_path = os.path.join(EXCEL_THUMB_DIR, f"{os.path.splitext(base)[0]}_{px}px.jpg")
+try:
+im.save(out_path, "JPEG", quality=85, optimize=True)
+return out_path
+except Exception:
+return path
+
+def export_excel_with_thumbs(table_value: Any, session_rows: List[dict], thumb_px: int) -> str:
try:
from openpyxl import Workbook
from openpyxl.drawing.image import Image as XLImage
except Exception as e:
raise RuntimeError("Excel export requires 'openpyxl' in requirements.txt.") from e
+
+# Respect user edits (table wins)
+caption_by_file = {}
+for row in (table_value or []):
+if not row:
+continue
+fn = str(row[0]) if len(row) > 0 else ""
+cap = str(row[1]) if len(row) > 1 and row[1] is not None else ""
+if fn:
+caption_by_file[fn] = cap
+
+wb = Workbook(); ws = wb.active; ws.title = "ForgeCaptions"
ws.append(["image", "filename", "caption"])
ws.column_dimensions["A"].width = 24
ws.column_dimensions["B"].width = 42
ws.column_dimensions["C"].width = 100
+
+# Approx px→points (Excel row height is points; ~0.75 pt per px @ 96dpi)
+row_h = int(thumb_px * 0.75)
+
r_i = 2
+for r in (session_rows or []):
+fn = r.get("filename","")
+cap = caption_by_file.get(fn, r.get("caption",""))
ws.cell(row=r_i, column=2, value=fn)
ws.cell(row=r_i, column=3, value=cap)
img_path = r.get("thumb_path") or r.get("path")
if img_path and os.path.exists(img_path):
try:
+resized = _resize_for_excel(img_path, int(thumb_px))
+xlimg = XLImage(resized)
+ws.add_image(xlimg, f"A{r_i}")
+ws.row_dimensions[r_i].height = row_h
except Exception:
pass
r_i += 1
+
+out = f"/tmp/forgecaptions_{int(time.time())}.xlsx"
+wb.save(out)
+return out
+
+# Rows<->Table helpers
+def _rows_to_table(rows: List[dict]) -> list:
+return [[r.get("filename",""), r.get("caption","")] for r in (rows or [])]
+
+def _table_to_rows(table_value: Any, rows: List[dict]) -> List[dict]:
+tbl = table_value or []
+new = []
+for i, r in enumerate(rows or []):
+r = dict(r)
+if i < len(tbl) and len(tbl[i]) >= 2:
+r["filename"] = str(tbl[i][0]) if tbl[i][0] is not None else r.get("filename","")
+r["caption"] = str(tbl[i][1]) if tbl[i][1] is not None else r.get("caption","")
+new.append(r)
+return new

# ────────────────────────────────────────────────────────
+# Batch captioning (returns rows, gallery, table)
# ────────────────────────────────────────────────────────
@torch.no_grad()
+def run_batch(
+files: List[Any],
+session_rows: List[dict],
+instr_text: str,
+temp: float,
+top_p: float,
+max_tokens: int,
+max_side: int,
+) -> Tuple[List[dict], list, list, str]:
if torch.cuda.is_available():
torch.cuda.empty_cache()

+session_rows = session_rows or []
+files = files or []
+if not files:
+gallery_pairs = [
+((r.get("thumb_path") or r.get("path")), r.get("caption",""))
+for r in session_rows if (r.get("thumb_path") or r.get("path"))
+]
+return session_rows, gallery_pairs, _rows_to_table(session_rows), f"Saved • {time.strftime('%H:%M:%S')}"

for f in files:
+path = f if isinstance(f, str) else getattr(f, "name", None) or getattr(f, "path", None)
+if not path or not os.path.exists(path):
+continue
try:
+im = Image.open(path).convert("RGB")
except Exception:
continue
+im = resize_for_model(im, max_side)
+cap = caption_once(im, instr_text, temp, top_p, max_tokens)
+cap = apply_shape_aliases(cap)
+s = load_settings()
+cap = apply_prefix_suffix(cap, s.get("trigger",""), s.get("begin",""), s.get("end",""))
+
+filename = os.path.basename(path)
+thumb = ensure_thumb(path, 256)
+session_rows.append({"filename": filename, "caption": cap, "path": path, "thumb_path": thumb})
+
+save_session(session_rows)
+gallery_pairs = [
+((r.get("thumb_path") or r.get("path")), r.get("caption",""))
+for r in session_rows if (r.get("thumb_path") or r.get("path"))
+]
+return session_rows, gallery_pairs, _rows_to_table(session_rows), f"Saved • {time.strftime('%H:%M:%S')}"
+
+def sync_table_to_session(table_value: Any, session_rows: List[dict]) -> Tuple[List[dict], list, str]:
+session_rows = _table_to_rows(table_value, session_rows or [])
+save_session(session_rows)
+gallery_pairs = [
+((r.get("thumb_path") or r.get("path")), r.get("caption",""))
+for r in session_rows if (r.get("thumb_path") or r.get("path"))
+]
+return session_rows, gallery_pairs, f"Saved • {time.strftime('%H:%M:%S')}"

# ────────────────────────────────────────────────────────
# UI

@@ -580,12 +496,24 @@ BASE_CSS = """
.cf-scroll{max-height:70vh; overflow-y:auto; border:1px solid #e6e6e6; border-radius:10px; padding:8px}
/* Uniform sizes */
#cfGal .grid > div { height: 96px; }
"""

+def logo_b64_img() -> str:
+candidates = [
+os.path.join(APP_DIR, "forgecaptions-logo.png"),
+os.path.join(APP_DIR, "captionforge-logo.png"), # fallback if you kept the old name
+"/home/user/app/forgecaptions-logo.png",
+"forgecaptions-logo.png",
+"captionforge-logo.png",
+]
+for p in candidates:
+if os.path.exists(p):
+with open(p, "rb") as f:
+b64 = base64.b64encode(f.read()).decode("ascii")
+return f"<img src='data:image/png;base64,{b64}' alt='ForgeCaptions' class='cf-logo'>"
+return ""
+
+with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
settings = load_settings()
settings["styles"] = [s for s in settings.get("styles", []) if s in STYLE_OPTIONS] or ["Character training (short)"]

@@ -593,10 +521,10 @@ with gr.Blocks(css=BASE_CSS, title="CaptionForge") as demo:
<div class="cf-hero">
{logo_b64_img()}
<div>
+<h1 class="cf-title">ForgeCaptions</h1>
<div class="cf-sub">Batch captioning</div>
<div class="cf-sub">Scrollable editor & autosave</div>
+<div class="cf-sub">CSV / Excel export</div>
</div>
</div>
<hr>

@@ -619,123 +547,116 @@ with gr.Blocks(css=BASE_CSS, title="CaptionForge") as demo:
)
with gr.Accordion("Name & Prefix/Suffix", open=True):
name_input = gr.Textbox(label="Person / Character Name", value=settings.get("name", ""))
+trig = gr.Textbox(label="Trigger word", value=settings.get("trigger",""))
+add_start = gr.Textbox(label="Add text to start", value=settings.get("begin",""))
+add_end = gr.Textbox(label="Add text to end", value=settings.get("end",""))

with gr.Column(scale=1):
+instruction_preview = gr.Textbox(label="Model Instructions", lines=12)
+dataset_name = gr.Textbox(label="Dataset name (used for export file titles)", value=settings.get("dataset_name", "forgecaptions"))
+max_side = gr.Slider(256, MAX_SIDE_CAP, settings.get("max_side", min(896, MAX_SIDE_CAP)), step=32, label="Max side (resize)")
+excel_thumb_px = gr.Slider(64, 256, value=settings.get("excel_thumb_px", 128), step=8, label="Excel thumbnail size (px)")
+gr.Markdown("Generation (saved in settings): temperature 0.6 • top-p 0.9 • max_tokens 256")

+# Auto-refresh instruction & persist key controls
+def _refresh_instruction(styles, extra, name_value, trigv, begv, endv, excel_px, ms):
+instr = final_instruction(styles or ["Character training (short)"], extra or [], name_value)
+cfg = load_settings()
+cfg.update({
+"styles": styles or ["Character training (short)"],
+"extras": extra or [],
+"name": name_value,
+"trigger": trigv, "begin": begv, "end": endv,
+"excel_thumb_px": int(excel_px),
+"max_side": int(ms),
+})
+save_settings(cfg)
+return instr
+
+for comp in [style_checks, extra_opts, name_input, trig, add_start, add_end, excel_thumb_px, max_side]:
+comp.change(
+_refresh_instruction,
+inputs=[style_checks, extra_opts, name_input, trig, add_start, add_end, excel_thumb_px, max_side],
+outputs=[instruction_preview]
+)
+
+# Set initial instruction text on load
+demo.load(lambda s,e,n: final_instruction(s or ["Character training (short)"], e or [], n),
+inputs=[style_checks, extra_opts, name_input], outputs=[instruction_preview])

+# ===== Shape Aliases
with gr.Accordion("Shape Aliases", open=False):
gr.Markdown(
"### 🔷 Shape Aliases\n"
"Replace literal **shape tokens** in captions with a preferred **name**.\n\n"
+"**How to use:** left column = shape token (e.g., `diamond`), right column = replacement name (e.g., `starkey-emblem`)."
)
init_rows, init_enabled = get_shape_alias_rows_ui_defaults()
+enable_aliases = gr.Checkbox(label="Enable shape alias replacements", value=init_enabled)
alias_table = gr.Dataframe(
headers=["shape (literal token)", "name to insert"],
+value=init_rows,
+col_count=(2, "fixed"),
+row_count=(max(1, len(init_rows)), "dynamic"),
datatype=["str", "str"],
+type="array",
interactive=True
)
with gr.Row():
add_row_btn = gr.Button("+ Add row", variant="secondary")
clear_btn = gr.Button("Clear", variant="secondary")
save_btn = gr.Button("💾 Save", variant="primary")
save_status = gr.Markdown("")
def _add_row(cur):
cur = (cur or []) + [["", ""]]
return gr.update(value=cur, row_count=(max(1, len(cur)), "dynamic"))
def _clear_rows():
return gr.update(value=[["", ""]], row_count=(1, "dynamic"))
add_row_btn.click(_add_row, inputs=[alias_table], outputs=[alias_table])
clear_btn.click(_clear_rows, outputs=[alias_table])
+save_btn.click(save_shape_alias_rows, inputs=[enable_aliases, alias_table], outputs=[save_status, alias_table])
+
+# ===== Tabs: Single & Batch (keeps gallery/table position below)
+with gr.Tabs():
+with gr.Tab("Single"):
+input_image_single = gr.Image(type="pil", label="Input Image", height=512, width=512)
+single_caption_btn = gr.Button("Caption")
+single_caption_out = gr.Textbox(label="Caption (single)")
+
+def _caption_single(img, instr):
+if img is None:
+return "No image provided."
+s = load_settings()
+im = resize_for_model(img, int(s.get("max_side", MAX_SIDE_CAP)))
+t = s.get("temperature", 0.6)
+p = s.get("top_p", 0.9)
+m = s.get("max_tokens", 256)
+cap = caption_once(im, instr, t, p, m)
+cap = apply_shape_aliases(cap)
+cap = apply_prefix_suffix(cap, s.get("trigger",""), s.get("begin",""), s.get("end",""))
+return cap
+
+single_caption_btn.click(
+_caption_single,
+inputs=[input_image_single, instruction_preview],
+outputs=[single_caption_out]
+)

+with gr.Tab("Batch"):
+with gr.Accordion("Uploaded images", open=True):
+input_files = gr.File(label="Drop images", file_types=["image"], file_count="multiple", type="filepath")
+run_button = gr.Button("Caption batch", variant="primary")

+# ===== Results/Table (position unchanged)
+rows_state = gr.State(load_session())
+autosave_md = gr.Markdown("Ready.")

with gr.Row():
with gr.Column(scale=1):
gallery = gr.Gallery(
+label="Results (image + caption)",
show_label=True,
+columns=3,
height=520,
elem_id="cfGal",
elem_classes=["cf-scroll"]

@@ -743,49 +664,30 @@ with gr.Blocks(css=BASE_CSS, title="CaptionForge") as demo:
with gr.Column(scale=1):
table = gr.Dataframe(
label="Editable captions (whole session)",
+value=_rows_to_table(load_session()),
+headers=["filename", "caption"],
interactive=True,
+wrap=True,
elem_id="cfTable",
elem_classes=["cf-scroll"]
)

+# Exports
with gr.Row():
with gr.Column():
export_csv_btn = gr.Button("Export CSV")
csv_file = gr.File(label="CSV file", visible=False)
with gr.Column():
+export_xlsx_btn = gr.Button("Export Excel (.xlsx) with thumbnails")
xlsx_file = gr.File(label="Excel file", visible=False)

+# Initial gallery render
+def _initial_gallery(rows):
rows = rows or []
+return [((r.get("thumb_path") or r.get("path")), r.get("caption","")) for r in rows if (r.get("thumb_path") or r.get("path"))]
+demo.load(_initial_gallery, inputs=[rows_state], outputs=[gallery])

+# Scroll sync
gr.HTML("""
<script>
(function () {

@@ -830,123 +732,36 @@ with gr.Blocks(css=BASE_CSS, title="CaptionForge") as demo:
</script>
""")

+# Run batch
+def _run_click(files, rows, instr, ms):
+s = load_settings()
+t = s.get("temperature", 0.6)
+p = s.get("top_p", 0.9)
+m = s.get("max_tokens", 256)
+new_rows, gal, tbl, stamp = run_batch(files, rows or [], instr, t, p, m, int(ms))
+return new_rows, gal, tbl, stamp

run_button.click(
+_run_click,
+inputs=[input_files, rows_state, instruction_preview, max_side],
+outputs=[rows_state, gallery, table, autosave_md]
)

+# Table edits sync → rows + gallery
+table.change(
+sync_table_to_session,
+inputs=[table, rows_state],
+outputs=[rows_state, gallery, autosave_md]
)

+# Exports (use slider value)
export_csv_btn.click(
+lambda tbl: (export_csv_from_table(tbl), gr.update(visible=True)),
+inputs=[table], outputs=[csv_file, csv_file]
)
export_xlsx_btn.click(
+lambda tbl, rows, px: (export_excel_with_thumbs(tbl, rows or [], int(px)), gr.update(visible=True)),
+inputs=[table, rows_state, excel_thumb_px], outputs=[xlsx_file, xlsx_file]
)

# Launch
|