Updates to fix issues associated with Flux 0.14 (#76)
* updates to fix issues associated with Flux 0.14

* version bump
dylan-asmar authored Jan 13, 2025
1 parent bdaa6cb commit 5449310
Showing 7 changed files with 73 additions and 24 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -6,4 +6,5 @@ log*
*.bson
events.out.tfevents*
.vscode
Manifest.toml
Manifest.toml
.DS_Store
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "DeepQLearning"
uuid = "de0a67f4-c691-11e8-0034-5fc6e16e22d3"
repo = "https://github.com/JuliaPOMDP/DeepQLearning.jl"
version = "0.7.1"
version = "0.7.2"

[deps]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
24 changes: 12 additions & 12 deletions README.md
@@ -27,7 +27,6 @@ using DeepQLearning
using POMDPs
using Flux
using POMDPModels
using POMDPSimulators
using POMDPTools

# load MDP model from POMDPModels or define your own!
@@ -37,7 +36,7 @@ mdp = SimpleGridWorld();
# the gridworld state is represented by a 2 dimensional vector.
model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))

exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2))
exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2));

solver = DeepQLearningSolver(qnetwork = model, max_steps=10000,
exploration_policy = exploration,
@@ -99,39 +98,40 @@ mdp = SimpleGridWorld();
# the model weights will be sent to the gpu in the call to solve
model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))

solver = DeepQLearningSolver(qnetwork = model, max_steps=10000,
learning_rate=0.005,log_freq=500,
recurrence=false,double_q=true, dueling=true, prioritized_replay=true)
exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2));

solver = DeepQLearningSolver(qnetwork=model, max_steps=10000,
exploration_policy=exploration,
learning_rate=0.005,log_freq=500,
recurrence=false,double_q=true, dueling=true, prioritized_replay=true)
policy = solve(solver, mdp)
```

## Solver Options

**Fields of the Q Learning solver:**
- `qnetwork::Any = nothing` Specify the architecture of the Q network
- `exploration_policy::ExplorationPolicy` exploration strategy (e.g. `EpsGreedyPolicy`)
- `learning_rate::Float64 = 1e-4` learning rate
- `max_steps::Int64` total number of training steps default = 1000
- `target_update_freq::Int64` frequency at which the target network is updated default = 500
- `batch_size::Int64` batch size sampled from the replay buffer default = 32
- `train_freq::Int64` frequency at which the active network is updated default = 4
- `log_freq::Int64` frequency at which to log info default = 100
- `eval_freq::Int64` frequency at which to evaluate the network default = 100
- `target_update_freq::Int64` frequency at which the target network is updated default = 500
- `num_ep_eval::Int64` number of episodes to evaluate the policy default = 100
- `eps_fraction::Float64` fraction of the training set used to explore default = 0.5
- `eps_end::Float64` value of epsilon at the end of the exploration phase default = 0.01
- `double_q::Bool` double q learning update default = true
- `dueling::Bool` dueling structure for the q network default = true
- `recurrence::Bool = false` set to true to use DRQN; an error is thrown if this is false and a recurrent model is passed.
- `evaluation_policy::Function = basic_evaluation` function used to evaluate the policy every `eval_freq` steps; the default is a rollout that returns the undiscounted average reward
- `prioritized_replay::Bool` enable prioritized experience replay default = true
- `prioritized_replay_alpha::Float64` default = 0.6
- `prioritized_replay_epsilon::Float64` default = 1e-6
- `prioritized_replay_beta::Float64` default = 0.4
- `buffer_size::Int64` size of the experience replay buffer default = 1000
- `max_episode_length::Int64` maximum length of a training episode default = 100
- `train_start::Int64` number of steps used to fill in the replay buffer initially default = 200
- `save_freq::Int64` save the model every `save_freq` steps, default = 1000
- `evaluation_policy::Function = basic_evaluation` function used to evaluate the policy every `eval_freq` steps; the default is a rollout that returns the undiscounted average reward
- `exploration_policy::Any = linear_epsilon_greedy(max_steps, eps_fraction, eps_end)` exploration strategy (default is epsilon greedy with linear decay)
- `rng::AbstractRNG` random number generator default = MersenneTwister(0)
- `logdir::String = ""` folder in which to save the model
- `save_freq::Int64` save the model every `save_freq` steps, default = 1000
- `log_freq::Int64` frequency at which to log info default = 100
- `verbose::Bool` default = true
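
The options above map one-to-one onto `DeepQLearningSolver` keyword arguments. A minimal sketch of a configuration that exercises several of them, assuming the `SimpleGridWorld` setup from the examples above (the values are illustrative, not tuned recommendations):

```julia
using DeepQLearning, POMDPs, POMDPModels, POMDPTools, Flux

mdp = SimpleGridWorld()

# Q-network: 2-dimensional gridworld state in, one Q-value per action out
model = Chain(Dense(2, 32, relu), Dense(32, length(actions(mdp))))

# epsilon decays linearly from 1.0 to 0.01 over the first half of training
exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=5000))

solver = DeepQLearningSolver(qnetwork = model,
                             exploration_policy = exploration,
                             max_steps = 10_000,
                             learning_rate = 1e-3,      # illustrative value
                             target_update_freq = 500,
                             batch_size = 32,
                             buffer_size = 10_000,
                             eval_freq = 1_000,
                             num_ep_eval = 20,
                             double_q = true,
                             dueling = true,
                             prioritized_replay = true,
                             log_freq = 500,
                             verbose = true)

policy = solve(solver, mdp)
```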
2 changes: 1 addition & 1 deletion src/dueling.jl
@@ -10,7 +10,7 @@ function (m::DuelingNetwork)(inpt)
return m.val(x) .+ m.adv(x) .- mean(m.adv(x), dims=1)
end

Flux.@functor DuelingNetwork
Flux.@layer DuelingNetwork

function Flux.reset!(m::DuelingNetwork)
Flux.reset!(m.base)
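
For context, `Flux.@layer` is the Flux 0.14 replacement for `Flux.@functor` when registering a custom layer's fields with Flux. A standalone sketch of the same pattern on a hypothetical layer (`ScaledDense` is not part of this package):

```julia
using Flux

# a hypothetical custom layer: an inner Dense followed by a learnable per-output scale
struct ScaledDense
    inner::Dense
    scale::Vector{Float32}
end

ScaledDense(in::Int, out::Int) = ScaledDense(Dense(in, out), ones(Float32, out))

(m::ScaledDense)(x) = m.scale .* m.inner(x)

# Flux 0.14 style: @layer makes the fields visible to gradients, optimisers, gpu, Flux.state, etc.
Flux.@layer ScaledDense

m = ScaledDense(2, 4)
x = rand(Float32, 2)
y = m(x)                                   # forward pass
g = Flux.gradient(m -> sum(m(x)), m)[1]    # explicit-style gradient over all registered fields
```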
19 changes: 10 additions & 9 deletions src/solver.jl
@@ -1,5 +1,6 @@
@with_kw mutable struct DeepQLearningSolver{E<:ExplorationPolicy} <: Solver
qnetwork::Any = nothing # intended to be a flux model
exploration_policy::E # No default since 9ac3ab
learning_rate::Float32 = 1f-4
max_steps::Int64 = 1000
batch_size::Int64 = 32
@@ -11,7 +12,6 @@
dueling::Bool = true
recurrence::Bool = false
evaluation_policy::Any = basic_evaluation
exploration_policy::E
trace_length::Int64 = 40
prioritized_replay::Bool = true
prioritized_replay_alpha::Float32 = 0.6f0
@@ -139,9 +139,8 @@ function dqn_train!(solver::DeepQLearningSolver, env::AbstractEnv, policy::Abstr
sethiddenstates!(active_q, hs)
end

if t%solver.target_update_freq == 0
weights = Flux.params(active_q)
Flux.loadparams!(target_q, weights)
if t % solver.target_update_freq == 0
target_q = deepcopy(active_q)
end

if t % solver.eval_freq == 0
@@ -170,9 +169,9 @@ function dqn_train!(solver::DeepQLearningSolver, env::AbstractEnv, policy::Abstr
if model_saved
if solver.verbose
@printf("Restore model with eval reward %1.3f \n", saved_mean_reward)
saved_model = BSON.load(joinpath(solver.logdir, "qnetwork.bson"))[:qnetwork]
Flux.loadparams!(getnetwork(policy), saved_model)
end
saved_model_state = BSON.load(joinpath(solver.logdir, "qnetwork_state.bson"))[:qnetwork_state]
Flux.loadmodel!(policy.qnetwork, saved_model_state)
end
return policy
end
@@ -289,7 +288,9 @@ end

function save_model(solver::DeepQLearningSolver, active_q, scores_eval::Float64, saved_mean_reward::Float64, model_saved::Bool)
if scores_eval >= saved_mean_reward
bson(joinpath(solver.logdir, "qnetwork.bson"), qnetwork=[w for w in Flux.params(active_q)])
copied_model = deepcopy(active_q)
Flux.reset!(copied_model)
bson(joinpath(solver.logdir, "qnetwork_state.bson"), qnetwork_state=Flux.state(copied_model))
if solver.verbose
@printf("Saving new model with eval reward %1.3f \n", scores_eval)
end
@@ -311,8 +312,8 @@ function restore_best_model(solver::DeepQLearningSolver, env::AbstractEnv)
active_q = solver.qnetwork
end
policy = NNPolicy(env, active_q, collect(actions(env)), length(obs_dimensions(env)))
weights = BSON.load(solver.logdir*"qnetwork.bson")[:qnetwork]
Flux.loadparams!(getnetwork(policy), weights)
saved_network_state = BSON.load(solver.logdir*"qnetwork_state.bson")[:qnetwork_state]
Flux.loadmodel!(getnetwork(policy), saved_network_state)
Flux.testmode!(getnetwork(policy))
return policy
end
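
The solver.jl changes follow the Flux 0.14 checkpointing idiom: the target network is synced with a plain `deepcopy`, and weights are saved and restored via `Flux.state` and `Flux.loadmodel!` instead of `Flux.params`/`Flux.loadparams!`. A small standalone sketch of that round trip (the model architecture and file name here are illustrative):

```julia
using Flux, BSON

active_q = Chain(Dense(2, 32, relu), Dense(32, 4))
target_q = deepcopy(active_q)                 # target-network sync, as in dqn_train!

# ... training updates active_q in place ...

# save: Flux.state extracts a plain, serialisable snapshot of the weights
bson("qnetwork_state.bson", qnetwork_state = Flux.state(active_q))

# restore: rebuild the same architecture, then load the saved state into it
restored_q = Chain(Dense(2, 32, relu), Dense(32, 4))
saved_state = BSON.load("qnetwork_state.bson")[:qnetwork_state]
Flux.loadmodel!(restored_q, saved_state)
```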
42 changes: 42 additions & 0 deletions test/README_examples.jl
@@ -0,0 +1,42 @@
using DeepQLearning
using POMDPs
using Flux
using POMDPModels
using POMDPTools

@testset "README Example 1" begin
# load MDP model from POMDPModels or define your own!
mdp = SimpleGridWorld();

# Define the Q network (see Flux.jl documentation)
# the gridworld state is represented by a 2 dimensional vector.
model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))

exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2));

solver = DeepQLearningSolver(qnetwork = model, max_steps=10000,
exploration_policy = exploration,
learning_rate=0.005,log_freq=500,
recurrence=false,double_q=true, dueling=true, prioritized_replay=true)
policy = solve(solver, mdp)

sim = RolloutSimulator(max_steps=30)
r_tot = simulate(sim, mdp, policy)
println("Total discounted reward for 1 simulation: $r_tot")
end

@testset "README Example 2" begin
# Without using CuArrays
mdp = SimpleGridWorld();

# the model weights will be sent to the gpu in the call to solve
model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))

exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2));

solver = DeepQLearningSolver(qnetwork=model, max_steps=10000,
exploration_policy=exploration,
learning_rate=0.005,log_freq=500,
recurrence=false,double_q=true, dueling=true, prioritized_replay=true)
policy = solve(solver, mdp)
end
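
Once `solve` returns, the trained policy can be queried like any `POMDPs.Policy`. A short sketch, assuming the `mdp` and `policy` from Example 1 above (`GWPos` comes from POMDPModels, `stepthrough` from POMDPTools):

```julia
# greedy action at a specific gridworld state
s = GWPos(1, 1)
a = action(policy, s)

# inspect a short trajectory under the learned policy
for (s, a, r) in stepthrough(mdp, policy, "s,a,r", max_steps=10)
    @show s a r
end
```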
5 changes: 5 additions & 0 deletions test/runtests.jl
@@ -232,3 +232,8 @@ end

@test evaluate(env, policy, GLOBAL_RNG) > 1.0
end


@testset "README Examples" begin
include("README_examples.jl")
end

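The README examples now run as part of the test suite. Assuming the package is in your active environment, something like the following should pick them up:

```julia
using Pkg
Pkg.test("DeepQLearning")   # runs test/runtests.jl, which now includes README_examples.jl
```
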
2 comments on commit 5449310

@dylan-asmar (Member Author)
@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/122872

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

```
@JuliaRegistrator register

Release notes:

## Breaking changes

- blah
```

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

```
git tag -a v0.7.2 -m "<description of version>" 5449310559b2cf2947cfe3a195e497c948919075
git push origin v0.7.2
```
