Commit d9b4f04

update
Signed-off-by: Matteo Bettini <[email protected]>
matteobettini committed Oct 3, 2023
1 parent 69d5343 commit d9b4f04
Showing 1 changed file with 22 additions and 27 deletions.
49 changes: 22 additions & 27 deletions tutorials/sphinx-tutorials/multiagent_ppo.py
@@ -253,12 +253,11 @@
#
#

print("action_spec:", env.action_spec)
print("reward_spec:", env.reward_spec)
print("done_spec:", env.done_spec)
print("action_spec:", env.full_action_spec)
print("reward_spec:", env.full_reward_spec)
print("done_spec:", env.full_done_spec)
print("observation_spec:", env.observation_spec)


######################################################################
# Using the commands just shown we can access the domain of each value.
# Doing this we can see that all specs apart from done have a leading shape ``(num_vmas_envs, n_agents)``.
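
A minimal sketch of the shapes being described (not part of the commit; it assumes torchrl and vmas are installed and uses the tutorial's "navigation" scenario, so the exact sizes are illustrative):

from torchrl.envs.libs.vmas import VmasEnv

env = VmasEnv(scenario="navigation", num_envs=32, n_agents=3)
# Per-agent specs are nested under the "agents" key and carry the
# (num_vmas_envs, n_agents) leading shape; the done spec does not.
print(env.full_action_spec["agents", "action"].shape)  # e.g. torch.Size([32, 3, 2])
print(env.full_reward_spec["agents", "reward"].shape)  # e.g. torch.Size([32, 3, 1])
print(env.full_done_spec["done"].shape)                # e.g. torch.Size([32, 1])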
@@ -270,35 +269,20 @@
# In fact, specs that have the additional agent dimension
# (i.e., they vary for each agent) will be contained in an inner "agents" key.
#
-# To access the full structure of the specs we can use
-#
-
-print("full_action_spec:", env.input_spec["full_action_spec"])
-print("full_reward_spec:", env.output_spec["full_reward_spec"])
-print("full_done_spec:", env.output_spec["full_done_spec"])
-
-######################################################################
# As you can see the reward and action spec present the "agents" key,
# meaning that entries in tensordicts belonging to those specs will be nested in an "agents" tensordict,
# grouping all per-agent values.
#
-# To quickly access the key for each of these values in tensordicts, we can simply ask the environment for the
-# respective key, and
+# To quickly access the keys for each of these values in tensordicts, we can simply ask the environment for the
+# respective keys, and
# we will immediately understand which are per-agent and which shared.
# This info will be useful in order to tell all other TorchRL components where to find each value.
#

print("action_key:", env.action_key)
print("reward_key:", env.reward_key)
print("done_key:", env.done_key)
print("action_key:", env.action_keys)
print("reward_key:", env.reward_keys)
print("done_key:", env.done_keys)

-######################################################################
-# To tie it all together, we can see that passing these keys to the full specs gives us the leaf domains
-#
-
-assert env.action_spec == env.input_spec["full_action_spec"][env.action_key]
-assert env.reward_spec == env.output_spec["full_reward_spec"][env.reward_key]
-assert env.done_spec == env.output_spec["full_done_spec"][env.done_key]

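Continuing the sketch above (again illustrative, not part of the commit: the exact key lists depend on the environment), these accessors would print something like:

# Per-agent entries are nested under "agents"; done-style entries live at the root.
print(env.action_keys)  # e.g. [('agents', 'action')]
print(env.reward_keys)  # e.g. [('agents', 'reward')]
print(env.done_keys)    # e.g. ['done', 'terminated']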
######################################################################
# Transforms
@@ -615,6 +599,9 @@
    action=env.action_key,
    sample_log_prob=("agents", "sample_log_prob"),
    value=("agents", "state_value"),
+    # These last 2 keys will be expanded to match the reward shape
+    done=("agents", "done"),
+    terminated=("agents", "terminated"),
)


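For context, the keys added above sit inside a set_keys(...) call on the tutorial's PPO loss module; a hedged reconstruction of the full call (the surrounding name loss_module follows the tutorial, but details may differ from the file):

loss_module.set_keys(
    reward=env.reward_key,
    action=env.action_key,
    sample_log_prob=("agents", "sample_log_prob"),
    value=("agents", "state_value"),
    # These last 2 keys will be expanded to match the reward shape
    done=("agents", "done"),
    terminated=("agents", "terminated"),
)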
@@ -649,11 +636,19 @@
episode_reward_mean_list = []
for tensordict_data in collector:
    tensordict_data.set(
-        ("next", "done"),
+        ("next", "agents", "done"),
        tensordict_data.get(("next", "done"))
        .unsqueeze(-1)
-        .expand(tensordict_data.get(("next", env.reward_key)).shape),
-    )  # We need to expand the done to match the reward shape (this is expected by the value estimator)
+        .expand(tensordict_data.get_item_shape(("next", env.reward_key))),
+    )
+    tensordict_data.set(
+        ("next", "agents", "terminated"),
+        tensordict_data.get(("next", "terminated"))
+        .unsqueeze(-1)
+        .expand(tensordict_data.get_item_shape(("next", env.reward_key))),
+    )
+
+    # We need to expand the done to match the reward shape (this is expected by the value estimator)

    with torch.no_grad():
        GAE(

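To see why the unsqueeze/expand in the hunk above is needed, a self-contained sketch with made-up sizes (the shared done/terminated flags have no agent dimension, while the value estimator expects them to match the per-agent reward shape):

import torch

num_envs, n_agents = 32, 3
done = torch.zeros(num_envs, 1, dtype=torch.bool)      # shared: one flag per env
reward_shape = (num_envs, n_agents, 1)                 # per-agent reward shape
agents_done = done.unsqueeze(-1).expand(reward_shape)  # broadcast across agents
print(agents_done.shape)  # torch.Size([32, 3, 1])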