From 5449310559b2cf2947cfe3a195e497c948919075 Mon Sep 17 00:00:00 2001
From: Dylan Asmar <91484811+dylan-asmar@users.noreply.github.com>
Date: Sun, 12 Jan 2025 17:48:01 -0800
Subject: [PATCH] Updates to fix issues associated with Flux 0.14 (#76)

* updates to fix issues associated with Flux 0.14

* version bump
---
 .gitignore              |  3 ++-
 Project.toml            |  2 +-
 README.md               | 24 +++++++++++------------
 src/dueling.jl          |  2 +-
 src/solver.jl           | 19 ++++++++++---------
 test/README_examples.jl | 42 +++++++++++++++++++++++++++++++++++++++++
 test/runtests.jl        |  5 +++++
 7 files changed, 73 insertions(+), 24 deletions(-)
 create mode 100644 test/README_examples.jl

diff --git a/.gitignore b/.gitignore
index 1496a5e..84ede71 100755
--- a/.gitignore
+++ b/.gitignore
@@ -6,4 +6,5 @@ log*
 *.bson
 events.out.tfevents*
 .vscode
-Manifest.toml
\ No newline at end of file
+Manifest.toml
+.DS_Store
\ No newline at end of file
diff --git a/Project.toml b/Project.toml
index 2ce03c3..e034d14 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "DeepQLearning"
 uuid = "de0a67f4-c691-11e8-0034-5fc6e16e22d3"
 repo = "https://github.com/JuliaPOMDP/DeepQLearning.jl"
-version = "0.7.1"
+version = "0.7.2"

 [deps]
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
diff --git a/README.md b/README.md
index c86a089..6b60450 100755
--- a/README.md
+++ b/README.md
@@ -27,7 +27,6 @@ using DeepQLearning
 using POMDPs
 using Flux
 using POMDPModels
-using POMDPSimulators
 using POMDPTools

 # load MDP model from POMDPModels or define your own!
@@ -37,7 +36,7 @@ mdp = SimpleGridWorld();
 # the gridworld state is represented by a 2 dimensional vector.
 model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))

-exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2))
+exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2));

 solver = DeepQLearningSolver(qnetwork = model, max_steps=10000,
                              exploration_policy = exploration,
@@ -99,9 +98,12 @@ mdp = SimpleGridWorld();
 # the model weights will be send to the gpu in the call to solve
 model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))

-solver = DeepQLearningSolver(qnetwork = model, max_steps=10000,
-                             learning_rate=0.005,log_freq=500,
-                             recurrence=false,double_q=true, dueling=true, prioritized_replay=true)
+exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2));
+
+solver = DeepQLearningSolver(qnetwork=model, max_steps=10000,
+                             exploration_policy=exploration,
+                             learning_rate=0.005,log_freq=500,
+                             recurrence=false,double_q=true, dueling=true, prioritized_replay=true)
 policy = solve(solver, mdp)
 ```

@@ -109,19 +111,18 @@ policy = solve(solver, mdp)

 **Fields of the Q Learning solver:**
 - `qnetwork::Any = nothing` Specify the architecture of the Q network
+- `exploration_policy::
= saved_mean_reward
-            bson(joinpath(solver.logdir, "qnetwork.bson"), qnetwork=[w for w in Flux.params(active_q)])
+            copied_model = deepcopy(active_q)
+            Flux.reset!(copied_model)
+            bson(joinpath(solver.logdir, "qnetwork_state.bson"), qnetwork_state=Flux.state(copied_model))
             if solver.verbose
                 @printf("Saving new model with eval reward %1.3f \n", scores_eval)
             end
@@ -311,8 +312,8 @@ function restore_best_model(solver::DeepQLearningSolver, env::AbstractEnv)
         active_q = solver.qnetwork
     end
     policy = NNPolicy(env, active_q, collect(actions(env)), length(obs_dimensions(env)))
-    weights = BSON.load(solver.logdir*"qnetwork.bson")[:qnetwork]
-    Flux.loadparams!(getnetwork(policy), weights)
+    saved_network_state = BSON.load(solver.logdir*"qnetwork_state.bson")[:qnetwork_state]
+    Flux.loadmodel!(getnetwork(policy), saved_network_state)
     Flux.testmode!(getnetwork(policy))
     return policy
 end
diff --git a/test/README_examples.jl b/test/README_examples.jl
new file mode 100644
index 0000000..6590b5e
--- /dev/null
+++ b/test/README_examples.jl
@@ -0,0 +1,42 @@
+using DeepQLearning
+using POMDPs
+using Flux
+using POMDPModels
+using POMDPTools
+
+@testset "README Example 1" begin
+    # load MDP model from POMDPModels or define your own!
+    mdp = SimpleGridWorld();
+
+    # Define the Q network (see Flux.jl documentation)
+    # the gridworld state is represented by a 2 dimensional vector.
+    model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))
+
+    exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2));
+
+    solver = DeepQLearningSolver(qnetwork = model, max_steps=10000,
+                                 exploration_policy = exploration,
+                                 learning_rate=0.005,log_freq=500,
+                                 recurrence=false,double_q=true, dueling=true, prioritized_replay=true)
+    policy = solve(solver, mdp)
+
+    sim = RolloutSimulator(max_steps=30)
+    r_tot = simulate(sim, mdp, policy)
+    println("Total discounted reward for 1 simulation: $r_tot")
+end
+
+@testset "README Example 2" begin
+    # Without using CuArrays
+    mdp = SimpleGridWorld();
+
+    # the model weights will be send to the gpu in the call to solve
+    model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))
+
+    exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2));
+
+    solver = DeepQLearningSolver(qnetwork=model, max_steps=10000,
+                                 exploration_policy=exploration,
+                                 learning_rate=0.005,log_freq=500,
+                                 recurrence=false,double_q=true, dueling=true, prioritized_replay=true)
+    policy = solve(solver, mdp)
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 262d626..2c1df14 100755
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -232,3 +232,8 @@ end
     @test evaluate(env, policy, GLOBAL_RNG) > 1.0
 end

+
+
+@testset "README Examples" begin
+    include("README_examples.jl")
+end
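
The substantive change in `src/solver.jl` is the serialization strategy: instead of writing the raw `Flux.params` arrays to BSON and reloading them with `Flux.loadparams!`, the solver now saves the model state with `Flux.state` and restores it with `Flux.loadmodel!`, the mechanism supported in Flux 0.14. Below is a minimal standalone sketch of that save/restore pattern; the file name and the toy two-layer network are illustrative only, not part of the package.

```julia
using Flux
using BSON

# Illustrative Q-network; any Flux model works the same way.
model = Chain(Dense(2, 32, relu), Dense(32, 4))

# Save: Flux.state extracts the numerical state (weights, biases, etc.) as plain
# data, which BSON can serialize without capturing the model object itself.
BSON.bson("qnetwork_state.bson", qnetwork_state = Flux.state(model))

# Restore: rebuild the same architecture, then copy the saved state into it.
restored = Chain(Dense(2, 32, relu), Dense(32, 4))
saved_state = BSON.load("qnetwork_state.bson")[:qnetwork_state]
Flux.loadmodel!(restored, saved_state)
```

`restore_best_model` applies the same two steps to the network inside the `NNPolicy`, reading the saved state from `solver.logdir`.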