| import json | |
| import os | |
| from pathlib import Path | |
| import farconf | |
| from cleanba.config import Args | |
| from cleanba.environments import SokobanConfig | |
# Reference environment: a single (num_envs=1), synchronous 10x10 Sokoban
# room with one box. It is passed to `count_params` when sizing each network.
soko_env = SokobanConfig(
    dim_room=(10, 10),
    num_boxes=1,
    max_episode_steps=100,
    num_envs=1,
    tinyworld_obs=True,
    asynchronous=False,
).make()
def _first_subdir(path: Path) -> Path:
    """Return the lexicographically first subdirectory of *path*.

    ``os.listdir`` order is arbitrary and can include plain files (e.g.
    ``.DS_Store``); sorting and keeping only directories makes the choice
    deterministic and avoids descending into a non-directory.

    Raises:
        FileNotFoundError: if *path* contains no subdirectory.
    """
    subdirs = sorted(p for p in path.iterdir() if p.is_dir())
    if not subdirs:
        raise FileNotFoundError(f"no subdirectory found in {path}")
    return subdirs[0]


def parameter_count(root: Path) -> str:
    """Return a human-readable parameter count for the network stored under *root*.

    The layout is expected to be ``root/<model_dir>/<checkpoint_dir>/cfg.json``.
    At each level, the first subdirectory in sorted order is used. The
    ``"cfg"`` entry of that JSON file is parsed into ``Args``, and the
    network's parameters are counted against the module-level ``soko_env``.

    Args:
        root: Directory containing the model directory.

    Returns:
        A string such as ``"1,234,567 (1.23M)"``.

    Raises:
        FileNotFoundError: if a level has no subdirectory or ``cfg.json``
            is missing.
    """
    model_dir = _first_subdir(root)
    cp_dir = _first_subdir(model_dir)
    with open(cp_dir / "cfg.json", "r", encoding="utf-8") as f:
        cfg = json.load(f)
    args = farconf.from_dict(cfg["cfg"], Args)
    num = args.net.count_params(soko_env)
    return f"{num:,} ({num / 1_000_000:.2f}M)"
# Print one line per architecture. Each checkpoint directory is given as a
# relative path, so it is resolved against the current working directory.
for _label, _subdir in (
    ("- DRC(3, 3): ", "drc33"),
    ("- DRC(1, 1): ", "drc11"),
    ("- ResNet: ", "resnet"),
):
    print(_label, parameter_count(Path(_subdir)))