Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,28 +5,20 @@
|
|
| 5 |
# ------------------------------
|
| 6 |
# 0) Imports & environment
|
| 7 |
# ------------------------------
|
| 8 |
-
import os
|
| 9 |
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
| 10 |
os.environ.setdefault("HF_HOME", "/home/user/.cache/huggingface")
|
| 11 |
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
|
| 12 |
-
import io, csv, time, json, base64, re, zipfile
|
| 13 |
-
from typing import Generator, List, Tuple, Dict, Any
|
| 14 |
-
from threading import Thread
|
| 15 |
-
# Persist model caches between restarts
|
| 16 |
|
|
|
|
|
|
|
| 17 |
|
| 18 |
import gradio as gr
|
| 19 |
from PIL import Image
|
| 20 |
import torch
|
| 21 |
-
from transformers import LlavaForConditionalGeneration, AutoProcessor
|
| 22 |
-
|
| 23 |
-
# Optional deps for import/export (we handle gracefully if missing)
|
| 24 |
-
try:
|
| 25 |
-
import pandas as pd # not required at runtime; kept for future use
|
| 26 |
-
except Exception:
|
| 27 |
-
pd = None
|
| 28 |
|
| 29 |
-
# Liger
|
| 30 |
try:
|
| 31 |
from liger_kernel.transformers import apply_liger_kernel_to_llama
|
| 32 |
except Exception:
|
|
@@ -46,11 +38,10 @@ except Exception:
|
|
| 46 |
# ------------------------------
|
| 47 |
APP_DIR = os.getcwd()
|
| 48 |
SESSION_FILE = "/tmp/forge_session.json"
|
| 49 |
-
# --- Branding
|
| 50 |
|
|
|
|
| 51 |
LOGO_HEIGHT_PX = int(os.getenv("FORGE_LOGO_PX", 60))
|
| 52 |
|
| 53 |
-
|
| 54 |
# Settings live in a user cache dir (persists better than /tmp)
|
| 55 |
CONFIG_DIR = os.path.expanduser("~/.cache/forgecaptions")
|
| 56 |
os.makedirs(CONFIG_DIR, exist_ok=True)
|
|
@@ -58,7 +49,7 @@ SETTINGS_FILE = os.path.join(CONFIG_DIR, "settings.json")
|
|
| 58 |
|
| 59 |
JOURNAL_FILE = "/tmp/forge_journal.json"
|
| 60 |
|
| 61 |
-
#
|
| 62 |
THUMB_CACHE = "/tmp/forgecaptions/thumbs"
|
| 63 |
EXCEL_THUMB_DIR = "/tmp/forge_excel_thumbs"
|
| 64 |
TXT_EXPORT_DIR = "/tmp/forge_txt"
|
|
@@ -80,8 +71,6 @@ print(f"[ForgeCaptions] Gradio version: {gr.__version__}")
|
|
| 80 |
|
| 81 |
# ------------------------------
|
| 82 |
# 2) Model loader (GPU-safe lazy init)
|
| 83 |
-
# - Processor on CPU
|
| 84 |
-
# - Model is ONLY created inside @gpu functions to satisfy Stateless GPU
|
| 85 |
# ------------------------------
|
| 86 |
processor = AutoProcessor.from_pretrained(MODEL_PATH)
|
| 87 |
_MODEL = None
|
|
@@ -104,8 +93,8 @@ def get_model():
|
|
| 104 |
low_cpu_mem_usage=True,
|
| 105 |
device_map=0,
|
| 106 |
)
|
|
|
|
| 107 |
try:
|
| 108 |
-
from liger_kernel.transformers import apply_liger_kernel_to_llama
|
| 109 |
lm = getattr(_MODEL, "language_model", None) or getattr(_MODEL, "model", None)
|
| 110 |
if lm is not None:
|
| 111 |
ok = apply_liger_kernel_to_llama(lm)
|
|
@@ -137,19 +126,19 @@ STYLE_OPTIONS = [
|
|
| 137 |
"Flux.1-Dev",
|
| 138 |
"Stable Diffusion",
|
| 139 |
"MidJourney",
|
| 140 |
-
"E-commerce product",
|
| 141 |
"Portrait (photography)",
|
| 142 |
"Landscape (photography)",
|
| 143 |
"Art analysis (no artist names)",
|
| 144 |
"Social caption",
|
| 145 |
"Aesthetic tags (comma-sep)"
|
| 146 |
-
|
| 147 |
|
| 148 |
CAPTION_TYPE_MAP: Dict[str, str] = {
|
| 149 |
"Descriptive": "Write a detailed description for this image.",
|
| 150 |
"Character training": (
|
| 151 |
"Write a thorough, training-ready caption for a character dataset. "
|
| 152 |
-
"Describe subject appearance (physique, face/hair), clothing and accessories, actions/pose/gesture, camera angle/focal cues."
|
| 153 |
"If multiple subjects are present, describe each briefly (most prominent first) and distinguish them by visible traits."
|
| 154 |
),
|
| 155 |
"Flux.1-Dev": "Write a Flux.1-Dev style prompt that would reproduce this image faithfully.",
|
|
@@ -163,7 +152,6 @@ CAPTION_TYPE_MAP: Dict[str, str] = {
|
|
| 163 |
"Social caption": "Write an engaging caption describing the visible content. No hashtags.",
|
| 164 |
}
|
| 165 |
|
| 166 |
-
|
| 167 |
LENGTH_CHOICES = ["any", "very short", "short", "medium-length", "long", "very long"] + [str(i) for i in range(20, 261, 10)]
|
| 168 |
|
| 169 |
_LENGTH_HINTS = {
|
|
@@ -173,7 +161,6 @@ _LENGTH_HINTS = {
|
|
| 173 |
"long": "Write a detailed caption (≈80–120 words).",
|
| 174 |
"very long": "Write a very detailed caption (≈150–250 words).",
|
| 175 |
}
|
| 176 |
-
|
| 177 |
def _length_hint(choice: str) -> str:
|
| 178 |
if not choice or choice == "any":
|
| 179 |
return ""
|
|
@@ -181,10 +168,9 @@ def _length_hint(choice: str) -> str:
|
|
| 181 |
return f"Limit the caption to at most {choice} words."
|
| 182 |
return _LENGTH_HINTS.get(choice, "")
|
| 183 |
|
| 184 |
-
|
| 185 |
EXTRA_CHOICES = [
|
| 186 |
"Only include a character's modifiable, style-level attributes (hair style/color, makeup, clothing/accessories, pose, expression). Do NOT mention identity traits (skin tone, age, body type).",
|
| 187 |
-
"Use profanity when describing sexual acts or genitalia (pussy, fucking, cum, cock, etc.).",
|
| 188 |
"Be sexually graphic and describe sexual position when visible.",
|
| 189 |
"Exclude brand names, watermarks, artist names, file/meta info, quality terms, hashtags, and resolution notes.",
|
| 190 |
"Arrange information in the order: Subject; Clothing/Accessories; Action/Pose; Setting/Environment; Lighting/Camera.",
|
|
@@ -204,7 +190,6 @@ EXTRA_CHOICES = [
|
|
| 204 |
]
|
| 205 |
NAME_OPTION = "If there is a person/character in the image you must refer to them as {name}."
|
| 206 |
|
| 207 |
-
# Place near EXTRA_CHOICES
|
| 208 |
DEFAULT_EXTRA_CHOICES = [
|
| 209 |
"Use profanity when describing sexual acts or genitalia (pussy, fucking, cum, cock, etc.).",
|
| 210 |
"Be sexually graphic and describe sexual position when visible.",
|
|
@@ -259,15 +244,16 @@ def load_settings() -> dict:
|
|
| 259 |
"shape_aliases_enabled": True,
|
| 260 |
"shape_aliases": [],
|
| 261 |
"excel_thumb_px": 128,
|
| 262 |
-
"logo_px":
|
| 263 |
"shape_aliases_persist": True,
|
| 264 |
"extras": DEFAULT_EXTRA_CHOICES,
|
|
|
|
| 265 |
}
|
| 266 |
|
| 267 |
for k, v in defaults.items():
|
| 268 |
cfg.setdefault(k, v)
|
| 269 |
|
| 270 |
-
# Normalize
|
| 271 |
styles = cfg.get("styles") or []
|
| 272 |
if not isinstance(styles, list):
|
| 273 |
styles = [styles]
|
|
@@ -348,13 +334,6 @@ def logo_b64_img() -> str:
|
|
| 348 |
# 6) Shape Aliases (plural-aware + '-shaped' variants)
|
| 349 |
# ------------------------------
|
| 350 |
def _plural_token_regex(tok: str) -> str:
|
| 351 |
-
"""
|
| 352 |
-
Build a regex for a token that also matches simple English plurals.
|
| 353 |
-
Rules:
|
| 354 |
-
- endswith s/x/z/ch/sh → add '(?:es)?'
|
| 355 |
-
- consonant + y → '(?:y|ies)'
|
| 356 |
-
- default → 's?'
|
| 357 |
-
"""
|
| 358 |
t = (tok or "").strip()
|
| 359 |
if not t: return ""
|
| 360 |
t_low = t.lower()
|
|
@@ -365,11 +344,6 @@ def _plural_token_regex(tok: str) -> str:
|
|
| 365 |
return re.escape(t) + r"s?"
|
| 366 |
|
| 367 |
def _compile_shape_aliases_from_file():
|
| 368 |
-
"""
|
| 369 |
-
Build regex list from settings["shape_aliases"].
|
| 370 |
-
Left cell accepts comma OR pipe separated synonyms (multi-word OK).
|
| 371 |
-
Matches are case-insensitive, catches simple plurals, and allows '-shaped' or ' shaped'.
|
| 372 |
-
"""
|
| 373 |
s = load_settings()
|
| 374 |
if not s.get("shape_aliases_enabled", True):
|
| 375 |
return []
|
|
@@ -409,11 +383,6 @@ def get_shape_alias_rows_ui_defaults():
|
|
| 409 |
return rows, enabled
|
| 410 |
|
| 411 |
def save_shape_alias_rows(enabled, df_rows, persist):
|
| 412 |
-
"""
|
| 413 |
-
Save or just apply alias rows.
|
| 414 |
-
- If persist=True → write to SETTINGS_FILE and recompile from file
|
| 415 |
-
- If persist=False → do NOT touch disk; just compile & apply in-memory
|
| 416 |
-
"""
|
| 417 |
cleaned = []
|
| 418 |
for r in (df_rows or []):
|
| 419 |
if not r:
|
|
@@ -431,7 +400,6 @@ def save_shape_alias_rows(enabled, df_rows, persist):
|
|
| 431 |
save_settings(cfg)
|
| 432 |
status = "💾 Saved to disk (will persist across restarts)."
|
| 433 |
|
| 434 |
-
# Recompile in-memory, regardless of persist
|
| 435 |
global _SHAPE_ALIASES
|
| 436 |
if bool(enabled):
|
| 437 |
compiled = []
|
|
@@ -453,7 +421,7 @@ def save_shape_alias_rows(enabled, df_rows, persist):
|
|
| 453 |
|
| 454 |
|
| 455 |
# ------------------------------
|
| 456 |
-
# 7) Prompt builder
|
| 457 |
# ------------------------------
|
| 458 |
def final_instruction(style_list: List[str], extra_opts: List[str], name_value: str, length_choice: str = "long") -> str:
|
| 459 |
styles = style_list or ["Character training"]
|
|
@@ -463,10 +431,10 @@ def final_instruction(style_list: List[str], extra_opts: List[str], name_value:
|
|
| 463 |
core += " " + " ".join(extra_opts)
|
| 464 |
if NAME_OPTION in (extra_opts or []):
|
| 465 |
core = core.replace("{name}", (name_value or "{NAME}").strip())
|
| 466 |
-
if "Aesthetic tags (comma-sep)" not in styles:
|
| 467 |
lh = _length_hint(length_choice or "any")
|
| 468 |
if lh:
|
| 469 |
-
core += " " + lh
|
| 470 |
return core
|
| 471 |
|
| 472 |
|
|
@@ -486,9 +454,6 @@ def _build_inputs(im: Image.Image, instr: str, dtype) -> Dict[str, Any]:
|
|
| 486 |
|
| 487 |
@torch.no_grad()
|
| 488 |
def caption_once(im: Image.Image, instr: str, temp: float, top_p: float, max_tokens: int) -> str:
|
| 489 |
-
"""
|
| 490 |
-
NOTE: Not GPU-decorated on purpose; call this only from within a @gpu function.
|
| 491 |
-
"""
|
| 492 |
model, device, dtype = get_model()
|
| 493 |
inputs = _build_inputs(im, instr, dtype)
|
| 494 |
inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}
|
|
@@ -525,8 +490,8 @@ def run_batch(
|
|
| 525 |
top_p: float,
|
| 526 |
max_tokens: int,
|
| 527 |
max_side: int,
|
| 528 |
-
time_budget_s: float | None = None,
|
| 529 |
-
progress: gr.Progress = gr.Progress(track_tqdm=True),
|
| 530 |
) -> Tuple[List[dict], list, list, str, List[str], int, int]:
|
| 531 |
"""
|
| 532 |
Returns:
|
|
@@ -580,17 +545,14 @@ def run_batch(
|
|
| 580 |
total,
|
| 581 |
)
|
| 582 |
|
| 583 |
-
@gpu
|
| 584 |
-
@torch.no_grad()
|
| 585 |
|
| 586 |
# ------------------------------
|
| 587 |
-
# 9) Export
|
| 588 |
# ------------------------------
|
| 589 |
def _rows_to_table(rows: List[dict]) -> list:
|
| 590 |
return [[r.get("filename",""), r.get("caption","")] for r in (rows or [])]
|
| 591 |
|
| 592 |
def _table_to_rows(table_value: Any, rows: List[dict]) -> List[dict]:
|
| 593 |
-
# Expect list-of-lists (Dataframe type="array")
|
| 594 |
tbl = table_value or []
|
| 595 |
new = []
|
| 596 |
for i, r in enumerate(rows or []):
|
|
@@ -706,79 +668,6 @@ def export_txt_zip(table_value: Any, dataset_name: str) -> str:
|
|
| 706 |
z.write(os.path.join(TXT_EXPORT_DIR, fn), arcname=fn)
|
| 707 |
return zpath
|
| 708 |
|
| 709 |
-
def import_captions_file(file_path: str, session_rows: List[dict]) -> Tuple[List[dict], list, list, str]:
|
| 710 |
-
"""
|
| 711 |
-
Import captions from CSV or XLSX and merge by filename into current session.
|
| 712 |
-
- If filename exists → update its caption
|
| 713 |
-
- Otherwise append a new row (without image path/thumbnail)
|
| 714 |
-
"""
|
| 715 |
-
if not file_path or not os.path.exists(file_path):
|
| 716 |
-
table_rows = _rows_to_table(session_rows)
|
| 717 |
-
gallery_pairs = [((r.get("thumb_path") or r.get("path")), r.get("caption",""))
|
| 718 |
-
for r in session_rows if (r.get("thumb_path") or r.get("path"))]
|
| 719 |
-
return session_rows, gallery_pairs, table_rows, "No file selected."
|
| 720 |
-
|
| 721 |
-
ext = os.path.splitext(file_path)[1].lower()
|
| 722 |
-
imported: List[Tuple[str, str]] = []
|
| 723 |
-
|
| 724 |
-
try:
|
| 725 |
-
if ext == ".csv":
|
| 726 |
-
with open(file_path, "r", encoding="utf-8") as f:
|
| 727 |
-
reader = csv.reader(f)
|
| 728 |
-
rows = list(reader)
|
| 729 |
-
if rows and len(rows[0]) >= 2 and str(rows[0][0]).lower().strip() == "filename":
|
| 730 |
-
rows = rows[1:]
|
| 731 |
-
for r in rows:
|
| 732 |
-
if not r or len(r) < 2:
|
| 733 |
-
continue
|
| 734 |
-
fn = str(r[0]).strip()
|
| 735 |
-
cap = str(r[1]).strip()
|
| 736 |
-
if fn:
|
| 737 |
-
imported.append((fn, cap))
|
| 738 |
-
|
| 739 |
-
elif ext in (".xlsx", ".xls"):
|
| 740 |
-
try:
|
| 741 |
-
from openpyxl import load_workbook
|
| 742 |
-
except Exception as e:
|
| 743 |
-
raise RuntimeError("XLSX import requires 'openpyxl' in requirements.txt.") from e
|
| 744 |
-
wb = load_workbook(file_path, read_only=True, data_only=True)
|
| 745 |
-
ws = wb.active
|
| 746 |
-
rows = list(ws.iter_rows(values_only=True))
|
| 747 |
-
if rows and len(rows[0]) >= 2 and str(rows[0][0]).lower().strip() == "filename":
|
| 748 |
-
rows = rows[1:]
|
| 749 |
-
for r in rows:
|
| 750 |
-
if not r or len(r) < 2:
|
| 751 |
-
continue
|
| 752 |
-
fn = str(r[0]).strip() if r[0] is not None else ""
|
| 753 |
-
cap = str(r[1]).strip() if r[1] is not None else ""
|
| 754 |
-
if fn:
|
| 755 |
-
imported.append((fn, cap))
|
| 756 |
-
else:
|
| 757 |
-
return session_rows, _rows_to_table(session_rows), _rows_to_table(session_rows), f"Unsupported file type: {ext}"
|
| 758 |
-
except Exception as e:
|
| 759 |
-
table_rows = _rows_to_table(session_rows)
|
| 760 |
-
gallery_pairs = [((r.get("thumb_path") or r.get("path")), r.get("caption",""))
|
| 761 |
-
for r in session_rows if (r.get("thumb_path") or r.get("path"))]
|
| 762 |
-
return session_rows, gallery_pairs, table_rows, f"Import failed: {e}"
|
| 763 |
-
|
| 764 |
-
# Merge
|
| 765 |
-
idx_by_name = {r.get("filename",""): i for i, r in enumerate(session_rows)}
|
| 766 |
-
updates, inserts = 0, 0
|
| 767 |
-
for fn, cap in imported:
|
| 768 |
-
if fn in idx_by_name:
|
| 769 |
-
session_rows[idx_by_name[fn]]["caption"] = cap
|
| 770 |
-
updates += 1
|
| 771 |
-
else:
|
| 772 |
-
session_rows.append({"filename": fn, "caption": cap, "path": "", "thumb_path": ""})
|
| 773 |
-
inserts += 1
|
| 774 |
-
|
| 775 |
-
save_session(session_rows)
|
| 776 |
-
gallery_pairs = [((r.get("thumb_path") or r.get("path")), r.get("caption",""))
|
| 777 |
-
for r in session_rows if (r.get("thumb_path") or r.get("path"))]
|
| 778 |
-
table_rows = _rows_to_table(session_rows)
|
| 779 |
-
stamp = f"Imported {len(imported)} rows • updated {updates} • added {inserts} • {time.strftime('%H:%M:%S')}"
|
| 780 |
-
return session_rows, gallery_pairs, table_rows, stamp
|
| 781 |
-
|
| 782 |
|
| 783 |
# ------------------------------
|
| 784 |
# 10) UI header helper (fixed logo size)
|
|
@@ -801,13 +690,15 @@ def _render_header_html(px: int) -> str:
|
|
| 801 |
width: auto;
|
| 802 |
object-fit: contain;
|
| 803 |
display: block;
|
|
|
|
| 804 |
}}
|
| 805 |
@media (max-width: 640px) {{
|
| 806 |
-
.cf-logo {{ height: {max(
|
| 807 |
}}
|
| 808 |
</style>
|
| 809 |
"""
|
| 810 |
|
|
|
|
| 811 |
# ------------------------------
|
| 812 |
# 11) UI (Blocks)
|
| 813 |
# ------------------------------
|
|
@@ -819,8 +710,8 @@ BASE_CSS = """
|
|
| 819 |
.cf-hero{display:flex; align-items:center; justify-content:center; gap:16px;
|
| 820 |
margin:4px 0 12px; text-align:center;}
|
| 821 |
.cf-hero .cf-text{text-align:center;}
|
| 822 |
-
.cf-title{margin:0;font-size:3.
|
| 823 |
-
.cf-sub{margin:6px 0 0;font-size:1.
|
| 824 |
|
| 825 |
/* Results area + robust scrollbars */
|
| 826 |
.cf-scroll{border:1px solid #e6e6e6; border-radius:10px; padding:8px}
|
|
@@ -831,14 +722,9 @@ BASE_CSS = """
|
|
| 831 |
"""
|
| 832 |
|
| 833 |
with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
|
| 834 |
-
# Ensure Spaces sees a GPU function (without touching CUDA in main)
|
| 835 |
-
demo.load(inputs=None, outputs=None)
|
| 836 |
-
|
| 837 |
# ---- Header
|
| 838 |
settings = load_settings()
|
| 839 |
-
header_html = gr.HTML(_render_header_html(LOGO_HEIGHT_PX))
|
| 840 |
-
|
| 841 |
-
|
| 842 |
|
| 843 |
# ---- Controls group
|
| 844 |
with gr.Group():
|
|
@@ -867,19 +753,18 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
|
|
| 867 |
trig = gr.Textbox(label="Trigger word", value=settings.get("trigger",""))
|
| 868 |
add_start = gr.Textbox(label="Add text to start", value=settings.get("begin",""))
|
| 869 |
add_end = gr.Textbox(label="Add text to end", value=settings.get("end",""))
|
| 870 |
-
|
| 871 |
|
| 872 |
-
# RIGHT: instructions + dataset + general sliders
|
| 873 |
with gr.Column(scale=1):
|
| 874 |
with gr.Accordion("Model Instructions", open=False):
|
| 875 |
instruction_preview = gr.Textbox(
|
| 876 |
-
label=None,
|
| 877 |
lines=12,
|
| 878 |
value=final_instruction(
|
| 879 |
settings.get("styles", ["Character training"]),
|
| 880 |
settings.get("extras", []),
|
| 881 |
settings.get("name",""),
|
| 882 |
-
settings.get("caption_length", "long"),
|
| 883 |
),
|
| 884 |
)
|
| 885 |
dataset_name = gr.Textbox(label="Dataset name (export title prefix)",
|
|
@@ -896,14 +781,13 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
|
|
| 896 |
gpu_budget = gr.Slider(20, 110, value=55, step=5, label="Max seconds per GPU call")
|
| 897 |
no_time_limit = gr.Checkbox(value=False, label="No time limit (ignore above)")
|
| 898 |
|
| 899 |
-
|
| 900 |
# Persist instruction + general settings
|
| 901 |
def _refresh_instruction(styles, extra, name_value, trigv, begv, endv, excel_px, ms, cap_len):
|
| 902 |
instr = final_instruction(styles or ["Character training"], extra or [], name_value, cap_len)
|
| 903 |
cfg = load_settings()
|
| 904 |
cfg.update({
|
| 905 |
"styles": styles or ["Character training"],
|
| 906 |
-
"extras": extra
|
| 907 |
"name": name_value,
|
| 908 |
"trigger": trigv, "begin": begv, "end": endv,
|
| 909 |
"excel_thumb_px": int(excel_px),
|
|
@@ -914,9 +798,11 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
|
|
| 914 |
return instr
|
| 915 |
|
| 916 |
for comp in [style_checks, extra_opts, name_input, trig, add_start, add_end, excel_thumb_px, max_side, caption_length]:
|
| 917 |
-
comp.change(
|
| 918 |
-
|
| 919 |
-
|
|
|
|
|
|
|
| 920 |
|
| 921 |
def _save_dataset_name(name):
|
| 922 |
cfg = load_settings()
|
|
@@ -925,8 +811,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
|
|
| 925 |
return gr.update()
|
| 926 |
dataset_name.change(_save_dataset_name, inputs=[dataset_name], outputs=[])
|
| 927 |
|
| 928 |
-
|
| 929 |
-
# ---- Shape Aliases (with plural matching + persist) ----
|
| 930 |
with gr.Accordion("Shape Aliases", open=False):
|
| 931 |
gr.Markdown(
|
| 932 |
"### 🔷 Shape Aliases\n"
|
|
@@ -934,8 +819,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
|
|
| 934 |
"**How to use:**\n"
|
| 935 |
"- Left column = a single token **or** comma/pipe-separated synonyms, e.g. `diamond, rhombus | lozenge`\n"
|
| 936 |
"- Right column = replacement name, e.g. `family-emblem`\n"
|
| 937 |
-
"Matches are case-insensitive, catches simple plurals
|
| 938 |
-
"and also matches `*-shaped` or `* shaped` variants."
|
| 939 |
)
|
| 940 |
|
| 941 |
init_rows, init_enabled = get_shape_alias_rows_ui_defaults()
|
|
@@ -959,10 +843,8 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
|
|
| 959 |
clear_btn = gr.Button("Clear", variant="secondary")
|
| 960 |
save_btn = gr.Button("💾 Save", variant="primary")
|
| 961 |
|
| 962 |
-
# status line for saves
|
| 963 |
save_status = gr.Markdown("")
|
| 964 |
|
| 965 |
-
# --- local handlers (must stay inside Blocks context) ---
|
| 966 |
def _add_row(cur):
|
| 967 |
cur = (cur or []) + [["", ""]]
|
| 968 |
return gr.update(value=cur, row_count=(max(1, len(cur)), "dynamic"))
|
|
@@ -980,7 +862,6 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
|
|
| 980 |
return gr.update()
|
| 981 |
persist_aliases.change(_save_alias_persist_flag, inputs=[persist_aliases], outputs=[])
|
| 982 |
|
| 983 |
-
# Persist rows if persist_aliases checked; otherwise apply in-memory only
|
| 984 |
save_btn.click(
|
| 985 |
save_shape_alias_rows,
|
| 986 |
inputs=[enable_aliases, alias_table, persist_aliases],
|
|
@@ -999,28 +880,16 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
|
|
| 999 |
outputs=[single_caption_out]
|
| 1000 |
)
|
| 1001 |
|
| 1002 |
-
# with gr.Tab("Batch"):
|
| 1003 |
-
# with gr.Accordion("Uploaded images", open=True):
|
| 1004 |
-
# input_files = gr.File(label="Drop images (or click to select)", file_types=["image"], file_count="multiple", type="filepath")
|
| 1005 |
-
# run_button = gr.Button("Caption batch", variant="primary")
|
| 1006 |
-
|
| 1007 |
-
# with gr.Accordion("Import captions from CSV/XLSX (merge by filename)", open=False):
|
| 1008 |
-
# import_file = gr.File(label="Choose .csv or .xlsx", file_types=[".csv", ".xlsx"], type="filepath")
|
| 1009 |
-
# import_btn = gr.Button("Import into current session")
|
| 1010 |
-
|
| 1011 |
with gr.Tab("Batch"):
|
| 1012 |
with gr.Accordion("Uploaded images", open=True):
|
| 1013 |
-
input_files = gr.File(
|
| 1014 |
-
|
| 1015 |
-
|
| 1016 |
-
|
| 1017 |
-
|
| 1018 |
-
|
| 1019 |
-
)
|
| 1020 |
-
input_files.change(on_files_changed, inputs=[input_files], outputs=[preview_gallery])
|
| 1021 |
run_button = gr.Button("Caption batch", variant="primary")
|
| 1022 |
|
| 1023 |
-
|
| 1024 |
# ---- Results area (gallery left / table right)
|
| 1025 |
rows_state = gr.State(load_session())
|
| 1026 |
autosave_md = gr.Markdown("Ready.")
|
|
@@ -1043,7 +912,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
|
|
| 1043 |
headers=["filename", "caption"],
|
| 1044 |
interactive=True,
|
| 1045 |
wrap=True,
|
| 1046 |
-
type="array",
|
| 1047 |
elem_id="cfTable"
|
| 1048 |
)
|
| 1049 |
|
|
@@ -1066,7 +935,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
|
|
| 1066 |
export_txt_btn = gr.Button("Export captions as .txt (zip)")
|
| 1067 |
txt_zip = gr.File(label="TXT zip", visible=False)
|
| 1068 |
|
| 1069 |
-
# ---- Robust scroll sync
|
| 1070 |
gr.HTML("""
|
| 1071 |
<script>
|
| 1072 |
(function () {
|
|
@@ -1075,9 +944,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
|
|
| 1075 |
if (!host) return null;
|
| 1076 |
return host.querySelector('[data-testid="gallery"]') || host;
|
| 1077 |
}
|
| 1078 |
-
function findTbl() {
|
| 1079 |
-
return document.querySelector("#cfTableWrap");
|
| 1080 |
-
}
|
| 1081 |
function syncScroll(a, b) {
|
| 1082 |
if (!a || !b) return;
|
| 1083 |
let lock = false;
|
|
@@ -1121,7 +988,6 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
|
|
| 1121 |
files = files or []
|
| 1122 |
budget = None if no_limit else float(budget_s)
|
| 1123 |
|
| 1124 |
-
# Manual step → process first chunk only
|
| 1125 |
if mode == "Manual (step)" and files:
|
| 1126 |
chunks = _split_chunks(files, int(csize))
|
| 1127 |
batch = chunks[0]
|
|
@@ -1135,7 +1001,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
|
|
| 1135 |
prog = f"Batch progress: {done}/{total} processed in this step • Remaining overall: {len(remaining)}"
|
| 1136 |
return new_rows, gal, tbl, stamp, remaining, panel_vis, gr.update(value=msg), gr.update(value=prog)
|
| 1137 |
|
| 1138 |
-
# Auto
|
| 1139 |
new_rows, gal, tbl, stamp, leftover, done, total = run_batch(
|
| 1140 |
files, rows or [], instr, t, p, m, int(ms), budget
|
| 1141 |
)
|
|
@@ -1147,21 +1013,9 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
|
|
| 1147 |
run_button.click(
|
| 1148 |
_run_click,
|
| 1149 |
inputs=[input_files, rows_state, instruction_preview, max_side, chunk_mode, chunk_size, gpu_budget, no_time_limit],
|
| 1150 |
-
outputs=[rows_state, gallery, table, autosave_md, remaining_state, step_panel, step_msg, progress_md]
|
| 1151 |
-
).then(
|
| 1152 |
-
lambda rows: [(Image.open(r["path"]).convert("RGB"), r["caption"]) for r in rows],
|
| 1153 |
-
inputs=[rows_state],
|
| 1154 |
-
outputs=[gallery],
|
| 1155 |
-
)
|
| 1156 |
-
table.change(
|
| 1157 |
-
sync_table_to_session,
|
| 1158 |
-
inputs=[table, rows_state],
|
| 1159 |
-
outputs=[rows_state, captions_text],
|
| 1160 |
-
).then(
|
| 1161 |
-
lambda rows: [(Image.open(r["path"]).convert("RGB"), r["caption"]) for r in rows],
|
| 1162 |
-
inputs=[rows_state],
|
| 1163 |
-
outputs=[gallery],
|
| 1164 |
)
|
|
|
|
| 1165 |
def _step_next(remain, rows, instr, ms, csize, budget_s, no_limit):
|
| 1166 |
t, p, m = _tpms()
|
| 1167 |
remain = remain or []
|
|
@@ -1171,7 +1025,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
|
|
| 1171 |
return (
|
| 1172 |
rows,
|
| 1173 |
gr.update(value="No files remaining."),
|
| 1174 |
-
gr.update(visible=
|
| 1175 |
[],
|
| 1176 |
[],
|
| 1177 |
[],
|
|
@@ -1207,16 +1061,8 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
|
|
| 1207 |
gallery_pairs = [((r.get("thumb_path") or r.get("path")), r.get("caption",""))
|
| 1208 |
for r in session_rows if (r.get("thumb_path") or r.get("path"))]
|
| 1209 |
return session_rows, gallery_pairs, f"Saved • {time.strftime('%H:%M:%S')}"
|
| 1210 |
-
table.change(sync_table_to_session, inputs=[table, rows_state], outputs=[rows_state, gallery, autosave_md])
|
| 1211 |
-
|
| 1212 |
-
def new_session() -> Tuple[List[dict], list, list, str]:
|
| 1213 |
-
return [], [], _rows_to_table([]), ""
|
| 1214 |
|
| 1215 |
-
|
| 1216 |
-
def _do_import(fpath, rows):
|
| 1217 |
-
new_rows, gal, tbl, stamp = import_captions_file(fpath, rows or [])
|
| 1218 |
-
return new_rows, gal, tbl, stamp
|
| 1219 |
-
import_btn.click(_do_import, inputs=[import_file, rows_state], outputs=[rows_state, gallery, table, autosave_md])
|
| 1220 |
|
| 1221 |
# ---- Exports
|
| 1222 |
export_csv_btn.click(
|
|
@@ -1234,7 +1080,7 @@ with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
|
|
| 1234 |
|
| 1235 |
|
| 1236 |
# ------------------------------
|
| 1237 |
-
# 12) Launch
|
| 1238 |
# ------------------------------
|
| 1239 |
if __name__ == "__main__":
|
| 1240 |
demo.queue(max_size=64).launch(
|
|
@@ -1243,6 +1089,5 @@ if __name__ == "__main__":
|
|
| 1243 |
ssr_mode=False,
|
| 1244 |
debug=True,
|
| 1245 |
show_error=True,
|
| 1246 |
-
# Allow Gradio to serve generated files from /tmp caches
|
| 1247 |
allowed_paths=[THUMB_CACHE, EXCEL_THUMB_DIR, TXT_EXPORT_DIR],
|
| 1248 |
)
|
|
|
|
| 5 |
# ------------------------------
|
| 6 |
# 0) Imports & environment
|
| 7 |
# ------------------------------
|
| 8 |
+
import os
|
| 9 |
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
| 10 |
os.environ.setdefault("HF_HOME", "/home/user/.cache/huggingface")
|
| 11 |
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
+
import io, csv, time, json, base64, re, zipfile
|
| 14 |
+
from typing import List, Tuple, Dict, Any
|
| 15 |
|
| 16 |
import gradio as gr
|
| 17 |
from PIL import Image
|
| 18 |
import torch
|
| 19 |
+
from transformers import LlavaForConditionalGeneration, AutoProcessor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
+
# Optional: Liger kernel (ignored if missing)
|
| 22 |
try:
|
| 23 |
from liger_kernel.transformers import apply_liger_kernel_to_llama
|
| 24 |
except Exception:
|
|
|
|
| 38 |
# ------------------------------
|
| 39 |
APP_DIR = os.getcwd()
|
| 40 |
SESSION_FILE = "/tmp/forge_session.json"
|
|
|
|
| 41 |
|
| 42 |
+
# Branding: fixed logo height
|
| 43 |
LOGO_HEIGHT_PX = int(os.getenv("FORGE_LOGO_PX", 60))
|
| 44 |
|
|
|
|
| 45 |
# Settings live in a user cache dir (persists better than /tmp)
|
| 46 |
CONFIG_DIR = os.path.expanduser("~/.cache/forgecaptions")
|
| 47 |
os.makedirs(CONFIG_DIR, exist_ok=True)
|
|
|
|
| 49 |
|
| 50 |
JOURNAL_FILE = "/tmp/forge_journal.json"
|
| 51 |
|
| 52 |
+
# Generated assets in /tmp so Gradio can serve them safely
|
| 53 |
THUMB_CACHE = "/tmp/forgecaptions/thumbs"
|
| 54 |
EXCEL_THUMB_DIR = "/tmp/forge_excel_thumbs"
|
| 55 |
TXT_EXPORT_DIR = "/tmp/forge_txt"
|
|
|
|
| 71 |
|
| 72 |
# ------------------------------
|
| 73 |
# 2) Model loader (GPU-safe lazy init)
|
|
|
|
|
|
|
| 74 |
# ------------------------------
|
| 75 |
processor = AutoProcessor.from_pretrained(MODEL_PATH)
|
| 76 |
_MODEL = None
|
|
|
|
| 93 |
low_cpu_mem_usage=True,
|
| 94 |
device_map=0,
|
| 95 |
)
|
| 96 |
+
# Try to enable Liger on the LLM submodule (best-effort, silent if missing)
|
| 97 |
try:
|
|
|
|
| 98 |
lm = getattr(_MODEL, "language_model", None) or getattr(_MODEL, "model", None)
|
| 99 |
if lm is not None:
|
| 100 |
ok = apply_liger_kernel_to_llama(lm)
|
|
|
|
| 126 |
"Flux.1-Dev",
|
| 127 |
"Stable Diffusion",
|
| 128 |
"MidJourney",
|
| 129 |
+
"E-commerce product",
|
| 130 |
"Portrait (photography)",
|
| 131 |
"Landscape (photography)",
|
| 132 |
"Art analysis (no artist names)",
|
| 133 |
"Social caption",
|
| 134 |
"Aesthetic tags (comma-sep)"
|
| 135 |
+
]
|
| 136 |
|
| 137 |
CAPTION_TYPE_MAP: Dict[str, str] = {
|
| 138 |
"Descriptive": "Write a detailed description for this image.",
|
| 139 |
"Character training": (
|
| 140 |
"Write a thorough, training-ready caption for a character dataset. "
|
| 141 |
+
"Describe subject appearance (physique, face/hair), clothing and accessories, actions/pose/gesture, camera angle/focal cues. "
|
| 142 |
"If multiple subjects are present, describe each briefly (most prominent first) and distinguish them by visible traits."
|
| 143 |
),
|
| 144 |
"Flux.1-Dev": "Write a Flux.1-Dev style prompt that would reproduce this image faithfully.",
|
|
|
|
| 152 |
"Social caption": "Write an engaging caption describing the visible content. No hashtags.",
|
| 153 |
}
|
| 154 |
|
|
|
|
| 155 |
LENGTH_CHOICES = ["any", "very short", "short", "medium-length", "long", "very long"] + [str(i) for i in range(20, 261, 10)]
|
| 156 |
|
| 157 |
_LENGTH_HINTS = {
|
|
|
|
| 161 |
"long": "Write a detailed caption (≈80–120 words).",
|
| 162 |
"very long": "Write a very detailed caption (≈150–250 words).",
|
| 163 |
}
|
|
|
|
| 164 |
def _length_hint(choice: str) -> str:
|
| 165 |
if not choice or choice == "any":
|
| 166 |
return ""
|
|
|
|
| 168 |
return f"Limit the caption to at most {choice} words."
|
| 169 |
return _LENGTH_HINTS.get(choice, "")
|
| 170 |
|
|
|
|
| 171 |
EXTRA_CHOICES = [
|
| 172 |
"Only include a character's modifiable, style-level attributes (hair style/color, makeup, clothing/accessories, pose, expression). Do NOT mention identity traits (skin tone, age, body type).",
|
| 173 |
+
"Use profanity when describing sexual acts or genitalia (pussy, fucking, cum, cock, etc.).",
|
| 174 |
"Be sexually graphic and describe sexual position when visible.",
|
| 175 |
"Exclude brand names, watermarks, artist names, file/meta info, quality terms, hashtags, and resolution notes.",
|
| 176 |
"Arrange information in the order: Subject; Clothing/Accessories; Action/Pose; Setting/Environment; Lighting/Camera.",
|
|
|
|
| 190 |
]
|
| 191 |
NAME_OPTION = "If there is a person/character in the image you must refer to them as {name}."
|
| 192 |
|
|
|
|
| 193 |
DEFAULT_EXTRA_CHOICES = [
|
| 194 |
"Use profanity when describing sexual acts or genitalia (pussy, fucking, cum, cock, etc.).",
|
| 195 |
"Be sexually graphic and describe sexual position when visible.",
|
|
|
|
| 244 |
"shape_aliases_enabled": True,
|
| 245 |
"shape_aliases": [],
|
| 246 |
"excel_thumb_px": 128,
|
| 247 |
+
"logo_px": LOGO_HEIGHT_PX,
|
| 248 |
"shape_aliases_persist": True,
|
| 249 |
"extras": DEFAULT_EXTRA_CHOICES,
|
| 250 |
+
"caption_length": "long",
|
| 251 |
}
|
| 252 |
|
| 253 |
for k, v in defaults.items():
|
| 254 |
cfg.setdefault(k, v)
|
| 255 |
|
| 256 |
+
# Normalize
|
| 257 |
styles = cfg.get("styles") or []
|
| 258 |
if not isinstance(styles, list):
|
| 259 |
styles = [styles]
|
|
|
|
| 334 |
# 6) Shape Aliases (plural-aware + '-shaped' variants)
|
| 335 |
# ------------------------------
|
| 336 |
def _plural_token_regex(tok: str) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
t = (tok or "").strip()
|
| 338 |
if not t: return ""
|
| 339 |
t_low = t.lower()
|
|
|
|
| 344 |
return re.escape(t) + r"s?"
|
| 345 |
|
| 346 |
def _compile_shape_aliases_from_file():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
s = load_settings()
|
| 348 |
if not s.get("shape_aliases_enabled", True):
|
| 349 |
return []
|
|
|
|
| 383 |
return rows, enabled
|
| 384 |
|
| 385 |
def save_shape_alias_rows(enabled, df_rows, persist):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
cleaned = []
|
| 387 |
for r in (df_rows or []):
|
| 388 |
if not r:
|
|
|
|
| 400 |
save_settings(cfg)
|
| 401 |
status = "💾 Saved to disk (will persist across restarts)."
|
| 402 |
|
|
|
|
| 403 |
global _SHAPE_ALIASES
|
| 404 |
if bool(enabled):
|
| 405 |
compiled = []
|
|
|
|
| 421 |
|
| 422 |
|
| 423 |
# ------------------------------
|
| 424 |
+
# 7) Prompt builder
|
| 425 |
# ------------------------------
|
| 426 |
def final_instruction(style_list: List[str], extra_opts: List[str], name_value: str, length_choice: str = "long") -> str:
|
| 427 |
styles = style_list or ["Character training"]
|
|
|
|
| 431 |
core += " " + " ".join(extra_opts)
|
| 432 |
if NAME_OPTION in (extra_opts or []):
|
| 433 |
core = core.replace("{name}", (name_value or "{NAME}").strip())
|
| 434 |
+
if "Aesthetic tags (comma-sep)" not in styles:
|
| 435 |
lh = _length_hint(length_choice or "any")
|
| 436 |
if lh:
|
| 437 |
+
core += " " + lh
|
| 438 |
return core
|
| 439 |
|
| 440 |
|
|
|
|
| 454 |
|
| 455 |
@torch.no_grad()
|
| 456 |
def caption_once(im: Image.Image, instr: str, temp: float, top_p: float, max_tokens: int) -> str:
|
|
|
|
|
|
|
|
|
|
| 457 |
model, device, dtype = get_model()
|
| 458 |
inputs = _build_inputs(im, instr, dtype)
|
| 459 |
inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}
|
|
|
|
| 490 |
top_p: float,
|
| 491 |
max_tokens: int,
|
| 492 |
max_side: int,
|
| 493 |
+
time_budget_s: float | None = None,
|
| 494 |
+
progress: gr.Progress = gr.Progress(track_tqdm=True),
|
| 495 |
) -> Tuple[List[dict], list, list, str, List[str], int, int]:
|
| 496 |
"""
|
| 497 |
Returns:
|
|
|
|
| 545 |
total,
|
| 546 |
)
|
| 547 |
|
|
|
|
|
|
|
| 548 |
|
| 549 |
# ------------------------------
|
| 550 |
+
# 9) Export helpers (CSV/XLSX/TXT ZIP)
|
| 551 |
# ------------------------------
|
| 552 |
def _rows_to_table(rows: List[dict]) -> list:
|
| 553 |
return [[r.get("filename",""), r.get("caption","")] for r in (rows or [])]
|
| 554 |
|
| 555 |
def _table_to_rows(table_value: Any, rows: List[dict]) -> List[dict]:
|
|
|
|
| 556 |
tbl = table_value or []
|
| 557 |
new = []
|
| 558 |
for i, r in enumerate(rows or []):
|
|
|
|
| 668 |
z.write(os.path.join(TXT_EXPORT_DIR, fn), arcname=fn)
|
| 669 |
return zpath
|
| 670 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 671 |
|
| 672 |
# ------------------------------
|
| 673 |
# 10) UI header helper (fixed logo size)
|
|
|
|
| 690 |
width: auto;
|
| 691 |
object-fit: contain;
|
| 692 |
display: block;
|
| 693 |
+
max-width: 320px; /* cap very wide logos */
|
| 694 |
}}
|
| 695 |
@media (max-width: 640px) {{
|
| 696 |
+
.cf-logo {{ height: {max(48, int(px) - 8)}px; }}
|
| 697 |
}}
|
| 698 |
</style>
|
| 699 |
"""
|
| 700 |
|
| 701 |
+
|
| 702 |
# ------------------------------
|
| 703 |
# 11) UI (Blocks)
|
| 704 |
# ------------------------------
|
|
|
|
| 710 |
.cf-hero{display:flex; align-items:center; justify-content:center; gap:16px;
|
| 711 |
margin:4px 0 12px; text-align:center;}
|
| 712 |
.cf-hero .cf-text{text-align:center;}
|
| 713 |
+
.cf-title{margin:0;font-size:3.0rem;line-height:1;letter-spacing:.2px}
|
| 714 |
+
.cf-sub{margin:6px 0 0;font-size:1.05rem;color:#cfd3da}
|
| 715 |
|
| 716 |
/* Results area + robust scrollbars */
|
| 717 |
.cf-scroll{border:1px solid #e6e6e6; border-radius:10px; padding:8px}
|
|
|
|
| 722 |
"""
|
| 723 |
|
| 724 |
with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
|
|
|
|
|
|
|
|
|
|
| 725 |
# ---- Header
|
| 726 |
settings = load_settings()
|
| 727 |
+
header_html = gr.HTML(_render_header_html(settings.get("logo_px", LOGO_HEIGHT_PX)))
|
|
|
|
|
|
|
| 728 |
|
| 729 |
# ---- Controls group
|
| 730 |
with gr.Group():
|
|
|
|
| 753 |
trig = gr.Textbox(label="Trigger word", value=settings.get("trigger",""))
|
| 754 |
add_start = gr.Textbox(label="Add text to start", value=settings.get("begin",""))
|
| 755 |
add_end = gr.Textbox(label="Add text to end", value=settings.get("end",""))
|
|
|
|
| 756 |
|
| 757 |
+
# RIGHT: instructions + dataset + general sliders
|
| 758 |
with gr.Column(scale=1):
|
| 759 |
with gr.Accordion("Model Instructions", open=False):
|
| 760 |
instruction_preview = gr.Textbox(
|
| 761 |
+
label=None,
|
| 762 |
lines=12,
|
| 763 |
value=final_instruction(
|
| 764 |
settings.get("styles", ["Character training"]),
|
| 765 |
settings.get("extras", []),
|
| 766 |
settings.get("name",""),
|
| 767 |
+
settings.get("caption_length", "long"),
|
| 768 |
),
|
| 769 |
)
|
| 770 |
dataset_name = gr.Textbox(label="Dataset name (export title prefix)",
|
|
|
|
| 781 |
gpu_budget = gr.Slider(20, 110, value=55, step=5, label="Max seconds per GPU call")
|
| 782 |
no_time_limit = gr.Checkbox(value=False, label="No time limit (ignore above)")
|
| 783 |
|
|
|
|
| 784 |
# Persist instruction + general settings
|
| 785 |
def _refresh_instruction(styles, extra, name_value, trigv, begv, endv, excel_px, ms, cap_len):
|
| 786 |
instr = final_instruction(styles or ["Character training"], extra or [], name_value, cap_len)
|
| 787 |
cfg = load_settings()
|
| 788 |
cfg.update({
|
| 789 |
"styles": styles or ["Character training"],
|
| 790 |
+
"extras": _valid_extras(extra),
|
| 791 |
"name": name_value,
|
| 792 |
"trigger": trigv, "begin": begv, "end": endv,
|
| 793 |
"excel_thumb_px": int(excel_px),
|
|
|
|
| 798 |
return instr
|
| 799 |
|
| 800 |
for comp in [style_checks, extra_opts, name_input, trig, add_start, add_end, excel_thumb_px, max_side, caption_length]:
|
| 801 |
+
comp.change(
|
| 802 |
+
_refresh_instruction,
|
| 803 |
+
inputs=[style_checks, extra_opts, name_input, trig, add_start, add_end, excel_thumb_px, max_side, caption_length],
|
| 804 |
+
outputs=[instruction_preview]
|
| 805 |
+
)
|
| 806 |
|
| 807 |
def _save_dataset_name(name):
|
| 808 |
cfg = load_settings()
|
|
|
|
| 811 |
return gr.update()
|
| 812 |
dataset_name.change(_save_dataset_name, inputs=[dataset_name], outputs=[])
|
| 813 |
|
| 814 |
+
# ---- Shape Aliases (with plural matching + persist)
|
|
|
|
| 815 |
with gr.Accordion("Shape Aliases", open=False):
|
| 816 |
gr.Markdown(
|
| 817 |
"### 🔷 Shape Aliases\n"
|
|
|
|
| 819 |
"**How to use:**\n"
|
| 820 |
"- Left column = a single token **or** comma/pipe-separated synonyms, e.g. `diamond, rhombus | lozenge`\n"
|
| 821 |
"- Right column = replacement name, e.g. `family-emblem`\n"
|
| 822 |
+
"Matches are case-insensitive, catches simple plurals, and also matches `*-shaped` / `* shaped` variants."
|
|
|
|
| 823 |
)
|
| 824 |
|
| 825 |
init_rows, init_enabled = get_shape_alias_rows_ui_defaults()
|
|
|
|
| 843 |
clear_btn = gr.Button("Clear", variant="secondary")
|
| 844 |
save_btn = gr.Button("💾 Save", variant="primary")
|
| 845 |
|
|
|
|
| 846 |
save_status = gr.Markdown("")
|
| 847 |
|
|
|
|
| 848 |
def _add_row(cur):
    """Append one blank two-column row and return a Gradio update for the table.

    ``cur`` is the table's current value; a falsy value is treated as an
    empty list before the blank ``["", ""]`` row is appended.
    """
    rows = cur or []
    rows = rows + [["", ""]]
    n_rows = max(1, len(rows))
    return gr.update(value=rows, row_count=(n_rows, "dynamic"))
|
|
|
|
| 862 |
return gr.update()
|
| 863 |
persist_aliases.change(_save_alias_persist_flag, inputs=[persist_aliases], outputs=[])
|
| 864 |
|
|
|
|
| 865 |
save_btn.click(
|
| 866 |
save_shape_alias_rows,
|
| 867 |
inputs=[enable_aliases, alias_table, persist_aliases],
|
|
|
|
| 880 |
outputs=[single_caption_out]
|
| 881 |
)
|
| 882 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 883 |
with gr.Tab("Batch"):
|
| 884 |
with gr.Accordion("Uploaded images", open=True):
|
| 885 |
+
input_files = gr.File(
|
| 886 |
+
label="Drop images (or click to select)",
|
| 887 |
+
file_types=["image"],
|
| 888 |
+
file_count="multiple",
|
| 889 |
+
type="filepath"
|
| 890 |
+
)
|
|
|
|
|
|
|
| 891 |
run_button = gr.Button("Caption batch", variant="primary")
|
| 892 |
|
|
|
|
| 893 |
# ---- Results area (gallery left / table right)
|
| 894 |
rows_state = gr.State(load_session())
|
| 895 |
autosave_md = gr.Markdown("Ready.")
|
|
|
|
| 912 |
headers=["filename", "caption"],
|
| 913 |
interactive=True,
|
| 914 |
wrap=True,
|
| 915 |
+
type="array",
|
| 916 |
elem_id="cfTable"
|
| 917 |
)
|
| 918 |
|
|
|
|
| 935 |
export_txt_btn = gr.Button("Export captions as .txt (zip)")
|
| 936 |
txt_zip = gr.File(label="TXT zip", visible=False)
|
| 937 |
|
| 938 |
+
# ---- Robust scroll sync
|
| 939 |
gr.HTML("""
|
| 940 |
<script>
|
| 941 |
(function () {
|
|
|
|
| 944 |
if (!host) return null;
|
| 945 |
return host.querySelector('[data-testid="gallery"]') || host;
|
| 946 |
}
|
| 947 |
+
function findTbl() { return document.querySelector("#cfTableWrap"); }
|
|
|
|
|
|
|
| 948 |
function syncScroll(a, b) {
|
| 949 |
if (!a || !b) return;
|
| 950 |
let lock = false;
|
|
|
|
| 988 |
files = files or []
|
| 989 |
budget = None if no_limit else float(budget_s)
|
| 990 |
|
|
|
|
| 991 |
if mode == "Manual (step)" and files:
|
| 992 |
chunks = _split_chunks(files, int(csize))
|
| 993 |
batch = chunks[0]
|
|
|
|
| 1001 |
prog = f"Batch progress: {done}/{total} processed in this step • Remaining overall: {len(remaining)}"
|
| 1002 |
return new_rows, gal, tbl, stamp, remaining, panel_vis, gr.update(value=msg), gr.update(value=prog)
|
| 1003 |
|
| 1004 |
+
# Auto
|
| 1005 |
new_rows, gal, tbl, stamp, leftover, done, total = run_batch(
|
| 1006 |
files, rows or [], instr, t, p, m, int(ms), budget
|
| 1007 |
)
|
|
|
|
| 1013 |
run_button.click(
|
| 1014 |
_run_click,
|
| 1015 |
inputs=[input_files, rows_state, instruction_preview, max_side, chunk_mode, chunk_size, gpu_budget, no_time_limit],
|
| 1016 |
+
outputs=[rows_state, gallery, table, autosave_md, remaining_state, step_panel, step_msg, progress_md]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1017 |
)
|
| 1018 |
+
|
| 1019 |
def _step_next(remain, rows, instr, ms, csize, budget_s, no_limit):
|
| 1020 |
t, p, m = _tpms()
|
| 1021 |
remain = remain or []
|
|
|
|
| 1025 |
return (
|
| 1026 |
rows,
|
| 1027 |
gr.update(value="No files remaining."),
|
| 1028 |
+
gr.update(visible=False),
|
| 1029 |
[],
|
| 1030 |
[],
|
| 1031 |
[],
|
|
|
|
| 1061 |
gallery_pairs = [((r.get("thumb_path") or r.get("path")), r.get("caption",""))
|
| 1062 |
for r in session_rows if (r.get("thumb_path") or r.get("path"))]
|
| 1063 |
return session_rows, gallery_pairs, f"Saved • {time.strftime('%H:%M:%S')}"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1064 |
|
| 1065 |
+
table.change(sync_table_to_session, inputs=[table, rows_state], outputs=[rows_state, gallery, autosave_md])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1066 |
|
| 1067 |
# ---- Exports
|
| 1068 |
export_csv_btn.click(
|
|
|
|
| 1080 |
|
| 1081 |
|
| 1082 |
# ------------------------------
|
| 1083 |
+
# 12) Launch
|
| 1084 |
# ------------------------------
|
| 1085 |
if __name__ == "__main__":
|
| 1086 |
demo.queue(max_size=64).launch(
|
|
|
|
| 1089 |
ssr_mode=False,
|
| 1090 |
debug=True,
|
| 1091 |
show_error=True,
|
|
|
|
| 1092 |
allowed_paths=[THUMB_CACHE, EXCEL_THUMB_DIR, TXT_EXPORT_DIR],
|
| 1093 |
)
|