Skip to content

Commit

Permalink
Merge branch 'master' into ttt_update
Browse files Browse the repository at this point in the history
  • Loading branch information
elliottower authored May 3, 2024
2 parents da8373e + 38e2520 commit 44716c3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
18 changes: 16 additions & 2 deletions tutorials/SB3/connect_four/sb3_connect_four_action_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,23 @@ def reset(self, seed=None, options=None):
return self.observe(self.agent_selection), {}

def step(self, action):
    """Gymnasium-like step function, returning observation, reward, termination, truncation, info.

    The observation is for the next agent (used to determine the next action), while the remaining
    items are for the agent that just acted (used to understand what just happened).
    """
    # Record who is acting now: AEC envs advance agent_selection inside step(),
    # so this must be captured before delegating to the wrapped environment.
    current_agent = self.agent_selection

    super().step(action)

    # After stepping, agent_selection points at the agent whose turn is next.
    next_agent = self.agent_selection
    return (
        self.observe(next_agent),
        self._cumulative_rewards[current_agent],
        self.terminations[current_agent],
        self.truncations[current_agent],
        self.infos[current_agent],
    )

def observe(self, agent):
"""Return only raw observation, removing action mask."""
Expand Down
9 changes: 3 additions & 6 deletions tutorials/SB3/test/test_sb3_action_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
# Environments that a masked-action PPO policy learns quickly.
EASY_ENVS = [
    gin_rummy_v4,
    texas_holdem_no_limit_v6,  # texas holdem human rendered game ends instantly, but with random actions it works fine
    tictactoe_v3,
    leduc_holdem_v4,
]

# More difficult environments which will likely take more training time
MEDIUM_ENVS = [
    hanabi_v5,  # even with 10x as many steps, total score seems to always be tied between the two agents
    texas_holdem_v4,  # this performs poorly with updates to SB3 wrapper
    chess_v6,  # difficult to train because games take so long, performance varies heavily
]

Expand All @@ -50,10 +50,7 @@ def test_action_mask_easy(env_fn):

env_kwargs = {}

steps = 8192
# These take slightly longer to outperform random
if env_fn in [leduc_holdem_v4, tictactoe_v3]:
steps *= 4
steps = 8192 * 4

# Train a model against itself (takes ~2 minutes on GPU)
train_action_mask(env_fn, steps=steps, seed=0, **env_kwargs)
Expand Down

0 comments on commit 44716c3

Please sign in to comment.