
Commit 50a5948

Polish GMPG; Add visualization for generated variables; Polish simulators.

zjowowen committed Sep 14, 2024
1 parent 49be438 commit 50a5948
Showing 4 changed files with 327 additions and 14 deletions.
2 changes: 1 addition & 1 deletion grl/algorithms/gmpg.py
@@ -410,7 +410,7 @@ def policy_gradient_loss_by_REINFORCE(
using_Hutchinson_trace_estimator=True,
)
bits_ratio = torch.prod(
torch.tensor(state_repeated.shape[1], device=state.device)
torch.tensor(action_repeated.shape[1], device=state.device)
) * torch.log(torch.tensor(2.0, device=state.device))
log_p_per_dim = log_p / bits_ratio
log_mu = compute_likelihood(
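The one-line change above computes bits_ratio from action_repeated.shape[1] instead of state_repeated.shape[1], so the computed log-likelihood log_p is normalized by the action dimensionality when converted to bits per dimension. A minimal standalone sketch of that conversion (the function name and example numbers are illustrative, not taken from the repository):

import torch

def log_prob_to_bits_per_dim(log_p: torch.Tensor, action_dim: int) -> torch.Tensor:
    # Convert a natural-log likelihood to bits per action dimension:
    # divide by (number of action dimensions) * ln(2).
    bits_ratio = action_dim * torch.log(torch.tensor(2.0))
    return log_p / bits_ratio

# Example: a log-likelihood of -5.0 nats over a 6-dimensional action
print(log_prob_to_bits_per_dim(torch.tensor(-5.0), action_dim=6))  # roughly -1.20 bits per dim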
225 changes: 212 additions & 13 deletions grl/rl_modules/simulators/dm_control_suite_env_simulator.py
@@ -3,6 +3,26 @@
import numpy as np
import torch


def partial_observation_rodent(obs_dict):
# Define the keys you want to keep
keys_to_keep = [
'walker/joints_pos',
'walker/joints_vel',
'walker/tendons_pos',
'walker/tendons_vel',
'walker/appendages_pos',
'walker/world_zaxis',
'walker/sensors_accelerometer',
'walker/sensors_velocimeter',
'walker/sensors_gyro',
'walker/sensors_touch',
'walker/egocentric_camera'
]
# Filter the observation dictionary to only include the specified keys
filtered_obs = {key: obs_dict[key] for key in keys_to_keep if key in obs_dict}
return filtered_obs

class DeepMindControlEnvSimulator:
"""
Overview:
@@ -14,20 +34,43 @@ class DeepMindControlEnvSimulator:
``__init__``, ``collect_episodes``, ``collect_steps``, ``evaluate``
"""

def __init__(self, domain_name: str,task_name: str) -> None:
def __init__(
self,
domain_name: str,
task_name: str,
dict_return=True
) -> None:
"""
Overview:
Initialize the DeepMindControlEnvSimulator according to the given configuration.
Arguments:
domain_name (:obj:`str`): The domain name of the environment.
task_name (:obj:`str`): The task name of the environment.
dict_return (:obj:`bool`): Whether to return the observation as a dictionary.
"""
from dm_control import suite
self.env_domain_name = domain_name
self.task_name=task_name
self.collect_env = suite.load(domain_name, task_name)
# self.observation_space = self.collect_env.observation_space
self.action_space = self.collect_env.action_spec()
if domain_name == "rodent" and task_name == "gaps":
import os
os.environ['MUJOCO_EGL_DEVICE_ID'] = '0'  # pin EGL rendering to GPU 0 (set per process on multi-GPU machines, e.g. 8 GPUs)
from dm_control import composer
from dm_control.locomotion.examples import basic_rodent_2020
self.domain_name = domain_name
self.task_name=task_name
self.collect_env=basic_rodent_2020.rodent_run_gaps()
self.action_space = self.collect_env.action_spec()
self.partial_observation=True
self.partial_observation_fn=partial_observation_rodent
else:
from dm_control import suite
self.domain_name = domain_name
self.task_name=task_name
self.collect_env = suite.load(domain_name, task_name)
self.action_space = self.collect_env.action_spec()
self.partial_observation=False

self.last_state_obs = self.collect_env.reset().observation
self.last_state_done = False
self.dict_return=dict_return
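# Example usage (editorial sketch, not part of this commit): the constructor either
# loads a standard dm_control suite task or, for the special ("rodent", "gaps") pair,
# builds the locomotion example environment. "cheetah"/"run" below is just one valid
# suite task chosen for illustration.
#
#   simulator = DeepMindControlEnvSimulator(domain_name="cheetah", task_name="run", dict_return=False)
#   rodent_simulator = DeepMindControlEnvSimulator(domain_name="rodent", task_name="gaps")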

def collect_episodes(
self,
policy: Union[Callable, torch.nn.Module],
@@ -56,6 +99,25 @@ def collect_episodes(
next_obs = time_step.observation
reward = time_step.reward
done = time_step.last()
if not self.dict_return:
obs_values = []
next_obs_values = []
for key, value in obs.items():
if isinstance(value, np.ndarray):
if value.ndim == 3 and value.shape[0] == 1:
value = value.reshape(1, -1)
elif np.isscalar(value):
value = [value]
obs_values.append(value)
for key, value in next_obs.items():
if isinstance(value, np.ndarray):
if value.ndim == 3 and value.shape[0] == 1:
value = value.reshape(1, -1)
elif np.isscalar(value):
value = [value]
next_obs_values.append(value)
obs = np.concatenate(obs_values, axis=0)
next_obs = np.concatenate(next_obs_values, axis=0)
data_list.append(
dict(
obs=obs,
@@ -67,8 +129,49 @@
)
obs = next_obs
return data_list


elif num_steps is not None:
data_list = []
with torch.no_grad():
for i in range(num_steps):
obs = self.collect_env.reset().observation
done = False
while not done:
action = policy(obs)
time_step = self.collect_env.step(action)
next_obs = time_step.observation
reward = time_step.reward
done = time_step.last()
if not self.dict_return:
obs_values = []
next_obs_values = []
for key, value in obs.items():
if isinstance(value, np.ndarray):
if value.ndim == 3 and value.shape[0] == 1:
value = value.reshape(1, -1)
elif np.isscalar(value):
value = [value]
obs_values.append(value)
for key, value in next_obs.items():
if isinstance(value, np.ndarray):
if value.ndim == 3 and value.shape[0] == 1:
value = value.reshape(1, -1)
elif np.isscalar(value):
value = [value]
next_obs_values.append(value)
obs = np.concatenate(obs_values, axis=0)
next_obs = np.concatenate(next_obs_values, axis=0)
data_list.append(
dict(
obs=obs,
action=action,
reward=reward,
done=done,
next_obs=next_obs,
)
)
obs = next_obs
return data_list
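# Editorial note (sketch, not part of this commit): the dict-to-array conversion above is
# repeated for obs and next_obs in collect_episodes, collect_steps and evaluate. A helper
# along the following lines could factor it out; flatten_dmc_observation is a hypothetical
# name, and unlike the inline code it flattens image-like arrays all the way to 1-D.
def flatten_dmc_observation(obs_dict) -> np.ndarray:
    values = []
    for value in obs_dict.values():
        if isinstance(value, np.ndarray):
            values.append(value.reshape(-1))   # flatten vectors, matrices and images alike
        elif np.isscalar(value):
            values.append(np.array([value]))   # promote scalars to 1-D arrays
    return np.concatenate(values, axis=0)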

def collect_steps(
self,
policy: Union[Callable, torch.nn.Module],
@@ -104,19 +207,96 @@ def collect_steps(
next_obs = time_step.observation
reward = time_step.reward
done = time_step.last()
if not self.dict_return:
obs_values = []
next_obs_values = []
for key, value in self.last_state_obs.items():
if isinstance(value, np.ndarray):
if value.ndim == 3 and value.shape[0] == 1:
value = value.reshape(1, -1)
elif np.isscalar(value):
value = [value]
obs_values.append(value)
for key, value in next_obs.items():
if isinstance(value, np.ndarray):
if value.ndim == 3 and value.shape[0] == 1:
value = value.reshape(1, -1)
elif np.isscalar(value):
value = [value]
next_obs_values.append(value)
obs_flatten = np.concatenate(obs_values, axis=0)
next_obs_flatten = np.concatenate(next_obs_values, axis=0)
data_list.append(
dict(
obs=obs,
obs=obs_flatten,
action=action,
reward=reward,
done=done,
next_obs=next_obs,
next_obs=next_obs_flatten,
)
)
obs = next_obs
self.last_state_obs = self.collect_env.reset().observation
self.last_state_done = False
return data_list
elif num_steps is not None:
data_list = []
with torch.no_grad():
for i in range(num_steps):
if self.last_state_done:
self.last_state_obs = self.collect_env.reset().observation
self.last_state_done = False
if random_policy:
action = np.random.uniform(self.action_space.minimum,
self.action_space.maximum,
size=self.action_space.shape)
else:
if not self.dict_return:
obs_values = []
for key, value in self.last_state_obs.items():
if isinstance(value, np.ndarray):
if value.ndim == 3 and value.shape[0] == 1:
value = value.reshape(1, -1)
elif np.isscalar(value):
value = [value]
obs_values.append(value)
obs = np.concatenate(obs_values, axis=0)
action = policy(obs)
time_step = self.collect_env.step(action)
next_obs = time_step.observation
reward = time_step.reward
done = time_step.last()
if not self.dict_return:
obs_values = []
next_obs_values = []
for key, value in self.last_state_obs.items():
if isinstance(value, np.ndarray):
if value.ndim == 3 and value.shape[0] == 1:
value = value.reshape(1, -1)
elif np.isscalar(value):
value = [value]
obs_values.append(value)
for key, value in next_obs.items():
if isinstance(value, np.ndarray):
if value.ndim == 3 and value.shape[0] == 1:
value = value.reshape(1, -1)
elif np.isscalar(value):
value = [value]
next_obs_values.append(value)
obs_flatten = np.concatenate(obs_values, axis=0)
next_obs_flatten = np.concatenate(next_obs_values, axis=0)
data_list.append(
dict(
obs=obs_flatten,
action=action,
reward=reward,
done=done,
next_obs=next_obs_flatten,
)
)
self.last_state_obs = next_obs
self.last_state_done = done
return data_list


def evaluate(
@@ -149,7 +329,14 @@ def render_env(env, render_args):
return render_output

eval_results = []
env = suite.load(self.env_domain_name, self.task_name)
if self.domain_name == "rodent" and self.task_name == "gaps":
import os
os.environ['MUJOCO_EGL_DEVICE_ID'] = '0'
from dm_control import composer
from dm_control.locomotion.examples import basic_rodent_2020
env=basic_rodent_2020.rodent_run_gaps()
else:
env = suite.load(self.domain_name, self.task_name)
for i in range(num_episodes):
if render:
render_output = []
@@ -163,6 +350,18 @@ def render_env(env, render_args):
done = False
action_spec = env.action_spec()
while not done:
if self.partial_observation:
obs = self.partial_observation_fn(obs)
if not self.dict_return:
obs_values = []
for key, value in obs.items():
if isinstance(value, np.ndarray):
if value.ndim == 2 :
value = value.reshape(-1)
elif np.isscalar(value):
value = [value]
obs_values.append(value)
obs = np.concatenate(obs_values, axis=0)
action = policy(obs)
time_step = env.step(action)
next_obs = time_step.observation
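A short usage sketch of the updated simulator (assumptions: the collapsed collect_steps signature accepts policy, num_steps and random_policy keywords, inferred from the code above; "cheetah"/"run" is an arbitrary suite task chosen for illustration):

from grl.rl_modules.simulators.dm_control_suite_env_simulator import DeepMindControlEnvSimulator

simulator = DeepMindControlEnvSimulator(domain_name="cheetah", task_name="run", dict_return=False)
# With random_policy=True the policy argument is never called, so a placeholder is enough.
transitions = simulator.collect_steps(policy=None, num_steps=100, random_policy=True)
print(len(transitions), transitions[0]["obs"].shape)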
53 changes: 53 additions & 0 deletions grl/unittest/utils/test_plot.py
@@ -0,0 +1,53 @@
import unittest
import os
import numpy as np
from grl.utils.plot import plot_distribution

class TestPlotDistribution(unittest.TestCase):

def setUp(self):
"""
Set up the test environment. This runs before each test.
"""
# Sample data for testing
self.B = 1000 # Number of samples
self.N = 4 # Number of features
self.data = np.random.randn(self.B, self.N) # Random data for demonstration
self.save_path = "test_distribution_plot.png" # Path to save test plot

def tearDown(self):
"""
Clean up after the test. This runs after each test.
"""
# Remove the plot file after the test if it was created
if os.path.exists(self.save_path):
os.remove(self.save_path)

def test_plot_creation(self):
"""
Test if the plot is created and saved to the specified path.
"""
# Call the plot_distribution function
plot_distribution(self.data, self.save_path)

# Check if the file was created
self.assertTrue(os.path.exists(self.save_path), "The plot file was not created.")

# Verify the file is not empty
self.assertGreater(os.path.getsize(self.save_path), 0, "The plot file is empty.")

def test_plot_size(self):
"""
Test if the plot can be saved with a specified size and DPI.
"""
size = (8, 8)
dpi = 300

# Call the plot_distribution function with a custom size and DPI
plot_distribution(self.data, self.save_path, size=size, dpi=dpi)

# Check if the file was created
self.assertTrue(os.path.exists(self.save_path), "The plot file was not created.")

# Verify the file is not empty
self.assertGreater(os.path.getsize(self.save_path), 0, "The plot file is empty.")
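The module under test, grl.utils.plot, is presumably the fourth changed file in this commit, but its diff is not rendered above. As a rough idea of what a compatible plot_distribution could look like (a guess reconstructed from how the tests call it; the real implementation may differ in defaults and plot style):

import matplotlib
matplotlib.use("Agg")  # headless backend so the example runs without a display
import matplotlib.pyplot as plt
import numpy as np

def plot_distribution(data: np.ndarray, save_path: str, size=(10, 10), dpi=100) -> None:
    # data has shape (B, N): B samples of N generated variables; draw one histogram per variable.
    B, N = data.shape
    fig, axes = plt.subplots(N, 1, figsize=size, dpi=dpi)
    if N == 1:
        axes = [axes]
    for i, ax in enumerate(axes):
        ax.hist(data[:, i], bins=50)
        ax.set_title(f"variable {i}")
    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)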
