Spaces:
Sleeping
Sleeping
| """ | |
| server/data_cleaning_env.py | |
| --------------------------- | |
| DataCleaningEnvironment β the heart of the environment. | |
| Implements the three abstract methods from openenv.core.env_server.interfaces.Environment: | |
| reset(seed, episode_id, **kwargs) -> CleanObservation | |
| step(action, timeout_s, **kwargs) -> CleanObservation | |
| state (property) -> CleanState | |
| Architecture | |
| ------------ | |
| Live DataFrames (_dirty_df, _clean_df) live as instance variables for speed. | |
| CleanState holds lightweight CSV snapshots used only for WebSocket state() | |
| responses β not for every step. This avoids serialising a 400-row DataFrame | |
| on every call. | |
| Action dispatch | |
| --------------- | |
| Each CleanAction.command routes to a private _apply_* method that mutates | |
| _dirty_df in place. Errors in those methods (bad column name, out-of-bounds | |
| row) are caught and returned as (success=False, error_msg=...) so the agent | |
| gets corrective feedback instead of a 500. | |
| Reward | |
| ------ | |
| compute_reward() implements the dense reward formula designed in the plan: | |
| progress term β grader score delta (main signal) | |
| efficiency bonus β small reward for early completion | |
| false-positive penalty β for dropping a valid-extreme row (medium task) | |
| early-DONE penalty β for calling DONE with a low score | |
| step cost β -0.005 every step to discourage padding | |
| """ | |
| from __future__ import annotations | |
| import sys | |
| import os | |
| from typing import Any, Optional | |
| from uuid import uuid4 | |
| import numpy as np | |
| import pandas as pd | |
| # ββ OpenEnv imports (try relative β absolute) βββββββββββββββββββββββββββββββββ | |
| try: | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import EnvironmentMetadata | |
| except ImportError: | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import EnvironmentMetadata | |
| # ββ Local imports (try relative β absolute for both server and standalone) βββ | |
| try: | |
| from ..models import ( | |
| CleanAction, CleanObservation, CleanState, | |
| MAX_STEPS, DONE_THRESHOLD, | |
| ) | |
| from ..dataset_factory import make_dataset, TaskDataset | |
| from ..graders import grade, GradeResult | |
| except ImportError: | |
| try: | |
| from models import ( | |
| CleanAction, CleanObservation, CleanState, | |
| MAX_STEPS, DONE_THRESHOLD, | |
| ) | |
| from dataset_factory import make_dataset, TaskDataset | |
| from graders import grade, GradeResult | |
| except ImportError: | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) | |
| from models import ( | |
| CleanAction, CleanObservation, CleanState, | |
| MAX_STEPS, DONE_THRESHOLD, | |
| ) | |
| from dataset_factory import make_dataset, TaskDataset | |
| from graders import grade, GradeResult | |
| # ββ Constants βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Per-step cost that discourages infinite loops / padding | |
| STEP_COST = -0.005 | |
| # Penalty for calling DONE before the score is reasonable | |
| EARLY_DONE_PENALTY = -0.20 | |
| EARLY_DONE_THRESHOLD = 0.60 # DONE below this score triggers the penalty | |
| # Penalty for removing a valid-extreme row in the medium task | |
| FALSE_POSITIVE_PENALTY = -0.15 | |
| # Efficiency bonus multiplier (only awarded when episode is solved) | |
| EFFICIENCY_BONUS_WEIGHT = 0.10 | |
| # Date formats the STANDARDIZE_COL handler will try, in priority order | |
| _DATE_PARSE_FORMATS = [ | |
| "%Y-%m-%d", # ISO β most reliable, try first | |
| "%m/%d/%Y", # US | |
| "%d.%m.%Y", # EU | |
| "%d/%m/%Y", # EU alt | |
| "%Y/%m/%d", # Asian | |
| ] | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # DataCleaningEnvironment | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class DataCleaningEnvironment(Environment): | |
| """ | |
| Gym-style environment for the data cleaning pipeline task. | |
| Each episode: | |
| 1. reset(task_id="easy"|"medium"|"hard") loads a dirty/clean CSV pair. | |
| 2. The agent calls step() repeatedly, each time sending a CleanAction. | |
| 3. The episode ends when the agent sends DONE, the score crosses the | |
| task threshold, or the step budget is exhausted. | |
| The environment is fully stateless between sessions β all mutable state | |
| lives in instance variables, so concurrent sessions each get their own | |
| isolated copy (SUPPORTS_CONCURRENT_SESSIONS = True). | |
| """ | |
| SUPPORTS_CONCURRENT_SESSIONS = True | |
| def __init__(self) -> None: | |
| super().__init__() | |
| # Live DataFrames β mutated by each step() | |
| self._dirty_df: Optional[pd.DataFrame] = None | |
| self._clean_df: Optional[pd.DataFrame] = None | |
| # Full task dataset from dataset_factory (holds metadata for grader) | |
| self._dataset: Optional[TaskDataset] = None | |
| # Pydantic state (lightweight; updated on demand) | |
| self._state: Optional[CleanState] = None | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # reset() | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def reset( | |
| self, | |
| seed: Optional[int] = None, | |
| episode_id: Optional[str] = None, | |
| task_id: str = "easy", | |
| **kwargs: Any, | |
| ) -> CleanObservation: | |
| """ | |
| Reset the environment for a new episode. | |
| Parameters | |
| ---------- | |
| seed | |
| Ignored β datasets use fixed seeds per task for reproducibility. | |
| episode_id | |
| Optional; auto-generated if not provided. | |
| task_id | |
| Which task to load: "easy", "medium", or "hard". | |
| """ | |
| if task_id not in MAX_STEPS: | |
| raise ValueError( | |
| f"Unknown task_id {task_id!r}. Must be one of: {list(MAX_STEPS)}" | |
| ) | |
| # Load dataset (always deterministic via fixed seed in dataset_factory) | |
| self._dataset = make_dataset(task_id) | |
| self._dirty_df = self._dataset.dirty_df.copy(deep=True) | |
| self._clean_df = self._dataset.clean_df.copy(deep=True) | |
| max_steps = MAX_STEPS[task_id] | |
| # Run grader on the initial dirty state so we have a starting score | |
| initial_result = grade( | |
| task_id=task_id, | |
| agent_df=self._dirty_df, | |
| clean_df=self._clean_df, | |
| metadata=self._dataset.metadata, | |
| initial_dirty_cells=self._dataset.total_dirty_cells, | |
| ) | |
| self._state = CleanState( | |
| episode_id=episode_id or str(uuid4()), | |
| step_count=0, | |
| task_id=task_id, | |
| dirty_csv_snapshot=self._df_to_csv(self._dirty_df), | |
| clean_csv_snapshot=self._df_to_csv(self._clean_df), | |
| initial_dirty_cells=self._dataset.total_dirty_cells, | |
| current_score=initial_result.score, | |
| previous_score=0.0, | |
| task_metadata=self._dataset.metadata, | |
| schema_hint=self._dataset.schema_hint, | |
| max_steps=max_steps, | |
| ) | |
| return self._build_observation( | |
| reward=None, | |
| done=False, | |
| last_action_success=True, | |
| last_action_error=None, | |
| grader_result=initial_result, | |
| column_status=self._compute_column_status(), | |
| ) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # step() | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def step( | |
| self, | |
| action: CleanAction, | |
| timeout_s: Optional[float] = None, | |
| **kwargs: Any, | |
| ) -> CleanObservation: | |
| """ | |
| Apply one CleanAction and return the resulting observation. | |
| Never raises for bad action inputs β instead returns | |
| last_action_success=False with a descriptive error message so the | |
| agent can self-correct on the next step. | |
| """ | |
| if self._state is None or self._dirty_df is None: | |
| raise RuntimeError("Environment not initialised. Call reset() first.") | |
| self._state.step_count += 1 | |
| # ββ Save previous score before mutating ββββββββββββββββββββββββββββββ | |
| prev_score = self._state.current_score | |
| self._state.previous_score = prev_score | |
| # ββ DONE handling ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # The environment decides termination β not the agent. | |
| # DONE is only accepted when score >= task threshold. | |
| # Below threshold: penalise heavily, refuse to terminate. | |
| if action.command == "DONE": | |
| threshold = DONE_THRESHOLD[self._state.task_id] | |
| if self._state.current_score >= threshold: | |
| # Legitimate DONE β episode ends with efficiency bonus | |
| reward = self._compute_reward( | |
| action=action, | |
| prev_score=prev_score, | |
| curr_score=prev_score, | |
| action_success=True, | |
| was_false_positive=False, | |
| ) | |
| done = True | |
| self._state.dirty_csv_snapshot = self._df_to_csv(self._dirty_df) | |
| col_status = self._compute_column_status() | |
| return self._build_observation( | |
| reward=reward, | |
| done=done, | |
| last_action_success=True, | |
| last_action_error=None, | |
| grader_result=GradeResult( | |
| score=prev_score, | |
| issues_remaining=0, | |
| detail="Agent signalled DONE β accepted.", | |
| ), | |
| column_status=col_status, | |
| ) | |
| else: | |
| # Premature DONE β refuse termination, apply strong penalty | |
| col_status = self._compute_column_status() | |
| still_broken = [ | |
| c for c, s in col_status.items() | |
| if s.get("missing", 0) > 0 or not s.get("standardized", True) | |
| ] | |
| error_msg = ( | |
| f"DONE rejected: score {self._state.current_score:.4f} < " | |
| f"required {threshold:.2f}. " | |
| f"Still broken: {still_broken[:5]}" | |
| ) | |
| # Hard penalty: -1.0 clamped to floor | |
| reward = round(float(np.clip(-1.0 + STEP_COST, -1.0, 1.0)), 4) | |
| self._state.dirty_csv_snapshot = self._df_to_csv(self._dirty_df) | |
| return self._build_observation( | |
| reward=reward, | |
| done=False, # episode continues | |
| last_action_success=False, | |
| last_action_error=error_msg, | |
| grader_result=GradeResult( | |
| score=self._state.current_score, | |
| issues_remaining=self._state.initial_dirty_cells | |
| - int(self._state.current_score * self._state.initial_dirty_cells), | |
| detail="Premature DONE blocked.", | |
| ), | |
| column_status=col_status, | |
| ) | |
| # ββ Apply action to _dirty_df ββββββββββββββββββββββββββββββββββββββββ | |
| action_success, error_msg, was_false_positive = self._apply_action(action) | |
| # ββ Grade the result ββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| grader_result = grade( | |
| task_id=self._state.task_id, | |
| agent_df=self._dirty_df, | |
| clean_df=self._clean_df, | |
| metadata=self._state.task_metadata, | |
| initial_dirty_cells=self._state.initial_dirty_cells, | |
| ) | |
| curr_score = grader_result.score | |
| self._state.current_score = curr_score | |
| # ββ Compute reward ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| reward = self._compute_reward( | |
| action=action, | |
| prev_score=prev_score, | |
| curr_score=curr_score, | |
| action_success=action_success, | |
| was_false_positive=was_false_positive, | |
| ) | |
| # ββ Check termination ββββββββββββββββββββββββββββββββββββββββββββββββ | |
| done = ( | |
| curr_score >= DONE_THRESHOLD[self._state.task_id] | |
| or self._state.step_count >= self._state.max_steps | |
| ) | |
| # ββ Sync state snapshot ββββββββββββββββββββββββββββββββββββββββββββββ | |
| self._state.dirty_csv_snapshot = self._df_to_csv(self._dirty_df) | |
| return self._build_observation( | |
| reward=reward, | |
| done=done, | |
| last_action_success=action_success, | |
| last_action_error=error_msg, | |
| grader_result=grader_result, | |
| column_status=self._compute_column_status(), | |
| ) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # state (property) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def state(self) -> CleanState: | |
| """Return the current environment state (serialisable snapshot).""" | |
| if self._state is None: | |
| raise RuntimeError("Environment not initialised. Call reset() first.") | |
| # Keep snapshot fresh in case step() was called without triggering a sync | |
| if self._dirty_df is not None: | |
| self._state.dirty_csv_snapshot = self._df_to_csv(self._dirty_df) | |
| return self._state | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Action dispatch | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _apply_action( | |
| self, action: CleanAction | |
| ) -> tuple[bool, Optional[str], bool]: | |
| """ | |
| Mutate self._dirty_df according to the action. | |
| Returns | |
| ------- | |
| (success, error_msg, was_false_positive) | |
| success β True if action applied without error | |
| error_msg β human-readable description if success=False | |
| was_false_positive β True if a DROP_ROW removed a valid-extreme row | |
| """ | |
| cmd = action.command | |
| if cmd == "SET_VALUE": | |
| return self._apply_set_value(action) | |
| elif cmd == "DROP_ROW": | |
| return self._apply_drop_row(action) | |
| elif cmd == "STANDARDIZE_COL": | |
| return self._apply_standardize_col(action) | |
| elif cmd == "FILL_MISSING": | |
| return self._apply_fill_missing(action) | |
| else: | |
| return False, f"Unknown command: {cmd!r}", False | |
| # ββ SET_VALUE βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _apply_set_value( | |
| self, action: CleanAction | |
| ) -> tuple[bool, Optional[str], bool]: | |
| df = self._dirty_df | |
| row_idx = action.row_index | |
| col = action.column | |
| val = action.value | |
| # Validate column | |
| if col not in df.columns: | |
| return ( | |
| False, | |
| f"Column {col!r} not found. Available: {list(df.columns)}", | |
| False, | |
| ) | |
| # Validate row index (positional) | |
| if row_idx < 0 or row_idx >= len(df): | |
| return ( | |
| False, | |
| f"Row index {row_idx} out of range. DataFrame has {len(df)} rows (0β{len(df)-1}).", | |
| False, | |
| ) | |
| # Try to cast value to the column's expected type | |
| cast_val, cast_err = self._cast_value(val, df, col) | |
| if cast_err: | |
| return False, cast_err, False | |
| df.iloc[row_idx, df.columns.get_loc(col)] = cast_val | |
| return True, None, False | |
| # ββ DROP_ROW ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _apply_drop_row( | |
| self, action: CleanAction | |
| ) -> tuple[bool, Optional[str], bool]: | |
| df = self._dirty_df | |
| row_idx = action.row_index | |
| if row_idx < 0 or row_idx >= len(df): | |
| return ( | |
| False, | |
| f"Row index {row_idx} out of range. DataFrame has {len(df)} rows.", | |
| False, | |
| ) | |
| # Detect false positive for medium task: is this a valid-extreme row? | |
| was_false_positive = self._is_valid_extreme_row(row_idx) | |
| # Drop the row and reset positional index so future iloc references stay valid | |
| self._dirty_df = df.drop(df.index[row_idx]).reset_index(drop=True) | |
| return True, None, was_false_positive | |
| def _is_valid_extreme_row(self, iloc_idx: int) -> bool: | |
| """ | |
| Return True if dropping this row would be a false positive. | |
| Only applies to the medium task, which tracks valid_extreme_rows | |
| by their original tx_id. | |
| """ | |
| if self._state is None or self._state.task_id != "medium": | |
| return False | |
| valid_extreme_rows: list = self._state.task_metadata.get( | |
| "valid_extreme_rows", [] | |
| ) | |
| if not valid_extreme_rows or self._clean_df is None: | |
| return False | |
| df = self._dirty_df | |
| if "tx_id" not in df.columns: | |
| return False | |
| # Get the tx_id of the row being dropped | |
| try: | |
| tx_id_to_drop = int(df.iloc[iloc_idx]["tx_id"]) | |
| except (IndexError, ValueError, KeyError): | |
| return False | |
| # Check if any valid-extreme row in clean_df has this tx_id | |
| for orig_idx in valid_extreme_rows: | |
| if orig_idx >= len(self._clean_df): | |
| continue | |
| if int(self._clean_df.iloc[orig_idx]["tx_id"]) == tx_id_to_drop: | |
| return True | |
| return False | |
| # ββ STANDARDIZE_COL βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _apply_standardize_col( | |
| self, action: CleanAction | |
| ) -> tuple[bool, Optional[str], bool]: | |
| df = self._dirty_df | |
| col = action.column | |
| if col not in df.columns: | |
| return ( | |
| False, | |
| f"Column {col!r} not found. Available: {list(df.columns)}", | |
| False, | |
| ) | |
| series = df[col].copy() | |
| # ββ Try date normalisation first ββββββββββββββββββββββββββββββββββββββ | |
| if self._looks_like_date_column(col, series): | |
| normalised, err = self._normalise_dates(series) | |
| if err: | |
| return False, f"Date normalisation failed for column {col!r}: {err}", False | |
| self._dirty_df[col] = normalised | |
| return True, None, False | |
| # ββ Try numeric coercion ββββββββββββββββββββββββββββββββββββββββββββββ | |
| if self._looks_like_numeric_column(col, series): | |
| numeric = pd.to_numeric(series, errors="coerce") | |
| # Only apply if we didn't lose more than 20% of non-null values | |
| original_non_null = series.notna().sum() | |
| coerced_non_null = numeric.notna().sum() | |
| if original_non_null == 0 or coerced_non_null / original_non_null >= 0.8: | |
| self._dirty_df[col] = numeric | |
| return True, None, False | |
| # ββ String normalisation: strip whitespace βββββββββββββββββββββββββββ | |
| self._dirty_df[col] = series.apply( | |
| lambda x: str(x).strip() if not _is_nan(x) else x | |
| ) | |
| return True, None, False | |
| def _looks_like_date_column(self, col: str, series: pd.Series) -> bool: | |
| """Heuristic: column name contains 'date' or most non-null values parse as dates.""" | |
| if "date" in col.lower(): | |
| return True | |
| sample = series.dropna().astype(str).head(5) | |
| parsed = 0 | |
| for s in sample: | |
| for fmt in _DATE_PARSE_FORMATS: | |
| try: | |
| pd.to_datetime(s, format=fmt) | |
| parsed += 1 | |
| break | |
| except Exception: | |
| pass | |
| return parsed >= max(1, len(sample) // 2) | |
| def _looks_like_numeric_column(self, col: str, series: pd.Series) -> bool: | |
| """Heuristic: column name or majority of values suggests numeric data.""" | |
| numeric_keywords = {"price", "amount", "value", "quantity", "qty", "count", "id", "num"} | |
| if any(kw in col.lower() for kw in numeric_keywords): | |
| return True | |
| sample = series.dropna().head(10) | |
| if len(sample) == 0: | |
| return False | |
| convertible = pd.to_numeric(sample, errors="coerce").notna().sum() | |
| return convertible / len(sample) >= 0.7 | |
| def _normalise_dates(self, series: pd.Series) -> tuple[pd.Series, Optional[str]]: | |
| """Parse dates in any supported format and reformat as YYYY-MM-DD.""" | |
| def _parse_one(x: Any) -> Any: | |
| if _is_nan(x): | |
| return x | |
| s = str(x).strip() | |
| for fmt in _DATE_PARSE_FORMATS: | |
| try: | |
| return pd.to_datetime(s, format=fmt).strftime("%Y-%m-%d") | |
| except Exception: | |
| pass | |
| # Last resort: let pandas guess | |
| try: | |
| parsed = pd.to_datetime(s, dayfirst=False) | |
| if 2000 <= parsed.year <= 2030: | |
| return parsed.strftime("%Y-%m-%d") | |
| except Exception: | |
| pass | |
| return x # leave unchanged if unparseable | |
| return series.apply(_parse_one), None | |
| # ββ FILL_MISSING ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _apply_fill_missing( | |
| self, action: CleanAction | |
| ) -> tuple[bool, Optional[str], bool]: | |
| df = self._dirty_df | |
| col = action.column | |
| strategy = action.fill_strategy | |
| if col not in df.columns: | |
| return ( | |
| False, | |
| f"Column {col!r} not found. Available: {list(df.columns)}", | |
| False, | |
| ) | |
| series = df[col].copy() | |
| numeric = pd.to_numeric(series, errors="coerce") | |
| has_numeric = numeric.notna().sum() > 0 | |
| if strategy == "mean": | |
| if not has_numeric: | |
| return False, f"Cannot compute mean for non-numeric column {col!r}.", False | |
| fill_val = numeric.mean() | |
| self._dirty_df[col] = numeric.fillna(round(fill_val, 2)) | |
| elif strategy == "median": | |
| if not has_numeric: | |
| return False, f"Cannot compute median for non-numeric column {col!r}.", False | |
| fill_val = numeric.median() | |
| self._dirty_df[col] = numeric.fillna(round(fill_val, 2)) | |
| elif strategy == "mode": | |
| mode_result = series.mode(dropna=True) | |
| if mode_result.empty: | |
| return False, f"No mode found for column {col!r} (all values missing?).", False | |
| self._dirty_df[col] = series.fillna(mode_result.iloc[0]) | |
| elif strategy == "drop": | |
| before = len(self._dirty_df) | |
| self._dirty_df = self._dirty_df.dropna(subset=[col]).reset_index(drop=True) | |
| after = len(self._dirty_df) | |
| return True, None, False | |
| else: | |
| return False, f"Unknown fill_strategy: {strategy!r}", False | |
| return True, None, False | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Reward computation | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _compute_reward( | |
| self, | |
| action: CleanAction, | |
| prev_score: float, | |
| curr_score: float, | |
| action_success: bool, | |
| was_false_positive: bool, | |
| ) -> float: | |
| """ | |
| Dense per-step reward in the range [-0.5, +1.0]. | |
| Components | |
| ---------- | |
| progress score delta (main learning signal) | |
| efficiency bonus small reward for solving with steps to spare | |
| fp_penalty penalise removing a valid-extreme row (medium task) | |
| early_done_penalty penalise calling DONE with a very low score | |
| step_cost tiny constant cost to discourage padding | |
| """ | |
| if self._state is None: | |
| return 0.0 | |
| max_steps = self._state.max_steps | |
| step_count = self._state.step_count | |
| # 1. Progress term | |
| progress = curr_score - prev_score | |
| # 2. Efficiency bonus (only when task is solved this step) | |
| threshold = DONE_THRESHOLD[self._state.task_id] | |
| just_solved = prev_score < threshold <= curr_score | |
| step_fraction = step_count / max_steps | |
| efficiency = EFFICIENCY_BONUS_WEIGHT * (1.0 - step_fraction) if just_solved else 0.0 | |
| # 3. False-positive penalty | |
| fp_penalty = FALSE_POSITIVE_PENALTY if was_false_positive else 0.0 | |
| # 4. Early-DONE penalty | |
| early_done = ( | |
| EARLY_DONE_PENALTY | |
| if action.command == "DONE" and curr_score < EARLY_DONE_THRESHOLD | |
| else 0.0 | |
| ) | |
| # 5. Step cost | |
| step_cost = STEP_COST | |
| reward = progress + efficiency + fp_penalty + early_done + step_cost | |
| return round(float(np.clip(reward, -0.5, 1.0)), 4) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Column status (for observation + DONE guard) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _compute_column_status(self) -> dict: | |
| """ | |
| Return a per-column health summary. | |
| Format:: | |
| { | |
| "price": {"missing": 0, "standardized": True, "issues": []}, | |
| "order_date": {"missing": 0, "standardized": False, "issues": ["bad_date_format"]}, | |
| } | |
| Rules | |
| ----- | |
| missing β count of NaN / None / empty-string values | |
| standardized β True when the column has no detectable type/format errors | |
| issues β human-readable list of problem types found | |
| Used by: | |
| - The DONE hard-block (DONE rejected if any col has missing>0 or standardized=False) | |
| - The observation (agent reads this directly to plan its next action) | |
| """ | |
| if self._dirty_df is None: | |
| return {} | |
| df = self._dirty_df | |
| status: dict = {} | |
| for col in df.columns: | |
| series = df[col] | |
| issues = [] | |
| # ββ Missing count ββββββββββββββββββββββββββββββββββββββββββββββ | |
| missing = int(series.apply(_is_nan).sum()) | |
| if missing > 0: | |
| issues.append(f"{missing}_missing") | |
| raw = series.dropna().astype(object).apply(str) | |
| # ββ Date column checks βββββββββββββββββββββββββββββββββββββββββ | |
| if "date" in col.lower(): | |
| import re | |
| iso_pat = re.compile(r"^\d{4}-\d{2}-\d{2}$") | |
| bad_dates = raw.apply( | |
| lambda x: not iso_pat.match(x.strip()) | |
| ).sum() | |
| if bad_dates > 0: | |
| issues.append(f"{int(bad_dates)}_bad_date_format") | |
| # Far-future years are a known injection pattern (year > 2030) | |
| def _bad_year(s: str) -> bool: | |
| try: | |
| return int(s[:4]) > 2030 | |
| except Exception: | |
| return False | |
| future = raw.apply(lambda x: _bad_year(x.strip())).sum() | |
| if future > 0: | |
| issues.append(f"{int(future)}_future_year") | |
| # ββ Numeric column checks ββββββββββββββββββββββββββββββββββββββ | |
| elif self._looks_like_numeric_column(col, series): | |
| non_numeric = pd.to_numeric(raw, errors="coerce").isna().sum() | |
| if non_numeric > 0: | |
| issues.append(f"{int(non_numeric)}_non_numeric_text") | |
| # ββ String column checks βββββββββββββββββββββββββββββββββββββββ | |
| else: | |
| whitespace = raw.apply( | |
| lambda x: x != x.strip() | |
| ).sum() | |
| if whitespace > 0: | |
| issues.append(f"{int(whitespace)}_has_whitespace") | |
| standardized = len(issues) == 0 and missing == 0 | |
| status[col] = { | |
| "missing": missing, | |
| "standardized": standardized, | |
| "issues": issues, | |
| } | |
| return status | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Observation builder | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _build_observation( | |
| self, | |
| reward: Optional[float], | |
| done: bool, | |
| last_action_success: bool, | |
| last_action_error: Optional[str], | |
| grader_result: GradeResult, | |
| column_status: Optional[dict] = None, | |
| ) -> CleanObservation: | |
| if self._state is None: | |
| raise RuntimeError("State not initialised.") | |
| return CleanObservation( | |
| # Inherited from Observation base | |
| done=done, | |
| reward=reward, | |
| # Task context | |
| task_id=self._state.task_id, | |
| schema_hint=self._state.schema_hint, | |
| initial_dirty_cells=self._state.initial_dirty_cells, | |
| # Per-step state | |
| dirty_csv=self._df_to_csv(self._dirty_df), | |
| current_score=grader_result.score, | |
| issues_remaining=grader_result.issues_remaining, | |
| step_number=self._state.step_count, | |
| max_steps=self._state.max_steps, | |
| # Last-action feedback | |
| last_action_success=last_action_success, | |
| last_action_error=last_action_error, | |
| # Per-column status | |
| column_status=column_status or {}, | |
| ) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Utilities | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _df_to_csv(df: Optional[pd.DataFrame]) -> str: | |
| """Serialise DataFrame to CSV string with the integer position index.""" | |
| if df is None: | |
| return "" | |
| return df.to_csv(index=True, index_label="row_index") | |
| def _cast_value( | |
| val: str, df: pd.DataFrame, col: str | |
| ) -> tuple[Any, Optional[str]]: | |
| """ | |
| Try to cast a string value to the appropriate type for `col`. | |
| Returns (cast_value, error_message). error_message is None on success. | |
| """ | |
| # Determine target type from the clean (non-null, non-text) column values | |
| sample = pd.to_numeric( | |
| df[col].dropna().astype(str).str.strip(), errors="coerce" | |
| ) | |
| majority_numeric = sample.notna().sum() / max(len(df[col].dropna()), 1) >= 0.5 | |
| if majority_numeric: | |
| try: | |
| float_val = float(val.strip().replace(",", "")) | |
| # If all sample values are whole numbers, keep as int | |
| if (sample.dropna() % 1 == 0).all() and float_val % 1 == 0: | |
| return int(float_val), None | |
| return round(float_val, 2), None | |
| except (ValueError, AttributeError): | |
| return ( | |
| None, | |
| f"Cannot cast {val!r} to numeric for column {col!r}. " | |
| f"Provide a plain number (e.g. '29.99').", | |
| ) | |
| # String column β accept as-is (strip whitespace) | |
| return val.strip(), None | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Lifecycle | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def close(self) -> None: | |
| self._dirty_df = None | |
| self._clean_df = None | |
| self._dataset = None | |
| self._state = None | |
| def get_metadata(self) -> EnvironmentMetadata: | |
| return EnvironmentMetadata( | |
| name="data_cleaning_env", | |
| description=( | |
| "Data cleaning pipeline: the agent receives a dirty CSV " | |
| "and must fix type errors, outliers, missing values, and " | |
| "schema inconsistencies to match a hidden ground truth." | |
| ), | |
| version="1.0.0", | |
| author="hackathon", | |
| ) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Helpers | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _is_nan(x: Any) -> bool: | |
| """Return True if x is any flavour of missing value.""" | |
| if x is None: | |
| return True | |
| try: | |
| return bool(pd.isna(x)) | |
| except (TypeError, ValueError): | |
| return False | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Smoke test | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if __name__ == "__main__": | |
| SEP = "β" * 64 | |
| for task_id in ("easy", "medium", "hard"): | |
| print(f"\n{SEP}\nTASK: {task_id.upper()}\n{SEP}") | |
| env = DataCleaningEnvironment() | |
| # ββ reset ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| obs = env.reset(task_id=task_id) | |
| print(f"reset() β score={obs.current_score:.4f} " | |
| f"issues={obs.issues_remaining} done={obs.done}") | |
| assert obs.reward is None, "reward must be None after reset" | |
| assert obs.done is False, "done must be False after reset" | |
| lines = obs.dirty_csv.strip().split("\n") | |
| print(f" CSV: {len(lines)} rows, {len(lines[0].split(','))} cols") | |
| print(f" Hint: {obs.schema_hint[:70]}β¦") | |
| # ββ state() ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| st = env.state | |
| print(f"state() β episode_id={st.episode_id[:8]}β¦ step_count={st.step_count}") | |
| # ββ step: bad column (should give feedback, not crash) βββββββββββββββ | |
| bad_action = CleanAction( | |
| command="SET_VALUE", row_index=0, column="DOES_NOT_EXIST", value="0" | |
| ) | |
| obs2 = env.step(bad_action) | |
| assert obs2.last_action_success is False | |
| print(f"step (bad col) β success={obs2.last_action_success} " | |
| f"error='{obs2.last_action_error[:50]}β¦'") | |
| # ββ step: out-of-bounds row ββββββββββββββββββββββββββββββββββββββββββ | |
| bad_row = CleanAction( | |
| command="SET_VALUE", row_index=9999, column="price", value="10.0" | |
| ) | |
| obs3 = env.step(bad_row) | |
| assert obs3.last_action_success is False | |
| print(f"step (bad row) β success={obs3.last_action_success} " | |
| f"error='{obs3.last_action_error[:50]}β¦'") | |
| # ββ step: valid fix ββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if task_id == "easy": | |
| # Find the first injected dirty cell and fix it | |
| injected = env._dataset.metadata.get("injected_cells", []) | |
| if injected: | |
| row, col = injected[0] | |
| clean_val = str(env._clean_df.iloc[row][col]) | |
| fix_action = CleanAction( | |
| command="SET_VALUE", row_index=row, column=col, value=clean_val | |
| ) | |
| obs4 = env.step(fix_action) | |
| print(f"step (fix row={row} col={col!r}) β " | |
| f"success={obs4.last_action_success} " | |
| f"score={obs4.current_score:.4f} " | |
| f"reward={obs4.reward:.4f}") | |
| assert obs4.last_action_success is True | |
| assert obs4.reward is not None | |
| elif task_id == "medium": | |
| # Fix one outlier row via FILL_MISSING on amount | |
| obs4 = env.step(CleanAction( | |
| command="FILL_MISSING", column="amount", fill_strategy="median" | |
| )) | |
| print(f"step (FILL_MISSING amount/median) β " | |
| f"score={obs4.current_score:.4f} reward={obs4.reward:.4f}") | |
| elif task_id == "hard": | |
| # Standardize the date column | |
| obs4 = env.step(CleanAction( | |
| command="STANDARDIZE_COL", column="date" | |
| )) | |
| print(f"step (STANDARDIZE_COL date) β " | |
| f"success={obs4.last_action_success} " | |
| f"score={obs4.current_score:.4f} reward={obs4.reward:.4f}") | |
| # ββ DONE action βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # DONE below threshold is now BLOCKED β done=False, reward=-1.0, error set | |
| done_obs = env.step(CleanAction(command="DONE")) | |
| threshold = DONE_THRESHOLD[task_id] | |
| if env.state.current_score >= threshold: | |
| # Score met threshold β DONE accepted | |
| assert done_obs.done is True, f"Expected done=True at score {env.state.current_score:.4f}" | |
| print(f"step (DONE, accepted) β done={done_obs.done} " | |
| f"reward={done_obs.reward:.4f} score={done_obs.current_score:.4f}") | |
| else: | |
| # Score below threshold β DONE blocked | |
| assert done_obs.done is False, \ | |
| f"Expected done=False (premature DONE blocked) at score {done_obs.current_score:.4f}" | |
| assert done_obs.last_action_success is False | |
| assert done_obs.reward == round(float(np.clip(-1.0 + STEP_COST, -1.0, 1.0)), 4) | |
| print(f"step (DONE, blocked) β done={done_obs.done} " | |
| f"reward={done_obs.reward:.4f} score={done_obs.current_score:.4f}") | |
| print(f" error: {done_obs.last_action_error[:80]}") | |
| env.close() | |
| print(f"\n{SEP}\nAll smoke tests passed.\n{SEP}") |