3v324v23's picture
feat: add NNGen project under NNGen/ and ignore local secrets
0bdbec3
from __future__ import annotations
from datetime import datetime
from pathlib import Path
from typing import Callable
from .state import AppState
from .nodes import parser, planner, prompt_gen, gen_generate, gen_labels, judge, select, edit, archive
from .nodes import gen_fusion
def run_pipeline(state: AppState) -> AppState:
# ensure outdir
outdir = Path(state.get("outdir") or _default_outdir())
outdir.mkdir(parents=True, exist_ok=True)
state["outdir"] = str(outdir)
state["round"] = int(state.get("round", 0))
# 1) parse β†’ 2) plan β†’ 3) prompts β†’ 4) generate (skeleton) β†’ 5) generator_2 (labels) β†’ 6) judge β†’ 7) select
state = parser.run(state)
state = planner.run(state)
state = prompt_gen.run(state)
state = gen_generate.run(state)
state = gen_labels.run(state)
state = judge.run(state)
state = select.run(state)
# 8) edit loop (if hard violations or any violations, and round < T)
T = int(state.get("T", 0))
while (state.get("hard_violations") or state.get("violations")) and state.get("round", 0) < T:
state["round"] = int(state.get("round", 0)) + 1
state = edit.apply_edits(state)
# re-judge best image
state = _judge_best_only(state)
state = select.run(state)
# 9) archive
state = archive.run(state)
return state
def run_fusion_pipeline(state: AppState) -> AppState:
# ensure outdir
outdir = Path(state.get("outdir") or _default_outdir())
outdir.mkdir(parents=True, exist_ok=True)
state["outdir"] = str(outdir)
state["round"] = int(state.get("round", 0))
# Generate fused candidates from images + text instructions
state = gen_fusion.run(state)
# If we have candidates, select first as best; optionally judge later
if state.get("images"):
state["best_image"] = state["images"][0]
# Archive results (final.png etc.)
state = archive.run(state)
return state
def _judge_best_only(state: AppState) -> AppState:
# Only score the current best image again
from .llm.gemini import call_gemini
if not state.get("best_image"):
return state
res = call_gemini("judge", image_path=state["best_image"].path, spec=state.get("spec", {}))
vios = list(res.get("violations", []))
hard = [v for v in vios if str(v).strip().lower().startswith("hard:")]
if not hard:
hard = [v for v in vios if ("labels" in str(v).lower() and "missing" in str(v).lower())]
state["scores"] = [{
"image_path": state["best_image"].path,
"score": float(res.get("score", 0.0)),
"violations": vios,
}]
state["hard_violations"] = hard
return state
def _default_outdir() -> str:
return f"artifacts/run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"