File size: 2,796 Bytes
0bdbec3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from __future__ import annotations

from datetime import datetime
from pathlib import Path
from typing import Callable

from .state import AppState
from .nodes import parser, planner, prompt_gen, gen_generate, gen_labels, judge, select, edit, archive
from .nodes import gen_fusion


def run_pipeline(state: AppState) -> AppState:
    """Run the full pipeline: generate, judge, select, edit loop, archive.

    The output directory is taken from ``state["outdir"]`` (or a timestamped
    default when that is missing/empty), created on disk, and written back to
    the state as a string. ``state["round"]`` is normalised to an int.

    The node modules are then run in order, each receiving and returning the
    state. Afterwards the edit loop repeats while the state still reports
    ``hard_violations`` or ``violations`` and ``state["round"]`` is below the
    limit ``state["T"]`` (0 when absent, which disables the loop).

    Args:
        state: Pipeline state, passed through every node.

    Returns:
        The state returned by the archive node.
    """
    # Resolve and create the output directory before any node runs.
    target_dir = Path(state.get("outdir") or _default_outdir())
    target_dir.mkdir(parents=True, exist_ok=True)
    state["outdir"] = str(target_dir)
    state["round"] = int(state.get("round", 0))

    # 1) parse -> 2) plan -> 3) prompts -> 4) generate (skeleton)
    # -> 5) generator_2 (labels) -> 6) judge -> 7) select
    stages = (parser, planner, prompt_gen, gen_generate, gen_labels, judge, select)
    for stage in stages:
        state = stage.run(state)

    # 8) edit loop: keep editing while violations remain and rounds are left.
    # NOTE(review): _judge_best_only refreshes "scores" and "hard_violations"
    # only; "violations" is presumably refreshed by select.run -- confirm.
    max_rounds = int(state.get("T", 0))
    while True:
        pending = state.get("hard_violations") or state.get("violations")
        if not pending:
            break
        if not (state.get("round", 0) < max_rounds):
            break
        state["round"] = int(state.get("round", 0)) + 1
        state = edit.apply_edits(state)
        # Only the current best image is re-scored after an edit.
        state = _judge_best_only(state)
        state = select.run(state)

    # 9) archive
    return archive.run(state)


def run_fusion_pipeline(state: AppState) -> AppState:
    """Run the fusion pipeline: generate fused candidates, pick one, archive.

    The output directory is taken from ``state["outdir"]`` (or a timestamped
    default when that is missing/empty), created on disk, and written back to
    the state as a string. ``state["round"]`` is normalised to an int.

    No judging happens here: when the fusion node leaves a non-empty
    ``state["images"]``, its first entry becomes ``state["best_image"]``.

    Args:
        state: Pipeline state, passed through the fusion and archive nodes.

    Returns:
        The state returned by the archive node.
    """
    # Resolve and create the output directory before any node runs.
    target_dir = Path(state.get("outdir") or _default_outdir())
    target_dir.mkdir(parents=True, exist_ok=True)
    state["outdir"] = str(target_dir)
    state["round"] = int(state.get("round", 0))

    # Fused candidates are produced from the input images + text instructions.
    state = gen_fusion.run(state)

    # Without a judge step the first candidate is taken as the best one.
    candidates = state.get("images")
    if candidates:
        state["best_image"] = candidates[0]

    # Archive results (final.png etc.)
    return archive.run(state)


def _judge_best_only(state: AppState) -> AppState:
    """Re-score only the current best image and refresh its judge results.

    Calls ``call_gemini("judge", ...)`` on ``state["best_image"].path`` with
    ``state["spec"]`` (``{}`` when absent), then overwrites ``state["scores"]``
    with a single entry for that image and ``state["hard_violations"]`` with
    the hard subset of the reported violations.

    A violation counts as hard when its text starts with ``hard:`` (ignoring
    case and surrounding whitespace). If no violation is tagged that way,
    those mentioning both "labels" and "missing" are treated as hard instead.

    Args:
        state: Pipeline state; mutated in place.

    Returns:
        The same ``state`` object. It is returned untouched when there is no
        best image to judge.
    """
    best = state.get("best_image")
    if not best:
        return state

    # Imported only once there is something to judge, so the no-op path does
    # not depend on the LLM client module being importable.
    from .llm.gemini import call_gemini

    res = call_gemini("judge", image_path=best.path, spec=state.get("spec", {}))

    # `or` also covers a key that is present but null in the judge response,
    # which would otherwise make list()/float() raise TypeError.
    vios = list(res.get("violations") or [])
    score = float(res.get("score") or 0.0)

    # Lowercase each violation once; both hard-violation rules match on it.
    texts = [(v, str(v).lower()) for v in vios]
    hard = [v for v, text in texts if text.strip().startswith("hard:")]
    if not hard:
        # Fallback for untagged responses: missing labels count as hard.
        hard = [v for v, text in texts if "labels" in text and "missing" in text]

    state["scores"] = [{
        "image_path": best.path,
        "score": score,
        "violations": vios,
    }]
    state["hard_violations"] = hard
    return state


def _default_outdir() -> str:
    """Return a default output directory path stamped with the current local time.

    The result has the form ``artifacts/run_YYYYMMDD_HHMMSS``. The directory
    itself is not created here.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return "artifacts/run_" + stamp