diff --git a/Project.toml b/Project.toml index c17e289..5e9f004 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Experimenter" uuid = "6aee034a-9508-47b1-8e11-813cc29af79f" authors = ["Jamie Mair and contributors"] -version = "0.1.1" +version = "0.1.2" [deps] DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 12a4fd9..5461707 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -1,6 +1,6 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.8.2" +julia_version = "1.9.2" manifest_format = "2.0" project_hash = "ed6ce97ede4b2f3f9b4003b8bba0e6161365248c" @@ -20,15 +20,19 @@ uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" [[deps.Compat]] -deps = ["Dates", "LinearAlgebra", "UUIDs"] -git-tree-sha1 = "3ca828fe1b75fa84b021a7860bd039eaea84d2f2" +deps = ["UUIDs"] +git-tree-sha1 = "4e88377ae7ebeaf29a047aa1ee40826e0b708a5d" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.3.0" +version = "4.7.0" +weakdeps = ["Dates", "LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "0.5.2+0" +version = "1.0.5+0" [[deps.Crayons]] git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" @@ -41,21 +45,21 @@ uuid = "a10d1c49-ce27-4219-8d33-6db1a4562965" version = "2.5.0" [[deps.DataAPI]] -git-tree-sha1 = "46d2680e618f8abd007bce0c3026cb0c4a8f2032" +git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.12.0" +version = "1.15.0" [[deps.DataFrames]] -deps = ["Compat", "DataAPI", "Future", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SnoopPrecompile", "SortingAlgorithms", 
"Statistics", "TableTraits", "Tables", "Unicode"] -git-tree-sha1 = "5b93f1b47eec9b7194814e40542752418546679f" +deps = ["Compat", "DataAPI", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = "089d29c0fc00a190661517e4f3cba5dcb3fd0c08" uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -version = "1.4.2" +version = "1.6.0" [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "d1fff3a548102f48987a52a2e0d114fa97d730f0" +git-tree-sha1 = "cf25ccb972fec4e4817764d01c82386ae94f77b4" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.13" +version = "0.18.14" [[deps.DataValueInterfaces]] git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" @@ -72,15 +76,15 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[deps.DocStringExtensions]] deps = ["LibGit2"] -git-tree-sha1 = "c36550cb29cbe373e95b3f40486b9a4148f89ffd" +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.9.2" +version = "0.9.3" [[deps.Documenter]] deps = ["ANSIColoredPrinters", "Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] -git-tree-sha1 = "6030186b00a38e9d0434518627426570aac2ef95" +git-tree-sha1 = "39fd748a73dce4c05a9655475e437170d8fb1b67" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "0.27.23" +version = "0.27.25" [[deps.Downloads]] deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] @@ -88,10 +92,10 @@ uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" version = "1.6.0" [[deps.Experimenter]] -deps = ["DataFrames", "Distributed", "Logging", "Pkg", "ProgressBars", "Random", "SQLite", "SafeTestsets", "TestItemRunner", 
"TestItems", "UUIDs"] +deps = ["DataFrames", "Distributed", "Logging", "Pkg", "ProgressBars", "Random", "SQLite", "SafeTestsets", "UUIDs"] path = ".." uuid = "6aee034a-9508-47b1-8e11-813cc29af79f" -version = "0.1.0" +version = "0.1.1" [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" @@ -108,24 +112,24 @@ uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" [[deps.IOCapture]] deps = ["Logging", "Random"] -git-tree-sha1 = "f7be53659ab06ddc986428d3a9dcc95f6fa6705a" +git-tree-sha1 = "d75853a0bdbfb1ac815478bacd89cd27b550ace6" uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" -version = "0.2.2" +version = "0.2.3" [[deps.InlineStrings]] deps = ["Parsers"] -git-tree-sha1 = "a62189e59d33e1615feb7a48c0bea7c11e4dc61d" +git-tree-sha1 = "9cc2baf75c6d09f9da536ddf58eb2f29dedaf461" uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" -version = "1.3.0" +version = "1.4.0" [[deps.InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" [[deps.InvertedIndices]] -git-tree-sha1 = "bee5f1ef5bf65df56bdd2e40447590b272a5471f" +git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" -version = "1.1.0" +version = "1.3.0" [[deps.IteratorInterfaceExtensions]] git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" @@ -140,9 +144,14 @@ version = "1.4.1" [[deps.JSON]] deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e" +git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.3" +version = "0.21.4" + +[[deps.LaTeXStrings]] +git-tree-sha1 = "f2355693d6778a178ade15952b7ac47a4ff97996" +uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" +version = "1.3.0" [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] @@ -167,7 +176,7 @@ version = "1.10.2+0" uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" [[deps.LinearAlgebra]] -deps = ["Libdl", "libblastrampoline_jll"] +deps = ["Libdl", "OpenBLAS_jll", 
"libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[deps.Logging]] @@ -180,20 +189,20 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.0+0" +version = "2.28.2+0" [[deps.Missings]] deps = ["DataAPI"] -git-tree-sha1 = "bf210ce90b6c9eed32d25dbcae1ebc565df2687f" +git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272" uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "1.0.2" +version = "1.1.0" [[deps.Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2022.2.1" +version = "2022.10.11" [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" @@ -202,23 +211,23 @@ version = "1.2.0" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.20+0" +version = "0.3.21+4" [[deps.OrderedCollections]] -git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" +git-tree-sha1 = "d321bf2de576bf25ec4d3e4360faca399afca282" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.4.1" +version = "1.6.0" [[deps.Parsers]] -deps = ["Dates"] -git-tree-sha1 = "6c01a9b494f6d2a9fc180a08b182fcb06f0958a0" +deps = ["Dates", "PrecompileTools", "UUIDs"] +git-tree-sha1 = "4b2e829ee66d4218e0cef22c0a64ee37cf258c29" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.4.2" +version = "2.7.1" [[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.8.0" +version = "1.9.2" 
[[deps.PooledArrays]] deps = ["DataAPI", "Future"] @@ -226,17 +235,23 @@ git-tree-sha1 = "a6062fe4063cdafe78f4a0a81cfffb89721b30e7" uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" version = "1.4.2" +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "9673d39decc5feece56ef3940e5dafba15ba0f81" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.1.2" + [[deps.Preferences]] deps = ["TOML"] -git-tree-sha1 = "47e5f437cc0e7ef2ce8406ce1e7e24d44915f88d" +git-tree-sha1 = "7eb1686b4f04b82f96ed7a4ea5890a4f0c7a09f1" uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.3.0" +version = "1.4.0" [[deps.PrettyTables]] -deps = ["Crayons", "Formatting", "Markdown", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "460d9e154365e058c4d886f6f7d6df5ffa1ea80e" +deps = ["Crayons", "Formatting", "LaTeXStrings", "Markdown", "Reexport", "StringManipulation", "Tables"] +git-tree-sha1 = "331cc8048cba270591eab381e7aa3e2e3fef7f5e" uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.1.2" +version = "2.2.5" [[deps.Printf]] deps = ["Unicode"] @@ -244,9 +259,9 @@ uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" [[deps.ProgressBars]] deps = ["Printf"] -git-tree-sha1 = "806ebc92e1b4b4f72192369a28dfcaf688566b2b" +git-tree-sha1 = "9d84c8646109eb8bc7a006d59b157c64d5155c81" uuid = "49802e3a-d2f1-5c88-81d8-b72133a6f568" -version = "1.4.1" +version = "1.5.0" [[deps.REPL]] deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] @@ -266,16 +281,16 @@ uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" [[deps.SQLite]] -deps = ["DBInterface", "Dates", "Libdl", "Random", "SQLite_jll", "Serialization", "Tables", "Test", "WeakRefStrings"] -git-tree-sha1 = "93863dea19851aa45d67d70e638e16666f33731e" +deps = ["DBInterface", "Random", "SQLite_jll", "Serialization", "Tables", "WeakRefStrings"] +git-tree-sha1 = "eb9a473c9b191ced349d04efa612ec9f39c087ea" uuid = "0aa819cd-b072-5ff4-a722-6bc24af294d9" -version = "1.4.2" +version = "1.6.0" [[deps.SQLite_jll]] 
-deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] -git-tree-sha1 = "9d920c4ee8cd5684e23bf84f43ead45c0af796e7" +deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"] +git-tree-sha1 = "4619dd3363610d94fb42a95a6dc35b526a26d0ef" uuid = "76ed43ae-9a5d-5a62-8c75-30186b810ce8" -version = "3.39.4+0" +version = "3.42.0+0" [[deps.SafeTestsets]] deps = ["Test"] @@ -283,40 +298,47 @@ git-tree-sha1 = "36ebc5622c82eb9324005cc75e7e2cc51181d181" uuid = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" version = "0.0.1" +[[deps.SentinelArrays]] +deps = ["Dates", "Random"] +git-tree-sha1 = "04bdff0b09c65ff3e06a05e3eb7b120223da3d39" +uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" +version = "1.4.0" + [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" -[[deps.SnoopPrecompile]] -git-tree-sha1 = "f604441450a3c0569830946e5b33b78c928e1a85" -uuid = "66db9d55-30c0-4569-8b51-7e840670fc0c" -version = "1.0.1" - [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" [[deps.SortingAlgorithms]] deps = ["DataStructures"] -git-tree-sha1 = "b3363d7460f7d098ca0912c69b082f75625d7508" +git-tree-sha1 = "c60ec5c62180f27efea3ba2908480f8055e17cee" uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.0.1" +version = "1.1.1" [[deps.SparseArrays]] -deps = ["LinearAlgebra", "Random"] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.9.0" [[deps.StringManipulation]] git-tree-sha1 = "46da2434b41f41ac3594ee9816ce5541c6096123" uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" version = "0.3.0" +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "5.10.1+6" + [[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.0" +version = "1.0.3" [[deps.TableTraits]] 
deps = ["IteratorInterfaceExtensions"] @@ -326,30 +348,19 @@ version = "1.0.1" [[deps.Tables]] deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"] -git-tree-sha1 = "c79322d36826aa2f4fd8ecfa96ddb47b174ac78d" +git-tree-sha1 = "1544b926975372da01227b382066ab70e574a3ec" uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.10.0" +version = "1.10.1" [[deps.Tar]] deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.1" +version = "1.10.0" [[deps.Test]] deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -[[deps.TestItemRunner]] -deps = ["TOML", "Test", "TestItems"] -git-tree-sha1 = "724a24752f42117bf071df3fad3d4d0ea780d30e" -uuid = "f8b46487-2199-4994-9208-9a1283c18c0a" -version = "0.2.1" - -[[deps.TestItems]] -git-tree-sha1 = "b7aab261a1662071e1c976b18ad2cee88e27d04a" -uuid = "1c621080-faea-4a02-84b6-bbd5e436b8fe" -version = "0.1.0" - [[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" @@ -366,12 +377,12 @@ version = "1.4.2" [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.12+3" +version = "1.2.13+0" [[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] +deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.1.1+0" +version = "5.8.0+0" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] diff --git a/docs/make.jl b/docs/make.jl index a76955b..cc8a9c9 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -19,6 +19,7 @@ makedocs(; "Getting Started" => "getting_started.md", "Running your Experiments" => "execution.md", "Distributed Execution" => "distributed.md", + "Data Store" => "store.md", "Custom Snapshots" => "snapshots.md", "Public API" => "api.md" ], diff --git a/docs/src/api.md b/docs/src/api.md index 1cd77cf..a59614d 100644 --- a/docs/src/api.md +++ 
b/docs/src/api.md @@ -17,6 +17,11 @@ get_experiment_by_name get_ratio_completed_trials_by_name ``` +## Data Storage +```@docs +get_global_store +``` + ## Trials ```@docs get_trial diff --git a/docs/src/index.md b/docs/src/index.md index ce0d5dd..aeecc2a 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -22,6 +22,7 @@ Pages = [ "getting_started.md", "execution.md", "distributed.md", + "store.md", "snapshots.md" ] Depth = 2 diff --git a/docs/src/store.md b/docs/src/store.md new file mode 100644 index 0000000..aef88be --- /dev/null +++ b/docs/src/store.md @@ -0,0 +1,44 @@ +# Store + +As many experiments may require a set of data to be preloaded for each trial, `Experimenter.jl` provides a data store that can be initialised once on each worker to reduce the amount of time required for loading the same data. + +This store is intended as a **read-only** store that is reused upon execution of each of the trials. + +## Usage + +To start, you must create a function which creates the data to be stored, similar to the function that runs the trial. As an example: +```julia +# Goes inside the same file as your experiment run file (i.e. the file that gets included). +function create_global_store(config) + # config is the global configuration given to the experiment + data = Dict{Symbol, Any}( + :dataset => rand(1000), + :flag => false, + # etc... + ) + return data +end +``` +The variable `config` will be the configuration provided to the `Experiment` struct created for your experiment. Importantly, this function will return a `Dict{Symbol, Any}`. + +The name of this function can be anything, but you need to supply it to the experiment when it is being created, i.e. 
+```julia +experiment = Experiment( + name="Test Experiment", + include_file="run.jl", + function_name="run_trial", + init_store_function_name="create_global_store", + configuration=config +) +``` + +Inside your `run_trial` function, you can access the global store using `get_global_store` +```julia +using Experimenter # exports get_global_store +function run_trial(config, trial_id) + store = get_global_store() + dataset = store[:dataset] + # gather your results + return results +end +``` \ No newline at end of file diff --git a/src/Experimenter.jl b/src/Experimenter.jl index 65378d3..9ff1914 100644 --- a/src/Experimenter.jl +++ b/src/Experimenter.jl @@ -1,5 +1,6 @@ module Experimenter +include("store.jl") include("snapshots.jl") include("experiment.jl") include("database.jl") @@ -24,7 +25,7 @@ export get_trial, get_trials, get_trials_by_name, get_trials_ids_by_name, get_re export complete_trial!, complete_trial_in_global_database, mark_trial_as_incomplete! ### Execution -export execute_trial, execute_trial_and_save_to_db_async +export execute_trial, execute_trial_and_save_to_db_async, get_global_store export @execute export SerialMode, MultithreadedMode, DistributedMode diff --git a/src/database.jl b/src/database.jl index 3b0bb40..814f90f 100644 --- a/src/database.jl +++ b/src/database.jl @@ -19,7 +19,7 @@ end function get_experiment_insert_stmt(db::SQLite.DB) sql = raw""" - INSERT OR IGNORE INTO Experiments (id, name, include_file, function_name, configuration, num_trials) VALUES (?, ?, ?, ?, ?, ?) + INSERT OR IGNORE INTO Experiments (id, name, include_file, function_name, init_store_function_name, configuration, num_trials) VALUES (?, ?, ?, ?, ?, ?, ?) 
""" return SQLite.Stmt(db, sql) end @@ -30,7 +30,7 @@ function get_trial_insert_stmt(db::SQLite.DB) return SQLite.Stmt(db, sql) end function Base.push!(db::ExperimentDatabase, experiment::Experiment) - vs = (string(experiment.id), experiment.name, experiment.include_file, experiment.function_name, experiment.configuration, experiment.num_trials) + vs = (string(experiment.id), experiment.name, experiment.include_file, experiment.function_name, experiment.init_store_function_name, experiment.configuration, experiment.num_trials) SQLite.execute(db._experimentInsertStmt, vs) nothing end @@ -46,7 +46,7 @@ function Base.push!(db::ExperimentDatabase, snapshot::Snapshot) end -Experiment(row::DataFrameRow) = Experiment(UUID(row.id), row.name, row.include_file, row.function_name, row.configuration, row.num_trials) +Experiment(row::DataFrameRow) = Experiment(UUID(row.id), row.name, row.include_file, row.function_name, row.init_store_function_name, row.configuration, row.num_trials) Trial(row::DataFrameRow) = Trial( id=UUID(row.id), experiment_id=UUID(row.experiment_id), @@ -64,6 +64,7 @@ function prepare_db(db::SQLite.DB) name TEXT NOT NULL UNIQUE, include_file TEXT, function_name TEXT, + init_store_function_name TEXT, configuration BLOB, num_trials INTEGER NOT NULL ); @@ -146,11 +147,13 @@ end function check_overlap(experimentA::Experiment, experimentB::Experiment) - if experimentA.name != experimentB.name + if experimentA.name !== experimentB.name return false - elseif experimentA.function_name != experimentB.function_name + elseif experimentA.function_name !== experimentB.function_name return false - elseif experimentA.num_trials != experimentB.num_trials + elseif experimentA.num_trials !== experimentB.num_trials + return false + elseif experimentA.init_store_function_name !== experimentB.init_store_function_name return false else trials_a = collect(experimentA) diff --git a/src/experiment.jl b/src/experiment.jl index 7801131..db582d1 100644 --- a/src/experiment.jl +++ 
b/src/experiment.jl @@ -147,6 +147,7 @@ Base.@kwdef struct Experiment name::AbstractString include_file::Union{Missing,AbstractString} = missing function_name::AbstractString + init_store_function_name::Union{Missing,AbstractString} = missing configuration::Dict{Symbol,Any} num_trials::Int = count_trials(configuration) end diff --git a/src/runner.jl b/src/runner.jl index 3a51e93..86ffba5 100644 --- a/src/runner.jl +++ b/src/runner.jl @@ -11,6 +11,11 @@ using Pkg @doc raw"Executes the trials of the experiment in parallel using `Threads.@Threads`" MultithreadedMode @doc raw"Executes the trials of the experiment in parallel using `Distributed.jl`s `pmap`." DistributedMode +# Global database +const global_database = Ref{Union{Missing, ExperimentDatabase}}(missing) +const global_database_lock = Ref{ReentrantLock}(ReentrantLock()) +const global_store = Ref{Union{Missing, Store}}(missing) + Base.@kwdef struct Runner execution_mode::EXECUTEMODE @@ -75,6 +80,64 @@ macro execute(experiment, database, mode=SerialMode, use_progress=false, directo end end +""" + construct_store(function_name, configuration) + +Calls the supplied function and constructs a store to use +globally per process. This can be used to initialise a +shared database. The store is intended to be read-only. +""" +function construct_store(function_name::AbstractString, configuration) + fn = Base.eval(Main, Meta.parse("$function_name")) + store_data = fn(configuration) + # ToDo add a potential lock here? This should only be called once per process. + global_store[] = Store(store_data) + return nothing +end +construct_store(::Missing, ::Any) = Store() # Construct an empty store + +""" + get_global_store() + +Tries to get the global store that is initialised by the supplied +function with the name specified by `init_store_function_name` set in +the running experiment. This store is local to each worker. 
+ +# Setup +To create the store, add a function in your include file which +returns a dictionary of type Dict{Symbol, Any}, which has the +signature similar to: +```julia +function create_global_store(config) + # config is the global configuration given to the experiment + data = Dict{Symbol, Any}( + :dataset => rand(1000), + :flag => false, + # etc... + ) + return data +end +``` + +Inside your main experiment execution function, you can get this +store via `get_global_store`, which is exported by `Experimenter`. +```julia +function myrunner(config, trial_id) + store = get_global_store() + dataset = store[:dataset] # Retrieve the keys from the store + # process data + return results +end +``` +""" +function get_global_store() + if ismissing(global_store[]) + error("Tried to get the global store, but it was not initialised. Make sure 'init_store_function_name' is set when you create the experiment.") + end + + return (global_store[])::Store +end + function execute_trial(function_name::AbstractString, trial::Trial)::Tuple{UUID,Dict{Symbol,Any}} fn = Base.eval(Main, Meta.parse("$function_name")) results = fn(trial.configuration, trial.id) @@ -88,24 +151,26 @@ function execute_trial_and_save_to_db_async(function_name::AbstractString, trial end function set_global_database(db::ExperimentDatabase) - global global_experiment_database = db - lck = ReentrantLock() - global global_database_lock = lck + lock(global_database_lock[]) do + global_database[] = db + end end function unset_global_database() - global global_experiment_database = nothing - global global_database_lock = nothing + lock(global_database_lock[]) do + global_database[] = missing + end end """ complete_trial_in_global_database(trial_id::UUID, results::Dict{Symbol,Any}) Marks a specific trial (with `trial_id`) complete in the global database and stores the supplied `results`. Redirects to the master node if on a worker node. Locks to secure access. 
""" -function complete_trial_in_global_database(trial_id::UUID, results::Dict{Symbol,Any}) - global global_experiment_database, global_database_lock - - lock(global_database_lock) do - complete_trial!(global_experiment_database, trial_id, results) +function complete_trial_in_global_database(trial_id::UUID, results::Dict{Symbol,Any}) + lock(global_database_lock[]) do + if ismissing(global_database[]) + @error "Global database should have been set prior to calling this function." + end + complete_trial!(global_database[], trial_id, results) end end @@ -119,10 +184,8 @@ function get_results_from_trial_global_database(trial_id::UUID) return remotecall_fetch(get_results_from_trial_global_database, 1, trial_id) end - global global_experiment_database, global_database_lock - - lock(global_database_lock) do - trial = get_trial(global_experiment_database, trial_id) + lock(global_database_lock[]) do + trial = get_trial(global_database[], trial_id) return trial.results end end @@ -138,10 +201,8 @@ function save_snapshot_in_global_database(trial_id::UUID, state::Dict{Symbol,Any return nothing end - global global_experiment_database, global_database_lock - - lock(global_database_lock) do - save_snapshot!(global_experiment_database, trial_id, state, label) + lock(global_database_lock[]) do + save_snapshot!(global_database[], trial_id, state, label) end nothing end @@ -155,12 +216,9 @@ function get_latest_snapshot_from_global_database(trial_id::UUID) if myid() != 1 return remotecall_fetch(get_latest_snapshot_from_global_database, 1, trial_id) end - - global global_experiment_database, global_database_lock - - - snapshot = lock(global_database_lock) do - return latest_snapshot(global_experiment_database, trial_id) + + snapshot = lock(global_database_lock[]) do + return latest_snapshot(global_database[], trial_id) end return snapshot end @@ -173,18 +231,39 @@ function run_trials(runner::Runner, trials::AbstractArray{Trial}; use_progress=f return nothing end + execution_mode = 
runner.execution_mode + iter = use_progress ? ProgressBar(trials) : trials - if runner == DistributedMode && length(workers()) <= 1 + if execution_mode == DistributedMode && length(workers()) <= 1 @info "Only one worker found, switching to serial execution." - runner = SerialMode + execution_mode = SerialMode end set_global_database(runner.database) - if runner.execution_mode == DistributedMode + + # Run initialisation + if !ismissing(runner.experiment.init_store_function_name) + init_fn_name = runner.experiment.init_store_function_name + experiment_config = runner.experiment.configuration + @info "Initialising the store." + if execution_mode == DistributedMode + # Each worker has their own copy of the store + tasks = map(workers()) do worker_id + remotecall(construct_store, worker_id, init_fn_name, experiment_config) + end + wait.(tasks) + else + construct_store(init_fn_name, experiment_config) + end + + end + + + if execution_mode == DistributedMode @info "Running $(length(trials)) trials across $(length(workers())) workers" function_names = (_ -> runner.experiment.function_name).(trials) use_progress && @debug "Progress bar not supported in distributed mode." pmap(execute_trial_and_save_to_db_async, function_names, trials) - elseif runner.execution_mode == MultithreadedMode + elseif execution_mode == MultithreadedMode @info "Running $(length(trials)) trials across $(Threads.nthreads()) threads" Threads.@threads for trial in iter (id, results) = execute_trial(runner.experiment.function_name, trial) diff --git a/src/store.jl b/src/store.jl new file mode 100644 index 0000000..bac77f5 --- /dev/null +++ b/src/store.jl @@ -0,0 +1,26 @@ +import Base + +""" + Store(data::Dict{Symbol, Any}) + +Contains a set of data that can be reused by functions. This is +initialised once by a custom function ran by the user and passed +into the runner function. 
+ +# Examples + +The store acts like a key-value store, where data can be accessed +via +```julia +store[:datakey] +``` +""" +Base.@kwdef struct Store + data::Dict{Symbol, Any} = Dict{Symbol, Any}() +end + +Base.getindex(store::Store, key::Symbol) = Base.getindex(store.data, key) +# Currently does not support setindex! as this is intended to be readonly storage +# Base.setindex!(store::Store, key::Symbol, value) = Base.setindex!(store.data, key, value) + +export Store \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 1a1819a..bad1bfb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,4 +11,7 @@ end end @safetestset "Restore from trial" begin include("restore_from_trial.jl") +end +@safetestset "Store" begin + include("store.jl") end \ No newline at end of file diff --git a/test/store.jl b/test/store.jl new file mode 100644 index 0000000..41961a5 --- /dev/null +++ b/test/store.jl @@ -0,0 +1,68 @@ +using Experimenter +using Experimenter: @execute +using Logging +using Distributed +include("trial_functions.jl") + +file_path = @__FILE__ +directory = dirname(file_path) + +Logging.disable_logging(Logging.Info) + +function get_store_test_config() + return Dict{Symbol,Any}( + :n => IterableVariable([3, 4, 5]), + :m => 5, + :flag => IterableVariable([false, true]), + :label => "configuration", + :param => (777, false) # test this param + ) +end + +function get_store_experiment(name, config) + experiment = Experiment( + name=name, + include_file="trial_functions.jl", + function_name="run_experiment_with_store", + init_store_function_name="init_store", + configuration=config + ) + return experiment +end + +@testset "Local store init" for mode in (SerialMode, MultithreadedMode) + database = open_db("store init test"; in_memory=true) + config = get_store_test_config() + experiment = get_store_experiment("local store init $(mode)", config) + + + @execute experiment database mode false directory + + trials = get_trials_by_name(database, 
experiment.name) + + for trial in trials + results = trial.results # should return the config + expected_results = init_store(config) + @test expected_results == results + end +end + + +@testset "Distributed store init" begin + ps = addprocs(2) + database = open_db("store init test"; in_memory=true) + config = get_store_test_config() + experiment = get_store_experiment("distributed store init test", config) + + @execute experiment database DistributedMode false directory + + trials = get_trials_by_name(database, experiment.name) + + for trial in trials + results = trial.results # should return the config + expected_results = init_store(config) + @test expected_results == results + end + # Cleanup + rmprocs(ps...) +end diff --git a/test/trial_functions.jl b/test/trial_functions.jl index 65297ed..e03976f 100644 --- a/test/trial_functions.jl +++ b/test/trial_functions.jl @@ -1,4 +1,24 @@ using Experimenter +using Test + +function init_store(config) + return Dict{Symbol, Any}( + :magic_number => 42, + :data => [1,101,1001], + :config_param => config[:param] + ) +end + +function run_experiment_with_store(config, trial_id) + store = get_global_store() + @assert store[:config_param] == config[:param] + + return Dict{Symbol, Any}( + :data=>store[:data], + :magic_number=>store[:magic_number], + :config_param=>store[:config_param] + ) +end function run_experiment(config, trial_id) info = Dict{Symbol,Any}()