Spaces:
Sleeping
Sleeping
| # ===================================================================== | |
| # ForgeCaptions - Gradio app for single & batch image captioning | |
| # ===================================================================== | |
| # ------------------------------ | |
| # 0) Imports & environment | |
| # ------------------------------ | |
import os

# Environment is configured BEFORE gradio/transformers are imported below,
# presumably so the Hugging Face libraries see these values when they load.
# setdefault() keeps any value already provided by the host environment.
# NOTE(review): HF_HUB_ENABLE_HF_TRANSFER=1 generally expects the optional
# `hf_transfer` package to be installed — confirm it is in requirements.txt.
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
# Persist model caches between restarts: pin the HF cache dir and create it eagerly.
os.environ.setdefault("HF_HOME", "/home/user/.cache/huggingface")
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
import io, csv, time, json, base64, re, zipfile
from typing import Generator, List, Tuple, Dict, Any
from threading import Thread
import gradio as gr
from PIL import Image
import torch
from transformers import LlavaForConditionalGeneration, AutoProcessor, TextIteratorStreamer
# Optional deps for import/export (we handle gracefully if missing)
try:
    import pandas as pd # not required at runtime; kept for future use
except Exception:
    pd = None
# Liger is optional; skip if missing. This module-level no-op fallback only
# keeps the name defined: the model loader re-imports liger_kernel itself and
# handles the failure there.
try:
    from liger_kernel.transformers import apply_liger_kernel_to_llama
except Exception:
    def apply_liger_kernel_to_llama(*args, **kwargs):
        pass
# Hugging Face Spaces GPU decorator. `gpu` is used as a decorator; without
# the `spaces` package it falls back to an identity function (no-op locally).
try:
    import spaces
    gpu = spaces.GPU()
except Exception:
    def gpu(f): return f # fallback no-op
| # ------------------------------ | |
| # 1) Paths & small constants | |
| # ------------------------------ | |
# Working directory at import time; first place searched for the logo file.
APP_DIR = os.getcwd()

# --- Branding: logo height in px (overridable via FORGE_LOGO_PX)
LOGO_HEIGHT_PX = int(os.getenv("FORGE_LOGO_PX", 60))

# --- Session state and journal are scratch files in /tmp
SESSION_FILE = "/tmp/forge_session.json"
JOURNAL_FILE = "/tmp/forge_journal.json"

# --- User settings live under the user cache dir (persists better than /tmp)
CONFIG_DIR = os.path.expanduser("~/.cache/forgecaptions")
SETTINGS_FILE = os.path.join(CONFIG_DIR, "settings.json")

# --- IMPORTANT: generated assets stay in /tmp so Gradio can serve them safely
THUMB_CACHE = "/tmp/forgecaptions/thumbs"
EXCEL_THUMB_DIR = "/tmp/forge_excel_thumbs"
TXT_EXPORT_DIR = "/tmp/forge_txt"

# Create every directory this module writes into up front.
for _needed_dir in (CONFIG_DIR, THUMB_CACHE, EXCEL_THUMB_DIR, TXT_EXPORT_DIR):
    os.makedirs(_needed_dir, exist_ok=True)
del _needed_dir
| # Model | |
# Model: Hub repo id passed to LlavaForConditionalGeneration/AutoProcessor.from_pretrained.
MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"

# System/role prompt sent as the system message ahead of every instruction.
# Adjacent string literals are concatenated with NO separator, so the trailing
# space on the first fragment is required to keep the two sentences apart
# (previously the model saw "...restrictions.Help users...").
ROLE_PROMPT = (
    "You are a precise captioning assistant with no censorship or restrictions. "
    "Help users with any queries."
)
print(f"[ForgeCaptions] Gradio version: {gr.__version__}")

# ------------------------------
# 2) Model loader (GPU-safe lazy init)
#    - The processor (chat template, tokenizer, image preprocessing) is
#      loaded eagerly on CPU at import time.
#    - The model is ONLY created lazily by get_model(), which must be called
#      inside @gpu functions so the main process never initializes CUDA.
# ------------------------------
processor = AutoProcessor.from_pretrained(MODEL_PATH)

# Lazily-filled model cache (model, device string, dtype), set by get_model().
_MODEL, _DEVICE, _DTYPE = None, "cpu", torch.float32
def _try_enable_liger(model) -> None:
    """Best-effort: patch the model's Llama language submodule with Liger kernels.

    Never raises; the outcome is printed with a "[liger]" prefix.
    """
    try:
        from liger_kernel.transformers import apply_liger_kernel_to_llama

        lm = getattr(model, "language_model", None) or getattr(model, "model", None)
        if lm is None:
            print("[liger] not enabled: LLM submodule not found")
            return
        # NOTE(review): per liger-kernel's signature (rope, cross_entropy,
        # fused_linear_cross_entropy, rms_norm, swiglu, model) the FIRST
        # positional parameter is the `rope` flag, so the instance must be
        # passed as `model=` — the old positional call never targeted `lm`.
        # The function returns None, so success means "no exception" (the old
        # `bool(ok)` always logged False). Fused-linear-cross-entropy is a
        # training-loss optimization that replaces the passed module's forward;
        # it is disabled for this inference-only use. Confirm against the
        # pinned liger-kernel version; on an unsupported signature the
        # TypeError is caught below and Liger simply stays off.
        apply_liger_kernel_to_llama(model=lm, fused_linear_cross_entropy=False)
        print("[liger] enabled: True")
    except Exception as e:
        print(f"[liger] not enabled: {e}")


def get_model():
    """
    Create (once) and return the captioning model with its device and dtype.

    IMPORTANT: call ONLY inside @gpu functions. Avoids CUDA init in the main
    process (Stateless GPU rule).

    Returns:
        (model, device, dtype): the cached LlavaForConditionalGeneration in
        eval mode, "cuda" or "cpu", and torch.bfloat16 or torch.float32.
    """
    global _MODEL, _DEVICE, _DTYPE
    if _MODEL is None:
        on_gpu = torch.cuda.is_available()
        device = "cuda" if on_gpu else "cpu"
        dtype = torch.bfloat16 if on_gpu else torch.float32
        model = LlavaForConditionalGeneration.from_pretrained(
            MODEL_PATH,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            device_map=0 if on_gpu else "cpu",
        )
        if on_gpu:
            _try_enable_liger(model)
        model.eval()
        # Publish to the cache only after the load fully succeeded, so a failed
        # attempt leaves no half-initialized state behind and can be retried.
        _MODEL, _DEVICE, _DTYPE = model, device, dtype
        print(f"[ForgeCaptions] Model ready on {_DEVICE} dtype={_DTYPE}")
    return _MODEL, _DEVICE, _DTYPE
| # ------------------------------ | |
| # 3) Instruction templates & options | |
| # ------------------------------ | |
# Caption styles offered in the UI, in display order. Every entry needs a
# matching key in CAPTION_TYPE_MAP (final_instruction looks the text up there).
STYLE_OPTIONS = [
    "Descriptive",
    "Character training",
    "Flux.1-Dev",
    "Stable Diffusion",
    "MidJourney",
    "E-commerce product",
    "Portrait (photography)",
    "Landscape (photography)",
    "Art analysis (no artist names)",
    "Social caption",
    "Aesthetic tags (comma-sep)",
]

# Style name -> instruction text for the model. Multi-fragment values rely on
# implicit string-literal concatenation, so every fragment but the last must
# end with a space (the "focal cues." fragment was missing one).
CAPTION_TYPE_MAP: Dict[str, str] = {
    "Descriptive": "Write a detailed description for this image.",
    "Character training": (
        "Write a thorough, training-ready caption for a character dataset. "
        "Describe subject appearance (physique, face/hair), clothing and accessories, actions/pose/gesture, camera angle/focal cues. "
        "If multiple subjects are present, describe each briefly (most prominent first) and distinguish them by visible traits."
    ),
    "Flux.1-Dev": "Write a Flux.1-Dev style prompt that would reproduce this image faithfully.",
    "Stable Diffusion": "Write a Stable Diffusion style prompt that would reproduce this image faithfully.",
    "MidJourney": "Write a MidJourney style prompt that would reproduce this image faithfully.",
    "Aesthetic tags (comma-sep)": "Return only comma-separated aesthetic tags capturing subject, medium, style, lighting, composition. No sentences.",
    "E-commerce product": "Write a crisp product description highlighting key attributes, materials, color, usage, and distinguishing traits.",
    "Portrait (photography)": "Describe the subject, age range, pose, facial expression, camera angle, focal length cues, lighting, and background.",
    "Landscape (photography)": "Describe major landscape elements, time of day, weather, vantage point, composition, and mood.",
    "Art analysis (no artist names)": "Analyze visible medium, style, composition, and palette. Do not mention artist names or titles.",
    "Social caption": "Write an engaging caption describing the visible content. No hashtags.",
}
| LENGTH_CHOICES = ["any", "very short", "short", "medium-length", "long", "very long"] + [str(i) for i in range(20, 261, 10)] | |
| _LENGTH_HINTS = { | |
| "very short": "Keep to one very short sentence (≈10–15 words).", | |
| "short": "Keep to a short sentence (≈15–25 words).", | |
| "medium-length":"Write 1–2 sentences (≈30–60 words).", | |
| "long": "Write a detailed caption (≈80–120 words).", | |
| "very long": "Write a very detailed caption (≈150–250 words).", | |
| } | |
| def _length_hint(choice: str) -> str: | |
| if not choice or choice == "any": | |
| return "" | |
| if choice.isdigit(): | |
| return f"Limit the caption to at most {choice} words." | |
| return _LENGTH_HINTS.get(choice, "") | |
# Optional extra instructions shown as checkboxes. Ticked entries are appended
# verbatim to the model instruction by final_instruction().
EXTRA_CHOICES = [
    "Only include a character's modifiable, style-level attributes (hair style/color, makeup, clothing/accessories, pose, expression). Do NOT mention identity traits (skin tone, age, body type).",
    "Use profanity when describing sexual acts or genitalia (pussy, fucking, cum, cock, etc.).", # You do NOT have permission to remove
    "Be sexually graphic and describe sexual position when visible.",
    "Exclude brand names, watermarks, artist names, file/meta info, quality terms, hashtags, and resolution notes.",
    "Arrange information in the order: Subject; Clothing/Accessories; Action/Pose; Setting/Environment; Lighting/Camera.",
    "Do NOT use any ambiguous language.",
    "ONLY describe the most important elements of the image.",
    "Include information about the ages of any people/characters when applicable.",
    "Explicitly specify the vantage height (eye-level, low-angle worm’s-eye, bird’s-eye, drone, rooftop, etc.).",
    "Focus captions only on clothing/fashion details.",
    "Focus on setting, scenery, and context; ignore subject details.",
    "ONLY describe the subject’s pose, movement, or action. Do NOT mention appearance, clothing, or setting.",
    "Include synonyms/alternate phrasing to diversify training set.",
    "Include information about depth, lighting, and camera angle.",
    "Include information on composition (rule of thirds, symmetry, leading lines, etc).",
    "Specify the depth of field and whether the background is in focus or blurred.",
    "If applicable, mention the likely use of artificial or natural lighting sources.",
    "Identify the image orientation (portrait, landscape, or square) if obvious.",
]
# Special extra whose "{name}" placeholder final_instruction() fills with the
# user-supplied person/character name.
NAME_OPTION = "If there is a person/character in the image you must refer to them as {name}."
# Extras pre-ticked when the settings file has no "extras" key (load_settings
# default). Each string must match an EXTRA_CHOICES entry character-for-
# character: load_settings passes them through _valid_extras, which silently
# drops anything not in ALL_EXTRA_CHOICES.
DEFAULT_EXTRA_CHOICES = [
    "Use profanity when describing sexual acts or genitalia (pussy, fucking, cum, cock, etc.).",
    "Be sexually graphic and describe sexual position when visible.",
    "Exclude brand names, watermarks, artist names, file/meta info, quality terms, hashtags, and resolution notes.",
    "Arrange information in the order: Subject; Clothing/Accessories; Action/Pose; Setting/Environment; Lighting/Camera.",
]
# Whitelist of every valid extra (NAME_OPTION first), used by _valid_extras.
ALL_EXTRA_CHOICES = [NAME_OPTION] + EXTRA_CHOICES
| def _valid_extras(selected: list[str] | None) -> list[str]: | |
| allowed = set(ALL_EXTRA_CHOICES) | |
| return [x for x in (selected or []) if x in allowed] | |
| # ------------------------------ | |
| # 4) Persistence helpers (settings/session/journal) | |
| # ------------------------------ | |
def save_session(rows: List[dict]):
    """Persist the session rows to SESSION_FILE as pretty-printed UTF-8 JSON.

    The JSON is written to a temporary sibling file first. That file is then
    moved over SESSION_FILE with os.replace, so a crash mid-write cannot leave
    a truncated session file behind.
    """
    tmp_path = SESSION_FILE + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(rows, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, SESSION_FILE)


def load_session() -> List[dict]:
    """Return the saved session rows, or [] when there is no usable session.

    A missing file, an unreadable or corrupt file, or JSON that is not a list
    all yield []. Previously a corrupt file raised straight into the UI.
    """
    try:
        with open(SESSION_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):  # ValueError covers JSONDecodeError / bad UTF-8
        return []
    return data if isinstance(data, list) else []
def save_settings(cfg: dict):
    """Persist *cfg* to SETTINGS_FILE as pretty-printed UTF-8 JSON.

    The JSON is written to a temporary sibling file first. That file is then
    moved over SETTINGS_FILE with os.replace, so an interrupted write never
    leaves a truncated file. load_settings would otherwise treat such a file
    as "no settings" and fall back to defaults.
    """
    tmp_path = SETTINGS_FILE + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(cfg, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, SETTINGS_FILE)


def load_settings() -> dict:
    """Return user settings from SETTINGS_FILE, merged over built-in defaults.

    A missing, unreadable or malformed file yields the defaults instead of
    raising. This includes valid JSON that is not an object, such as a list,
    which previously crashed on setdefault(). "styles" is normalized to a
    non-empty list of STYLE_OPTIONS entries, and "extras" is filtered through
    _valid_extras.
    """
    try:
        with open(SETTINGS_FILE, "r", encoding="utf-8") as f:
            loaded = json.load(f)
    except Exception:
        loaded = None
    cfg = loaded if isinstance(loaded, dict) else {}

    # Defaults (fresh objects each call, so callers may mutate the result)
    defaults = {
        "dataset_name": "forgecaptions",
        "temperature": 0.6,
        "top_p": 0.9,
        "max_tokens": 256,
        "max_side": 896,
        "styles": ["Character training"],
        "name": "",
        "trigger": "",
        "begin": "",
        "end": "",
        "shape_aliases_enabled": True,
        "shape_aliases": [],
        "excel_thumb_px": 128,
        "logo_px": 60,
        "shape_aliases_persist": True,
        "extras": list(DEFAULT_EXTRA_CHOICES),
    }
    for k, v in defaults.items():
        cfg.setdefault(k, v)

    # Normalize styles to a valid, non-empty list
    styles = cfg.get("styles") or []
    if not isinstance(styles, list):
        styles = [styles]
    cfg["styles"] = [s for s in styles if s in STYLE_OPTIONS] or ["Character training"]
    cfg["extras"] = _valid_extras(cfg.get("extras"))
    return cfg
def save_journal(data: dict):
    """Persist the journal dict to JOURNAL_FILE as pretty-printed UTF-8 JSON.

    The JSON is written to a temporary sibling file first. That file is then
    moved over JOURNAL_FILE with os.replace, so an interrupted write cannot
    leave a truncated journal behind.
    """
    tmp_path = JOURNAL_FILE + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, JOURNAL_FILE)


def load_journal() -> dict:
    """Return the saved journal dict, or {} when there is no usable journal.

    A missing, unreadable or corrupt file, or JSON that is not an object,
    yields {} instead of raising.
    """
    try:
        with open(JOURNAL_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):  # ValueError covers JSONDecodeError / bad UTF-8
        return {}
    return data if isinstance(data, dict) else {}
| # ------------------------------ | |
| # 5) Small utilities (thumbs, resize, prefix/suffix, names) | |
| # ------------------------------ | |
def sanitize_basename(s: str) -> str:
    """Turn arbitrary text into a safe file-name stem.

    Blank or None input becomes "forgecaptions". Every run of characters other
    than ASCII letters, digits, '.', '_' and '-' collapses to a single '_'.
    The result is truncated to 120 characters.
    """
    cleaned = (s or "").strip()
    if not cleaned:
        cleaned = "forgecaptions"
    safe = re.sub(r"[^A-Za-z0-9._-]+", "_", cleaned)
    return safe[:120]
def ensure_thumb(path: str, max_side=256) -> str:
    """Write a JPEG thumbnail of *path* into THUMB_CACHE and return its path.

    The image is converted to RGB and downscaled so its longer side is at most
    *max_side* px. Smaller images are only re-encoded.

    The thumbnail name includes a short hash of the full source path. Two
    uploads with the same basename in different folders therefore no longer
    overwrite each other's thumbnail, which made the gallery show the wrong
    picture.

    Best-effort: returns *path* unchanged if the image can't be opened or the
    thumbnail can't be written.
    """
    import hashlib  # local import: only needed for the cache-key suffix

    try:
        # Context manager closes the source file handle right after decoding.
        with Image.open(path) as src:
            im = src.convert("RGB")
    except Exception:
        return path
    w, h = im.size
    if max(w, h) > max_side:
        s = max_side / max(w, h)
        # Clamp to >= 1 px: with extreme aspect ratios the short side would
        # truncate to 0 and PIL's resize would raise.
        im = im.resize((max(1, int(w * s)), max(1, int(h * s))), Image.LANCZOS)
    stem = os.path.splitext(os.path.basename(path))[0]
    key = os.path.abspath(path).encode("utf-8", "surrogateescape")
    tag = hashlib.sha1(key, usedforsecurity=False).hexdigest()[:10]
    out_path = os.path.join(THUMB_CACHE, f"{stem}_{tag}_thumb_{max_side}.jpg")
    try:
        im.save(out_path, "JPEG", quality=85, optimize=True)
        return out_path
    except Exception:
        return path
def resize_for_model(im: Image.Image, max_side: int) -> Image.Image:
    """Downscale *im* so its longer side is at most *max_side* px, keeping aspect.

    Images already within the limit are returned as-is (same object). Each
    scaled side is clamped to at least 1 px. For extreme aspect ratios the
    short side would otherwise truncate to 0 and PIL's resize would raise,
    aborting the whole batch.
    """
    w, h = im.size
    longest = max(w, h)
    if longest <= max_side:
        return im
    s = max_side / longest
    return im.resize((max(1, int(w * s)), max(1, int(h * s))), Image.LANCZOS)
| def apply_prefix_suffix(caption: str, trigger_word: str, begin_text: str, end_text: str) -> str: | |
| parts = [] | |
| if trigger_word.strip(): | |
| parts.append(trigger_word.strip()) | |
| if begin_text.strip(): | |
| parts.append(begin_text.strip()) | |
| parts.append(caption.strip()) | |
| if end_text.strip(): | |
| parts.append(end_text.strip()) | |
| return " ".join([p for p in parts if p]) | |
def logo_b64_img() -> str:
    """Return an inline <img> tag for the first logo file found, or "".

    Search order: forgecaptions-logo.png, then the legacy captionforge-logo.png,
    first inside APP_DIR and then relative to the current working directory.
    The PNG is embedded as a base64 data URI, so no static file route is needed.
    """
    names = ("forgecaptions-logo.png", "captionforge-logo.png")
    search = [os.path.join(APP_DIR, n) for n in names] + list(names)
    found = next((p for p in search if os.path.exists(p)), None)
    if found is None:
        return ""
    with open(found, "rb") as fh:
        encoded = base64.b64encode(fh.read()).decode("ascii")
    return f"<img src='data:image/png;base64,{encoded}' alt='ForgeCaptions' class='cf-logo'>"
| # ------------------------------ | |
| # 6) Shape Aliases (plural-aware + '-shaped' variants) | |
| # ------------------------------ | |
| def _plural_token_regex(tok: str) -> str: | |
| """ | |
| Build a regex for a token that also matches simple English plurals. | |
| Rules: | |
| - endswith s/x/z/ch/sh → add '(?:es)?' | |
| - consonant + y → '(?:y|ies)' | |
| - default → 's?' | |
| """ | |
| t = (tok or "").strip() | |
| if not t: return "" | |
| t_low = t.lower() | |
| if re.search(r"[^aeiou]y$", t_low): | |
| return re.escape(t[:-1]) + r"(?:y|ies)" | |
| if re.search(r"(?:s|x|z|ch|sh)$", t_low): | |
| return re.escape(t) + r"(?:es)?" | |
| return re.escape(t) + r"s?" | |
def _compile_shape_aliases_from_file():
    """
    Build the [(compiled_regex, name), ...] alias list from settings["shape_aliases"].

    The "shape" cell accepts comma OR pipe separated synonyms (multi-word OK).
    Matches are case-insensitive, catch simple plurals, and also consume a
    trailing "-shaped", " shaped" or "shaped". Returns [] when aliases are
    disabled.

    This runs at import time, so malformed entries (non-dict items, a null
    "shape_aliases", non-string cells) are skipped instead of raising. An
    exception here would stop the app from starting.
    """
    s = load_settings()
    if not s.get("shape_aliases_enabled", True):
        return []
    compiled = []
    for item in s.get("shape_aliases") or []:
        if not isinstance(item, dict):
            continue
        raw = str(item.get("shape") or "").strip()
        name = str(item.get("name") or "").strip()
        if not raw or not name:
            continue
        tokens = [t.strip() for t in re.split(r"[|,]", raw) if t.strip()]
        alts = [a for a in (_plural_token_regex(t) for t in tokens) if a]
        if not alts:
            continue
        pat = r"\b(?:" + "|".join(alts) + r")(?:[-\s]?shaped)?\b"
        compiled.append((re.compile(pat, flags=re.I), name))
    return compiled


# In-memory alias cache used by apply_shape_aliases(); built once at import.
_SHAPE_ALIASES = _compile_shape_aliases_from_file()


def _refresh_shape_aliases_cache():
    """Rebuild the in-memory alias cache from the settings file on disk."""
    global _SHAPE_ALIASES
    _SHAPE_ALIASES = _compile_shape_aliases_from_file()
def apply_shape_aliases(caption: str) -> str:
    """Replace every alias match in *caption* with "(name)", alias by alias in order.

    The replacement is supplied as a function, so the user-provided name is
    inserted literally. As a plain string, re.sub would treat backslashes in
    the name as escapes or group references: "\\1" raised re.error and "\\n"
    inserted a newline.
    """
    for pat, name in _SHAPE_ALIASES:
        replacement = f"({name})"
        caption = pat.sub(lambda _m, r=replacement: r, caption)
    return caption
def get_shape_alias_rows_ui_defaults():
    """Return (rows, enabled) used to seed the shape-alias editor from settings.

    rows is a list of [shape, name] pairs. When nothing is saved, a single
    blank row is returned so the table is never empty. enabled is the saved
    "shape_aliases_enabled" flag (default True) as a bool.
    """
    cfg = load_settings()
    alias_rows = []
    for entry in cfg.get("shape_aliases", []):
        alias_rows.append([entry.get("shape", ""), entry.get("name", "")])
    is_enabled = bool(cfg.get("shape_aliases_enabled", True))
    return (alias_rows or [["", ""]]), is_enabled
def save_shape_alias_rows(enabled, df_rows, persist):
    """
    Save or just apply the alias rows edited in the UI table.

    - persist=True  -> write the enabled flag and cleaned rows to SETTINGS_FILE
                       (via save_settings) so they survive restarts.
    - persist=False -> do NOT touch disk; the change lasts for this process only.

    Either way, the in-memory _SHAPE_ALIASES list is rebuilt from the submitted
    rows, or cleared when *enabled* is false.

    Returns (status_text, gr.update). The update re-renders the table with the
    cleaned rows plus one trailing blank row for new entries.
    """
    global _SHAPE_ALIASES

    cleaned = []
    for r in (df_rows or []):
        # Skip empty or short rows. A one-cell row used to raise IndexError.
        # Cells are str()-ed so numeric cells don't fail on .strip().
        if not r or len(r) < 2:
            continue
        shape = str(r[0] or "").strip()
        name = str(r[1] or "").strip()
        if shape and name:
            cleaned.append({"shape": shape, "name": name})

    status = "✅ Applied for this session only."
    if persist:
        cfg = load_settings()
        cfg["shape_aliases_enabled"] = bool(enabled)
        cfg["shape_aliases"] = cleaned
        save_settings(cfg)
        status = "💾 Saved to disk (will persist across restarts)."

    # Recompile in-memory, regardless of persist (same rules as the settings loader).
    compiled = []
    if enabled:
        for item in cleaned:
            toks = [t.strip() for t in re.split(r"[|,]", item["shape"]) if t.strip()]
            alts = [a for a in (_plural_token_regex(t) for t in toks) if a]
            if not alts:
                continue
            pat = r"\b(?:" + "|".join(alts) + r")(?:[-\s]?shaped)?\b"
            compiled.append((re.compile(pat, flags=re.I), item["name"]))
    _SHAPE_ALIASES = compiled

    normalized = [[it["shape"], it["name"]] for it in cleaned] + [["", ""]]
    return status, gr.update(value=normalized, row_count=(max(1, len(normalized)), "dynamic"))
| # ------------------------------ | |
| # 7) Prompt builder (instruction text shown/used for model) | |
| # ------------------------------ | |
def final_instruction(style_list: List[str], extra_opts: List[str], name_value: str, length_choice: str = "long") -> str:
    """Build the instruction text sent to the model and shown in the UI preview.

    The following are joined with single spaces:
      1. the CAPTION_TYPE_MAP text of each selected style ("Character training"
         when none is selected);
      2. every selected extra option, verbatim;
      3. the _length_hint() text for *length_choice*, omitted when the
         "Aesthetic tags (comma-sep)" style is selected (tag lists ignore
         word-length guidance).
    Blank segments are skipped, so there is no stray leading space when no
    style text is present.

    If NAME_OPTION is among the extras, its "{name}" placeholder is filled with
    the stripped *name_value*. A blank or whitespace-only name leaves a visible
    "{NAME}" placeholder instead of producing "refer to them as .".
    """
    styles = style_list or ["Character training"]
    extras = list(extra_opts or [])
    segments: List[str] = [CAPTION_TYPE_MAP.get(s, "") for s in styles] + extras
    if "Aesthetic tags (comma-sep)" not in styles:
        segments.append(_length_hint(length_choice or "any"))
    core = " ".join(seg.strip() for seg in segments if seg and seg.strip())
    if NAME_OPTION in extras:
        core = core.replace("{name}", (name_value or "").strip() or "{NAME}")
    return core
| # ------------------------------ | |
| # 8) GPU caption functions | |
| # ------------------------------ | |
def _build_inputs(im: Image.Image, instr: str, dtype) -> Dict[str, Any]:
    """Render the chat prompt and preprocess *im* into model-ready inputs.

    The conversation is ROLE_PROMPT as the system message plus the stripped
    instruction as the user message. It is rendered with the processor's chat
    template, with the generation prompt appended. "pixel_values" (when
    present) is cast to *dtype* so it matches the model weights; every other
    entry is returned untouched.
    """
    messages = [
        {"role": "system", "content": ROLE_PROMPT},
        {"role": "user", "content": instr.strip()},
    ]
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    batch = processor(text=[prompt], images=[im], return_tensors="pt")
    if "pixel_values" in batch:
        batch["pixel_values"] = batch["pixel_values"].to(dtype)
    return batch
def caption_once(im: Image.Image, instr: str, temp: float, top_p: float, max_tokens: int) -> str:
    """
    Run one generation for *im* and return only the newly generated text.

    A positive *temp* enables sampling with temperature=temp and nucleus
    top_p. Otherwise decoding is greedy, and temperature/top_p are passed as
    None. The prompt tokens are sliced off before decoding, and special tokens
    are skipped.

    NOTE: Not GPU-decorated on purpose; call this only from within a @gpu function.
    """
    model, device, dtype = get_model()
    batch = {
        key: (value.to(device) if hasattr(value, "to") else value)
        for key, value in _build_inputs(im, instr, dtype).items()
    }
    sample = temp > 0
    generated = model.generate(
        **batch,
        max_new_tokens=max_tokens,
        do_sample=sample,
        temperature=temp if sample else None,
        top_p=top_p if sample else None,
        use_cache=True,
    )
    prompt_len = batch["input_ids"].shape[1]
    return processor.tokenizer.decode(generated[0, prompt_len:], skip_special_tokens=True)
def caption_single(img: Image.Image, instr: str) -> str:
    """Caption one PIL image using the generation settings from load_settings().

    Mirrors run_batch's per-image pipeline:
      1. convert to RGB;
      2. downscale to settings["max_side"];
      3. generate;
      4. apply shape aliases;
      5. add the trigger/begin/end text.
    Returns "No image provided." when *img* is None.

    NOTE: calls caption_once, so it must itself execute inside a @gpu context.
    """
    if img is None:
        return "No image provided."
    s = load_settings()
    # run_batch converts every image to RGB before captioning. Do the same here
    # so RGBA / palette / grayscale inputs take the identical preprocessing path.
    im = resize_for_model(img.convert("RGB"), int(s.get("max_side", 896)))
    cap = caption_once(im, instr, s.get("temperature", 0.6), s.get("top_p", 0.9), s.get("max_tokens", 256))
    cap = apply_shape_aliases(cap)
    return apply_prefix_suffix(cap, s.get("trigger", ""), s.get("begin", ""), s.get("end", ""))
def run_batch(
    files: List[Any],
    session_rows: List[dict],
    instr_text: str,
    temp: float,
    top_p: float,
    max_tokens: int,
    max_side: int,
    time_budget_s: float | None = None,  # respects Zero-GPU window (None = unlimited)
    progress: gr.Progress = gr.Progress(track_tqdm=True),  # drives the progress bar
) -> Tuple[List[dict], list, list, str, List[str], int, int]:
    """
    Caption each existing file in *files*, appending one row per image to the session.

    Paths that don't exist are dropped up front. Images that fail to open are
    skipped. If *time_budget_s* is set, the loop stops after the first image
    that brings elapsed time to or past the budget. The unprocessed paths are
    then returned as leftovers, so the caller can continue in another GPU
    window. When at least one existing file was given, the session is saved to
    disk after the loop.

    Returns:
        session_rows, gallery_pairs, table_rows, status_text,
        leftover_files, processed_in_this_call, total_in_this_call
    """
    session_rows = session_rows or []
    files = [f for f in (files or []) if f and os.path.exists(f)]
    total = len(files)
    processed = 0
    leftover: List[str] = []

    if total:
        # Prefix/suffix settings are read once per batch. They used to be
        # re-read from disk for every image; now every caption in the batch
        # uses the same values.
        s = load_settings()
        trigger, begin, end = s.get("trigger", ""), s.get("begin", ""), s.get("end", "")
        start = time.time()
        for idx, path in enumerate(progress.tqdm(files, desc="Captioning")):
            try:
                # Context manager releases the file handle once decoded.
                with Image.open(path) as src:
                    im = src.convert("RGB")
            except Exception:
                continue  # unreadable / not an image: skip it
            im = resize_for_model(im, max_side)
            cap = caption_once(im, instr_text, temp, top_p, max_tokens)
            cap = apply_shape_aliases(cap)
            cap = apply_prefix_suffix(cap, trigger, begin, end)
            session_rows.append({
                "filename": os.path.basename(path),
                "caption": cap,
                "path": path,
                "thumb_path": ensure_thumb(path, 256),
            })
            processed += 1
            if (time_budget_s is not None) and ((time.time() - start) >= float(time_budget_s)):
                leftover = files[idx + 1:]
                break
        save_session(session_rows)

    gallery_pairs = [((r.get("thumb_path") or r.get("path")), r.get("caption", ""))
                     for r in session_rows if (r.get("thumb_path") or r.get("path"))]
    table_rows = [[r.get("filename", ""), r.get("caption", "")] for r in session_rows]
    return (
        session_rows,
        gallery_pairs,
        table_rows,
        f"Saved • {time.strftime('%H:%M:%S')}",
        leftover,
        processed,
        total,
    )
| # ------------------------------ | |
| # 9) Export/Import helpers (CSV/XLSX/TXT ZIP) | |
| # ------------------------------ | |
| def _rows_to_table(rows: List[dict]) -> list: | |
| return [[r.get("filename",""), r.get("caption","")] for r in (rows or [])] | |
| def _table_to_rows(table_value: Any, rows: List[dict]) -> List[dict]: | |
| # Expect list-of-lists (Dataframe type="array") | |
| tbl = table_value or [] | |
| new = [] | |
| for i, r in enumerate(rows or []): | |
| r = dict(r) | |
| if i < len(tbl) and len(tbl[i]) >= 2: | |
| r["filename"] = str(tbl[i][0]) if tbl[i][0] is not None else r.get("filename","") | |
| r["caption"] = str(tbl[i][1]) if tbl[i][1] is not None else r.get("caption","") | |
| new.append(r) | |
| return new | |
def export_csv_from_table(table_value: Any, dataset_name: str) -> str:
    """Write the [filename, caption] rows to a new CSV in /tmp and return its path.

    The file is named <sanitized dataset name>_<unix seconds>.csv, is UTF-8
    encoded, and starts with a "filename,caption" header row.
    """
    rows_out = table_value or []
    stem = sanitize_basename(dataset_name)
    csv_path = f"/tmp/{stem}_{int(time.time())}.csv"
    with open(csv_path, "w", newline="", encoding="utf-8") as handle:
        writer = csv.writer(handle)
        writer.writerow(["filename", "caption"])
        writer.writerows(rows_out)
    return csv_path
def _resize_for_excel(path: str, px: int) -> str:
    """Write a JPEG copy of *path*, longer side at most *px*, into EXCEL_THUMB_DIR.

    Returns the new file's path, or *path* unchanged if the image can't be
    opened or the copy can't be written (best-effort).

    The output name includes a short hash of the full source path, so sources
    that share a basename never share an output file. NOTE(review): openpyxl's
    Image appears to re-read the file from disk when the workbook is saved, so
    with a shared file every such row would show the last image written —
    confirm for the pinned openpyxl version.
    """
    import hashlib  # local import: only needed for the cache-key suffix

    try:
        # Context manager closes the source file handle right after decoding.
        with Image.open(path) as src:
            im = src.convert("RGB")
    except Exception:
        return path
    w, h = im.size
    if max(w, h) > px:
        s = px / max(w, h)
        # Clamp to >= 1 px so extreme aspect ratios don't make PIL raise.
        im = im.resize((max(1, int(w * s)), max(1, int(h * s))), Image.LANCZOS)
    stem = os.path.splitext(os.path.basename(path))[0]
    key = os.path.abspath(path).encode("utf-8", "surrogateescape")
    tag = hashlib.sha1(key, usedforsecurity=False).hexdigest()[:10]
    out_path = os.path.join(EXCEL_THUMB_DIR, f"{stem}_{tag}_{px}px.jpg")
    try:
        im.save(out_path, "JPEG", quality=85, optimize=True)
        return out_path
    except Exception:
        return path
def export_excel_with_thumbs(table_value: Any, session_rows: List[dict], thumb_px: int, dataset_name: str) -> str:
    """
    Export the session to /tmp/<dataset>_<unix seconds>.xlsx with a thumbnail per row.

    Columns: image (thumbnail anchored in column A), filename, caption.
    Captions come from the (possibly edited) table, matched by filename, and
    fall back to the session row's caption. Rows whose image is missing or
    can't be embedded are still written, just without a picture. The skip is
    logged.

    Returns:
        Path of the written workbook.
    Raises:
        RuntimeError: if openpyxl is not installed.
    """
    try:
        from openpyxl import Workbook
        from openpyxl.drawing.image import Image as XLImage
    except Exception as e:
        raise RuntimeError("Excel export requires 'openpyxl' in requirements.txt.") from e

    caption_by_file = {}
    for row in (table_value or []):
        if not row:
            continue
        # A None filename cell maps to "" (str(None) used to create a bogus "None" key).
        fn = str(row[0]) if len(row) > 0 and row[0] is not None else ""
        cap = str(row[1]) if len(row) > 1 and row[1] is not None else ""
        if fn:
            caption_by_file[fn] = cap

    thumb_px = int(thumb_px)
    wb = Workbook()
    ws = wb.active
    ws.title = "ForgeCaptions"
    ws.append(["image", "filename", "caption"])
    # Widen column A for larger thumbnails so they don't spill over column B.
    # The width unit is characters, roughly 7 px each for the default font
    # (approximation). The original 24 is kept as the minimum.
    ws.column_dimensions["A"].width = max(24, thumb_px / 7 + 2)
    ws.column_dimensions["B"].width = 42
    ws.column_dimensions["C"].width = 100
    row_h = int(thumb_px * 0.75)  # row height is in points; ~0.75 pt per px

    for r_i, r in enumerate(session_rows or [], start=2):
        fn = r.get("filename", "")
        cap = caption_by_file.get(fn, r.get("caption", ""))
        ws.cell(row=r_i, column=2, value=fn)
        ws.cell(row=r_i, column=3, value=cap)
        img_path = r.get("thumb_path") or r.get("path")
        if img_path and os.path.exists(img_path):
            try:
                ws.add_image(XLImage(_resize_for_excel(img_path, thumb_px)), f"A{r_i}")
                ws.row_dimensions[r_i].height = row_h
            except Exception as e:
                # Best-effort: keep the text row, report the skipped picture.
                print(f"[excel] thumbnail skipped for {fn}: {e}")

    name = sanitize_basename(dataset_name)
    out = f"/tmp/{name}_{int(time.time())}.xlsx"
    wb.save(out)
    return out
def export_txt_zip(table_value: Any, dataset_name: str) -> str:
    """
    Zip one .txt per caption row into /tmp/<dataset>_<unix seconds>_txt.zip.

    Each entry is named after the row's filename with its extension removed
    and the rest sanitized via sanitize_basename (blank -> "item"). Repeated
    stems get _1, _2, ... suffixes. A suffix already taken by another row is
    skipped, compared case-insensitively, so no caption overwrites another.
    Previously ["a", "a_1", "a"] produced "a_1.txt" twice.

    Captions are written straight into the archive. Concurrent exports no
    longer share, and wipe, the TXT_EXPORT_DIR scratch folder.

    Returns:
        Path of the written zip file.
    """
    data = table_value or []
    name = sanitize_basename(dataset_name)
    zpath = f"/tmp/{name}_{int(time.time())}_txt.zip"
    taken: set = set()            # lower-cased stems already used in this archive
    next_suffix: Dict[str, int] = {}  # lower-cased base stem -> next suffix to try
    with zipfile.ZipFile(zpath, "w", zipfile.ZIP_DEFLATED) as z:
        for row in data:
            if not row:
                continue
            orig = str(row[0] or "item").strip() if len(row) > 0 else "item"
            base = sanitize_basename(re.sub(r"\.[A-Za-z0-9]+$", "", orig) or "item")
            k = next_suffix.get(base.lower(), 0)
            stem = base if k == 0 else f"{base}_{k}"
            while stem.lower() in taken:
                k += 1
                stem = f"{base}_{k}"
            next_suffix[base.lower()] = k + 1
            taken.add(stem.lower())
            cap = str(row[1]).strip() if len(row) > 1 and row[1] is not None else ""
            z.writestr(f"{stem}.txt", cap)  # str data is stored UTF-8 encoded
    return zpath
def _caption_pairs_from_rows(rows: list) -> List[Tuple[str, str]]:
    """Turn raw CSV/sheet rows into (filename, caption) pairs.

    Drops an optional leading header row whose first cell is "filename" (any
    case), rows with fewer than two cells, and rows with a blank filename.
    None cells become "".
    """
    if rows and len(rows[0]) >= 2 and str(rows[0][0]).lower().strip() == "filename":
        rows = rows[1:]
    pairs: List[Tuple[str, str]] = []
    for r in rows:
        if not r or len(r) < 2:
            continue
        fn = str(r[0]).strip() if r[0] is not None else ""
        cap = str(r[1]).strip() if r[1] is not None else ""
        if fn:
            pairs.append((fn, cap))
    return pairs


def _read_caption_file(file_path: str, ext: str) -> List[Tuple[str, str]]:
    """Read (filename, caption) pairs from a .csv, or from .xlsx/.xls via openpyxl."""
    if ext == ".csv":
        # utf-8-sig strips the BOM Excel writes into UTF-8 CSVs. With a BOM the
        # "filename" header was not recognized and got imported as data.
        # newline="" is what the csv module requires for quoted multi-line fields.
        with open(file_path, "r", encoding="utf-8-sig", newline="") as f:
            return _caption_pairs_from_rows(list(csv.reader(f)))
    try:
        from openpyxl import load_workbook
    except Exception as e:
        raise RuntimeError("XLSX import requires 'openpyxl' in requirements.txt.") from e
    wb = load_workbook(file_path, read_only=True, data_only=True)
    try:
        rows = list(wb.active.iter_rows(values_only=True))
    finally:
        wb.close()  # read-only workbooks keep the file open until explicitly closed
    return _caption_pairs_from_rows(rows)


def import_captions_file(file_path: str, session_rows: List[dict]) -> Tuple[List[dict], list, list, str]:
    """
    Import captions from CSV or XLSX and merge them by filename into the current session.

    - If the filename exists, its caption is updated.
    - Otherwise a new row is appended, without an image path or thumbnail.
      A filename repeated in the file is appended only once; later
      occurrences update that row.

    *session_rows* is updated in place and saved with save_session() after a
    successful import.

    Returns:
        (session_rows, gallery_pairs, table_rows, status_text). On every error
        path the gallery slot now holds gallery pairs. The unsupported-type
        branch used to return table rows there.
    """
    session_rows = session_rows if session_rows is not None else []

    def _views():
        # (gallery_pairs, table_rows) for the current session rows
        gallery = [((r.get("thumb_path") or r.get("path")), r.get("caption", ""))
                   for r in session_rows if (r.get("thumb_path") or r.get("path"))]
        return gallery, _rows_to_table(session_rows)

    if not file_path or not os.path.exists(file_path):
        gallery, table = _views()
        return session_rows, gallery, table, "No file selected."

    ext = os.path.splitext(file_path)[1].lower()
    if ext not in (".csv", ".xlsx", ".xls"):
        gallery, table = _views()
        return session_rows, gallery, table, f"Unsupported file type: {ext}"

    try:
        imported = _read_caption_file(file_path, ext)
    except Exception as e:
        gallery, table = _views()
        return session_rows, gallery, table, f"Import failed: {e}"

    # Merge
    idx_by_name = {r.get("filename", ""): i for i, r in enumerate(session_rows)}
    updates, inserts = 0, 0
    for fn, cap in imported:
        if fn in idx_by_name:
            session_rows[idx_by_name[fn]]["caption"] = cap
            updates += 1
        else:
            session_rows.append({"filename": fn, "caption": cap, "path": "", "thumb_path": ""})
            idx_by_name[fn] = len(session_rows) - 1
            inserts += 1
    save_session(session_rows)
    gallery, table = _views()
    stamp = f"Imported {len(imported)} rows • updated {updates} • added {inserts} • {time.strftime('%H:%M:%S')}"
    return session_rows, gallery, table, stamp
| # ------------------------------ | |
| # 10) UI header helper (fixed logo size) | |
| # ------------------------------ | |
def _render_header_html(px: int) -> str:
    """Return the header HTML (logo + title block) together with its logo-size CSS.

    The logo is rendered at a fixed *px* height. On narrow screens
    (max-width: 640px) it shrinks by 12 px with a 60 px floor, but it is never
    made larger than *px*. The old max(60, px - 12) rule enlarged logos
    smaller than 60 px on phones.
    """
    logo_px = int(px)
    small_px = min(logo_px, max(60, logo_px - 12))
    return f"""
<div class="cf-hero">
  {logo_b64_img()}
  <div class="cf-text">
    <h1 class="cf-title">ForgeCaptions</h1>
    <div class="cf-sub">JoyCaption Image Captioning</div>
    <div class="cf-sub">Import CSV/XLSX • Export CSV/XLSX/TXT</div>
    <div class="cf-sub">Batch 10–20 per Zero GPU run • Larger batches with dedicated GPU</div>
  </div>
</div>
<hr>
<style>
  .cf-logo {{
    height: {logo_px}px; /* fixed height */
    width: auto;
    object-fit: contain;
    display: block;
  }}
  @media (max-width: 640px) {{
    .cf-logo {{ height: {small_px}px; }} /* optional small-screen tweak */
  }}
</style>
"""
| # ------------------------------ | |
| # 11) UI (Blocks) | |
| # ------------------------------ | |
# Page-wide stylesheet handed to gr.Blocks(css=...) below.
# - .cf-hero / .cf-title / .cf-sub style the header built by _render_header_html.
# - #cfGal and #cfTableWrap cap the results gallery and table at 520px tall
#   and force a vertical scrollbar inside them.
# NOTE(review): --galleryW / --tableW are declared but not referenced in this
# stylesheet; presumably used elsewhere — verify before removing.
BASE_CSS = """
:root{--galleryW:50%;--tableW:50%;}
.gradio-container{max-width:100%!important}
/* Header */
.cf-hero{display:flex; align-items:center; justify-content:center; gap:16px;
  margin:4px 0 12px; text-align:center;}
.cf-hero .cf-text{text-align:center;}
.cf-title{margin:0;font-size:3.25rem;line-height:1;letter-spacing:.2px}
.cf-sub{margin:6px 0 0;font-size:1.1rem;color:#cfd3da}
/* Results area + robust scrollbars */
.cf-scroll{border:1px solid #e6e6e6; border-radius:10px; padding:8px}
#cfGal{max-height:520px; overflow-y:auto !important;}
#cfTableWrap{max-height:520px; overflow-y:auto !important;}
#cfGal [data-testid="gallery"]{height:auto !important;}
#cfGal .grid > div { height: 96px; }
"""
with gr.Blocks(css=BASE_CSS, title="ForgeCaptions") as demo:
    # Everything below is built inside the Blocks context. A `with` statement
    # does not open a new scope, so every component and handler defined here
    # is also a module-level name.

    # ---- Header
    settings = load_settings()
    header_html = gr.HTML(_render_header_html(LOGO_HEIGHT_PX))

    # ---- Controls group
    with gr.Group():
        with gr.Row():
            # LEFT: styles / extras / name & prefix-suffix
            with gr.Column(scale=2):
                with gr.Accordion("Caption style (choose one or combine)", open=True):
                    style_checks = gr.CheckboxGroup(
                        choices=STYLE_OPTIONS,
                        value=settings.get("styles", ["Character training"]),
                        label=None,
                    )
                    caption_length = gr.Dropdown(
                        choices=LENGTH_CHOICES,
                        label="Caption Length",
                        value=settings.get("caption_length", "long"),
                    )
                with gr.Accordion("Extra options", open=False):
                    extra_opts = gr.CheckboxGroup(
                        choices=[NAME_OPTION] + EXTRA_CHOICES,
                        value=settings.get("extras", []),
                        label=None,
                    )
                with gr.Accordion("Name & Prefix/Suffix", open=False):
                    name_input = gr.Textbox(label="Person / Character Name", value=settings.get("name", ""))
                    trig = gr.Textbox(label="Trigger word", value=settings.get("trigger", ""))
                    add_start = gr.Textbox(label="Add text to start", value=settings.get("begin", ""))
                    add_end = gr.Textbox(label="Add text to end", value=settings.get("end", ""))
            # RIGHT: instructions + dataset + general sliders
            with gr.Column(scale=1):
                with gr.Accordion("Model Instructions", open=False):
                    instruction_preview = gr.Textbox(
                        label=None,
                        lines=12,
                        value=final_instruction(
                            settings.get("styles", ["Character training"]),
                            settings.get("extras", []),
                            settings.get("name", ""),
                            settings.get("caption_length", "long"),
                        ),
                    )
                dataset_name = gr.Textbox(
                    label="Dataset name (export title prefix)",
                    value=settings.get("dataset_name", "forgecaptions"),
                )
                max_side = gr.Slider(256, 1024, settings.get("max_side", 896), step=32, label="Max side (resize)")
                excel_thumb_px = gr.Slider(
                    64, 256, value=settings.get("excel_thumb_px", 128),
                    step=8, label="Excel thumbnail size (px)",
                )
                # Chunking
                chunk_mode = gr.Radio(
                    choices=["Auto", "Manual (step)"],
                    value="Manual (step)", label="Batch mode",
                )
                chunk_size = gr.Slider(1, 200, value=15, step=1, label="Chunk size")
                gpu_budget = gr.Slider(20, 110, value=55, step=5, label="Max seconds per GPU call")
                no_time_limit = gr.Checkbox(value=False, label="No time limit (ignore above)")

    # ---- Persist instruction + general settings
    def _refresh_instruction(styles, extra, name_value, trigv, begv, endv, excel_px, ms, cap_len):
        """Rebuild the instruction preview and persist the current control values.

        Returns:
            The new instruction text shown in ``instruction_preview``.
        """
        instr = final_instruction(styles or ["Character training"], extra or [], name_value, cap_len)
        cfg = load_settings()
        cfg.update({
            "styles": styles or ["Character training"],
            "extras": extra or [],
            "name": name_value,
            "trigger": trigv, "begin": begv, "end": endv,
            "excel_thumb_px": int(excel_px),
            "max_side": int(ms),
            "caption_length": cap_len or "any",
        })
        save_settings(cfg)
        return instr

    _instruction_inputs = [style_checks, extra_opts, name_input, trig, add_start, add_end,
                           excel_thumb_px, max_side, caption_length]
    for comp in _instruction_inputs:
        comp.change(_refresh_instruction, inputs=_instruction_inputs, outputs=[instruction_preview])

    def _save_dataset_name(name):
        """Persist the sanitized dataset name. The handler has no outputs, so it returns nothing."""
        cfg = load_settings()
        cfg["dataset_name"] = sanitize_basename(name)
        save_settings(cfg)

    dataset_name.change(_save_dataset_name, inputs=[dataset_name], outputs=[])

    # ---- Shape Aliases (with plural matching + persist) ----
    with gr.Accordion("Shape Aliases", open=False):
        gr.Markdown(
            "### 🔷 Shape Aliases\n"
            "Replace literal **shape tokens** in captions with a preferred **name**.\n\n"
            "**How to use:**\n"
            "- Left column = a single token **or** comma/pipe-separated synonyms, e.g. `diamond, rhombus | lozenge`\n"
            "- Right column = replacement name, e.g. `family-emblem`\n"
            "Matches are case-insensitive, catches simple plurals (`box`→`boxes`, `lady`→`ladies`), "
            "and also matches `*-shaped` or `* shaped` variants."
        )
        init_rows, init_enabled = get_shape_alias_rows_ui_defaults()
        enable_aliases = gr.Checkbox(label="Enable shape alias replacements", value=init_enabled)
        persist_aliases = gr.Checkbox(
            label="Save aliases across sessions",
            value=load_settings().get("shape_aliases_persist", True),
        )
        alias_table = gr.Dataframe(
            headers=["shape (token or synonyms)", "name to insert"],
            value=init_rows,
            row_count=(max(1, len(init_rows)), "dynamic"),
            datatype=["str", "str"],
            type="array",
            interactive=True,
        )
        with gr.Row():
            add_row_btn = gr.Button("+ Add row", variant="secondary")
            clear_btn = gr.Button("Clear", variant="secondary")
            save_btn = gr.Button("💾 Save", variant="primary")
        # status line for saves
        save_status = gr.Markdown("")

        # --- local handlers (must stay inside Blocks context) ---
        def _add_row(cur):
            """Append one empty alias row to the table value."""
            cur = (cur or []) + [["", ""]]
            return gr.update(value=cur, row_count=(max(1, len(cur)), "dynamic"))

        def _clear_rows():
            """Reset the alias table to a single empty row."""
            return gr.update(value=[["", ""]], row_count=(1, "dynamic"))

        add_row_btn.click(_add_row, inputs=[alias_table], outputs=[alias_table])
        clear_btn.click(_clear_rows, outputs=[alias_table])

        def _save_alias_persist_flag(v):
            """Persist the 'save aliases across sessions' checkbox. No outputs, so it returns nothing."""
            cfg = load_settings()
            cfg["shape_aliases_persist"] = bool(v)
            save_settings(cfg)

        persist_aliases.change(_save_alias_persist_flag, inputs=[persist_aliases], outputs=[])
        # Persist rows if persist_aliases checked; otherwise apply in-memory only
        save_btn.click(
            save_shape_alias_rows,
            inputs=[enable_aliases, alias_table, persist_aliases],
            outputs=[save_status, alias_table],
        )

    # ---- Tabs: Single & Batch
    with gr.Tabs():
        with gr.Tab("Single"):
            input_image_single = gr.Image(type="pil", label="Input Image", height=512, width=512)
            single_caption_btn = gr.Button("Caption")
            single_caption_out = gr.Textbox(label="Caption (single)")
            single_caption_btn.click(
                caption_single,
                inputs=[input_image_single, instruction_preview],
                outputs=[single_caption_out],
            )
        with gr.Tab("Batch"):
            with gr.Accordion("Uploaded images", open=True):
                input_files = gr.File(
                    label="Drop images (or click to select)",
                    file_types=["image"],
                    file_count="multiple",
                )
                run_button = gr.Button("Caption batch", variant="primary")
            # The import hook registered further down (import_btn.click) needs
            # these two components to exist, otherwise building the UI fails.
            with gr.Accordion("Import captions from CSV/XLSX (merge by filename)", open=False):
                import_file = gr.File(label="Choose .csv or .xlsx", file_types=[".csv", ".xlsx"], type="filepath")
                import_btn = gr.Button("Import into current session")
            preview_gallery = gr.Gallery(
                label="Preview (un-captioned)",
                show_label=True,
                columns=5,
                height=220,
            )
            input_files.change(on_files_changed, inputs=[input_files], outputs=[preview_gallery])

    # ---- Results area (gallery left / table right)
    # Read the saved session once; gr.State deep-copies its initial value.
    initial_rows = load_session()
    rows_state = gr.State(initial_rows)
    autosave_md = gr.Markdown("Ready.")
    progress_md = gr.Markdown("")
    remaining_state = gr.State([])
    with gr.Row():
        with gr.Column(scale=2):
            gallery = gr.Gallery(
                label="Results",
                show_label=True,
                columns=3,
                elem_id="cfGal",
                elem_classes=["cf-scroll"],
            )
        with gr.Column(scale=1, elem_id="cfTableWrap", elem_classes=["cf-scroll"]):
            table = gr.Dataframe(
                label="Editable captions",
                value=_rows_to_table(initial_rows),
                headers=["filename", "caption"],
                interactive=True,
                wrap=True,
                type="array",  # prevents pandas truth ambiguity
                elem_id="cfTable",
            )

    # ---- Step panel
    step_panel = gr.Group(visible=False)
    with step_panel:
        step_msg = gr.Markdown("")
        step_next = gr.Button("Process next chunk")
        step_finish = gr.Button("Finish")

    # ---- Exports
    with gr.Row():
        with gr.Column():
            export_csv_btn = gr.Button("Export CSV")
            csv_file = gr.File(label="CSV file", visible=False)
        with gr.Column():
            export_xlsx_btn = gr.Button("Export Excel (.xlsx) with thumbnails")
            xlsx_file = gr.File(label="Excel file", visible=False)
        with gr.Column():
            export_txt_btn = gr.Button("Export captions as .txt (zip)")
            txt_zip = gr.File(label="TXT zip", visible=False)

    # ---- Robust scroll sync (works with Gradio v5 Gallery)
    # Gradio does not execute <script> tags placed inside gr.HTML content, so
    # the script runs through the page-load event's `js` hook instead. The
    # interval polls until the gallery and table wrapper are mounted, caps
    # both at the same height, then mirrors their vertical scroll position.
    _SCROLL_SYNC_JS = """
    () => {
      function findGal() {
        const host = document.querySelector("#cfGal");
        if (!host) return null;
        return host.querySelector('[data-testid="gallery"]') || host;
      }
      function findTbl() {
        return document.querySelector("#cfTableWrap");
      }
      function syncScroll(a, b) {
        if (!a || !b) return;
        let lock = false;
        const onA = () => { if (lock) return; lock = true; b.scrollTop = a.scrollTop; lock = false; };
        const onB = () => { if (lock) return; lock = true; a.scrollTop = b.scrollTop; lock = false; };
        a.addEventListener("scroll", onA, { passive: true });
        b.addEventListener("scroll", onB, { passive: true });
      }
      let tries = 0;
      const t = setInterval(() => {
        tries++;
        const gal = findGal();
        const tab = findTbl();
        if (gal && tab) {
          const H = Math.min(520, Math.max(360, tab.clientHeight || 520));
          gal.style.maxHeight = H + "px";
          gal.style.overflowY = "auto";
          tab.style.maxHeight = H + "px";
          tab.style.overflowY = "auto";
          syncScroll(gal, tab);
          clearInterval(t);
        }
        if (tries > 30) clearInterval(t);
      }, 120);
    }
    """
    demo.load(fn=None, inputs=None, outputs=None, js=_SCROLL_SYNC_JS)

    # ---- Chunking logic
    def _split_chunks(files, csize: int):
        """Split ``files`` into consecutive chunks of at most ``csize`` items (at least 1)."""
        files = files or []
        c = max(1, int(csize))
        return [files[i:i + c] for i in range(0, len(files), c)]

    def _tpms():
        """Return (temperature, top_p, max_tokens) from saved settings, with defaults."""
        s = load_settings()
        return s.get("temperature", 0.6), s.get("top_p", 0.9), s.get("max_tokens", 256)

    def _gallery_from_rows(rows):
        """Build (image_path, caption) pairs for the results gallery from session rows.

        Prefers the row's cached thumbnail over its original path, and skips
        rows whose image file is missing. Gradio serves the files by path, so
        nothing is decoded in Python.
        """
        pairs = []
        for r in rows or []:
            img = r.get("thumb_path") or r.get("path")
            if img and os.path.exists(img):
                pairs.append((img, r.get("caption", "")))
        return pairs

    def _run_click(files, rows, instr, ms, mode, csize, budget_s, no_limit):
        """Caption uploads: the first chunk only in manual mode, all files in auto mode.

        Returns the 8 values wired to run_button's outputs: session rows,
        gallery, table, autosave stamp, files still queued, step-panel
        visibility, step message, progress text.
        """
        t, p, m = _tpms()
        files = files or []
        budget = None if no_limit else float(budget_s)
        # Manual step → process first chunk only; the rest waits for "Process next chunk"
        if mode == "Manual (step)" and files:
            batch = _split_chunks(files, csize)[0]
            new_rows, gal, tbl, stamp, leftover_from_batch, done, total = run_batch(
                batch, rows or [], instr, t, p, m, int(ms), budget
            )
            # Unprocessed files from this call (presumably cut off by the time
            # budget) go back to the front of the queue.
            remaining = (leftover_from_batch or []) + files[len(batch):]
            panel_vis = gr.update(visible=bool(remaining))
            msg = f"{len(remaining)} files remain. Process next chunk?" if remaining else ""
            prog = f"Batch progress: {done}/{total} processed in this step • Remaining overall: {len(remaining)}"
            return new_rows, gal, tbl, stamp, remaining, panel_vis, gr.update(value=msg), gr.update(value=prog)
        # Auto
        new_rows, gal, tbl, stamp, leftover, done, total = run_batch(
            files, rows or [], instr, t, p, m, int(ms), budget
        )
        panel_vis = gr.update(visible=bool(leftover))
        msg = f"{len(leftover)} files remain. Process next chunk?" if leftover else ""
        prog = f"Batch progress: {done}/{total} processed in this call • Remaining: {len(leftover)}"
        return new_rows, gal, tbl, stamp, leftover, panel_vis, gr.update(value=msg), gr.update(value=prog)

    run_button.click(
        _run_click,
        inputs=[input_files, rows_state, instruction_preview, max_side, chunk_mode, chunk_size, gpu_budget, no_time_limit],
        outputs=[rows_state, gallery, table, autosave_md, remaining_state, step_panel, step_msg, progress_md],
    ).then(
        _gallery_from_rows,
        inputs=[rows_state],
        outputs=[gallery],
    )

    def _step_next(remain, rows, instr, ms, csize, budget_s, no_limit):
        """Caption the next chunk of queued files.

        Returns values for: rows_state, step_msg, step_panel, remaining_state,
        gallery, table, autosave_md, progress_md.
        """
        t, p, m = _tpms()
        remain = remain or []
        budget = None if no_limit else float(budget_s)
        if not remain:
            # Nothing queued: leave the gallery and table as they are rather
            # than blanking the results already on screen.
            return (
                rows,
                gr.update(value="No files remaining."),
                gr.update(visible=True),
                [],
                gr.update(),
                gr.update(),
                "Saved.",
                gr.update(value=""),
            )
        c = max(1, int(csize))
        batch, leftover = remain[:c], remain[c:]
        new_rows, gal, tbl, stamp, leftover_from_batch, done, total = run_batch(
            batch, rows or [], instr, t, p, m, int(ms), budget
        )
        leftover = (leftover_from_batch or []) + leftover
        panel_vis = gr.update(visible=bool(leftover))
        msg = f"{len(leftover)} files remain. Process next chunk?" if leftover else "All done."
        prog = f"Batch progress: {done}/{total} processed in this step • Remaining overall: {len(leftover)}"
        return new_rows, msg, panel_vis, leftover, gal, tbl, stamp, gr.update(value=prog)

    step_next.click(
        _step_next,
        inputs=[remaining_state, rows_state, instruction_preview, max_side, chunk_size, gpu_budget, no_time_limit],
        outputs=[rows_state, step_msg, step_panel, remaining_state, gallery, table, autosave_md, progress_md],
    )

    def _step_finish():
        """Hide the step panel, clear its message and drop any queued files."""
        return gr.update(visible=False), gr.update(value=""), []

    step_finish.click(_step_finish, inputs=None, outputs=[step_panel, step_msg, remaining_state])

    # ---- Table edits → persist + refresh gallery
    def sync_table_to_session(table_value: Any, session_rows: List[dict]) -> Tuple[List[dict], list, str]:
        """Apply table edits to the session rows, persist them and rebuild the gallery.

        Returns:
            (updated session rows, gallery pairs, autosave status text).
        """
        session_rows = _table_to_rows(table_value, session_rows or [])
        save_session(session_rows)
        return session_rows, _gallery_from_rows(session_rows), f"Saved • {time.strftime('%H:%M:%S')}"

    table.change(sync_table_to_session, inputs=[table, rows_state], outputs=[rows_state, gallery, autosave_md])

    def new_session() -> Tuple[List[dict], list, list, str]:
        """Return empty (rows, gallery, table, stamp) values for a blank session."""
        return [], [], _rows_to_table([]), ""

    # ---- Import hook
    def _do_import(fpath, rows):
        """Merge captions from an uploaded CSV/XLSX into the session (matched by filename)."""
        if not fpath:
            # Button pressed without choosing a file: change nothing.
            return rows, gr.update(), gr.update(), "No import file selected."
        new_rows, gal, tbl, stamp = import_captions_file(fpath, rows or [])
        return new_rows, gal, tbl, stamp

    import_btn.click(_do_import, inputs=[import_file, rows_state], outputs=[rows_state, gallery, table, autosave_md])

    # ---- Exports: each button writes a file, then shows it in its gr.File.
    def _reveal_file(path):
        """Set a gr.File's value and make it visible in a single update."""
        return gr.update(value=path, visible=True)

    export_csv_btn.click(
        lambda tbl, ds: _reveal_file(export_csv_from_table(tbl, ds)),
        inputs=[table, dataset_name],
        outputs=[csv_file],
    )
    export_xlsx_btn.click(
        lambda tbl, rows, px, ds: _reveal_file(export_excel_with_thumbs(tbl, rows or [], int(px), ds)),
        inputs=[table, rows_state, excel_thumb_px, dataset_name],
        outputs=[xlsx_file],
    )
    export_txt_btn.click(
        lambda tbl, ds: _reveal_file(export_txt_zip(tbl, ds)),
        inputs=[table, dataset_name],
        outputs=[txt_zip],
    )
| # ------------------------------ | |
| # 12) Launch (SSR disabled for stability on Spaces) | |
| # ------------------------------ | |
if __name__ == "__main__":
    # Collect launch options in one place so they are easy to scan and tweak.
    launch_options = {
        # Listen on every interface; the PORT env var overrides the 7860 default.
        "server_name": "0.0.0.0",
        "server_port": int(os.getenv("PORT", "7860")),
        "ssr_mode": False,
        "debug": True,
        "show_error": True,
        # Allow Gradio to serve generated files from /tmp caches
        "allowed_paths": [THUMB_CACHE, EXCEL_THUMB_DIR, TXT_EXPORT_DIR],
    }
    queued_app = demo.queue(max_size=64)
    queued_app.launch(**launch_options)