| import copy |
| import json |
| import random |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any, Dict, Optional, List |
|
|
|
|
@dataclass
class StepResult:
    """Value object bundling everything a single environment step yields."""

    # Observation dictionary handed back to the caller after the step.
    obs: Dict[str, Any]
    # Scalar reward assigned to the step.
    reward: float
    # True once the episode has finished.
    done: bool
    # Extra diagnostic details about the step (e.g. a result label).
    info: Dict[str, Any]
|
|
|
|
class FlowDebugEnv:
    """Single-action debugging environment over a list of failing flow cases.

    Every case is a dict from which this class reads the keys ``case_id``,
    ``steps`` (a list of step dicts with at least a ``name``), ``error``,
    ``failed_step`` and ``gold_fix`` (a dict with ``step``, ``field`` and
    ``value``).

    The only supported action is::

        {"action": "patch_step", "step": <name>,
         "field": "inputs.expression", "value": <new expression>}

    An episode ends successfully when the patch matches ``gold_fix`` exactly,
    or unsuccessfully when the attempt budget is used up.

    Rewards:
        * ``1.0``  -- the patch equals the gold fix (episode done).
        * ``-0.1`` -- a valid but wrong patch, or an invalid action.
        * ``-0.2`` -- a valid but wrong patch on the last attempt (episode done).
    """

    def __init__(self, cases: List[Dict[str, Any]], max_attempts: int = 3,
                 seed: Optional[int] = None):
        """Store the case pool and set up the per-episode bookkeeping.

        Args:
            cases: Pool of case dicts that ``reset()`` samples from.
            max_attempts: Number of ``step()`` calls allowed per episode.
            seed: Seed for the private RNG used to pick cases.
        """
        self.cases = cases
        self.max_attempts = max_attempts
        self.rng = random.Random(seed)
        self.current_case: Optional[Dict[str, Any]] = None
        self.attempts_left = max_attempts

    @classmethod
    def from_json(cls, cases_json_path: str, max_attempts: int = 3,
                  seed: Optional[int] = None):
        """Build an environment from a UTF-8 JSON file containing the cases.

        Args:
            cases_json_path: Path of the JSON file to load.
            max_attempts: Forwarded to the constructor.
            seed: Forwarded to the constructor.

        Returns:
            A new instance of ``cls`` holding the loaded cases.
        """
        path = Path(cases_json_path)
        with open(path, "r", encoding="utf-8") as f:
            cases = json.load(f)
        return cls(cases=cases, max_attempts=max_attempts, seed=seed)

    def reset(self) -> Dict[str, Any]:
        """Start a new episode on a randomly chosen case.

        The chosen case is deep-copied so patches never touch the case pool.

        Returns:
            The initial observation.
        """
        self.current_case = copy.deepcopy(self.rng.choice(self.cases))
        self.attempts_left = self.max_attempts
        return self._make_observation()

    def step(self, action: Dict[str, Any]) -> "StepResult":
        """Apply one action and report the outcome.

        Every call consumes one attempt, including invalid actions.

        Args:
            action: Action dict; see the class docstring for the schema.

        Returns:
            A ``StepResult`` with the new observation, reward, done flag and
            an ``info`` dict whose ``result`` is one of ``"success"``,
            ``"still_failed"``, ``"out_of_attempts"`` or ``"invalid_action"``.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
        """
        if self.current_case is None:
            raise RuntimeError("Call reset() before step().")

        # The attempt is spent up front so invalid actions also cost one.
        self.attempts_left -= 1

        if action.get("action") != "patch_step":
            return self._invalid_action("Unsupported action type")

        step_name = action.get("step")
        field = action.get("field")
        value = action.get("value")

        if not self._apply_patch(step_name, field, value):
            return self._invalid_action("Patch failed (step/field not found)")

        case_id = self.current_case["case_id"]
        gold = self.current_case["gold_fix"]
        solved = (
            step_name == gold["step"]
            and field == gold["field"]
            and value == gold["value"]
        )

        if solved:
            self._mark_success()
            obs = self._make_observation(
                run_status="Succeeded", error=None, failed_step=None
            )
            return StepResult(obs=obs, reward=1.0, done=True,
                              info={"result": "success", "case_id": case_id})

        if self.attempts_left <= 0:
            return StepResult(obs=self._make_observation(), reward=-0.2, done=True,
                              info={"result": "out_of_attempts", "case_id": case_id})

        return StepResult(obs=self._make_observation(), reward=-0.1, done=False,
                          info={"result": "still_failed", "case_id": case_id})

    def _apply_patch(self, step_name: Optional[str], field: Optional[str],
                     value: Any) -> bool:
        """Write ``value`` into ``inputs.expression`` of the named step.

        Only the field ``"inputs.expression"`` can be patched. The patch is
        written into the current case whether or not it is the correct fix.

        Returns:
            True if a step with that name was patched, False otherwise.
        """
        for step in self.current_case["steps"]:
            if step["name"] == step_name and field == "inputs.expression":
                step.setdefault("inputs", {})
                step["inputs"]["expression"] = value
                return True
        return False

    def _mark_success(self):
        """Set the ``status`` of every step in the current case to "Succeeded"."""
        for step in self.current_case["steps"]:
            step["status"] = "Succeeded"

    def _make_observation(self, run_status="Failed", error="keep", failed_step="keep"):
        """Build the observation dict for the current case.

        Args:
            run_status: Value reported under ``run_status``.
            error: Value reported under ``error``; the string ``"keep"`` means
                "use the error stored in the current case".
            failed_step: Value reported under ``failed_step``; ``"keep"`` means
                "use the failed step stored in the current case".

        Returns:
            A dict with ``case_id``, ``run_status``, ``failed_step``,
            ``error``, ``steps`` and ``attempts_left``.
        """
        err_obj = self.current_case["error"] if error == "keep" else error
        failed = self.current_case["failed_step"] if failed_step == "keep" else failed_step

        observation = {
            "case_id": self.current_case["case_id"],
            "run_status": run_status,
            "failed_step": failed,
            "error": err_obj,
            "steps": self.current_case["steps"],
            "attempts_left": self.attempts_left,
        }
        # Hand out a snapshot: later patches mutate the case in place, and
        # without a copy they would rewrite observations already returned
        # (and callers editing an observation would corrupt the env state).
        return copy.deepcopy(observation)

    def _invalid_action(self, msg: str) -> "StepResult":
        """Build the result for a rejected action.

        The episode is flagged done when no attempts remain.

        Args:
            msg: Human-readable reason, reported under ``info["message"]``.
        """
        obs = self._make_observation()
        done = (self.attempts_left <= 0)
        return StepResult(obs=obs, reward=-0.1, done=done,
                          info={"result": "invalid_action", "message": msg,
                                "case_id": self.current_case["case_id"]})
|
|