JS6969 committed on
Commit 5570fdd · verified · 1 Parent(s): 613dc76

Upload app.py

Files changed (1)
  1. app.py +836 -4
app.py CHANGED
@@ -1,7 +1,839 @@
+ import os, io, csv, time, json, hashlib, base64, zipfile, re
+ from typing import List, Any, Tuple, Dict
+
+ import spaces
  import gradio as gr
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ from PIL import Image
+ import torch
+ from transformers import LlavaForConditionalGeneration, AutoProcessor
+
+ # ────────────────────────────────────────────────────────
+ # Paths & caches
+ # ────────────────────────────────────────────────────────
+ os.environ.setdefault("HF_HOME", "/home/user/.cache/huggingface")
+ os.makedirs(os.environ["HF_HOME"], exist_ok=True)
+
+ APP_DIR = os.getcwd()
+ SESSION_FILE = "/tmp/session.json"
+ SETTINGS_FILE = "/tmp/cf_settings.json"
+ JOURNAL_FILE = "/tmp/cf_journal.json"
+ TXT_EXPORT_DIR = "/tmp/cf_txt"
+ THUMB_CACHE = os.path.expanduser("~/.cache/captionforge/thumbs")
+ os.makedirs(THUMB_CACHE, exist_ok=True)
+ os.makedirs(TXT_EXPORT_DIR, exist_ok=True)
+
+ # ────────────────────────────────────────────────────────
+ # Model setup
+ # ────────────────────────────────────────────────────────
+ MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
+
+ def _detect_gpu():
+     if torch.cuda.is_available():
+         p = torch.cuda.get_device_properties(0)
+         return "cuda", int(p.total_memory / (1024**3)), p.name
+     return "cpu", 0, "CPU"
+
+ BACKEND, VRAM_GB, GPU_NAME = _detect_gpu()
+ DTYPE = torch.bfloat16 if BACKEND == "cuda" else torch.float32
+ MAX_SIDE_CAP = 1024 if BACKEND == "cuda" else 640
+ DEVICE = "cuda" if BACKEND == "cuda" else "cpu"
+
+ processor = AutoProcessor.from_pretrained(MODEL_PATH)
+ model = LlavaForConditionalGeneration.from_pretrained(
+     MODEL_PATH,
+     torch_dtype=DTYPE,
+     low_cpu_mem_usage=True,
+     device_map=0 if BACKEND == "cuda" else "cpu",
+ )
+ model.eval()
+
+ print(f"[CaptionForge] Backend={BACKEND} GPU={GPU_NAME} VRAM={VRAM_GB}GB dtype={DTYPE}")
+ print(f"[CaptionForge] Gradio version: {gr.__version__}")
+
+ # ────────────────────────────────────────────────────────
+ # Instruction templates & options
+ # ────────────────────────────────────────────────────────
+ STYLE_OPTIONS = [
+     "Descriptive (short)", "Descriptive (long)",
+     "Character training (short)", "Character training (long)",
+     "LoRA (Flux_D Realism) (short)", "LoRA (Flux_D Realism) (long)",
+     "E-commerce product (short)", "E-commerce product (long)",
+     "Portrait (photography) (short)", "Portrait (photography) (long)",
+     "Landscape (photography) (short)", "Landscape (photography) (long)",
+     "Art analysis (no artist names) (short)", "Art analysis (no artist names) (long)",
+     "Social caption (short)", "Social caption (long)",
+     "Aesthetic tags (comma-sep)"
+ ]
+
+ CAPTION_TYPE_MAP = {
+     "Descriptive (short)": "One sentence (≤25 words) describing the most important visible elements only. No speculation.",
+     "Descriptive (long)": "Write a detailed description for this image.",
+
+     "Character training (short)": (
+         "Output a concise, prompt-like caption for character LoRA/ID training. "
+         "Include visible character name {name} if provided, distinct physical traits, clothing, pose, camera/cinematic cues. "
+         "No backstory; no non-visible traits. Prefer comma-separated phrases."
+     ),
+     "Character training (long)": (
+         "Write a thorough, training-ready caption for a character dataset. "
+         "Use {name} if provided; describe only what is visible: physique, face/hair, clothing, accessories, actions, pose, "
+         "camera angle/focal cues, lighting, background context. 1–3 sentences; no backstory or meta."
+     ),
+
+     "LoRA (Flux_D Realism) (short)": "Output a short Flux.Dev prompt that is indistinguishable from a real Flux.Dev prompt.",
+     "LoRA (Flux_D Realism) (long)": "Output a long Flux.Dev prompt that is indistinguishable from a real Flux.Dev prompt.",
+
+     "Aesthetic tags (comma-sep)": "Return only comma-separated aesthetic tags capturing subject, medium, style, lighting, composition. No sentences.",
+
+     "E-commerce product (short)": "One sentence highlighting key attributes, material, color, use case. No fluff.",
+     "E-commerce product (long)": "Write a crisp product description highlighting key attributes, materials, color, usage, and distinguishing traits.",
+
+     "Portrait (photography) (short)": "One sentence portrait description: subject, pose/expression, camera angle, lighting, background.",
+     "Portrait (photography) (long)": "Describe a portrait: subject, age range, pose, facial expression, camera angle, focal length cues, lighting, background.",
+
+     "Landscape (photography) (short)": "One sentence landscape description: major elements, time of day, weather, vantage point, mood.",
+     "Landscape (photography) (long)": "Describe landscape elements, time of day, weather, vantage point, composition, and mood.",
+
+     "Art analysis (no artist names) (short)": "One sentence describing medium, style, composition, palette; do not mention artist/title.",
+     "Art analysis (no artist names) (long)": "Analyze the artwork's visible elements, medium, style, composition, palette. Do not mention artist names or titles.",
+
+     "Social caption (short)": "Write a short, catchy caption (max 25 words) describing the visible content. No hashtags.",
+     "Social caption (long)": "Write a slightly longer, engaging caption (≤50 words) describing the visible content. No hashtags."
+ }
+
+ EXTRA_CHOICES = [
+     "Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
+     "Do NOT include information about whether there is a watermark or not.",
+     "Do NOT use any ambiguous language.",
+     "ONLY describe the most important elements of the image.",
+     "Include information about the ages of any people/characters when applicable.",
+     "Explicitly specify the vantage height (eye-level, low-angle worm’s-eye, bird’s-eye, drone, rooftop, etc.).",
+     "Focus captions only on clothing/fashion details.",
+     "Focus on setting, scenery, and context; ignore subject details.",
+     "ONLY describe the subject’s pose, movement, or action. Do NOT mention appearance, clothing, or setting.",
+     "Do NOT include anything sexual; keep it PG.",
+     "Include synonyms/alternate phrasing to diversify training set.",
+     "ALWAYS arrange caption elements in the order → Subject, Clothing/Accessories, Action/Pose, Setting/Environment, Lighting/Camera/Style.",
+     "Do NOT mention the image's resolution.",
+     "Include information about depth, lighting, and camera angle.",
+     "Include information on composition (rule of thirds, symmetry, leading lines, etc).",
+     "Specify the depth of field and whether the background is in focus or blurred.",
+     "If applicable, mention the likely use of artificial or natural lighting sources.",
+     "If it is a work of art, do not include the artist's name or the title of the work.",
+     "Identify the image orientation (portrait, landscape, or square) if obvious.",
+ ]
+ NAME_OPTION = "If there is a person/character in the image you must refer to them as {name}."
+
+ # ────────────────────────────────────────────────────────
+ # Helpers (hashing, thumbs, resize)
+ # ────────────────────────────────────────────────────────
+ def sha1_file(path: str) -> str:
+     h = hashlib.sha1()
+     with open(path, "rb") as f:
+         for chunk in iter(lambda: f.read(8192), b""):
+             h.update(chunk)
+     return h.hexdigest()
+
+ def thumb_path(sha1: str, max_side=96) -> str:
+     return os.path.join(THUMB_CACHE, f"{sha1}_{max_side}.jpg")
+
+ def ensure_thumb(path: str, sha1: str, max_side=96) -> str:
+     out = thumb_path(sha1, max_side)
+     if os.path.exists(out):
+         return out
+     try:
+         im = Image.open(path).convert("RGB")
+     except Exception:
+         return ""
+     w, h = im.size
+     if max(w, h) > max_side:
+         s = max_side / max(w, h)
+         im = im.resize((int(w*s), int(h*s)), Image.LANCZOS)
+     try:
+         im.save(out, "JPEG", quality=80)
+     except Exception:
+         return ""
+     return out
+
+ def resize_for_model(im: Image.Image, max_side: int) -> Image.Image:
+     w, h = im.size
+     if max(w, h) <= max_side:
+         return im
+     s = max_side / max(w, h)
+     return im.resize((int(w*s), int(h*s)), Image.LANCZOS)
+
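Taken together, the three helpers above form a small content-addressed pipeline: hash the file once, key the thumbnail cache (and later the dedup check) on that hash, and downscale a working copy for the model. A quick illustrative sketch, where photo.jpg is a stand-in path and not a file in this commit:

    digest = sha1_file("photo.jpg")                  # content hash keys caches & dedup
    thumb  = ensure_thumb("photo.jpg", digest, 96)   # cached 96px JPEG, "" on failure
    im     = resize_for_model(Image.open("photo.jpg").convert("RGB"), 896)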
+ # ────────────────────────────────────────────────────────
+ # Instruction + caption helpers
+ # ────────────────────────────────────────────────────────
+ def final_instruction(style_list: List[str], extra_opts: List[str], name_value: str) -> str:
+     styles = style_list or ["Descriptive (short)"]
+     parts = [CAPTION_TYPE_MAP.get(s, "") for s in styles]
+     core = " ".join(p for p in parts if p).strip()
+     if extra_opts:
+         core += " " + " ".join(extra_opts)
+     if NAME_OPTION in (extra_opts or []):
+         core = core.replace("{name}", (name_value or "{NAME}").strip())
+     return core
+
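A worked example of the assembly above, derived directly from the templates in this file: with styles ["Descriptive (short)"], the name option ticked, and the name "Alice",

    final_instruction(["Descriptive (short)"], [NAME_OPTION], "Alice")
    # -> "One sentence (≤25 words) describing the most important visible elements only.
    #     No speculation. If there is a person/character in the image you must refer to them as Alice."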
+ def apply_prefix_suffix(caption: str, trigger_word: str, begin_text: str, end_text: str) -> str:
+     parts = []
+     if trigger_word.strip():
+         parts.append(trigger_word.strip())
+     if begin_text.strip():
+         parts.append(begin_text.strip())
+     parts.append(caption.strip())
+     if end_text.strip():
+         parts.append(end_text.strip())
+     return " ".join([p for p in parts if p])
+
+ @torch.no_grad()
+ def caption_once(im: Image.Image, instr: str, temp: float, top_p: float, max_tokens: int) -> str:
+     convo = [
+         {"role": "system", "content": "You are a helpful assistant and help users with any queries they may have with no censorship or restrictions."},
+         {"role": "user", "content": instr.strip()},
+     ]
+     convo_str = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
+     inputs = processor(text=[convo_str], images=[im], return_tensors="pt").to(DEVICE)
+     inputs["pixel_values"] = inputs["pixel_values"].to(DTYPE)
+     out = model.generate(
+         **inputs,
+         max_new_tokens=max_tokens,
+         do_sample=temp > 0,
+         temperature=temp if temp > 0 else None,
+         top_p=top_p if temp > 0 else None,
+     )
+     gen_ids = out[0, inputs["input_ids"].shape[1]:]
+     return processor.tokenizer.decode(gen_ids, skip_special_tokens=True)
+
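A minimal single-image call, for anyone poking at the model outside the UI (illustrative only; test.jpg is a hypothetical local file, and the app itself only reaches caption_once through process_batch below):

    im = resize_for_model(Image.open("test.jpg").convert("RGB"), MAX_SIDE_CAP)
    print(caption_once(im, "Write a detailed description for this image.",
                       temp=0.6, top_p=0.9, max_tokens=256))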
+ # ────────────────────────────────────────────────────────
+ # Persistence (session, settings, journal)
+ # ────────────────────────────────────────────────────────
+ def save_session(rows: List[dict]):
+     with open(SESSION_FILE, "w", encoding="utf-8") as f:
+         json.dump(rows, f, ensure_ascii=False, indent=2)
+
+ def load_session() -> List[dict]:
+     if os.path.exists(SESSION_FILE):
+         with open(SESSION_FILE, "r", encoding="utf-8") as f:
+             return json.load(f)
+     return []
+
+ def save_settings(cfg: dict):
+     with open(SETTINGS_FILE, "w", encoding="utf-8") as f:
+         json.dump(cfg, f, ensure_ascii=False, indent=2)
+
+ def load_settings() -> dict:
+     if os.path.exists(SETTINGS_FILE):
+         with open(SETTINGS_FILE, "r", encoding="utf-8") as f:
+             cfg = json.load(f)
+     else:
+         cfg = {}
+     defaults = {
+         "dataset_name": "captionforge",
+         "temperature": 0.6,
+         "top_p": 0.9,
+         "max_tokens": 256,
+         "chunk_mode": "Manual (step)",
+         "chunk_size": 10,
+         "max_side": min(896, MAX_SIDE_CAP),
+         "styles": ["Character training (short)"],
+         "extras": [],
+         "name": "",
+         "trigger": "",
+         "begin": "",
+         "end": "",
+     }
+     for k, v in defaults.items():
+         cfg.setdefault(k, v)
+
+     # migrate legacy names
+     legacy_map = {
+         "Descriptive": "Descriptive (short)",
+         "LoRA (Flux_D Realism)": "LoRA (Flux_D Realism) (short)",
+         "Portrait (photography)": "Portrait (photography) (short)",
+         "Landscape (photography)": "Landscape (photography) (short)",
+         "Art analysis (no artist names)": "Art analysis (no artist names) (short)",
+         "E-commerce product": "E-commerce product (short)",
+     }
+     styles = cfg.get("styles") or []
+     migrated = []
+     for s in styles if isinstance(styles, list) else [styles]:
+         migrated.append(legacy_map.get(s, s))
+     migrated = [s for s in migrated if s in STYLE_OPTIONS]
+     if not migrated:
+         migrated = ["Descriptive (short)"]
+     cfg["styles"] = migrated
+     return cfg
+
+ def save_journal(data: dict):
+     with open(JOURNAL_FILE, "w", encoding="utf-8") as f:
+         json.dump(data, f, ensure_ascii=False, indent=2)
+
+ def load_journal() -> dict:
+     if os.path.exists(JOURNAL_FILE):
+         with open(JOURNAL_FILE, "r", encoding="utf-8") as f:
+             return json.load(f)
+     return {}
+
+ # ────────────────────────────────────────────────────────
+ # Logo loader
+ # ────────────────────────────────────────────────────────
+ def logo_b64_img() -> str:
+     candidates = [
+         os.path.join(APP_DIR, "captionforge-logo.png"),
+         "/home/user/app/captionforge-logo.png",
+         "captionforge-logo.png",
+     ]
+     for p in candidates:
+         if os.path.exists(p):
+             with open(p, "rb") as f:
+                 b64 = base64.b64encode(f.read()).decode("ascii")
+             return f"<img src='data:image/png;base64,{b64}' alt='CaptionForge' class='cf-logo'>"
+     return ""
+
+ # ────────────────────────────────────────────────────────
+ # Import / Export helpers
+ # ────────────────────────────────────────────────────────
+ def sanitize_basename(s: str) -> str:
+     s = (s or "").strip() or "captionforge"
+     return re.sub(r"[^A-Za-z0-9._-]+", "_", s)[:120]
+
+ def ts_stamp() -> str:
+     return time.strftime("%Y%m%d_%H%M%S")
+
+ def export_csv(rows: List[dict], basename: str) -> str:
+     base = sanitize_basename(basename)
+     out = f"/tmp/{base}_{ts_stamp()}.csv"
+     with open(out, "w", newline="", encoding="utf-8") as f:
+         w = csv.writer(f)
+         w.writerow(["sha1", "filename", "caption"])
+         for r in rows:
+             w.writerow([r.get("sha1", ""), r.get("filename", ""), r.get("caption", "")])
+     return out
+
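The CSV keeps the sha1 alongside the filename so it can be round-tripped through import_csv_jsonl below even if files are renamed. Illustrative contents (hash truncated):

    sha1,filename,caption
    3f2a...,img_001.png,"a woman in a red coat walking a dog at dusk"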
+ def export_excel(rows: List[dict], basename: str) -> str:
+     try:
+         from openpyxl import Workbook
+         from openpyxl.drawing.image import Image as XLImage
+     except Exception as e:
+         raise RuntimeError("Excel export requires 'openpyxl' in requirements.txt.") from e
+     base = sanitize_basename(basename)
+     out = f"/tmp/{base}_{ts_stamp()}.xlsx"
+     wb = Workbook()
+     ws = wb.active
+     ws.title = "CaptionForge"
+     ws.append(["image", "filename", "caption"])
+     ws.column_dimensions["A"].width = 24
+     ws.column_dimensions["B"].width = 42
+     ws.column_dimensions["C"].width = 100
+     r_i = 2
+     for r in rows:
+         fn = r.get("filename", "")
+         cap = r.get("caption", "")
+         ws.cell(row=r_i, column=2, value=fn)
+         ws.cell(row=r_i, column=3, value=cap)
+         img_path = r.get("thumb_path") or r.get("path")
+         if img_path and os.path.exists(img_path):
+             try:
+                 ximg = XLImage(img_path)
+                 ws.add_image(ximg, f"A{r_i}")
+                 ws.row_dimensions[r_i].height = 110
+             except Exception:
+                 pass
+         r_i += 1
+     wb.save(out)
+     return out
+
+ def export_txt_zip(rows: List[dict], basename: str) -> str:
+     base = sanitize_basename(basename)
+     for name in os.listdir(TXT_EXPORT_DIR):
+         try:
+             os.remove(os.path.join(TXT_EXPORT_DIR, name))
+         except Exception:
+             pass
+     used: Dict[str, int] = {}
+     for r in rows:
+         orig = r.get("filename") or r.get("sha1") or "item"
+         stem = re.sub(r"\.[A-Za-z0-9]+$", "", orig)
+         stem = sanitize_basename(stem)
+         if stem in used:
+             used[stem] += 1
+             stem = f"{stem}_{used[stem]}"
+         else:
+             used[stem] = 0
+         txt_path = os.path.join(TXT_EXPORT_DIR, f"{stem}.txt")
+         with open(txt_path, "w", encoding="utf-8") as f:
+             f.write(r.get("caption", ""))
+     zpath = f"/tmp/{base}_{ts_stamp()}_txt.zip"
+     with zipfile.ZipFile(zpath, "w", zipfile.ZIP_DEFLATED) as z:
+         for name in os.listdir(TXT_EXPORT_DIR):
+             if name.endswith(".txt"):
+                 z.write(os.path.join(TXT_EXPORT_DIR, name), arcname=name)
+     return zpath
+
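The resulting zip holds one .txt per image, named after the image's basename (img_001.png pairs with img_001.txt, with _1, _2, ... suffixes on collisions), which matches the image-plus-sidecar-caption layout commonly expected by LoRA training tools.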
+ def import_csv_jsonl(path: str, rows: List[dict]) -> List[dict]:
+     """Merge captions from CSV/JSONL into current rows by sha1→filename, updating only 'caption'."""
+     if not path or not os.path.exists(path):
+         return rows or []
+     rows = rows or []
+     by_sha = {r.get("sha1"): r for r in rows if r.get("sha1")}
+     by_fn = {r.get("filename"): r for r in rows if r.get("filename")}
+
+     def _apply(item: dict):
+         if not isinstance(item, dict):
+             return
+         cap = (item.get("caption") or "").strip()
+         if not cap:
+             return
+         sha = (item.get("sha1") or "").strip()
+         fn = (item.get("filename") or "").strip()
+         tgt = by_sha.get(sha) if sha else None
+         if not tgt and fn:
+             tgt = by_fn.get(fn)
+         if tgt is not None:
+             prev = tgt.get("caption", "")
+             if cap != prev:
+                 tgt.setdefault("history", []).append(prev)
+                 tgt["caption"] = cap
+
+     ext = os.path.splitext(path)[1].lower()
+     try:
+         if ext == ".csv":
+             with open(path, "r", encoding="utf-8") as f:
+                 for row in csv.DictReader(f):
+                     _apply(row)
+         elif ext in (".jsonl", ".ndjson"):
+             with open(path, "r", encoding="utf-8") as f:
+                 for line in f:
+                     line = line.strip()
+                     if not line:
+                         continue
+                     try:
+                         obj = json.loads(line)
+                     except Exception:
+                         continue
+                     _apply(obj)
+         else:
+             with open(path, "r", encoding="utf-8") as f:
+                 data = json.load(f)
+             if isinstance(data, list):
+                 for obj in data:
+                     _apply(obj)
+     except Exception as e:
+         print("[CaptionForge] Import merge failed:", e)
+     save_session(rows)
+     return rows
+
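So a caption file produced elsewhere merges cleanly as long as each record carries a caption plus a sha1 or filename to match on. Illustrative JSONL lines (hash truncated):

    {"sha1": "3f2a...", "caption": "corrected caption text"}
    {"filename": "img_001.png", "caption": "another corrected caption"}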
+ # ────────────────────────────────────────────────────────
+ # Batching / processing
+ # ────────────────────────────────────────────────────────
+ def adaptive_defaults(vram_gb: int) -> Tuple[int, int]:
+     if vram_gb >= 40: return 24, min(1024, MAX_SIDE_CAP)
+     if vram_gb >= 24: return 16, min(1024, MAX_SIDE_CAP)
+     if vram_gb >= 12: return 10, 896
+     if vram_gb >= 8: return 6, 768
+     return 4, 640
+
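Concretely, adaptive_defaults maps VRAM to (chunk_size, max_side): adaptive_defaults(16) returns (10, 896), adaptive_defaults(8) returns (6, 768), and anything under 8 GB falls back to (4, 640), i.e. smaller chunks and smaller images as VRAM shrinks.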
+ def warmup_if_needed(temp: float, top_p: float, max_tokens: int, max_side: int):
+     try:
+         im = Image.new("RGB", (min(128, max_side), min(128, max_side)), (127, 127, 127))
+         _ = caption_once(im, "Warm up.", temp=0.0, top_p=top_p, max_tokens=min(16, max_tokens))
+     except Exception as e:
+         print("[CaptionForge] Warm-up skipped:", e)
+
+ @spaces.GPU()
+ @torch.no_grad()
+ def process_batch(files: List[str], include_dups: bool, rows: List[dict],
+                   instr_text: str, temp: float, top_p: float, max_tokens: int,
+                   trigger_word: str, begin_text: str, end_text: str,
+                   mode: str, chunk_size: int, max_side: int, step_once: bool,
+                   progress=gr.Progress(track_tqdm=True)) -> Tuple[List[dict], List[str], List[str]]:
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+     files = files or []
+     if not files:
+         return rows, [], []
+
+     auto_mode = mode == "Auto"
+     if auto_mode:
+         chunk_size, max_side = adaptive_defaults(VRAM_GB)
+
+     sha_seen = {r.get("sha1") for r in rows}
+     f_infos = []
+     for f in files:
+         try:
+             s = sha1_file(f)
+         except Exception:
+             continue
+         if (not include_dups) and (s in sha_seen):
+             continue
+         f_infos.append((f, os.path.basename(f), s))
+     if not f_infos:
+         return rows, [], []
+
+     warmup_if_needed(temp, top_p, max_tokens, max_side)
+
+     chunks = [f_infos[i:i+chunk_size] for i in range(0, len(f_infos), chunk_size)]
+     if step_once:
+         chunks = chunks[:1]
+
+     error_log: List[str] = []
+     remaining: List[str] = []
+
+     for ci, chunk in enumerate(chunks):
+         for (f, fname, sha) in progress.tqdm(chunk, desc=f"Chunk {ci+1}/{len(chunks)}"):
+             thumb = ensure_thumb(f, sha, 96)
+             attempt = 0
+             cap = None
+             while attempt < 2:
+                 attempt += 1
+                 try:
+                     im = Image.open(f).convert("RGB")
+                     im = resize_for_model(im, max_side)
+                     cap = caption_once(im, instr_text, temp, top_p, max_tokens)
+                     cap = apply_prefix_suffix(cap, trigger_word, begin_text, end_text)
+                     break
+                 except Exception as e:
+                     error_log.append(f"{fname} ({sha[:8]}): {type(e).__name__}: {e} [attempt {attempt}]")
+             rows.append({
+                 "sha1": sha, "filename": fname, "path": f, "thumb_path": thumb,
+                 "caption": cap or "", "history": []
+             })
+         save_session(rows)
+
+     if step_once and len(f_infos) > len(chunks[0]):
+         remaining = [fi[0] for fi in f_infos[len(chunks[0]):]]
+         save_journal({"remaining_files": remaining})
+     return rows, remaining, error_log
+
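In step mode the journal is what makes a run resumable: only the not-yet-processed paths are written, so the Resume button can reload them after a restart. The journal itself is plain JSON of this shape (paths illustrative):

    {"remaining_files": ["/tmp/gradio/abc/img_041.png", "/tmp/gradio/abc/img_042.png"]}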
+ # ────────────────────────────────────────────────────────
+ # UI
+ # ────────────────────────────────────────────────────────
+ BASE_CSS = """
+ :root{--galleryW:50%;--tableW:50%;}
+ .gradio-container{max-width:100%!important}
+ .cf-hero{display:grid;grid-template-columns:auto 1fr;column-gap:18px;align-items:center;justify-content:center;
+ gap: 16px; margin:4px 0 12px;}
+ .cf-logo{height:calc(3.25rem + 3 * 1.1rem + 18px);width:auto;object-fit:contain}
+ .cf-title{margin:0;font-size:3.25rem;line-height:1;letter-spacing:.2px}
+ .cf-sub{margin:6px 0 0;font-size:1.1rem;color:#cfd3da}
+ .cf-row{display:flex;gap:12px}
+ .cf-col-gallery{flex:0 0 var(--galleryW)}
+ .cf-col-table{flex:0 0 var(--tableW)}
+ /* Shared scroll look */
+ .cf-scroll{max-height:70vh; overflow-y:auto; border:1px solid #e6e6e6; border-radius:10px; padding:8px}
+ /* Uniform sizes */
+ #cfGal .grid > div { height: 96px; }
+ /* Hide SHA1 in table (first column) */
+ #cfTable table thead tr th:first-child { display:none; }
+ #cfTable table tbody tr td:first-child { display:none; }
+ """
+
+ with gr.Blocks(css=BASE_CSS, title="CaptionForge") as demo:
+     settings = load_settings()
+     settings["styles"] = [s for s in settings.get("styles", []) if s in STYLE_OPTIONS] or ["Descriptive (short)"]
+
+     header = gr.HTML(value="""
+     <div class="cf-hero">
+       %s
+       <div>
+         <h1 class="cf-title">CaptionForge</h1>
+         <div class="cf-sub">Batch captioning</div>
+         <div class="cf-sub">Scrollable editor & autosave</div>
+         <div class="cf-sub">CSV / Excel / TXT export</div>
+       </div>
+     </div>
+     <hr>
+     """ % logo_b64_img())
+
+     # ===== Controls (top)
+     with gr.Group():
+         with gr.Row():
+             with gr.Column(scale=2):
+                 style_checks = gr.CheckboxGroup(
+                     choices=STYLE_OPTIONS,
+                     value=settings.get("styles", ["Descriptive (short)"]),
+                     label="Caption style (choose one or combine)"
+                 )
+                 with gr.Accordion("Extra options", open=True):
+                     extra_opts = gr.CheckboxGroup(
+                         choices=[NAME_OPTION] + EXTRA_CHOICES,
+                         value=settings.get("extras", []),
+                         label=None
+                     )
+                 with gr.Accordion("Name & Prefix/Suffix", open=True):
+                     name_input = gr.Textbox(label="Person / Character Name", value=settings.get("name", ""))
+                     trig = gr.Textbox(label="Trigger word")
+                     add_start = gr.Textbox(label="Add text to start")
+                     add_end = gr.Textbox(label="Add text to end")
+
+             with gr.Column(scale=1):
+                 instruction_preview = gr.Textbox(label="Model Instructions", lines=14)
+                 dataset_name = gr.Textbox(label="Dataset name (used for export file titles)", value=settings.get("dataset_name", "captionforge"))
+
+         with gr.Row():
+             chunk_mode = gr.Radio(
+                 choices=["Auto", "Manual (all at once)", "Manual (step)"],
+                 value=settings.get("chunk_mode", "Manual (step)"),
+                 label="Batch mode"
+             )
+             chunk_size = gr.Slider(1, 50, settings.get("chunk_size", 10), step=1, label="Chunk size")
+             max_side = gr.Slider(256, MAX_SIDE_CAP, settings.get("max_side", min(896, MAX_SIDE_CAP)), step=32, label="Max side (resize)")
+
+     # Auto-refresh instruction & persist key controls
+     def _refresh_instruction(styles, extra, name_value, trigv, begv, endv):
+         instr = final_instruction(styles or ["Descriptive (short)"], extra or [], name_value)
+         cfg = load_settings()
+         cfg.update({
+             "styles": styles or ["Descriptive (short)"],
+             "extras": extra or [],
+             "name": name_value,
+             "trigger": trigv, "begin": begv, "end": endv
+         })
+         save_settings(cfg)
+         return instr
+
+     for comp in [style_checks, extra_opts, name_input, trig, add_start, add_end]:
+         comp.change(_refresh_instruction, inputs=[style_checks, extra_opts, name_input, trig, add_start, add_end], outputs=[instruction_preview])
+
+     # ===== File inputs & actions
+     with gr.Accordion("Uploaded images", open=True) as uploads_acc:
+         input_files = gr.File(label="Drop images", file_types=["image"], file_count="multiple", type="filepath")
+         import_file = gr.File(label="Import CSV or JSONL (merge captions)")
+
+     with gr.Row():
+         import_btn = gr.Button("Import → Merge")
+         resume_btn = gr.Button("Resume last run", variant="secondary")
+         run_button = gr.Button("Caption batch", variant="primary")
+
+     # Step-chunk controls
+     step_panel = gr.Group(visible=False)
+     with step_panel:
+         step_msg = gr.Markdown("")
+         step_next = gr.Button("Process next chunk")
+         step_finish = gr.Button("Finish")
+
+     # States
+     rows_state = gr.State(load_session())
+     visible_count_state = gr.State(100)
+     pending_files = gr.State([])
+     remaining_state = gr.State([])
+     autosave_md = gr.Markdown("Ready.")
+
+     # Error log
+     with gr.Accordion("Error log", open=False):
+         log_md = gr.Markdown("")
+
+     # Side-by-side gallery + table with shared scroll
+     with gr.Row():
+         with gr.Column(scale=1):
+             gallery = gr.Gallery(
+                 label="Results (images)",
+                 show_label=True,
+                 columns=10,  # 10 per row as requested
+                 height=520,
+                 elem_id="cfGal",
+                 elem_classes=["cf-scroll"]
+             )
+         with gr.Column(scale=1):
+             table = gr.Dataframe(
+                 label="Editable captions (whole session)",
+                 value=[],  # filled by _render
+                 headers=["sha1", "filename", "caption"],  # include sha1, but it's hidden by CSS
+                 interactive=True,
+                 elem_id="cfTable",
+                 elem_classes=["cf-scroll"]
+             )
+
+     # Exports aligned under buttons
+     with gr.Row():
+         with gr.Column():
+             export_csv_btn = gr.Button("Export CSV")
+             csv_file = gr.File(label="CSV file", visible=False)
+         with gr.Column():
+             export_xlsx_btn = gr.Button("Export Excel (.xlsx)")
+             xlsx_file = gr.File(label="Excel file", visible=False)
+         with gr.Column():
+             export_txt_btn = gr.Button("Export captions as .txt (zip)")
+             txt_zip = gr.File(label="TXT zip", visible=False)
+
+     # Render helpers
+     def _render(rows, vis_n):
+         rows = rows or []
+         n = max(0, int(vis_n) if vis_n else 0)
+         win = rows[:n]
+
+         # Gallery expects a list of paths (Gradio 5 ok)
+         gal = []
+         for r in win:
+             p = r.get("thumb_path") or r.get("path")
+             if isinstance(p, str) and os.path.exists(p):
+                 gal.append(p)
+
+         # Dataframe rows (sha1 kept for mapping; hidden via CSS)
+         df = [[str(r.get("sha1", "")), str(r.get("filename", "")), str(r.get("caption", ""))] for r in win]
+         info = f"Showing {len(win)} of {len(rows)}"
+         return gal, df, info
+
+     # Initial loads
+     demo.load(lambda s, e, n: final_instruction(s or ["Descriptive (short)"], e or [], n),
+               inputs=[style_checks, extra_opts, name_input], outputs=[instruction_preview])
+     demo.load(_render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md])
+
+     # Scroll sync (Gallery + Table)
+     gr.HTML("""
+     <script>
+     (function () {
+       function findGalleryScrollRoot() {
+         const host = document.querySelector("#cfGal");
+         if (!host) return null;
+         return host.querySelector(".grid") || host.querySelector("[data-testid='gallery']") || host;
+       }
+       function findTableScrollRoot() {
+         const host = document.querySelector("#cfTable");
+         if (!host) return null;
+         return host.querySelector(".wrap") ||
+                host.querySelector(".dataframe-wrap") ||
+                (host.querySelector("table") ? host.querySelector("table").parentElement : null) ||
+                host;
+       }
+       function syncScroll(a, b) {
+         if (!a || !b) return;
+         let lock = false;
+         const onScrollA = () => { if (lock) return; lock = true; b.scrollTop = a.scrollTop; lock = false; };
+         const onScrollB = () => { if (lock) return; lock = true; a.scrollTop = b.scrollTop; lock = false; };
+         a.addEventListener("scroll", onScrollA, { passive: true });
+         b.addEventListener("scroll", onScrollB, { passive: true });
+       }
+       let tries = 0;
+       const timer = setInterval(() => {
+         tries++;
+         const gal = findGalleryScrollRoot();
+         const tab = findTableScrollRoot();
+         if (gal && tab) {
+           const H = Math.min(gal.clientHeight || 520, tab.clientHeight || 520);
+           gal.style.maxHeight = H + "px";
+           gal.style.overflowY = "auto";
+           tab.style.maxHeight = H + "px";
+           tab.style.overflowY = "auto";
+           syncScroll(gal, tab);
+           clearInterval(timer);
+         }
+         if (tries > 20) clearInterval(timer);
+       }, 100);
+     })();
+     </script>
+     """)
+
+     # Auto refresh instruction persists (already wired above)
+     # Table edits -> save + rerender (map by sha1, which is hidden in UI)
+     def on_table_edit(df_rows, rows):
+         # gr.Dataframe delivers a pandas DataFrame by default; normalize to a list of rows
+         if hasattr(df_rows, "values"):
+             df_rows = df_rows.values.tolist()
+         by_sha = {r.get("sha1"): r for r in (rows or [])}
+         for row in (df_rows or []):
+             if not row:
+                 continue
+             sha = str(row[0]) if len(row) > 0 else ""
+             tgt = by_sha.get(sha)
+             if not tgt:
+                 continue
+             prev = tgt.get("caption", "")
+             newc = row[2] if len(row) > 2 else prev
+             if newc != prev:
+                 hist = tgt.setdefault("history", [])
+                 hist.append(prev)
+                 tgt["caption"] = newc
+         save_session(rows or [])
+         return rows, f"Saved • {time.strftime('%H:%M:%S')}"
+     # .input fires only on user edits; .change would also fire on the rerender below and loop
+     table.input(on_table_edit, inputs=[table, rows_state], outputs=[rows_state, autosave_md])\
+          .then(_render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md])
+
+     # Import merge
+     def do_import(file, rows):
+         if not file:
+             return rows, gr.update(value="No file")
+         path = file if isinstance(file, str) else file.name  # gr.File defaults to filepath strings
+         rows = import_csv_jsonl(path, rows or [])
+         return rows, gr.update(value=f"Merged from {os.path.basename(path)} at {time.strftime('%H:%M:%S')}")
+     import_btn.click(do_import, inputs=[import_file, rows_state], outputs=[rows_state, autosave_md])\
+               .then(_render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md])
+
+     # Prepare run (hash dupes)
+     def prepare_run(files, rows, pending):
+         files = files or []
+         rows = rows or []
+         pending = pending or []
+         if not files:
+             return [], gr.update(open=True)
+         existing = {r.get("sha1") for r in rows if r.get("sha1")}
+         pending_sha = set()
+         for f in pending:
+             try:
+                 pending_sha.add(sha1_file(f))
+             except Exception:
+                 pass
+         keep = []
+         for f in files:
+             try:
+                 s = sha1_file(f)
+             except Exception:
+                 continue
+             if s not in existing and s not in pending_sha:
+                 keep.append(f)
+         return keep, gr.update(open=False)
+     input_files.change(prepare_run, inputs=[input_files, rows_state, pending_files], outputs=[pending_files, uploads_acc])
+
+     # Step visibility helper (dynamic message)
+     def _step_visibility(remain, mode):
+         if (mode == "Manual (step)") and remain:
+             return gr.update(visible=True), gr.update(value=f"{len(remain)} files remain. Process next chunk?")
+         return gr.update(visible=False), gr.update(value="")
+
+     # Run helpers (use settings for gen params; sliders removed from UI)
+     def _run_once(pending, rows, instr, trigv, begv, endv, mode, csize, mside):
+         cfg = load_settings()
+         t = cfg.get("temperature", 0.6)
+         p = cfg.get("top_p", 0.9)
+         m = cfg.get("max_tokens", 256)
+         step_once = (mode == "Manual (step)")
+         new_rows, remaining, errors = process_batch(
+             pending, False, rows or [], instr, t, p, m, trigv, begv, endv,
+             mode, int(csize), int(mside), step_once
+         )
+         log = "\n".join(errors) if errors else "(no errors)"
+         return new_rows, remaining, log, f"Saved • {time.strftime('%H:%M:%S')}"
+
+     run_button.click(
+         _run_once,
+         inputs=[pending_files, rows_state, instruction_preview, trig, add_start, add_end, chunk_mode, chunk_size, max_side],
+         outputs=[rows_state, remaining_state, log_md, autosave_md]
+     ).then(
+         _render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md]
+     ).then(
+         _step_visibility, inputs=[remaining_state, chunk_mode], outputs=[step_panel, step_msg]
+     )
+
+     def _resume(rows):
+         j = load_journal()
+         rem = j.get("remaining_files", [])
+         return rem, ("Nothing to resume" if not rem else f"Loaded {len(rem)} remaining")
+     resume_btn.click(_resume, inputs=[rows_state], outputs=[pending_files, autosave_md])
+
+     def _step_next(remain, rows, instr, trigv, begv, endv, mode, csize, mside):
+         return _run_once(remain, rows, instr, trigv, begv, endv, mode, csize, mside)
+     step_next.click(
+         _step_next,
+         inputs=[remaining_state, rows_state, instruction_preview, trig, add_start, add_end, chunk_mode, chunk_size, max_side],
+         outputs=[rows_state, remaining_state, log_md, autosave_md]
+     ).then(
+         _render, inputs=[rows_state, visible_count_state], outputs=[gallery, table, autosave_md]
+     ).then(
+         _step_visibility, inputs=[remaining_state, chunk_mode], outputs=[step_panel, step_msg]
+     )
+
+     step_finish.click(lambda: (gr.update(visible=False), gr.update(value=""), []), outputs=[step_panel, step_msg, remaining_state])
+
+     # Exports (reveal file widgets when created)
+     export_csv_btn.click(
+         lambda rows, base: gr.update(value=export_csv(rows or [], base), visible=True),
+         inputs=[rows_state, dataset_name], outputs=[csv_file]
+     )
+     export_xlsx_btn.click(
+         lambda rows, base: gr.update(value=export_excel(rows or [], base), visible=True),
+         inputs=[rows_state, dataset_name], outputs=[xlsx_file]
+     )
+     export_txt_btn.click(
+         lambda rows, base: gr.update(value=export_txt_zip(rows or [], base), visible=True),
+         inputs=[rows_state, dataset_name], outputs=[txt_zip]
+     )
+
+ # Launch
+ demo.queue(max_size=64).launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
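For anyone reproducing the Space locally: the imports above imply at least gradio, torch, transformers, pillow, and spaces in requirements.txt, plus openpyxl for the Excel export (export_excel raises a RuntimeError without it). A plausible minimal requirements.txt, versions unpinned since this commit does not pin them:

    gradio
    torch
    transformers
    pillow
    openpyxl
    spaces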