Skip to content

Commit

Permalink
add Snake environment (#123)
Browse files Browse the repository at this point in the history
* add DataStructures package

* import DataStructures

* add Snake environment

* add tests for Snake

* add gif for Snake
  • Loading branch information
Sid-Bhatia-0 authored Jan 31, 2021
1 parent dedfa11 commit a714045
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 2 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@ version = "0.3.0"

[deps]
Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"

[compat]
Crayons = "4.0"
DataStructures = "0.18"
MacroTools = "0.5"
Requires = "1.1"
ReinforcementLearningBase = "0.9"
Requires = "1.1"
julia = "1"

[extras]
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,7 @@ The behaviour of environments is easily customizable. Here are some of the thing
DirectedNavigation | UndirectedNavigation
------------ | -------------
<img src="https://github.com/JuliaReinforcementLearning/GridWorlds.jl/raw/master/docs/src/assets/img/sokoban_directed.gif" width="300px"> | <img src="https://github.com/JuliaReinforcementLearning/GridWorlds.jl/raw/master/docs/src/assets/img/sokoban_undirected.gif" width="300px">

1. ### Snake

<img src="https://github.com/JuliaReinforcementLearning/GridWorlds.jl/raw/master/docs/src/assets/img/snake.gif" width="300px">
Binary file added docs/src/assets/img/snake.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 2 additions & 0 deletions src/GridWorlds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import Requires
import Crayons
import MacroTools:@forward
import Random
import DataStructures
const DS = DataStructures
import ReinforcementLearningBase
import ReinforcementLearningBase:RLBase

Expand Down
1 change: 1 addition & 0 deletions src/envs/envs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ include("door_key.jl")
include("collect_gems.jl")
include("dynamic_obstacles.jl")
include("sokoban/sokoban.jl")
include("snake.jl")
123 changes: 123 additions & 0 deletions src/envs/snake.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
export Snake

"""
    Snake{R} <: AbstractGridWorld

Grid-world environment for the game of Snake. The type parameter `R` is the type of
the stored random number generator.
"""
mutable struct Snake{R} <: AbstractGridWorld
# Grid over the object types `Empty`, `Wall`, `Body` and `Food`.
world::GridWorldBase{Tuple{Empty, Wall, Body, Food}}
# The snake's head; the keyword constructor gives it a queue of body cell positions as inventory.
agent::Agent
# Reward field; presumably the value written by `set_reward!` — confirm against its definition.
reward::Float64
# Random number generator supplied at construction.
rng::R
# Set to `true` by a move into a cell that is neither empty nor food; cleared by `reset!`.
terminated::Bool
# Reward handed out when the snake moves onto the food cell.
food_reward::Float64
# Position of the current food cell.
food_pos::CartesianIndex{2}
end

"""
    Snake(; height = 8, width = 8, rng = Random.GLOBAL_RNG)

Build a `Snake` environment on a `height` × `width` grid, reset it, and return it.

A room is placed over the whole grid, then one provisional food cell and a provisional
one-segment body are written into the world. `RLBase.reset!` is called before returning;
it clears both provisional cells and samples fresh positions with `rng`.

# Keywords
- `height`, `width`: grid dimensions passed to `GridWorldBase` and `Room`.
- `rng`: random number generator stored in the environment.
"""
function Snake(; height = 8, width = 8, rng = Random.GLOBAL_RNG)
    world = GridWorldBase((EMPTY, WALL, BODY, FOOD), height, width)
    place_room!(world, Room(CartesianIndex(1, 1), height, width))

    # Provisional food cell, so the struct starts in a state `reset!` can undo.
    initial_food = CartesianIndex(height - 1, width - 1)
    world[FOOD, initial_food] = true
    world[EMPTY, initial_food] = false

    # Provisional body: a queue holding a single segment.
    initial_head = CartesianIndex(2, 2)
    segments = DS.Queue{CartesianIndex{2}}()
    DS.enqueue!(segments, initial_head)
    world[BODY, initial_head] = true
    world[EMPTY, initial_head] = false

    # The body queue is carried around as the agent's inventory.
    agent = Agent(
        inventory_type = typeof(segments),
        inventory = segments,
        pos = initial_head,
        dir = CENTER,
    )

    # Positional fields: world, agent, reward, rng, terminated, food_reward, food_pos.
    env = Snake(world, agent, 0.0, rng, false, 1.0, initial_food)
    RLBase.reset!(env)
    return env
end

# State style reported to RLBase for `Snake`: `InternalState{Any}`.
function RLBase.StateStyle(env::Snake)
    return RLBase.InternalState{Any}()
end

# `Snake` reports undirected navigation.
function get_navigation_style(::Snake)
    return UNDIRECTED_NAVIGATION
end

# The episode is over once the `terminated` flag has been raised.
function RLBase.is_terminated(env::Snake)
    return env.terminated
end

"""
    (env::Snake)(action::AbstractMoveAction)

Apply one move to the snake and return `env`.

The destination cell is computed with `move` from the agent's current direction and
position. Three outcomes are possible:

- destination is `EMPTY`: the head advances and the oldest body cell is released, so the
  length is unchanged; the reward is `0.0`.
- destination is `FOOD`: the head advances and no body cell is released, so the snake
  grows by one; the reward is `env.food_reward`. New food is then sampled on an empty
  cell. If no empty cell is left, no new food is placed and the episode is terminated.
- anything else: the episode is terminated with reward `0.0`.
"""
function (env::Snake)(action::AbstractMoveAction)
    world = get_world(env)
    body = get_inventory(env)

    dest = move(action, get_agent_dir(env), get_agent_pos(env))

    if world[EMPTY, dest]
        set_agent_pos!(env, dest)

        # Add the new head cell to the back of the queue ...
        DS.enqueue!(body, dest)
        world[BODY, dest] = true
        world[EMPTY, dest] = false

        # ... and free the oldest cell at the front, keeping the length constant.
        tail = DS.dequeue!(body)
        world[BODY, tail] = false
        world[EMPTY, tail] = true

        set_reward!(env, 0.0)
    elseif world[FOOD, dest]
        set_agent_pos!(env, dest)

        # Growing: the head takes over the food cell and nothing is dequeued.
        DS.enqueue!(body, dest)
        world[BODY, dest] = true
        world[FOOD, dest] = false

        # Sampling an empty cell cannot succeed once the body covers every free cell,
        # so check first and end the episode instead of sampling.
        cells = CartesianIndices((get_height(env), get_width(env)))
        if any(pos -> world[EMPTY, pos], cells)
            food_pos = rand(get_rng(env), pos -> world[EMPTY, pos], env)
            env.food_pos = food_pos
            world[FOOD, food_pos] = true
            world[EMPTY, food_pos] = false
        else
            # `env.food_pos` keeps pointing at the cell just eaten; `reset!` clears
            # that cell together with the rest of the body.
            env.terminated = true
        end

        set_reward!(env, env.food_reward)
    else
        # Neither empty nor food: the move is not carried out and the episode ends.
        env.terminated = true
        set_reward!(env, 0.0)
    end

    return env
end

"""
    RLBase.reset!(env::Snake)

Return `env` to a fresh starting state.

The current food cell and every body cell are turned back into empty cells, then a new
food position and a new one-segment snake are sampled (in that order) from the empty
cells using the environment's random number generator. The reward is set to `0.0` and
the `terminated` flag is cleared.
"""
function RLBase.reset!(env::Snake)
    world = get_world(env)
    rng = get_rng(env)
    segments = get_inventory(env)
    is_free = pos -> world[EMPTY, pos]

    # Remove the food left over from the previous episode.
    stale_food = env.food_pos
    world[FOOD, stale_food] = false
    world[EMPTY, stale_food] = true

    # Drain the body queue, freeing each cell it occupied.
    while !isempty(segments)
        cell = DS.dequeue!(segments)
        world[BODY, cell] = false
        world[EMPTY, cell] = true
    end

    # Food is drawn before the snake, so the food cell is excluded from the second draw.
    fresh_food = rand(rng, is_free, env)
    env.food_pos = fresh_food
    world[FOOD, fresh_food] = true
    world[EMPTY, fresh_food] = false

    head = rand(rng, is_free, env)
    world[BODY, head] = true
    world[EMPTY, head] = false
    DS.enqueue!(segments, head)

    start_dir = get_agent_start_dir(env)
    set_agent_pos!(env, head)
    set_agent_dir!(env, start_dir)

    set_reward!(env, 0.0)
    env.terminated = false

    return env
end
10 changes: 10 additions & 0 deletions src/objects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ const TARGET = Target()
get_char(::Target) = ''
get_color(::Target) = :red

# Singleton object type for a cell occupied by the snake's body, drawn in green.
struct Body <: AbstractObject end
const BODY = Body()
# NOTE(review): the glyph was missing from this copy of the file (an empty `''` does not
# parse); '█' is a restored value — confirm against the upstream repository.
get_char(::Body) = '█'
get_color(::Body) = :green

# Singleton object type for the food cell, drawn in yellow.
struct Food <: AbstractObject end
const FOOD = Food()
# NOTE(review): the glyph was missing from this copy of the file (an empty `''` does not
# parse); '♦' is a restored value — confirm against the upstream repository.
get_char(::Food) = '♦'
get_color(::Food) = :yellow

#####
# Agent
#####
Expand Down
6 changes: 5 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Test
using Random
using ReinforcementLearningBase

ENVS = [EmptyRoom, GridRooms, SequentialRooms, GoToDoor, DoorKey, CollectGems, DynamicObstacles, Sokoban]
ENVS = [EmptyRoom, GridRooms, SequentialRooms, GoToDoor, DoorKey, CollectGems, DynamicObstacles, Sokoban, Snake]

MAX_STEPS = 3000
NUM_RESETS = 3
Expand All @@ -16,10 +16,14 @@ get_terminal_rewards(env::DoorKey) = (env.terminal_reward,)
get_terminal_rewards(env::CollectGems) = (env.num_gem_init * env.gem_reward,)
get_terminal_rewards(env::DynamicObstacles) = (env.terminal_reward, env.terminal_penalty)
get_terminal_rewards(env::Sokoban) = (Float64(length(env.box_pos)),)
# Rewards for `Snake`: the range 0, 1, 2, …, height * width, expressed in the numeric
# type of `env.food_reward` (the step is `one(...)` of that type, not `food_reward` itself).
function get_terminal_rewards(env::Snake)
    unit = one(env.food_reward)
    upper = GW.get_height(env) * GW.get_width(env) * unit
    return zero(env.food_reward):unit:upper
end

@testset "GridWorlds.jl" begin
for Env in ENVS
for NAVIGATION in (GridWorlds.DIRECTED_NAVIGATION, GridWorlds.UNDIRECTED_NAVIGATION)
if Env == Snake && NAVIGATION == GridWorlds.DIRECTED_NAVIGATION
continue
end
GridWorlds.get_navigation_style(::Env) = NAVIGATION
@testset "$(NAVIGATION) $(Env)" begin
env = Env()
Expand Down

0 comments on commit a714045

Please sign in to comment.