import os import glob import re import hashlib from typing import Dict, List, Optional, Any, Tuple import numpy as np import torch from torch.utils.data import Dataset, DataLoader from tqdm import tqdm import json # CONFIG & REGISTRY DROP_OBS_KEYS = [] DATA_DIR = "TrajectoryData_from_docker" INDEX_CACHE_PATH = os.path.join(DATA_DIR, "episode_index_cache_topk.json") NORM_CACHE_PATH = os.path.join(DATA_DIR, "norm_stats_v_topk.npz") PAD_ID = 0 UNK_ID = 1 SENSOR_START_ID = 2 ACTION_START_ID = 300 VOCAB_SIZE = 512 CONTEXT_LEN = 48 MAX_TOKENS_PER_STEP = 64 MAX_ZONES = 32 PHYSICS_HORIZON = 16 SEED = 42 USE_TOPK = True TOPK_FRAC = 0.8 TOPK_MODE = "filter" TOPK_ON = "energy" TOPK_BOOST = 3.0 # --- Action Discretization --- NUM_ACTION_BINS = 64 HTG_LOW, HTG_HIGH = 15.0, 30.0 CLG_LOW, CLG_HIGH = 15.0, 30.0 # --- Normalization & Scaling --- USE_NORMALIZATION = True ACTION_VALUE_INPUT_MODE = "prev" ACTION_VALUE_MASK_CONST = 0.0 COMFORT_SCALE = 1.0 # --- Preference conditioning --- PREF_MODE = "sample" PREF_FIXED_LAMBDA = 0.5 PREF_BETA_A = 5.0 PREF_BETA_B = 2.0 ZONE_SRC_REGEX = 1 ZONE_SRC_PAREN = 2 ZONE_SRC_CORE_PERIM = 3 ZONE_SRC_HASH = 4 HVAC_KEYWORD_MAP = { # Sensors (2..299) "temp": 10, "t_in": 10, "temperature": 10, "humidity": 11, "rh": 11, "co2": 12, "ppm": 12, "power": 13, "energy": 13, "kw": 13, "occupancy": 14, "occ": 14, "people": 14, "solar": 15, "rad": 15, "radiation": 15, "outdoor": 16, "site": 16, "environment": 16, "pressure": 17, "flow": 18, "fan": 19, "speed": 19, # Actions (offset from ACTION_START_ID) "setpoint": 10, "stpt": 10, "damper": 11, "position": 11, "valve": 12, } # ============================================================ # HELPER # ============================================================ def compute_comfort_indices_from_state_keys(state_keys: List[str]) -> List[int]: kl = [str(k).lower() for k in state_keys] any_idx = [i for i, k in enumerate(kl) if ("ash55" in k and "notcomfortable" in k and "any" in k)] if len(any_idx) > 0: return any_idx return [i for i, k in enumerate(kl) if ("ash55" in k and "notcomfortable" in k)] def extract_zone_id_with_source(name_lower: str) -> Tuple[int, int]: m = re.search(r'(?:\bzone\b|\bz\b|\bzn\b)[_\s\-]*?(\d+)\b', name_lower) if m: zid = int(m.group(1)) zid = min(max(zid, 0), MAX_ZONES - 1) return zid, ZONE_SRC_REGEX parens = re.findall(r'\(([^)]+)\)', name_lower) for chunk in parens: m2 = re.search(r'(?:\bzone\b|\bz\b|\bzn\b)[_\s\-]*?(\d+)\b', chunk) if m2: return min(max(int(m2.group(1)), 0), MAX_ZONES - 1), ZONE_SRC_PAREN m4 = re.search(r'(?:perimeter|perim|core)[_\s\-]*?(?:zn[_\s\-]*)?(\d+)\b', name_lower) if m4: return min(max(int(m4.group(1)), 0), MAX_ZONES - 1), ZONE_SRC_CORE_PERIM h = int(hashlib.md5(name_lower.encode()).hexdigest(), 16) return 1 + (h % max(1, (MAX_ZONES - 1))), ZONE_SRC_HASH def parse_feature_identity(name: str, is_action: bool = False) -> Tuple[int, int, int]: name_lower = str(name).lower() zone_id, zone_src = extract_zone_id_with_source(name_lower) found_id = UNK_ID for key, val in HVAC_KEYWORD_MAP.items(): if key in name_lower: found_id = val break if found_id == UNK_ID: hash_val = int(hashlib.md5(name_lower.encode()).hexdigest(), 16) found_id = 50 + (hash_val % 50) final_id = (ACTION_START_ID if is_action else SENSOR_START_ID) + found_id if final_id >= VOCAB_SIZE: final_id = UNK_ID return final_id, zone_id, zone_src def discretize_actions_to_bins(actions: np.ndarray, action_keys: List[str]) -> np.ndarray: out = np.zeros_like(actions, dtype=np.int64) for j, k in enumerate(action_keys): kl = k.lower() if "clg" in kl or "cool" in kl: lo, hi = CLG_LOW, CLG_HIGH else: lo, hi = HTG_LOW, HTG_HIGH a = np.clip(actions[:, j], lo, hi) x = (a - lo) / (hi - lo + 1e-12) bins = np.rint(x * (NUM_ACTION_BINS - 1)).astype(np.int64) out[:, j] = np.clip(bins, 0, NUM_ACTION_BINS - 1) return out def discounted_cumsum(x: np.ndarray, gamma: float = 1.0) -> np.ndarray: y = np.zeros_like(x, dtype=np.float32) running = 0.0 for t in range(len(x)-1, -1, -1): running = x[t] + gamma * running y[t] = running return y def _mix_u64(x: int) -> int: x &= 0xFFFFFFFFFFFFFFFF x ^= (x >> 33) x = (x * 0xff51afd7ed558ccd) & 0xFFFFFFFFFFFFFFFF x ^= (x >> 33) x = (x * 0xc4ceb9fe1a85ec53) & 0xFFFFFFFFFFFFFFFF x ^= (x >> 33) return x & 0xFFFFFFFFFFFFFFFF def dataset_signature(npz_paths: List[str]) -> str: parts = [] for p in npz_paths: try: st = os.stat(p) parts.append(f"{p}|{st.st_size}|{int(st.st_mtime)}") except FileNotFoundError: parts.append(f"{p}|missing") raw = "\n".join(parts).encode("utf-8") return hashlib.md5(raw).hexdigest() def compute_occupancy_indices_from_state_keys(state_keys: List[str]) -> List[int]: kl = [str(k).lower() for k in state_keys] return [i for i, k in enumerate(kl) if ("occ" in k and "count" in k)] # ============================================================ # 1) EPISODE INDEX # ============================================================ class EpisodeIndex: def __init__(self, npz_paths: List[str]): self.paths = list(npz_paths) self.T: List[int] = [] self.returns_energy: List[float] = [] self.returns_comfort: List[float] = [] self.s_meta: List[List[Tuple[int,int,int]]] = [] self.a_meta: List[List[Tuple[int,int,int]]] = [] self.state_keys: List[List[str]] = [] self.action_keys: List[List[str]] = [] self.keep_indices_map: List[List[int]] = [] self.comfort_idx: List[List[int]] = [] sig = dataset_signature(self.paths) if os.path.exists(INDEX_CACHE_PATH): try: with open(INDEX_CACHE_PATH, "r") as f: cache = json.load(f) if cache.get("signature") == sig and "returns_energy" in cache: print(f"[DataLoader] Loading cached index: {INDEX_CACHE_PATH}") self.T = cache["T"] self.returns_energy = cache["returns_energy"] self.returns_comfort = cache["returns_comfort"] self.state_keys = cache["state_keys"] self.action_keys = cache["action_keys"] self.keep_indices_map = cache.get("keep_indices_map", []) self.s_meta = [[parse_feature_identity(k, is_action=False) for k in ks] for ks in self.state_keys] self.a_meta = [[parse_feature_identity(k, is_action=True) for k in ks] for ks in self.action_keys] if "comfort_idx" in cache: self.comfort_idx = cache["comfort_idx"] else: print("[DataLoader] Cache missing comfort_idx. Rebuilding.") raise ValueError("Outdated Cache") print(f"[DataLoader] Cache loaded. Episodes indexed: {len(self.T)}") return else: print("[DataLoader] Cache signature mismatch") except Exception as e: print(f"[DataLoader] Failed load cache: {e}") for p in tqdm(self.paths, desc="Indexing"): try: with np.load(p, allow_pickle=True) as d: obs = d["observations"] if "rewards_energy" in d: r_e = d["rewards_energy"] r_c = d["rewards_comfort"] else: r_e = d["rewards"] r_c = np.zeros_like(r_e) ret_e = float(np.sum(r_e)) ret_c = float(np.sum(r_c)) T = int(obs.shape[0]) # Get RAW keys raw_s_keys = d["state_keys"].astype(object).tolist() if "state_keys" in d else [] a_keys = d["action_keys"].astype(object).tolist() if "action_keys" in d else [] raw_s_keys = list(map(str, raw_s_keys)) a_keys = list(map(str, a_keys)) c_idx = compute_comfort_indices_from_state_keys(raw_s_keys) keep_idxs = [i for i, k in enumerate(raw_s_keys) if k not in DROP_OBS_KEYS] s_keys = [raw_s_keys[i] for i in keep_idxs] s_meta = [parse_feature_identity(k, is_action=False) for k in s_keys] a_meta = [parse_feature_identity(k, is_action=True) for k in a_keys] self.T.append(T) self.returns_energy.append(ret_e) self.returns_comfort.append(ret_c) self.state_keys.append(s_keys) self.action_keys.append(a_keys) self.comfort_idx.append(c_idx) # Save indices relative to RAW array self.s_meta.append(s_meta) self.a_meta.append(a_meta) self.keep_indices_map.append(keep_idxs) except Exception as e: print(f"[IndexError] {p}: {e}") # Save Cache try: cache = { "signature": sig, "T": self.T, "returns_energy": self.returns_energy, "returns_comfort": self.returns_comfort, "state_keys": self.state_keys, "action_keys": self.action_keys, "keep_indices_map": self.keep_indices_map, "comfort_idx": self.comfort_idx, # Added } with open(INDEX_CACHE_PATH, "w") as f: json.dump(cache, f) print(f"[DataLoader] Saved index cache: {INDEX_CACHE_PATH}") except Exception as e: print(f"[DataLoader] Warning: failed to save cache: {e}") def __len__(self): return len(self.T) # ============================================================ # 2) NORMALIZATION # ============================================================ def compute_and_save_norm_stats(npz_paths: List[str], index: "EpisodeIndex", max_episodes: int = 1000, stride: int = 4): rng = np.random.default_rng(SEED) n = len(index) if n == 0: raise RuntimeError("EpisodeIndex is empty (no valid episodes).") k = min(max_episodes, n) eps_idx = rng.choice(np.arange(n), size=k, replace=False) obs_sum, obs_sumsq = None, None act_sum, act_sumsq = None, None count = 0 for ei in tqdm(eps_idx, desc="Computing norm stats"): p = index.paths[int(ei)] with np.load(p, allow_pickle=True) as d: obs = d["observations"].astype(np.float32) act = d["actions"].astype(np.float32) keep_idxs = index.keep_indices_map[int(ei)] obs = obs[:, keep_idxs] obs = obs[::stride] act = act[::stride] if obs_sum is None: obs_sum = np.zeros(obs.shape[1], dtype=np.float64) obs_sumsq = np.zeros(obs.shape[1], dtype=np.float64) act_sum = np.zeros(act.shape[1], dtype=np.float64) act_sumsq = np.zeros(act.shape[1], dtype=np.float64) obs_sum += obs.sum(axis=0) obs_sumsq += (obs**2).sum(axis=0) act_sum += act.sum(axis=0) act_sumsq += (act**2).sum(axis=0) count += obs.shape[0] if obs_sum is None or obs_sumsq is None or act_sum is None or act_sumsq is None: raise ValueError("obs_sum, obs_sumsq, act_sum, or act_sumsq is not initialized properly.") obs_mean = (obs_sum / max(count, 1)).astype(np.float32) obs_std = np.sqrt(np.maximum((obs_sumsq / max(count, 1)) - obs_mean**2, 1e-6)).astype(np.float32) act_mean = (act_sum / max(count, 1)).astype(np.float32) act_std = np.sqrt(np.maximum((act_sumsq / max(count, 1)) - act_mean**2, 1e-6)).astype(np.float32) all_re = np.abs(np.array(index.returns_energy)) all_rc = np.abs(np.array(index.returns_comfort)) scale_energy = float(np.percentile(all_re, 95)) if len(all_re) > 0 else 1.0 scale_comfort = float(np.percentile(all_rc, 95)) if len(all_rc) > 0 else 1.0 scale_energy = max(scale_energy, 1.0) scale_comfort = max(scale_comfort, 1.0) np.savez_compressed( NORM_CACHE_PATH, obs_mean=obs_mean, obs_std=obs_std, act_mean=act_mean, act_std=act_std, scale_energy=np.array([scale_energy], dtype=np.float32), scale_comfort=np.array([scale_comfort], dtype=np.float32), ) class GeneralistDataset(Dataset): def __init__( self, npz_paths: List[str], max_tokens: int = MAX_TOKENS_PER_STEP, seed: int = SEED, virtual_len: int = 60_000, gamma_rtg: float = 1.0, topk_frac: Optional[float] = None, topk_mode: Optional[str] = None, topk_on: Optional[str] = None, ): self.index = EpisodeIndex(npz_paths) self.max_tokens = int(max_tokens) self.seed = int(seed) self.virtual_len = int(virtual_len) self.epoch = 0 self.gamma_rtg = float(gamma_rtg) self.is_train = True self.all_eps = np.arange(len(self.index), dtype=np.int64) # ---------------- Top-K selection ---------------- self.use_topk = bool(USE_TOPK) if topk_frac is None else True self.topk_frac = float(TOPK_FRAC) if topk_frac is None else float(topk_frac) self.topk_mode = str(TOPK_MODE) if topk_mode is None else str(topk_mode) self.topk_on = str(TOPK_ON) if topk_on is None else str(topk_on) rets_e = np.asarray(self.index.returns_energy, dtype=np.float32) rets_c = np.asarray(self.index.returns_comfort, dtype=np.float32) self.sel_eps = self.all_eps self.weights = None if self.use_topk and len(self.all_eps) > 0: total_k = max(1, int(round(self.topk_frac * len(self.all_eps)))) # === STRATEGY 1: PARETO UNION (Energy + Comfort + Mixed) === if self.topk_on == "pareto": print("[Top-K] Strategy: Energy + Comfort + Mixed") k_part = max(1, total_k // 3) # 1. Best Energy idx_energy = np.argsort(rets_e)[::-1][:k_part] # 2. Best Comfort idx_comfort = np.argsort(rets_c)[::-1][:k_part] # 3. Best Mixed (Balanced) norm_e = (rets_e - rets_e.mean()) / (rets_e.std() + 1e-6) norm_c = (rets_c - rets_c.mean()) / (rets_c.std() + 1e-6) idx_mixed = np.argsort(norm_e + norm_c)[::-1][:k_part] # Combine unique indices top_eps = np.unique(np.concatenate([idx_energy, idx_comfort, idx_mixed])) else: if self.topk_on == "energy": rank_signal = rets_e elif self.topk_on == "comfort": rank_signal = rets_c elif self.topk_on == "mixed": rank_signal = rets_e + rets_c else: rank_signal = rets_e # Fallback order = np.argsort(rank_signal)[::-1] top_eps = order[:total_k] # === APPLY FILTER === if self.topk_mode == "filter": self.sel_eps = top_eps self.weights = None elif self.topk_mode == "weighted": self.sel_eps = top_eps self.weights = None # Load Norm Stats if USE_NORMALIZATION: if not os.path.exists(NORM_CACHE_PATH): print("[DataLoader] Computing Norm Stats...") compute_and_save_norm_stats(npz_paths, self.index) z = np.load(NORM_CACHE_PATH) self.obs_mean = z["obs_mean"].astype(np.float32) self.obs_std = z["obs_std"].astype(np.float32) self.act_mean = z["act_mean"].astype(np.float32) self.act_std = z["act_std"].astype(np.float32) self.scale_energy = float(z["scale_energy"][0]) self.scale_comfort = float(z["scale_comfort"][0]) else: self.obs_mean = None self.scale_energy = 1.0 self.scale_comfort = 1.0 def set_epoch(self, e: int): self.epoch = int(e) def __len__(self): return self.virtual_len def __getitem__(self, i: int) -> Dict[str, Any]: x = _mix_u64(self.seed ^ (self.epoch * 0x9E3779B97F4A7C15) ^ (int(i) * 0xD1B54A32D192ED03)) # Preference sampling if PREF_MODE == "fixed": lam = float(PREF_FIXED_LAMBDA) else: rng = np.random.default_rng(int(x & 0xFFFFFFFF)) lam = float(rng.beta(PREF_BETA_A, PREF_BETA_B)) if self.weights is None: ep_i = int(self.sel_eps[x % len(self.sel_eps)]) else: u = ((x & 0xFFFFFFFF) / 2**32) #Clip index to avoid out-of-bounds cdf = np.cumsum(self.weights) idx = int(np.searchsorted(cdf, u, side="right")) idx = min(idx, len(self.weights) - 1) ep_i = int(self.sel_eps[idx]) p = self.index.paths[ep_i] T_total = int(self.index.T[ep_i]) L = CONTEXT_LEN # 1. Load Data with np.load(p, allow_pickle=True) as d: raw_obs = d["observations"].astype(np.float32) at = d["actions"].astype(np.float32) if "rewards_energy" in d: re = d["rewards_energy"].astype(np.float32) rc = d["rewards_comfort"].astype(np.float32) else: re = d["rewards"].astype(np.float32) rc = np.zeros_like(re) if T_total >= L: total_r = re + rc num_candidates = 20 candidates = np.random.randint(0, T_total - L, size=num_candidates) scores = np.array([total_r[c : c + L].sum() for c in candidates]) scores_stab = (scores - np.max(scores)) / (np.std(scores) + 1e-6) probs = np.exp(scores_stab) probs /= probs.sum() s0 = np.random.choice(candidates, p=probs) else: s0 = 0 cidx = self.index.comfort_idx[ep_i] if len(cidx) > 0: ash55_raw_slice = raw_obs[:, cidx] else: ash55_raw_slice = np.zeros((T_total, 1), dtype=np.float32) keep_idxs = self.index.keep_indices_map[ep_i] st = raw_obs[:, keep_idxs] s_keys_ep = self.index.state_keys[ep_i] def find_idx(substring): for idx, k in enumerate(s_keys_ep): if substring in k.lower(): return idx return -1 idx_out = find_idx("outdoor_temp") idx_dew = find_idx("dewpoint") idx_hr = find_idx("hour") idx_mth = find_idx("month") idx_occ = compute_occupancy_indices_from_state_keys(s_keys_ep) def get_window(arr, pad_val=0.0): if T_total >= L: return arr[s0:s0+L] else: out = np.full((L, *arr.shape[1:]), pad_val, dtype=np.float32) out[:T_total] = arr return out st_win = get_window(st) at_win = get_window(at) at_win_raw = at_win.copy() re_win = get_window(re) rc_win = get_window(rc) ash55_win = get_window(ash55_raw_slice) ash55_any = ash55_win.mean(axis=1).astype(np.float32) tm_win = np.zeros((L,), dtype=np.float32) valid_len = min(T_total, L) tm_win[:valid_len] = 1.0 valid_mask = (tm_win > 0.5) FORECAST_STEPS = 48 future_start = s0 + L future_end = min(T_total, future_start + FORECAST_STEPS) forecast_temp = 0.0 if idx_out != -1: current_vals = st_win[valid_mask, idx_out] if len(current_vals) > 0: forecast_temp = current_vals.mean() if future_end > future_start: future_vals = st[future_start:future_end, idx_out] if len(future_vals) > 0: forecast_temp = future_vals.mean() # 3. Context Vector t_mean, t_std = 0.0, 0.0 if idx_out != -1 and valid_mask.sum() > 0: vals = st_win[valid_mask, idx_out] t_mean, t_std = vals.mean(), vals.std() d_mean = 0.0 if idx_dew != -1 and valid_mask.sum() > 0: d_mean = st_win[valid_mask, idx_dew].mean() occ_frac = 0.0 if len(idx_occ) > 0 and valid_mask.sum() > 0: occ_sum = st_win[valid_mask][:, idx_occ].sum(axis=1) occ_frac = (occ_sum > 0.5).mean() # Cyclical Time hr_sin, hr_cos = 0.0, 0.0 if idx_hr != -1 and valid_mask.sum() > 0: hr_val = st_win[valid_mask, idx_hr][0] hr_sin = np.sin(2 * np.pi * hr_val / 24.0) hr_cos = np.cos(2 * np.pi * hr_val / 24.0) mth_sin, mth_cos = 0.0, 0.0 if idx_mth != -1 and valid_mask.sum() > 0: mth_val = st_win[valid_mask, idx_mth][0] mth_sin = np.sin(2 * np.pi * mth_val / 12.0) mth_cos = np.cos(2 * np.pi * mth_val / 12.0) ctx_vec = np.array([ t_mean, t_std, d_mean, occ_frac, hr_sin, hr_cos, mth_sin, mth_cos, forecast_temp, 0.0 ], dtype=np.float32) next_st_win = np.zeros_like(st_win) future_4h_st_win = np.zeros_like(st_win) if T_total >= L: end_idx = min(s0 + L + 1, T_total) actual_len = end_idx - (s0 + 1) if actual_len > 0: next_st_win[:actual_len] = st[s0+1 : end_idx] f_end_idx = min(s0 + L + PHYSICS_HORIZON, T_total) f_actual_len = f_end_idx - (s0 + PHYSICS_HORIZON) if f_actual_len > 0: future_4h_st_win[:f_actual_len] = st[s0 + PHYSICS_HORIZON : f_end_idx] else: if T_total > 1: next_st_win[:T_total-1] = st[1:T_total] if USE_NORMALIZATION and (self.obs_mean is not None): st_win = (st_win - self.obs_mean) / self.obs_std next_st_win = (next_st_win - self.obs_mean) / self.obs_std future_4h_st_win = (future_4h_st_win - self.obs_mean) / self.obs_std at_win = (at_win - self.act_mean) / self.act_std delta_4h_win = future_4h_st_win - st_win full_rtg_e = discounted_cumsum(re, gamma=self.gamma_rtg) full_rtg_c = discounted_cumsum(rc, gamma=self.gamma_rtg) rtg_e_win = get_window(full_rtg_e) rtg_c_win = get_window(full_rtg_c) rtg_e_norm = rtg_e_win / self.scale_energy rtg_c_norm = rtg_c_win / self.scale_comfort rtg_combined = np.stack([rtg_e_norm, rtg_c_norm], axis=-1) if getattr(self, "is_train", True): rtg_combined += np.random.normal(0, 0.005, rtg_combined.shape).astype(np.float32) feat_ids = np.full((L, self.max_tokens), PAD_ID, dtype=np.int64) feat_vals = np.zeros((L, self.max_tokens), dtype=np.float32) zone_ids = np.zeros((L, self.max_tokens), dtype=np.int64) attn_mask = np.zeros((L, self.max_tokens), dtype=np.int64) target_toks = np.full((L, self.max_tokens), -100, dtype=np.int64) target_mask = np.zeros((L, self.max_tokens), dtype=np.float32) s_meta = self.index.s_meta[ep_i] a_meta = self.index.a_meta[ep_i] S_dim = min(len(s_meta), st_win.shape[1]) A_dim = min(len(a_meta), at_win.shape[1]) num_act_toks = min(A_dim, self.max_tokens) num_state_toks = min(S_dim, self.max_tokens - num_act_toks) if num_state_toks > 0: feat_ids[:, :num_state_toks] = [m[0] for m in s_meta[:num_state_toks]] zone_ids[:, :num_state_toks] = [m[1] for m in s_meta[:num_state_toks]] feat_vals[:, :num_state_toks] = st_win[:, :num_state_toks] attn_mask[:, :num_state_toks] = 1 if num_act_toks > 0: start = num_state_toks end = start + num_act_toks feat_ids[:, start:end] = [m[0] for m in a_meta[:num_act_toks]] zone_ids[:, start:end] = [m[1] for m in a_meta[:num_act_toks]] attn_mask[:, start:end] = 1 a_in = np.zeros((L, num_act_toks), dtype=np.float32) if L > 1: a_in[1:] = at_win[:-1, :num_act_toks] feat_vals[:, start:end] = a_in a_keys = self.index.action_keys[ep_i] at_discrete = discretize_actions_to_bins(at_win_raw, a_keys) target_toks[:, start:end] = at_discrete[:, :num_act_toks] target_mask[:, start:end] = 1.0 valid_t = (tm_win > 0.5)[:, None] attn_mask *= valid_t.astype(np.int64) target_mask *= valid_t return { "feature_ids": feat_ids, "feature_values": feat_vals, "zone_ids": zone_ids, "attention_mask": attn_mask, "target_action_tokens": target_toks, "target_mask": target_mask, "rtg": rtg_combined, "rtg_energy": rtg_e_norm, "rtg_comfort": rtg_c_norm, "rewards_energy": re_win, "rewards_comfort": rc_win, "pref_lambda": np.float32(lam), "ash55_any": ash55_any, "next_obs": next_st_win, "target_4h_delta": delta_4h_win, "time_mask": tm_win, "context": ctx_vec, } def generalist_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]: def stack(k): return np.stack([b[k] for b in batch]) return { "feature_ids": torch.from_numpy(stack("feature_ids")).long(), "feature_values": torch.from_numpy(stack("feature_values")).float(), "zone_ids": torch.from_numpy(stack("zone_ids")).long(), "attention_mask": torch.from_numpy(stack("attention_mask")).long(), "target_action_tokens": torch.from_numpy(stack("target_action_tokens")).long(), "target_mask": torch.from_numpy(stack("target_mask")).float(), "rtg": torch.from_numpy(stack("rtg")).float(), "rtg_energy": torch.from_numpy(stack("rtg_energy")).float(), "rtg_comfort": torch.from_numpy(stack("rtg_comfort")).float(), "rewards_energy": torch.from_numpy(stack("rewards_energy")).float(), "rewards_comfort": torch.from_numpy(stack("rewards_comfort")).float(), "pref_lambda": torch.from_numpy(stack("pref_lambda")).float(), "ash55_any": torch.from_numpy(stack("ash55_any")).float(), "next_obs": torch.from_numpy(stack("next_obs")).float(), "target_4h_delta": torch.from_numpy(stack("target_4h_delta")).float(), "time_mask": torch.from_numpy(stack("time_mask")).float(), "context": torch.from_numpy(stack("context")).float(), } # ============================================================ # 4) DEBUG MAIN # ============================================================ def main(): npz_paths = sorted(glob.glob(os.path.join(DATA_DIR, "TrajectoryData_officesmall", "**", "traj_ep*_seed*.npz"), recursive=True)) npz_paths = [p for p in npz_paths if os.path.basename(p) not in ("norm_stats.npz",)] if not npz_paths: print(f"No data found in {DATA_DIR}") return ds = GeneralistDataset(npz_paths, max_tokens=64) loader = DataLoader(ds, batch_size=4, collate_fn=generalist_collate_fn, num_workers=0) batch = next(iter(loader)) if __name__ == "__main__": main()