| from typing import Callable, List, Type |
| import gymnasium as gym |
| import numpy as np |
| from mani_skill.envs.sapien_env import BaseEnv |
| from mani_skill.utils import common, gym_utils |
| import argparse |
| import yaml |
| import torch |
| from collections import deque |
| from PIL import Image |
| import cv2 |
| import imageio |
| from functools import partial |
|
|
| from diffusion_policy.workspace.robotworkspace import RobotWorkspace |
|
|
def parse_args(args=None) -> argparse.Namespace:
    """Parse the command-line options of the evaluation script.

    Args:
        args: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-e", "--env-id", type=str, default="PickCube-v1", help="Environment to run motion planning solver on. ")
    parser.add_argument("-o", "--obs-mode", type=str, default="rgb", help="Observation mode to use. Usually this is kept as 'none' as observations are not necesary to be stored, they can be replayed later via the mani_skill.trajectory.replay_trajectory script.")
    parser.add_argument("-n", "--num-traj", type=int, default=25, help="Number of trajectories to generate.")
    parser.add_argument("--only-count-success", action="store_true", help="If true, generates trajectories until num_traj of them are successful and only saves the successful trajectories/videos")
    parser.add_argument("--reward-mode", type=str)
    parser.add_argument("-b", "--sim-backend", type=str, default="auto", help="Which simulation backend to use. Can be 'auto', 'cpu', 'gpu'")
    parser.add_argument("--render-mode", type=str, default="rgb_array", help="can be 'sensors' or 'rgb_array' which only affect what is saved to videos")
    parser.add_argument("--vis", action="store_true", help="whether or not to open a GUI to visualize the solution live")
    parser.add_argument("--save-video", action="store_true", help="whether or not to save videos locally")
    parser.add_argument("--traj-name", type=str, help="The name of the trajectory .h5 file that will be created.")
    parser.add_argument("--shader", default="default", type=str, help="Change shader used for rendering. Default is 'default' which is very fast. Can also be 'rt' for ray tracing and generating photo-realistic renders. Can also be 'rt-fast' for a faster but lower quality ray-traced renderer")
    parser.add_argument("--record-dir", type=str, default="demos", help="where to save the recorded trajectories")
    parser.add_argument("--num-procs", type=int, default=1, help="Number of processes to use to help parallelize the trajectory replay process. This uses CPU multiprocessing and only works with the CPU simulation backend at the moment.")
    parser.add_argument("--random_seed", type=int, default=0, help="Random seed for the environment.")
    # The help text used to be a copy of the --random_seed description.
    parser.add_argument("--pretrained_path", type=str, default=None, help="Path to the pretrained policy checkpoint to evaluate.")

    # Forward the caller-supplied list; it was previously dropped, so
    # parse_args([...]) silently parsed sys.argv instead.
    return parser.parse_args(args)
|
|
# Natural-language instruction associated with each supported environment id.
# The strings are looked up by env id further down in the script.
# NOTE(review): the PlugCharger-v1 text describes inserting shapes into a
# board/kit, which reads like a different task — confirm before changing it,
# since trained checkpoints may depend on the exact wording.
task2lang = {
    "PegInsertionSide-v1":
        "Pick up a orange-white peg and insert the orange end into the box with a hole in it.",
    "PickCube-v1":
        "Grasp a red cube and move it to a target goal position.",
    "StackCube-v1":
        "Pick up a red cube and stack it on top of a green cube and let go of the cube without it falling.",
    "PlugCharger-v1":
        "Pick up one of the misplaced shapes on the board/kit and insert it into the correct empty slot.",
    "PushCube-v1":
        "Push and move a cube to a goal region in front of it.",
}
| import random |
| import os |
|
|
# Parse the CLI once and seed every random number generator the script
# touches (Python, hash seed, NumPy, PyTorch CPU and CUDA).
args = parse_args()
seed = args.random_seed
os.environ['PYTHONHASHSEED'] = str(seed)
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(seed)
# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
|
|
# Build the simulation environment from the CLI options. The controller is
# fixed to joint-position control, and the same shader pack is applied to
# the sensor, human-render and viewer cameras.
env_id = args.env_id
env = gym.make(
    env_id,
    obs_mode=args.obs_mode,
    control_mode="pd_joint_pos",
    render_mode=args.render_mode,
    # Fall back to the dense reward when --reward-mode was not supplied.
    reward_mode=args.reward_mode if args.reward_mode is not None else "dense",
    sensor_configs={"shader_pack": args.shader},
    human_render_camera_configs={"shader_pack": args.shader},
    viewer_camera_configs={"shader_pack": args.shader},
    sim_backend=args.sim_backend,
)
|
|
| from diffusion_policy.workspace.robotworkspace import RobotWorkspace |
| import hydra |
| import dill |
|
|
# Resolve the checkpoint to evaluate and fail fast on unusable inputs:
# --pretrained_path defaults to None, which would otherwise only surface
# later as an opaque TypeError when the checkpoint file is opened.
checkpoint_path = args.pretrained_path
if checkpoint_path is None:
    raise ValueError("--pretrained_path is required: pass the path of the policy checkpoint to evaluate.")
if env_id not in task2lang:
    # Same exception type as the plain dict lookup, with an actionable message.
    raise KeyError(f"No language instruction registered for env id {env_id!r}; known ids: {sorted(task2lang)}")
print(f"Loading policy from {checkpoint_path}. Task is {task2lang[env_id]}")
|
|
def get_policy(output_dir, device):
    """Load the policy stored in the module-level ``checkpoint_path``.

    The checkpoint payload carries its own config (``payload['cfg']``); the
    workspace class named by ``cfg._target_`` is instantiated, the payload is
    loaded into it, and the policy is taken from the workspace.

    Args:
        output_dir: Passed to the workspace constructor as ``output_dir``.
        device: Anything accepted by ``torch.device`` (e.g. ``'cuda'``).

    Returns:
        ``workspace.ema_model`` when ``cfg.training.use_ema`` is truthy,
        otherwise ``workspace.model``; moved to ``device`` and put in eval mode.
    """
    # NOTE(security): torch.load with pickle_module=dill unpickles arbitrary
    # objects — only load checkpoints from a trusted source.
    # The context manager closes the file handle, which was previously leaked.
    with open(checkpoint_path, 'rb') as ckpt_file:
        payload = torch.load(ckpt_file, pickle_module=dill)
    cfg = payload['cfg']
    cls = hydra.utils.get_class(cfg._target_)
    workspace: RobotWorkspace = cls(cfg, output_dir=output_dir)
    workspace.load_payload(payload, exclude_keys=None, include_keys=None)

    policy = workspace.ema_model if cfg.training.use_ema else workspace.model

    policy.to(torch.device(device))
    policy.eval()

    return policy
|
|
# Load the policy onto the GPU, then set up the bookkeeping for the run.
policy = get_policy(output_dir='./', device='cuda')
MAX_EPISODE_STEPS = 400          # per-episode step budget used by the rollout loop
base_seed = 20241201             # episode i is reset with seed base_seed + i
total_episodes = args.num_traj
success_count = 0
instr = task2lang[env_id]
| import tqdm |
|
|
# Per-dimension statistics used to map proprioception into [-1, 1] and to
# map the policy's [-1, 1] outputs back to raw actions (min/max scaling).
DATA_STAT = {
    'state_min': [
        -0.7463043928146362, -0.0801204964518547, -0.4976441562175751,
        -2.657780647277832, -0.5742632150650024, 1.8309762477874756,
        -2.2423808574676514, 0.0, 0.0,
    ],
    'state_max': [
        0.7645499110221863, 1.4967026710510254, 0.4650936424732208,
        -0.3866899907588959, 0.5505855679512024, 3.2900545597076416,
        2.5737812519073486, 0.03999999910593033, 0.03999999910593033,
    ],
    'action_min': [
        -0.7472005486488342, -0.08631071448326111, -0.4995281398296356,
        -2.658363103866577, -0.5751323103904724, 1.8290787935256958,
        -2.245187997817993, -1.0,
    ],
    'action_max': [
        0.7654682397842407, 1.4984270334243774, 0.46786263585090637,
        -0.38181185722351074, 0.5517147779464722, 3.291581630706787,
        2.575840711593628, 1.0,
    ],
    'action_std': [
        0.2199309915304184, 0.18780815601348877, 0.13044124841690063,
        0.30669933557510376, 0.1340624988079071, 0.24968451261520386,
        0.9589747190475464, 0.9827960729598999,
    ],
    'action_mean': [
        -0.00885344110429287, 0.5523102879524231, -0.007564723491668701,
        -2.0108158588409424, 0.004714342765510082, 2.615924596786499,
        0.08461848646402359, -0.19301606714725494,
    ],
}

# Only the min/max entries are turned into GPU tensors; they are the ones
# the rollout loop uses for (de)normalisation.
state_min, state_max, action_min, action_max = (
    torch.tensor(DATA_STAT[_stat_key], device='cuda')
    for _stat_key in ('state_min', 'state_max', 'action_min', 'action_max')
)
|
|
# Number of predicted actions executed before the policy is queried again.
ACTIONS_PER_INFERENCE = 8


def _build_policy_obs(env, raw_obs):
    """Render the scene and assemble the observation dict fed to the policy.

    Returns a ``(policy_obs, frame)`` pair: ``policy_obs`` holds the min/max
    normalised joint positions (``'agent_pos'``) and the rendered image moved
    to channel-first layout (``'head_cam'``); ``frame`` is the untouched
    output of ``env.render()`` so callers can reuse it without rendering twice.
    """
    frame = env.render()
    img = frame.cuda().float()
    proprio = raw_obs['agent']['qpos'][:].cuda()
    # Scale joint positions into [-1, 1] with the dataset min/max statistics.
    proprio = (proprio - state_min) / (state_max - state_min) * 2 - 1
    policy_obs = {
        'agent_pos': proprio,
        "head_cam": img.permute(0, 3, 1, 2),
    }
    return policy_obs, frame


for episode in tqdm.trange(total_episodes):
    obs_window = deque(maxlen=2)
    obs, _ = env.reset(seed=episode + base_seed)

    # Fill the two-slot window with the initial observation.
    initial_obs, _ = _build_policy_obs(env, obs)
    obs_window.append(initial_obs)
    obs_window.append(dict(initial_obs))

    global_steps = 0
    video_frames = []
    done = False

    while global_steps < MAX_EPISODE_STEPS and not done:
        # Inference only: skip building an autograd graph for every query.
        with torch.no_grad():
            prediction = policy.predict_action(obs_window[-1])
        actions = prediction['action_pred'].squeeze(0)
        # Map the policy output from [-1, 1] back to raw action values.
        actions = (actions + 1) / 2 * (action_max - action_min) + action_min
        actions = actions.detach().cpu().numpy()[:ACTIONS_PER_INFERENCE]
        for action in actions:
            obs, reward, terminated, truncated, info = env.step(action)
            # One render per step serves both the policy input and the video frame.
            policy_obs, frame = _build_policy_obs(env, obs)
            obs_window.append(policy_obs)
            video_frames.append(frame.squeeze(0).detach().cpu().numpy())
            global_steps += 1
            if terminated or truncated:
                assert "success" in info, sorted(info.keys())
                # NOTE(review): the rollout only stops early on success; an
                # episode that is terminated/truncated without success keeps
                # stepping until MAX_EPISODE_STEPS — confirm this is intended.
                if info['success']:
                    done = True
                    success_count += 1
                    break
    print(f"Trial {episode+1} finished, success: {info['success']}, steps: {global_steps}")
|
|
# Summarise the run. Guard the division so an empty run (--num-traj 0)
# reports 0% instead of raising ZeroDivisionError.
success_rate = success_count / total_episodes * 100 if total_episodes > 0 else 0.0
print(f"Tested {total_episodes} episodes, success rate: {success_rate:.2f}%")
# Log file is named after the checkpoint file name up to its first dot.
checkpoint_name = os.path.basename(checkpoint_path).split('.')[0]
log_file = f"results_dp_{checkpoint_name}.txt"
# Append so results from several seeds/environments accumulate in one file.
with open(log_file, 'a') as f:
    f.write(f"{args.env_id}:{seed}:{success_count}\n")