Skip to content

Commit

Permalink
Rename run_sim, save_snapshot, and load_snapshot to run, save, and lo…
Browse files Browse the repository at this point in the history
…ad. (#28)

* rename

* bump version

* make backwards compat

* add test of run_sim

* fix typo
  • Loading branch information
nhz2 authored Oct 24, 2023
1 parent 45f150d commit 54e6465
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 47 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MEDYANSimRunner"
uuid = "b58a3b99-22e3-44d1-b5ea-258f082a6fe8"
authors = ["nhz2 <[email protected]>"]
version = "0.4.2"
version = "0.5.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
20 changes: 10 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ These functions can also use the default random number generator, this will auto
At the end of `main.jl` there should be the lines:
```julia
if abspath(PROGRAM_FILE) == @__FILE__
MEDYANSimRunner.run_sim(ARGS; jobs, setup, loop, load_snapshot, save_snapshot, done)
MEDYANSimRunner.run(ARGS; jobs, setup, loop, load, save, done)
end
```

Expand All @@ -59,16 +59,16 @@ The `job` string is also used to seed the default RNG right before `setup` is ca

#### `setup(job::String; kwargs...) -> header_dict, state`
Return the header dictionary to be written as the `header.json` file in output trajectory.
Also return the state that gets passed on to `loop` and the state that gets passed to `save_snapshot` and `load_snapshot`.
Also return the state that gets passed on to `loop` and the state that gets passed to `save` and `load`.

`job::String`: The job. This is used for multi job simulations.

#### `save_snapshot(step::Int, state; kwargs...)-> group::SmallZarrGroups.ZGroup`
#### `save(step::Int, state; kwargs...)-> group::SmallZarrGroups.ZGroup`
Return the state of the system as a `SmallZarrGroups.ZGroup`
This function should not mutate `state`

#### `load_snapshot(step::Int, group::SmallZarrGroups.ZGroup, state; kwargs...) -> state`
Load the state saved by `save_snapshot`
#### `load(step::Int, group::SmallZarrGroups.ZGroup, state; kwargs...) -> state`
Load the state saved by `save`
This function can mutate `state`.
`state` may be the state returned from `setup` or the `state` returned by `loop`.
This function should return the same output if `state` is the state returned by `loop` or the
Expand All @@ -82,7 +82,7 @@ Also return the expected value of step when done will first be true, used for di
This function should not mutate `state`

#### `loop(step::Int, state; kwargs...) -> state`
Return the state that gets passed to `save_snapshot`
Return the state that gets passed to `save`


### Main loop pseudo code
Expand All @@ -95,13 +95,13 @@ Random.seed!(collect(reinterpret(UInt64, sha256(job))))
job_header, state = setup(job)
save job_header
step = 0
SmallZarrGroups.save_zip(snapshot_zip_file, save_snapshot(step, state))
state = load_snapshot(step, SmallZarrGroups.load_zip(snapshot_zip_file), state)
SmallZarrGroups.save_zip(snapshot_zip_file, save(step, state))
state = load(step, SmallZarrGroups.load_zip(snapshot_zip_file), state)
while true
state = loop(step, state)
step = step + 1
SmallZarrGroups.save_zip(snapshot_zip_file, save_snapshot(step, state))
state = load_snapshot(step, SmallZarrGroups.load_zip(snapshot_zip_file), state)
SmallZarrGroups.save_zip(snapshot_zip_file, save(step, state))
state = load(step, SmallZarrGroups.load_zip(snapshot_zip_file), state)
if done(step::Int, state)[1]
break
end
Expand Down
56 changes: 29 additions & 27 deletions src/run-sim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ function get_version_string()
end

"""
run_sim(ARGS; setup, loop, load_snapshot, save_snapshot, done)
run(ARGS; setup, loop, load, save, done)
This function should be called at the end of a script to run a simulation.
It takes keyword arguments:
Expand All @@ -41,10 +41,10 @@ is called once at the beginning of the simulation.
- `loop(step::Int, state; kwargs...) -> state`
is called once per step of the simulation.
- `save_snapshot(step::Int, state; kwargs...) -> group::SmallZarrGroups.ZGroup`
- `save(step::Int, state; kwargs...) -> group::SmallZarrGroups.ZGroup`
is called to save a snapshot.
- `load_snapshot(step::Int, group::SmallZarrGroups.ZGroup, state; kwargs...) -> state`
- `load(step::Int, group::SmallZarrGroups.ZGroup, state; kwargs...) -> state`
is called to load a snapshot.
- `done(step::Int, state; kwargs...) -> done::Bool, expected_final_step::Int`
Expand All @@ -54,12 +54,12 @@ is called to check if the simulation is done.
$(CLI_HELP)
"""
function run_sim(cli_args;
function run(cli_args;
jobs::Vector{String},
setup,
loop,
save_snapshot,
load_snapshot,
save,
load,
done,
kwargs...
)
Expand All @@ -78,16 +78,16 @@ function run_sim(cli_args;
continue_job(options.out_dir, job;
setup,
loop,
save_snapshot,
load_snapshot,
save,
load,
done,
)
else
start_job(options.out_dir, job;
setup,
loop,
save_snapshot,
load_snapshot,
save,
load,
done,
)
end
Expand All @@ -98,29 +98,31 @@ function run_sim(cli_args;
continue_job(options.out_dir, job;
setup,
loop,
save_snapshot,
load_snapshot,
save,
load,
done,
)
else
start_job(options.out_dir, job;
setup,
loop,
save_snapshot,
load_snapshot,
save,
load,
done,
)
end
end
return
end

@deprecate run_sim(cli_args;save_snapshot, load_snapshot, kwargs...) run(cli_args; save=save_snapshot, load=load_snapshot, kwargs...) false


function start_job(out_dir, job::String;
setup,
loop,
save_snapshot,
load_snapshot,
save,
load,
done,
)
basic_name_check.(String.(split(job, '/'; keepempty=true)))
Expand Down Expand Up @@ -148,15 +150,15 @@ function start_job(out_dir, job::String;
write_traj_file(traj, "header.json", codeunits(header_str))
local step::Int = 0

state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save_snapshot, load_snapshot, prev_sha256)
state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save, load, prev_sha256)
@info "Simulation started."
while true
copy!(Random.default_rng(), rng_state)
state = loop(step, state)
copy!(rng_state, Random.default_rng())

step += 1
state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save_snapshot, load_snapshot, prev_sha256)
state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save, load, prev_sha256)

copy!(Random.default_rng(), rng_state)
isdone::Bool, expected_final_step::Int64 = done(step::Int, state)
Expand All @@ -177,8 +179,8 @@ end
function continue_job(out_dir, job;
setup,
loop,
save_snapshot,
load_snapshot,
save,
load,
done,
)
basic_name_check.(String.(split(job, '/'; keepempty=true)))
Expand Down Expand Up @@ -232,7 +234,7 @@ function continue_job(out_dir, job;
prev_sha256 = bytes2hex(sha256(header_str))
write_traj_file(traj, "header.json", codeunits(header_str))
step = 0
state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save_snapshot, load_snapshot, prev_sha256)
state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save, load, prev_sha256)
@info "Simulation started."
else
@info "Continuing simulation from step $(step)."
Expand All @@ -241,7 +243,7 @@ function continue_job(out_dir, job;
reread_sub_snapshot_group = snapshot_group["snap"]
rng_state = str_2_rng(attrs(snapshot_group)["rng_state"])
copy!(Random.default_rng(), rng_state)
state = load_snapshot(step, reread_sub_snapshot_group, state)
state = load(step, reread_sub_snapshot_group, state)
copy!(rng_state, Random.default_rng())
prev_sha256 = bytes2hex(sha256(snapshot_data))
if step > 0
Expand All @@ -262,7 +264,7 @@ function continue_job(out_dir, job;
copy!(rng_state, Random.default_rng())

step += 1
state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save_snapshot, load_snapshot, prev_sha256)
state, prev_sha256 = save_load_state!(rng_state, step, state, traj, save, load, prev_sha256)

copy!(Random.default_rng(), rng_state)
isdone, expected_final_step = done(step::Int, state)
Expand All @@ -286,14 +288,14 @@ function save_load_state!(
step::Int,
state,
traj::String,
save_snapshot,
load_snapshot,
save,
load,
prev_sha256::String,
)
snapshot_group = ZGroup()

copy!(Random.default_rng(), rng_state)
sub_snapshot_group = save_snapshot(step, state)
sub_snapshot_group = save(step, state)
copy!(rng_state, Random.default_rng())

snapshot_group["snap"] = sub_snapshot_group
Expand All @@ -304,7 +306,7 @@ function save_load_state!(
reread_sub_snapshot_group = unzip_group(snapshot_data)["snap"]

copy!(Random.default_rng(), rng_state)
state = load_snapshot(step, reread_sub_snapshot_group, state)
state = load(step, reread_sub_snapshot_group, state)
copy!(rng_state, Random.default_rng())

write_traj_file(traj, SNAP_PREFIX*string(step)*SNAP_POSTFIX, snapshot_data)
Expand Down
6 changes: 3 additions & 3 deletions test/example/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ function setup(job::String; kwargs...)
header, state
end

function save_snapshot(step::Int, state; kwargs...)::ZGroup
function save(step::Int, state; kwargs...)::ZGroup
# @info "saving states" state
group = ZGroup()
group["states"] = state
group
end

function load_snapshot(step::Int, group, state; kwargs...)
function load(step::Int, group, state; kwargs...)
state .= collect(group["states"])
state
end
Expand All @@ -52,5 +52,5 @@ function loop(step::Int, state; kwargs...)
end

if abspath(PROGRAM_FILE) == @__FILE__
MEDYANSimRunner.run_sim(ARGS; jobs, setup, loop, load_snapshot, save_snapshot, done)
MEDYANSimRunner.run(ARGS; jobs, setup, loop, load, save, done)
end
33 changes: 27 additions & 6 deletions test/test_ref_out.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ warn_only_logger = MinLevelLogger(current_logger(), Logging.Warn);
args = ["--out=$test_out","--batch=1"]
continue_sim && push!(args,"--continue")
with_logger(warn_only_logger) do
MEDYANSimRunner.run_sim(args;
MEDYANSimRunner.run(args;
UserCode.jobs,
UserCode.setup,
UserCode.save_snapshot,
UserCode.load_snapshot,
UserCode.save,
UserCode.load,
UserCode.loop,
UserCode.done,
)
Expand Down Expand Up @@ -77,11 +77,11 @@ end
args = ["--out=$test_out","--batch=1"]
continue_sim && push!(args,"--continue")
with_logger(warn_only_logger) do
MEDYANSimRunner.run_sim(args;
MEDYANSimRunner.run(args;
UserCode.jobs,
UserCode.setup,
UserCode.save_snapshot,
UserCode.load_snapshot,
UserCode.save,
UserCode.load,
UserCode.loop,
UserCode.done,
)
Expand All @@ -93,4 +93,25 @@ end
end
end
end

@testset "deprecated full run example" begin
test_out = joinpath(@__DIR__, "example/output/deprecated-full")
rm(test_out; force=true, recursive=true)
args = ["--out=$test_out","--batch=1"]
with_logger(warn_only_logger) do
MEDYANSimRunner.run_sim(args;
UserCode.jobs,
UserCode.setup,
save_snapshot=UserCode.save,
load_snapshot=UserCode.load,
UserCode.loop,
UserCode.done,
)
end
out_diff = sprint(MEDYANSimRunner.print_traj_diff, joinpath(ref_out,"a"), joinpath(test_out,"a"))
if !isempty(out_diff)
println(out_diff)
@test false
end
end
end

2 comments on commit 54e6465

@nhz2
Copy link
Member Author

@nhz2 nhz2 commented on 54e6465 Oct 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/94047

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.0 -m "<description of version>" 54e646503c18b8f05d3e74959fa4daf2d4802374
git push origin v0.5.0

Please sign in to comment.