diff --git a/examples/lqr_pg.py b/examples/lqr_pg.py index 8e4e9565..ad84484d 100644 --- a/examples/lqr_pg.py +++ b/examples/lqr_pg.py @@ -29,7 +29,7 @@ def experiment(alg, n_epochs, n_iterations, ep_per_run): logger.info('Experiment Algorithm: ' + alg.__name__) # MDP - mdp = LQR.generate(dimensions=1) + mdp = LQR.generate(dimensions=2, max_action=1., max_pos=1.) approximator = Regressor(LinearApproximator, input_shape=mdp.info.observation_space.shape, @@ -39,13 +39,13 @@ def experiment(alg, n_epochs, n_iterations, ep_per_run): input_shape=mdp.info.observation_space.shape, output_shape=mdp.info.action_space.shape) - sigma_weights = 2 * np.ones(sigma.weights_size) + sigma_weights = 0.25 * np.ones(sigma.weights_size) sigma.set_weights(sigma_weights) policy = StateStdGaussianPolicy(approximator, sigma) # Agent - optimizer = AdaptiveOptimizer(eps=.01) + optimizer = AdaptiveOptimizer(eps=1e-2) algorithm_params = dict(optimizer=optimizer) agent = alg(mdp.info, policy, **algorithm_params) @@ -53,18 +53,18 @@ def experiment(alg, n_epochs, n_iterations, ep_per_run): core = Core(agent, mdp) dataset_eval = core.evaluate(n_episodes=ep_per_run) J = compute_J(dataset_eval, gamma=mdp.info.gamma) - logger.epoch_info(0, J=np.mean(J), policy_weights=policy.get_weights()) + logger.epoch_info(0, J=np.mean(J), policy_weights=policy.get_weights().tolist()) for i in trange(n_epochs, leave=False): core.learn(n_episodes=n_iterations * ep_per_run, n_episodes_per_fit=ep_per_run) dataset_eval = core.evaluate(n_episodes=ep_per_run) J = compute_J(dataset_eval, gamma=mdp.info.gamma) - logger.epoch_info(i+1, J=np.mean(J), policy_weights=policy.get_weights()) + logger.epoch_info(i+1, J=np.mean(J), policy_weights=policy.get_weights().tolist()) if __name__ == '__main__': algs = [REINFORCE, GPOMDP, eNAC] for alg in algs: - experiment(alg, n_epochs=10, n_iterations=4, ep_per_run=100) + experiment(alg, n_epochs=10, n_iterations=4, ep_per_run=25) diff --git a/mushroom_rl/algorithms/policy_search/policy_gradient/enac.py b/mushroom_rl/algorithms/policy_search/policy_gradient/enac.py index e6106712..5df90907 100644 --- a/mushroom_rl/algorithms/policy_search/policy_gradient/enac.py +++ b/mushroom_rl/algorithms/policy_search/policy_gradient/enac.py @@ -42,7 +42,7 @@ def _compute_gradient(self, J): self.sum_grad_log_list = list() - return nat_grad, + return nat_grad def _step_update(self, x, u, r): self.sum_grad_log += self.policy.diff_log(x, u) diff --git a/mushroom_rl/algorithms/policy_search/policy_gradient/gpomdp.py b/mushroom_rl/algorithms/policy_search/policy_gradient/gpomdp.py index 77a1159f..c0a68eb4 100644 --- a/mushroom_rl/algorithms/policy_search/policy_gradient/gpomdp.py +++ b/mushroom_rl/algorithms/policy_search/policy_gradient/gpomdp.py @@ -40,24 +40,26 @@ def __init__(self, mdp_info, policy, optimizer, features=None): np.seterr(divide='ignore', invalid='ignore') def _compute_gradient(self, J): - gradient = np.zeros(self.policy.weights_size) - n_episodes = len(self.list_sum_d_log_pi_ep) - + grad_J_episode = list() for i in range(n_episodes): list_sum_d_log_pi = self.list_sum_d_log_pi_ep[i] list_reward = self.list_reward_ep[i] n_steps = len(list_sum_d_log_pi) + gradient = np.zeros(self.policy.weights_size) + for t in range(n_steps): step_grad = list_sum_d_log_pi[t] step_reward = list_reward[t] - baseline = self.baseline_num[t] / self.baseline_den[t] + baseline = np.mean(self.baseline_num[t], axis=0) / np.mean(self.baseline_den[t], axis=0) baseline[np.logical_not(np.isfinite(baseline))] = 0. - gradient += (step_reward - baseline) * step_grad + gradient += step_grad * (step_reward - baseline) - gradient /= n_episodes + grad_J_episode.append(gradient) + + gradJ = np.mean(grad_J_episode, axis=0) self.list_reward_ep = list() self.list_sum_d_log_pi_ep = list() @@ -65,7 +67,7 @@ def _compute_gradient(self, J): self.baseline_num = list() self.baseline_den = list() - return gradient, + return gradJ def _step_update(self, x, u, r): discounted_reward = self.df * r @@ -74,17 +76,16 @@ def _step_update(self, x, u, r): d_log_pi = self.policy.diff_log(x, u) self.sum_d_log_pi += d_log_pi - self.list_sum_d_log_pi.append(self.sum_d_log_pi) + self.list_sum_d_log_pi.append(self.sum_d_log_pi.copy()) squared_sum_d_log_pi = np.square(self.sum_d_log_pi) - if self.step_count < len(self.baseline_num): - self.baseline_num[ - self.step_count] += discounted_reward * squared_sum_d_log_pi - self.baseline_den[self.step_count] += squared_sum_d_log_pi - else: - self.baseline_num.append(discounted_reward * squared_sum_d_log_pi) - self.baseline_den.append(squared_sum_d_log_pi) + if self.step_count >= len(self.baseline_num): + self.baseline_num.append(list()) + self.baseline_den.append(list()) + + self.baseline_num[self.step_count].append(discounted_reward * squared_sum_d_log_pi) + self.baseline_den[self.step_count].append(squared_sum_d_log_pi) self.step_count += 1 diff --git a/mushroom_rl/algorithms/policy_search/policy_gradient/policy_gradient.py b/mushroom_rl/algorithms/policy_search/policy_gradient/policy_gradient.py index 332e1f9e..755d1438 100644 --- a/mushroom_rl/algorithms/policy_search/policy_gradient/policy_gradient.py +++ b/mushroom_rl/algorithms/policy_search/policy_gradient/policy_gradient.py @@ -60,16 +60,11 @@ def _update_parameters(self, J): episode in the dataset. """ - res = self._compute_gradient(J) + grad = self._compute_gradient(J) theta_old = self.policy.get_weights() - if len(res) == 1: - grad = res[0] - theta_new = self.optimizer(theta_old, grad) - else: - grad, nat_grad = res - theta_new = self.optimizer(theta_old, grad, nat_grad) + theta_new = self.optimizer(theta_old, grad) self.policy.set_weights(theta_new) @@ -111,6 +106,9 @@ def _compute_gradient(self, J): J (list): list of the cumulative discounted rewards for each episode in the dataset. + Returns: + The gradient computed by the algorithm. + """ raise NotImplementedError('PolicyGradient is an abstract class') diff --git a/mushroom_rl/algorithms/policy_search/policy_gradient/reinforce.py b/mushroom_rl/algorithms/policy_search/policy_gradient/reinforce.py index bec19a9d..15d4db55 100644 --- a/mushroom_rl/algorithms/policy_search/policy_gradient/reinforce.py +++ b/mushroom_rl/algorithms/policy_search/policy_gradient/reinforce.py @@ -28,8 +28,7 @@ def __init__(self, mdp_info, policy, optimizer, features=None): np.seterr(divide='ignore', invalid='ignore') def _compute_gradient(self, J): - baseline = np.mean(self.baseline_num, axis=0) / np.mean( - self.baseline_den, axis=0) + baseline = np.mean(self.baseline_num, axis=0) / np.mean(self.baseline_den, axis=0) baseline[np.logical_not(np.isfinite(baseline))] = 0. grad_J_episode = list() for i, J_episode in enumerate(J): @@ -41,7 +40,7 @@ def _compute_gradient(self, J): self.baseline_den = list() self.baseline_num = list() - return grad_J, + return grad_J def _step_update(self, x, u, r): d_log_pi = self.policy.diff_log(x, u) diff --git a/mushroom_rl/environments/__init__.py b/mushroom_rl/environments/__init__.py index 5a0f166e..65f456af 100644 --- a/mushroom_rl/environments/__init__.py +++ b/mushroom_rl/environments/__init__.py @@ -30,7 +30,6 @@ try: PyBullet = None from .pybullet import PyBullet - from .pybullet_envs import * except ImportError: pass diff --git a/mushroom_rl/environments/pybullet.py b/mushroom_rl/environments/pybullet.py index 0a33d262..58035366 100644 --- a/mushroom_rl/environments/pybullet.py +++ b/mushroom_rl/environments/pybullet.py @@ -1,66 +1,10 @@ import numpy as np -from enum import Enum import pybullet import pybullet_data from pybullet_utils.bullet_client import BulletClient from mushroom_rl.core import Environment, MDPInfo from mushroom_rl.utils.spaces import Box -from mushroom_rl.utils.viewer import ImageViewer - - -class PyBulletObservationType(Enum): - """ - An enum indicating the type of data that should be added to the observation - of the environment, can be Joint-/Body-/Site- positions and velocities. - - """ - __order__ = "BODY_POS BODY_LIN_VEL BODY_ANG_VEL JOINT_POS JOINT_VEL LINK_POS LINK_LIN_VEL LINK_ANG_VEL" - BODY_POS = 0 - BODY_LIN_VEL = 1 - BODY_ANG_VEL = 2 - JOINT_POS = 3 - JOINT_VEL = 4 - LINK_POS = 5 - LINK_LIN_VEL = 6 - LINK_ANG_VEL = 7 - - -class PyBulletViewer(ImageViewer): - def __init__(self, client, dt, size=(500, 500), distance=4, origin=(0, 0, 1), angles=(0, -45, 60), - fov=60, aspect=1, near_val=0.01, far_val=100): - self._client = client - self._size = size - self._distance = distance - self._origin = origin - self._angles = angles - self._fov = fov - self._aspect = aspect - self._near_val = near_val - self._far_val = far_val - super().__init__(size, dt) - - def display(self): - img = self._get_image() - super().display(img) - - def _get_image(self): - view_matrix = self._client.computeViewMatrixFromYawPitchRoll(cameraTargetPosition=self._origin, - distance=self._distance, - roll=self._angles[0], - pitch=self._angles[1], - yaw=self._angles[2], - upAxisIndex=2) - proj_matrix = self._client.computeProjectionMatrixFOV(fov=self._fov, aspect=self._aspect, - nearVal=self._near_val, farVal=self._far_val) - (_, _, px, _, _) = self._client.getCameraImage(width=self._size[0], - height=self._size[1], - viewMatrix=view_matrix, - projectionMatrix=proj_matrix, - renderer=pybullet.ER_BULLET_HARDWARE_OPENGL) - - rgb_array = np.reshape(np.array(px), (self._size[0], self._size[1], -1)) - rgb_array = rgb_array[:, :, :3] - return rgb_array +from mushroom_rl.utils.pybullet import * class PyBullet(Environment): @@ -69,13 +13,13 @@ class PyBullet(Environment): """ def __init__(self, files, actuation_spec, observation_spec, gamma, - horizon, timestep=1/240, n_intermediate_steps=1, + horizon, timestep=1/240, n_intermediate_steps=1, enforce_joint_velocity_limits=False, debug_gui=False, **viewer_params): """ Constructor. Args: - files (list): Paths to the URDF files to load; + files (dict): dictionary of the URDF/MJCF/SDF files to load (key) and parameters dictionary (value); actuation_spec (list): A list of tuples specifying the names of the joints which should be controllable by the agent and tehir control mode. Can be left empty when all actuators should be used in position control; @@ -90,10 +34,15 @@ def __init__(self, files, actuation_spec, observation_spec, gamma, n_intermediate_steps (int): The number of steps between every action taken by the agent. Allows the user to modify, control and access intermediate states; + enforce_joint_velocity_limits (bool, False): flag to enforce the velocity limits; + debug_gui (bool, False): flag to activate the default pybullet visualizer, that can be used for debug + purposes; **viewer_params: other parameters to be passed to the viewer. See PyBulletViewer documentation for the available options. """ + assert len(observation_spec) > 0 + assert len(actuation_spec) > 0 # Store simulation parameters self._timestep = timestep @@ -102,68 +51,34 @@ def __init__(self, files, actuation_spec, observation_spec, gamma, # Create the simulation and viewer if debug_gui: self._client = BulletClient(connection_mode=pybullet.GUI) + self._client.configureDebugVisualizer(pybullet.COV_ENABLE_GUI, 0) else: self._client = BulletClient(connection_mode=pybullet.DIRECT) self._client.setTimeStep(self._timestep) self._client.setGravity(0, 0, -9.81) self._client.setAdditionalSearchPath(pybullet_data.getDataPath()) - self._viewer = PyBulletViewer(self._client, self._timestep * self._n_intermediate_steps, **viewer_params) + self._viewer = PyBulletViewer(self._client, self.dt, **viewer_params) self._state = None # Load model and create access maps self._model_map = dict() - for file_name, kwargs in files.items(): - model_id = self._client.loadURDF(file_name, **kwargs) - model_name = self._client.getBodyInfo(model_id)[1].decode('UTF-8') - self._model_map[model_name] = model_id - self._model_map.update(self._custom_load_models()) + self._load_all_models(files, enforce_joint_velocity_limits) + + # Build utils + self._indexer = IndexMap(self._client, self._model_map, actuation_spec, observation_spec) + self.joints = JointsHelper(self._client, self._indexer, observation_spec) + + # Finally, we create the MDP information and call the constructor of the parent class + action_space = Box(*self._indexer.action_limits) - self._joint_map = dict() - self._link_map = dict() - for model_id in self._model_map.values(): - for joint_id in range(self._client.getNumJoints(model_id)): - joint_data = self._client.getJointInfo(model_id, joint_id) - if joint_data[2] != pybullet.JOINT_FIXED: - joint_name = joint_data[1].decode('UTF-8') - self._joint_map[joint_name] = (model_id, joint_id) - link_name = joint_data[12].decode('UTF-8') - self._link_map[link_name] = (model_id, joint_id) - - # Read the actuation spec and build the mapping between actions and ids - # as well as their limits - assert(len(actuation_spec) > 0) - self._action_data = list() - for name, mode in actuation_spec: - if name in self._joint_map: - data = self._joint_map[name] + (mode,) - self._action_data.append(data) - - if mode == pybullet.TORQUE_CONTROL: - self._client.setJointMotorControl2(data[0], data[1], - controlMode=pybullet.VELOCITY_CONTROL, - force=0) - - low, high = self._compute_action_limits() - action_space = Box(np.array(low), np.array(high)) - - # Read the observation spec to build a mapping at every step. It is - # ensured that the values appear in the order they are specified. - if len(observation_spec) == 0: - raise AttributeError("No Environment observations were specified. " - "Add at least one observation to the observation_spec.") - - self._observation_map = observation_spec - self._observation_indices_map = dict() - - # We can only specify limits for the joint positions, all other - # information can be potentially unbounded - low, high = self._compute_observation_limits() - observation_space = Box(low, high) - - # Finally, we create the MDP information and call the constructor of - # the parent class + observation_space = Box(*self._indexer.observation_limits) mdp_info = MDPInfo(observation_space, action_space, gamma, horizon) + + # Let the child class modify the mdp_info data structure + mdp_info = self._modify_mdp_info(mdp_info) + + # Provide the structure to the superclass super().__init__(mdp_info) # Save initial state of the MDP @@ -174,9 +89,11 @@ def seed(self, seed): def reset(self, state=None): self._client.restoreState(self._initial_state) - self.setup() - self._state = self._create_observation() - return self._state + self.setup(state) + self._state = self._indexer.create_sim_state() + observation = self._create_observation(self._state) + + return observation def render(self): self._viewer.display() @@ -185,16 +102,16 @@ def stop(self): pass def step(self, action): - cur_obs = self._state + curr_state = self._state.copy() action = self._preprocess_action(action) - self._step_init(cur_obs, action) + self._step_init(curr_state, action) for i in range(self._n_intermediate_steps): - ctrl_action = self._compute_action(action) - self._apply_control(ctrl_action) + ctrl_action = self._compute_action(curr_state, action) + self._indexer.apply_control(ctrl_action) self._simulation_pre_step() @@ -202,18 +119,23 @@ def step(self, action): self._simulation_post_step() - self._state = self._create_observation() + curr_state = self._indexer.create_sim_state() self._step_finalize() - reward = self.reward(cur_obs, action, self._state) + absorbing = self.is_absorbing(curr_state) + reward = self.reward(self._state, action, curr_state, absorbing) - return self._state, reward, self.is_absorbing(self._state), {} + observation = self._create_observation(curr_state) - def get_observation_index(self, name, obs_type): - return self._observation_indices_map[name][obs_type] + self._state = curr_state - def get_observation(self, obs, name, obs_type): + return observation, reward, absorbing, {} + + def get_sim_state_index(self, name, obs_type): + return self._indexer.get_index(name, obs_type) + + def get_sim_state(self, obs, name, obs_type): """ Returns a specific observation value @@ -226,126 +148,64 @@ def get_observation(self, obs, name, obs_type): The required elements of the input state vector. """ - indices = self.get_observation_index(name, obs_type) + indices = self.get_sim_state_index(name, obs_type) return obs[indices] - def _compute_action_limits(self): - low = list() - high = list() - - for model_id, joint_id, mode in self._action_data: - joint_info = self._client.getJointInfo(model_id, joint_id) - if mode is pybullet.POSITION_CONTROL: - low.append(joint_info[8]) - high.append(joint_info[9]) - elif mode is pybullet.VELOCITY_CONTROL: - low.append(-joint_info[11]) - high.append(joint_info[11]) - elif mode is pybullet.TORQUE_CONTROL: - low.append(-joint_info[10]) - high.append(joint_info[10]) - else: - raise NotImplementedError - - return np.array(low), np.array(high) - - def _compute_observation_limits(self): - low = list() - high = list() - - for name, obs_type in self._observation_map: - index_count = len(low) - if obs_type is PyBulletObservationType.BODY_POS \ - or obs_type is PyBulletObservationType.BODY_LIN_VEL \ - or obs_type is PyBulletObservationType.BODY_ANG_VEL: - n_dim = 7 if obs_type is PyBulletObservationType.BODY_POS else 3 - low += [-np.inf] * n_dim - high += [-np.inf] * n_dim - elif obs_type is PyBulletObservationType.LINK_POS \ - or obs_type is PyBulletObservationType.LINK_LIN_VEL \ - or obs_type is PyBulletObservationType.LINK_ANG_VEL: - n_dim = 7 if obs_type is PyBulletObservationType.LINK_POS else 3 - low += [-np.inf] * n_dim - high += [-np.inf] * n_dim - else: - model_id, joint_id = self._joint_map[name] - joint_info = self._client.getJointInfo(model_id, joint_id) - - if obs_type is PyBulletObservationType.JOINT_POS: - low.append(joint_info[8]) - high.append(joint_info[9]) - else: - low.append(-np.inf) - high.append(np.inf) - - self._add_observation_index(name, obs_type, index_count, len(low)) - - return np.array(low), np.array(high) - - def _add_observation_index(self, name, obs_type, start, end): - if name not in self._observation_indices_map: - self._observation_indices_map[name] = dict() - - self._observation_indices_map[name][obs_type] = list(range(start, end)) - - def _create_observation(self): - data_obs = list() - - for name, obs_type in self._observation_map: - if obs_type is PyBulletObservationType.BODY_POS \ - or obs_type is PyBulletObservationType.BODY_LIN_VEL \ - or obs_type is PyBulletObservationType.BODY_ANG_VEL: - model_id = self._model_map[name] - if obs_type is PyBulletObservationType.BODY_POS: - t, q = self._client.getBasePositionAndOrientation(model_id) - data_obs += t + q - else: - v, w = self._client.getBaseVelocity(model_id) - if obs_type is PyBulletObservationType.BODY_LIN_VEL: - data_obs += v - else: - data_obs += w - elif obs_type is PyBulletObservationType.LINK_POS \ - or obs_type is PyBulletObservationType.LINK_LIN_VEL \ - or obs_type is PyBulletObservationType.LINK_ANG_VEL: - model_id, link_id = self._link_map[name] - - if obs_type is PyBulletObservationType.LINK_POS: - link_data = self._client.getLinkState(model_id, link_id) - t = link_data[0] - q = link_data[1] - data_obs += t + q - elif obs_type is PyBulletObservationType.LINK_LIN_VEL: - data_obs += self._client.getLinkState(model_id, link_id, computeLinkVelocity=True)[-2] - elif obs_type is PyBulletObservationType.LINK_ANG_VEL: - data_obs += self._client.getLinkState(model_id, link_id, computeLinkVelocity=True)[-1] - else: - model_id, joint_id = self._joint_map[name] - pos, vel, _, _ = self._client.getJointState(model_id, joint_id) - if obs_type is PyBulletObservationType.JOINT_POS: - data_obs.append(pos) - elif obs_type is PyBulletObservationType.JOINT_VEL: - data_obs.append(vel) - - return np.array(data_obs) - - def _apply_control(self, action): - - i = 0 - for model_id, joint_id, mode in self._action_data: - u = action[i] - if mode is pybullet.POSITION_CONTROL: - kwargs = dict(targetPosition=u, maxVelocity=self._client.getJointInfo(model_id, joint_id)[11]) - elif mode is pybullet.VELOCITY_CONTROL: - kwargs = dict(targetVelocity=u, maxVelocity=self._client.getJointInfo(model_id, joint_id)[11]) - elif mode is pybullet.TORQUE_CONTROL: - kwargs = dict(force=u) - else: - raise NotImplementedError - - self._client.setJointMotorControl2(model_id, joint_id, mode, **kwargs) - i += 1 + def _modify_mdp_info(self, mdp_info): + """ + This method can be overridden to modify the automatically generated MDPInfo data structure. + By default, returns the given mdp_info structure unchanged. + + Args: + mdp_info (MDPInfo): the MDPInfo structure automatically computed by the environment. + + Returns: + The modified MDPInfo data structure. + + """ + return mdp_info + + def _create_observation(self, state): + """ + This method can be overridden to ctreate an observation vector from the simulator state vector. + By default, returns the simulator state vector unchanged. + + Args: + state (np.ndarray): the simulator state vector. + + Returns: + The environment observation. + + """ + return state + + def _load_model(self, file_name, kwargs): + if file_name.endswith('.urdf'): + model_id = self._client.loadURDF(file_name, **kwargs) + elif file_name.endswith('.sdf'): + model_id = self._client.loadSDF(file_name, **kwargs)[0] + else: + model_id = self._client.loadMJCF(file_name, **kwargs)[0] + + for j in range(self._client.getNumJoints(model_id)): + self._client.setJointMotorControl2(model_id, j, pybullet.POSITION_CONTROL, force=0) + + return model_id + + def _load_all_models(self, files, enforce_joint_velocity_limits): + for file_name, kwargs in files.items(): + model_id = self._load_model(file_name, kwargs) + model_name = self._client.getBodyInfo(model_id)[1].decode('UTF-8') + self._model_map[model_name] = model_id + self._model_map.update(self._custom_load_models()) + + # Enforce velocity limits on every joint + if enforce_joint_velocity_limits: + for model_id in self._model_map.values(): + for joint_id in range(self._client.getNumJoints(model_id)): + joint_data = self._client.getJointInfo(model_id, joint_id) + self._client.changeDynamics(model_id, joint_id, maxJointVelocity=joint_data[11]) def _preprocess_action(self, action): """ @@ -367,14 +227,14 @@ def _step_init(self, state, action): """ pass - def _compute_action(self, action): + def _compute_action(self, state, action): """ Compute a transformation of the action at every intermediate step. Useful to add control signals simulated directly in python. Args: - action (np.ndarray): numpy array with the actions - provided at every step. + state (np.ndarray): numpy array with the current state of teh simulation; + action (np.ndarray): numpy array with the actions, provided at every step. Returns: The action to be set in the actual pybullet simulation. @@ -413,15 +273,15 @@ def _custom_load_models(self): """ return list() - def reward(self, state, action, next_state): + def reward(self, state, action, next_state, absorbing): """ Compute the reward based on the given transition. Args: state (np.array): the current state of the system; action (np.array): the action that is applied in the current state; - next_state (np.array): the state reached after applying the given - action. + next_state (np.array): the state reached after applying the given action; + absorbing (bool): whether next_state is an absorbing state or not. Returns: The reward as a floating point scalar value. @@ -442,11 +302,16 @@ def is_absorbing(self, state): """ raise NotImplementedError - def setup(self): + def setup(self, state): """ A function that allows to execute setup code after an environment reset. + Args: + state (np.ndarray): the state to be restored. If the state should be + chosen by the environment, state is None. Environments can ignore this + value if the initial state cannot be set programmatically. + """ pass @@ -454,3 +319,7 @@ def setup(self): def client(self): return self._client + @property + def dt(self): + return self._timestep * self._n_intermediate_steps + diff --git a/mushroom_rl/environments/pybullet_envs/__init__.py b/mushroom_rl/environments/pybullet_envs/__init__.py deleted file mode 100644 index 6ba022fd..00000000 --- a/mushroom_rl/environments/pybullet_envs/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .hexapod import HexapodBullet - -HexapodBullet.register() \ No newline at end of file diff --git a/mushroom_rl/environments/pybullet_envs/data/__init__.py b/mushroom_rl/environments/pybullet_envs/data/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/mushroom_rl/environments/pybullet_envs/data/hexapod/hexapod.urdf b/mushroom_rl/environments/pybullet_envs/data/hexapod/hexapod.urdf deleted file mode 100644 index 2029f842..00000000 --- a/mushroom_rl/environments/pybullet_envs/data/hexapod/hexapod.urdf +++ /dev/null @@ -1,963 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/collision/base.stl b/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/collision/base.stl deleted file mode 100644 index 1c7c9dd4..00000000 Binary files a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/collision/base.stl and /dev/null differ diff --git a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/collision/leg_0_left.stl b/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/collision/leg_0_left.stl deleted file mode 100644 index e9096962..00000000 Binary files a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/collision/leg_0_left.stl and /dev/null differ diff --git a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/collision/leg_0_right.stl b/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/collision/leg_0_right.stl deleted file mode 100644 index abe69380..00000000 Binary files a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/collision/leg_0_right.stl and /dev/null differ diff --git a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/collision/leg_1.stl b/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/collision/leg_1.stl deleted file mode 100644 index 9948294e..00000000 Binary files a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/collision/leg_1.stl and /dev/null differ diff --git a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/collision/leg_2_left.stl b/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/collision/leg_2_left.stl deleted file mode 100644 index b6134ab1..00000000 Binary files a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/collision/leg_2_left.stl and /dev/null differ diff --git a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/collision/leg_2_right.stl b/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/collision/leg_2_right.stl deleted file mode 100644 index 0bab3e80..00000000 Binary files a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/collision/leg_2_right.stl and /dev/null differ diff --git a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/collision/motor.stl b/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/collision/motor.stl deleted file mode 100644 index b5f4f028..00000000 Binary files a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/collision/motor.stl and /dev/null differ diff --git a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/visual/base.stl b/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/visual/base.stl deleted file mode 100644 index 972582f4..00000000 Binary files a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/visual/base.stl and /dev/null differ diff --git a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/visual/leg_0_left.stl b/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/visual/leg_0_left.stl deleted file mode 100644 index 4c02c3c6..00000000 Binary files a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/visual/leg_0_left.stl and /dev/null differ diff --git a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/visual/leg_0_right.stl b/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/visual/leg_0_right.stl deleted file mode 100644 index c439c1a9..00000000 Binary files a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/visual/leg_0_right.stl and /dev/null differ diff --git a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/visual/leg_1.stl b/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/visual/leg_1.stl deleted file mode 100644 index d86c1dc5..00000000 Binary files a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/visual/leg_1.stl and /dev/null differ diff --git a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/visual/leg_2_left.stl b/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/visual/leg_2_left.stl deleted file mode 100644 index 47f31bdc..00000000 Binary files a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/visual/leg_2_left.stl and /dev/null differ diff --git a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/visual/leg_2_right.stl b/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/visual/leg_2_right.stl deleted file mode 100644 index 1bc8d804..00000000 Binary files a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/visual/leg_2_right.stl and /dev/null differ diff --git a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/visual/motor.stl b/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/visual/motor.stl deleted file mode 100644 index 21224648..00000000 Binary files a/mushroom_rl/environments/pybullet_envs/data/hexapod/meshes/visual/motor.stl and /dev/null differ diff --git a/mushroom_rl/environments/pybullet_envs/hexapod.py b/mushroom_rl/environments/pybullet_envs/hexapod.py deleted file mode 100644 index 393536a5..00000000 --- a/mushroom_rl/environments/pybullet_envs/hexapod.py +++ /dev/null @@ -1,184 +0,0 @@ -import time -import numpy as np -import pybullet -from mushroom_rl.environments.pybullet import PyBullet, PyBulletObservationType - -from pathlib import Path - -from mushroom_rl.environments.pybullet_envs import __file__ as path_robots - - -class HexapodBullet(PyBullet): - def __init__(self, gamma=0.99, horizon=1000, goal=None, debug_gui=False): - hexapod_path = Path(path_robots).absolute().parent / 'data' / 'hexapod'/ 'hexapod.urdf' - self.robot_path = str(hexapod_path) - - action_spec = [ - ("hexapod/leg_0/joint_0", pybullet.VELOCITY_CONTROL), - ("hexapod/leg_0/joint_1", pybullet.VELOCITY_CONTROL), - ("hexapod/leg_0/joint_2", pybullet.VELOCITY_CONTROL), - - ("hexapod/leg_1/joint_0", pybullet.VELOCITY_CONTROL), - ("hexapod/leg_1/joint_1", pybullet.VELOCITY_CONTROL), - ("hexapod/leg_1/joint_2", pybullet.VELOCITY_CONTROL), - - ("hexapod/leg_2/joint_0", pybullet.VELOCITY_CONTROL), - ("hexapod/leg_2/joint_1", pybullet.VELOCITY_CONTROL), - ("hexapod/leg_2/joint_2", pybullet.VELOCITY_CONTROL), - - ("hexapod/leg_3/joint_0", pybullet.VELOCITY_CONTROL), - ("hexapod/leg_3/joint_1", pybullet.VELOCITY_CONTROL), - ("hexapod/leg_3/joint_2", pybullet.VELOCITY_CONTROL), - - ("hexapod/leg_4/joint_0", pybullet.VELOCITY_CONTROL), - ("hexapod/leg_4/joint_1", pybullet.VELOCITY_CONTROL), - ("hexapod/leg_4/joint_2", pybullet.VELOCITY_CONTROL), - - ("hexapod/leg_5/joint_0", pybullet.VELOCITY_CONTROL), - ("hexapod/leg_5/joint_1", pybullet.VELOCITY_CONTROL), - ("hexapod/leg_5/joint_2", pybullet.VELOCITY_CONTROL), - ] - - observation_spec = [ - ("hexapod/leg_0/joint_0", PyBulletObservationType.JOINT_POS), - ("hexapod/leg_0/joint_1", PyBulletObservationType.JOINT_POS), - ("hexapod/leg_0/joint_2", PyBulletObservationType.JOINT_POS), - - ("hexapod/leg_1/joint_0", PyBulletObservationType.JOINT_POS), - ("hexapod/leg_1/joint_1", PyBulletObservationType.JOINT_POS), - ("hexapod/leg_1/joint_2", PyBulletObservationType.JOINT_POS), - - ("hexapod/leg_2/joint_0", PyBulletObservationType.JOINT_POS), - ("hexapod/leg_2/joint_1", PyBulletObservationType.JOINT_POS), - ("hexapod/leg_2/joint_2", PyBulletObservationType.JOINT_POS), - - ("hexapod/leg_3/joint_0", PyBulletObservationType.JOINT_POS), - ("hexapod/leg_3/joint_1", PyBulletObservationType.JOINT_POS), - ("hexapod/leg_3/joint_2", PyBulletObservationType.JOINT_POS), - - ("hexapod/leg_4/joint_0", PyBulletObservationType.JOINT_POS), - ("hexapod/leg_4/joint_1", PyBulletObservationType.JOINT_POS), - ("hexapod/leg_4/joint_2", PyBulletObservationType.JOINT_POS), - - ("hexapod/leg_5/joint_0", PyBulletObservationType.JOINT_POS), - ("hexapod/leg_5/joint_1", PyBulletObservationType.JOINT_POS), - ("hexapod/leg_5/joint_2", PyBulletObservationType.JOINT_POS), - - ("hexapod", PyBulletObservationType.BODY_POS), - ("hexapod", PyBulletObservationType.BODY_LIN_VEL) - ] - - files = { - self.robot_path: dict(basePosition=[0.0, 0, 0.12], - baseOrientation=[0, 0, 0.0, 1.0], - flags=pybullet.URDF_USE_SELF_COLLISION), - 'plane.urdf': {} - } - - super().__init__(files, action_spec, observation_spec, gamma, horizon, n_intermediate_steps=8, debug_gui=debug_gui, - distance=3, origin=[0., 0., 0.], angles=[0., -45., 0.]) - - self._client.setGravity(0, 0, -9.81) - - self.hexapod_initial_state = np.array( - [-0.66, 0.66, -1.45, - 0.66, -0.66, 1.45, - 0.00, 0.66, -1.45, - 0.00, -0.66, 1.45, - 0.66, 0.66, -1.45, - -0.66, -0.66, 1.45] - ) - - self._goal = np.array([2.0, 2.0]) if goal is None else goal - - def setup(self): - for i, (model_id, joint_id, _) in enumerate(self._action_data): - self._client.resetJointState(model_id, joint_id, self.hexapod_initial_state[i]) - - self._client.resetDebugVisualizerCamera(cameraDistance=3, cameraYaw=0.0, cameraPitch=-45, - cameraTargetPosition=[0., 0., 0.]) - - self._filter_collisions() - - def reward(self, state, action, next_state): - - pose = self.get_observation(next_state, "hexapod", PyBulletObservationType.BODY_POS) - euler = pybullet.getEulerFromQuaternion(pose[3:]) - - goal_distance = np.linalg.norm(pose[:2] - self._goal) - goal_reward = np.exp(-goal_distance) - - attitude_distance = np.linalg.norm(euler[:2]) - attitude_reward = np.exp(-attitude_distance) - - action_penalty = np.linalg.norm(action) - - self_collisions_penalty = 1.0*self._count_self_collisions() - - return goal_reward + 1e-2*attitude_reward - 1e-3*action_penalty - self_collisions_penalty - - def is_absorbing(self, state): - pose = self.get_observation(state, "hexapod", PyBulletObservationType.BODY_POS) - - euler = pybullet.getEulerFromQuaternion(pose[3:]) - - return pose[2] > 0.5 or abs(euler[0]) > np.pi/2 or abs(euler[1]) > np.pi/2 or self._count_self_collisions() >= 2 - - def _count_self_collisions(self): - hexapod_id = self._model_map['hexapod'] - - collision_count = 0 - collisions = self._client.getContactPoints(hexapod_id) - - for collision in collisions: - body_2 = collision[2] - if body_2 == hexapod_id: - collision_count += 1 - - return collision_count - - def _filter_collisions(self): - # Disable fixed links collisions - for leg_n in range(6): - for link_n in range(3): - motor_name = f'hexapod/leg_{leg_n}/motor_{link_n}' - link_name = f'hexapod/leg_{leg_n}/link_{link_n}' - self._client.setCollisionFilterPair(self._link_map[motor_name][0], self._link_map[link_name][0], - self._link_map[motor_name][1], self._link_map[link_name][1], 0) - - -if __name__ == '__main__': - from mushroom_rl.core import Core - from mushroom_rl.core import Agent - from mushroom_rl.utils.dataset import compute_J - - - class DummyAgent(Agent): - def __init__(self, n_actions): - self._n_actions = n_actions - - def draw_action(self, state): - time.sleep(0.01) - - return np.random.randn(self._n_actions) - - def episode_start(self): - pass - - def fit(self, dataset): - pass - - - mdp = HexapodBullet(debug_gui=True) - agent = DummyAgent(mdp.info.action_space.shape[0]) - - core = Core(agent, mdp) - dataset = core.evaluate(n_episodes=10, render=False) - print('reward: ', compute_J(dataset, mdp.info.gamma)) - print("mdp_info state shape", mdp.info.observation_space.shape) - print("actual state shape", dataset[0][0].shape) - print("mdp_info action shape", mdp.info.action_space.shape) - print("actual action shape", dataset[0][1].shape) - - print("action low", mdp.info.action_space.low) - print("action high", mdp.info.action_space.high) diff --git a/mushroom_rl/utils/pybullet/__init__.py b/mushroom_rl/utils/pybullet/__init__.py new file mode 100644 index 00000000..66de4454 --- /dev/null +++ b/mushroom_rl/utils/pybullet/__init__.py @@ -0,0 +1,4 @@ +from .observation import PyBulletObservationType +from .index_map import IndexMap +from .viewer import PyBulletViewer +from .joints_helper import JointsHelper diff --git a/mushroom_rl/utils/pybullet/contacts.py b/mushroom_rl/utils/pybullet/contacts.py new file mode 100644 index 00000000..b886efcb --- /dev/null +++ b/mushroom_rl/utils/pybullet/contacts.py @@ -0,0 +1,90 @@ +class ContactHelper(object): + def __init__(self, client, contacts, model_map, link_map): + self._client = client + + self._contact_list = list() + self._contact_name_map = dict() + self._computed_contacts = dict() + + for contact in contacts: + name_1, name_2 = contact.split('<->') + + model_1, link_1 = self._get_link_ids(name_1, link_map, model_map) + model_2, link_2 = self._get_link_ids(name_2, link_map, model_map) + + self._contact_name_map[contact] = (model_1, model_2, link_1, link_2) + self._add_contact(model_1, model_2, link_1, link_2) + + def compute_contacts(self): + if len(self._contact_list) == 0: + return + + self._reset_computed_contacts() + for contact in self._contact_list: + model_a, model_b = contact[0] + contacts_points = self._client.getContactPoints(model_a, model_b) + + for contact_p in contacts_points: + involved_links = contact_p[3:5] + if involved_links in contact[1]: + self._computed_contacts[model_a][model_b][contact_p[3]][contact_p[4]] = contact_p + + def get_contact(self, contact_name): + model_1, model_2, link_1, link_2 = self._order_contact(*self._contact_name_map[contact_name]) + return self._computed_contacts[model_1][model_2][link_1][link_2] + + def _add_contact(self, model_1, model_2, link_1, link_2): + model_a, model_b, link_a, link_b = self._order_contact(model_1, model_2, link_1, link_2) + + done = False + for contact in self._contact_list: + models = contact[0] + links_list = contact[1] + if models == (model_a, model_b): + exists = False + for links in links_list: + if (link_a, link_b) == links: + exists = True + break + if not exists: + links_list.append((link_a, link_b)) + done = True + break + + if not done: + contact = ((model_a, model_b), [(link_a, link_b)]) + self._contact_list.append(contact) + + def _reset_computed_contacts(self): + for contact in self._contact_name_map.values(): + model_a, model_b, link_a, link_b = self._order_contact(*contact) + + if model_a not in self._computed_contacts: + self._computed_contacts[model_a] = dict() + if model_b not in self._computed_contacts[model_a]: + self._computed_contacts[model_a][model_b] = dict() + if link_a not in self._computed_contacts[model_a][model_b]: + self._computed_contacts[model_a][model_b][link_a] = dict() + self._computed_contacts[model_a][model_b][link_a][link_b] = None + + @staticmethod + def _get_link_ids(name, link_map, model_map): + if name in link_map: + return link_map[name] + else: + return model_map[name], -1 + + @staticmethod + def _order_contact(model_1, model_2, link_1, link_2): + if model_1 < model_2: + model_a = model_1 + link_a = link_1 + model_b = model_2 + link_b = link_2 + else: + model_a = model_2 + link_a = link_2 + model_b = model_1 + link_b = link_1 + + return model_a, model_b, link_a, link_b diff --git a/mushroom_rl/utils/pybullet/index_map.py b/mushroom_rl/utils/pybullet/index_map.py new file mode 100644 index 00000000..afd21539 --- /dev/null +++ b/mushroom_rl/utils/pybullet/index_map.py @@ -0,0 +1,189 @@ +import numpy as np +import pybullet +from .observation import PyBulletObservationType +from .contacts import ContactHelper + + +class IndexMap(object): + def __init__(self, client, model_map, actuation_spec, observation_spec): + self._client = client + self.model_map = model_map + self.joint_map = dict() + self.link_map = dict() + + self._build_joint_and_link_maps() + + # Contact utils + contact_types = [PyBulletObservationType.CONTACT_FLAG] + contacts = [obs[0] for obs in observation_spec if obs[1] in contact_types] + self._contacts = ContactHelper(client, contacts, self.model_map, self.link_map) + + # Read the actuation spec and build the mapping between actions and ids as well as their limits + self.action_data = list() + self._action_low, self._action_high = self._process_actuation_spec(actuation_spec) + + # Read the observation spec to build a mapping at every step. + # It is ensured that the values appear in the order they are specified. + self.observation_map = observation_spec + self.observation_indices_map = dict() + + # We can only specify limits for the joints, all other information can be potentially unbounded + self._observation_low, self._observation_high = self._process_observations() + + def create_sim_state(self): + data_obs = list() + + self._contacts.compute_contacts() + + for name, obs_type in self.observation_map: + if obs_type is PyBulletObservationType.BODY_POS \ + or obs_type is PyBulletObservationType.BODY_LIN_VEL \ + or obs_type is PyBulletObservationType.BODY_ANG_VEL: + model_id = self.model_map[name] + if obs_type is PyBulletObservationType.BODY_POS: + t, q = self._client.getBasePositionAndOrientation(model_id) + data_obs += t + q + else: + v, w = self._client.getBaseVelocity(model_id) + if obs_type is PyBulletObservationType.BODY_LIN_VEL: + data_obs += v + else: + data_obs += w + elif obs_type is PyBulletObservationType.LINK_POS \ + or obs_type is PyBulletObservationType.LINK_LIN_VEL \ + or obs_type is PyBulletObservationType.LINK_ANG_VEL: + model_id, link_id = self.link_map[name] + + if obs_type is PyBulletObservationType.LINK_POS: + link_data = self._client.getLinkState(model_id, link_id) + t = link_data[0] + q = link_data[1] + data_obs += t + q + elif obs_type is PyBulletObservationType.LINK_LIN_VEL: + data_obs += self._client.getLinkState(model_id, link_id, computeLinkVelocity=True)[-2] + elif obs_type is PyBulletObservationType.LINK_ANG_VEL: + data_obs += self._client.getLinkState(model_id, link_id, computeLinkVelocity=True)[-1] + elif obs_type is PyBulletObservationType.JOINT_POS \ + or obs_type is PyBulletObservationType.JOINT_VEL: + model_id, joint_id = self.joint_map[name] + pos, vel, _, _ = self._client.getJointState(model_id, joint_id) + if obs_type is PyBulletObservationType.JOINT_POS: + data_obs.append(pos) + elif obs_type is PyBulletObservationType.JOINT_VEL: + data_obs.append(vel) + elif obs_type is PyBulletObservationType.CONTACT_FLAG: + contact = self._contacts.get_contact(name) + contact_flag = 0 if contact is None else 1 + data_obs.append(contact_flag) + + return np.array(data_obs) + + def apply_control(self, action): + + i = 0 + for model_id, joint_id, mode in self.action_data: + u = action[i] + if mode is pybullet.POSITION_CONTROL: + kwargs = dict(targetPosition=u, maxVelocity=self._client.getJointInfo(model_id, joint_id)[11]) + elif mode is pybullet.VELOCITY_CONTROL: + kwargs = dict(targetVelocity=u, maxVelocity=self._client.getJointInfo(model_id, joint_id)[11]) + elif mode is pybullet.TORQUE_CONTROL: + kwargs = dict(force=u) + else: + raise NotImplementedError + + self._client.setJointMotorControl2(model_id, joint_id, mode, **kwargs) + i += 1 + + def get_index(self, name, obs_type): + return self.observation_indices_map[name][obs_type] + + def _build_joint_and_link_maps(self): + for model_id in self.model_map.values(): + for joint_id in range(self._client.getNumJoints(model_id)): + joint_data = self._client.getJointInfo(model_id, joint_id) + + if joint_data[2] != pybullet.JOINT_FIXED: + joint_name = joint_data[1].decode('UTF-8') + self.joint_map[joint_name] = (model_id, joint_id) + link_name = joint_data[12].decode('UTF-8') + self.link_map[link_name] = (model_id, joint_id) + + def _process_actuation_spec(self, actuation_spec): + for name, mode in actuation_spec: + if name in self.joint_map: + data = self.joint_map[name] + (mode,) + self.action_data.append(data) + + low = list() + high = list() + + for model_id, joint_id, mode in self.action_data: + joint_info = self._client.getJointInfo(model_id, joint_id) + if mode is pybullet.POSITION_CONTROL: + low.append(joint_info[8]) + high.append(joint_info[9]) + elif mode is pybullet.VELOCITY_CONTROL: + low.append(-joint_info[11]) + high.append(joint_info[11]) + elif mode is pybullet.TORQUE_CONTROL: + low.append(-joint_info[10]) + high.append(joint_info[10]) + else: + raise NotImplementedError + + return np.array(low), np.array(high) + + def _process_observations(self): + low = list() + high = list() + + for name, obs_type in self.observation_map: + index_count = len(low) + if obs_type is PyBulletObservationType.BODY_POS \ + or obs_type is PyBulletObservationType.BODY_LIN_VEL \ + or obs_type is PyBulletObservationType.BODY_ANG_VEL: + n_dim = 7 if obs_type is PyBulletObservationType.BODY_POS else 3 + low += [-np.inf] * n_dim + high += [np.inf] * n_dim + elif obs_type is PyBulletObservationType.LINK_POS \ + or obs_type is PyBulletObservationType.LINK_LIN_VEL \ + or obs_type is PyBulletObservationType.LINK_ANG_VEL: + n_dim = 7 if obs_type is PyBulletObservationType.LINK_POS else 3 + low += [-np.inf] * n_dim + high += [np.inf] * n_dim + elif obs_type is PyBulletObservationType.JOINT_POS \ + or obs_type is PyBulletObservationType.JOINT_VEL: + model_id, joint_id = self.joint_map[name] + joint_info = self._client.getJointInfo(model_id, joint_id) + + if obs_type is PyBulletObservationType.JOINT_POS: + low.append(joint_info[8]) + high.append(joint_info[9]) + else: + max_joint_vel = joint_info[11] + low.append(-max_joint_vel) + high.append(max_joint_vel) + elif obs_type is PyBulletObservationType.CONTACT_FLAG: + low.append(0.) + high.append(1.) + else: + raise RuntimeError('Unsupported observation type') + + self._add_observation_index(name, obs_type, index_count, len(low)) + + return np.array(low), np.array(high) + + def _add_observation_index(self, name, obs_type, start, end): + if name not in self.observation_indices_map: + self.observation_indices_map[name] = dict() + + self.observation_indices_map[name][obs_type] = list(range(start, end)) + + @property + def observation_limits(self): + return self._observation_low, self._observation_high + + @property + def action_limits(self): + return self._action_low, self._action_high \ No newline at end of file diff --git a/mushroom_rl/utils/pybullet/joints_helper.py b/mushroom_rl/utils/pybullet/joints_helper.py new file mode 100644 index 00000000..7d096c77 --- /dev/null +++ b/mushroom_rl/utils/pybullet/joints_helper.py @@ -0,0 +1,44 @@ +import numpy as np + +from .observation import PyBulletObservationType + + +class JointsHelper(object): + def __init__(self, client, indexer, observation_spec): + self._joint_pos_indexes = list() + self._joint_velocity_indexes = list() + joint_limits_low = list() + joint_limits_high = list() + joint_velocity_limits = list() + for joint_name, obs_type in observation_spec: + joint_idx = indexer.get_index(joint_name, obs_type) + if obs_type == PyBulletObservationType.JOINT_VEL: + self._joint_velocity_indexes.append(joint_idx[0]) + + model_id, joint_id = indexer.joint_map[joint_name] + joint_info = client.getJointInfo(model_id, joint_id) + joint_velocity_limits.append(joint_info[11]) + + elif obs_type == PyBulletObservationType.JOINT_POS: + self._joint_pos_indexes.append(joint_idx[0]) + + model_id, joint_id = indexer.joint_map[joint_name] + joint_info = client.getJointInfo(model_id, joint_id) + joint_limits_low.append(joint_info[8]) + joint_limits_high.append(joint_info[9]) + + self._joint_limits_low = np.array(joint_limits_low) + self._joint_limits_high = np.array(joint_limits_high) + self._joint_velocity_limits = np.array(joint_velocity_limits) + + def positions(self, state): + return state[self._joint_pos_indexes] + + def velocities(self, state): + return state[self._joint_velocity_indexes] + + def limits(self): + return self._joint_limits_low, self._joint_limits_high + + def velocity_limits(self): + return self._joint_velocity_limits diff --git a/mushroom_rl/utils/pybullet/observation.py b/mushroom_rl/utils/pybullet/observation.py new file mode 100644 index 00000000..32082927 --- /dev/null +++ b/mushroom_rl/utils/pybullet/observation.py @@ -0,0 +1,19 @@ +from enum import Enum + + +class PyBulletObservationType(Enum): + """ + An enum indicating the type of data that should be added to the observation + of the environment, can be Joint-/Body-/Site- positions and velocities. + + """ + __order__ = "BODY_POS BODY_LIN_VEL BODY_ANG_VEL JOINT_POS JOINT_VEL LINK_POS LINK_LIN_VEL LINK_ANG_VEL CONTACT_FLAG" + BODY_POS = 0 + BODY_LIN_VEL = 1 + BODY_ANG_VEL = 2 + JOINT_POS = 3 + JOINT_VEL = 4 + LINK_POS = 5 + LINK_LIN_VEL = 6 + LINK_ANG_VEL = 7 + CONTACT_FLAG = 8 \ No newline at end of file diff --git a/mushroom_rl/utils/pybullet/viewer.py b/mushroom_rl/utils/pybullet/viewer.py new file mode 100644 index 00000000..d66ac72f --- /dev/null +++ b/mushroom_rl/utils/pybullet/viewer.py @@ -0,0 +1,41 @@ +import numpy as np +import pybullet +from mushroom_rl.utils.viewer import ImageViewer + + +class PyBulletViewer(ImageViewer): + def __init__(self, client, dt, size=(500, 500), distance=4, origin=(0, 0, 1), angles=(0, -45, 60), + fov=60, aspect=1, near_val=0.01, far_val=100): + self._client = client + self._size = size + self._distance = distance + self._origin = origin + self._angles = angles + self._fov = fov + self._aspect = aspect + self._near_val = near_val + self._far_val = far_val + super().__init__(size, dt) + + def display(self): + img = self._get_image() + super().display(img) + + def _get_image(self): + view_matrix = self._client.computeViewMatrixFromYawPitchRoll(cameraTargetPosition=self._origin, + distance=self._distance, + roll=self._angles[0], + pitch=self._angles[1], + yaw=self._angles[2], + upAxisIndex=2) + proj_matrix = self._client.computeProjectionMatrixFOV(fov=self._fov, aspect=self._aspect, + nearVal=self._near_val, farVal=self._far_val) + (_, _, px, _, _) = self._client.getCameraImage(width=self._size[0], + height=self._size[1], + viewMatrix=view_matrix, + projectionMatrix=proj_matrix, + renderer=pybullet.ER_BULLET_HARDWARE_OPENGL) + + rgb_array = np.reshape(np.array(px), (self._size[0], self._size[1], -1)) + rgb_array = rgb_array[:, :, :3] + return rgb_array diff --git a/tests/algorithms/test_policy_gradient.py b/tests/algorithms/test_policy_gradient.py index bdcc5796..9e413a66 100644 --- a/tests/algorithms/test_policy_gradient.py +++ b/tests/algorithms/test_policy_gradient.py @@ -72,7 +72,7 @@ def test_REINFORCE_save(tmpdir): def test_GPOMDP(): params = dict(optimizer=AdaptiveOptimizer(eps=.01)) policy = learn(GPOMDP, params).policy - w = np.array([-0.07623939, 2.05232858]) + w = np.array([-0.11457566, 1.99784316]) assert np.allclose(w, policy.get_weights())