Skip to content

Commit

Permalink
Merge pull request #8 from gkwt/main
Browse files Browse the repository at this point in the history
Example scripts added
  • Loading branch information
gkwt authored Nov 29, 2024
2 parents cecf394 + 07c6a3b commit 7121e61
Show file tree
Hide file tree
Showing 9 changed files with 1,069 additions and 540 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
269 changes: 267 additions & 2 deletions docs/source/examples/notebooks/multi_fidelity.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,277 @@
{
"cell_type": "markdown",
"metadata": {},
"source": []
"source": [
"Some experiments can be very expensive. These may be supplemented by simpler alternatives, or perhaps high-throughput calculations. This would give measurements of lower *fidelity*, and the planner can take advantage of these measurements to guide high fidelity optimization.\n",
"\n",
"This can also be used in a virtual screening setting. Expensive quantum chemistry calculations can be supplemented by faster semi-empirical methods. Another example could also be the virtual screening of compounds for drug activity, with high fidelity free-energy perturbation calcualtions being approximated by faster and lower fidelity docking calculations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/garyk/mambaforge/envs/atlas/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import json\n",
"import pickle\n",
"import numpy as np\n",
"import pandas as pd\n",
"from copy import deepcopy\n",
"\n",
"from olympus.datasets import Dataset\n",
"from olympus.objects import (\n",
"\tParameterContinuous,\n",
"\tParameterDiscrete, \n",
"\tParameterCategorical, \n",
"\tParameterVector\n",
")\n",
"from olympus.campaigns import ParameterSpace, Campaign\n",
"\n",
"from atlas.planners.multi_fidelity.planner import MultiFidelityPlanner # specially designed planner for multi-fidelity optimization\n",
"\n",
"import pickle"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For this example, we will perform a screening of the bandgap of perovskites. There's two fidelities of measurements, one using GGA (*low*), and one use HSE06 (*high*). You can set the associated cost to each one, but we will consider queries to GGA calculations as 10 times cheaper than with HSE06."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"COST_BUDGET = 50 # this time the budget is a cost\n",
"NUM_INIT_DESIGN = 10\n",
"NUM_CHEAP = 8 # this is the ratio of low:high measurements (ie. 8:1 low/high fidelity)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we will create an additional fidelity parameter `s`, which can only be the permitted fidelities. The `MultiFidelityPlanner` will be allowed to vary this parameter, and perform optimization with an additional constrained *fidelity* parameter."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"dataset = Dataset(kind='perovskites')\n",
"\n",
"# build parameter space\n",
"param_space = ParameterSpace()\n",
"\n",
"# fidelity param\n",
"param_space.add(ParameterDiscrete(name='s', options=[0.1, 1.0], low=0.1, high=1.0))\n",
"for param in dataset.param_space: # add perovskite component parameters ('organic', 'cation', and 'anion')\n",
"\tparam_space.add(param)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# lower fidelity data calucated using GGA is available in the examples folder\n",
"# so we will load it here to create a new function for measurements\n",
"# fill in the ATLAS_PATH\n",
"ATLAS_PATH = '.'\n",
"LOOKUP = pickle.load(open(f'{ATLAS_PATH}/examples/multi_fidelity/perovskites/lookup/lookup_table.pkl', 'rb'))\n",
"\n",
"def measure(params, s):\n",
"\t# high-fidelity is hse06, low-fidelity is gga\n",
"\tif s == 1.0:\n",
"\t\tmeasurement = np.amin(\n",
"\t\t\tLOOKUP[params.organic.capitalize()][params.cation][params.anion]['bandgap_hse06']\n",
"\t\t)\n",
"\telif s == 0.1:\n",
"\t\tmeasurement = np.amin(\n",
"\t\t\tLOOKUP[params.organic.capitalize()][params.cation][params.anion]['bandgap_gga']\n",
"\t\t)\n",
"\treturn measurement\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #051923; text-decoration-color: #051923\">───────────────────────────────────────────────────────────────────────────────────────────────────────────────────</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[38;2;5;25;35m───────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #05a6fb; text-decoration-color: #05a6fb; font-weight: bold\"> </span>\n",
"<span style=\"color: #05a6fb; text-decoration-color: #05a6fb; font-weight: bold\"> Welcome to ATLAS! </span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;38;2;5;166;251m \u001b[0m\u001b[1;38;2;5;166;251m \u001b[0m\u001b[1;38;2;5;166;251m \u001b[0m\n",
"\u001b[1;38;2;5;166;251m \u001b[0m\u001b[1;38;2;5;166;251mWelcome to ATLAS!\u001b[0m\u001b[1;38;2;5;166;251m \u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #006494; text-decoration-color: #006494\"> Made with 💕 in 🇨🇦 </span>\n",
"<span style=\"color: #006494; text-decoration-color: #006494\"> </span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[38;2;0;100;148m \u001b[0m\u001b[38;2;0;100;148mMade with 💕 in 🇨🇦\u001b[0m\u001b[38;2;0;100;148m \u001b[0m\n",
"\u001b[38;2;0;100;148m \u001b[0m\u001b[38;2;0;100;148m \u001b[0m\u001b[38;2;0;100;148m \u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #051923; text-decoration-color: #051923\">───────────────────────────────────────────────────────────────────────────────────────────────────────────────────</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[38;2;5;25;35m───────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #05a6fb; text-decoration-color: #05a6fb\">───────────────────────────── Initial design phase ─────────────────────────────</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[38;2;5;166;251m───────────────────────────── Initial design phase ─────────────────────────────\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"campaign = Campaign()\n",
"campaign.set_param_space(param_space)\n",
"\n",
"planner = MultiFidelityPlanner(\n",
" goal='minimize',\n",
" init_design_strategy='random',\n",
" num_init_design=NUM_INIT_DESIGN,\n",
" use_descriptors=True,\n",
" batch_size=1,\n",
" acquisition_optimizer_kind='pymoo', # this is required\n",
" fidelity_params=0, # this dimension is the fidelity parameter (we use the first one)\n",
" fidelities=[0.1, 1.], # these are the possible fidelities (GGA = 0.1, and HSE = 1.0)\n",
")\n",
"\n",
"planner.set_param_space(param_space)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# accumulated cost, the budget is also cost\n",
"COST = 0.\n",
"\n",
"target_rec_measurements = []\n",
"iter_ = 0\n",
"while COST < COST_BUDGET:\n",
" print(f'\\nITER : {iter_+1}\\tCOST : {COST}\\n')\n",
"\n",
" # this is how much the corresponding measurement will cost\n",
" if iter_ % NUM_CHEAP == 0:\n",
" planner.set_ask_fidelity(1.0)\n",
" else:\n",
" planner.set_ask_fidelity(0.1)\n",
"\n",
" samples = planner.recommend(campaign.observations)\n",
" for sample in samples:\n",
" measurement = measure(sample, sample.s)\n",
" campaign.add_observation(sample, measurement)\n",
"\n",
" print('SAMPLE : ', sample)\n",
" print('MEASUREMENT : ', measurement)\n",
"\n",
" iter_+=1\n",
"\n",
" # do a check to see if model will find the optimal\n",
" if campaign.num_obs > NUM_INIT_DESIGN:\n",
" # make greedy recommendation on the target fidelity\n",
" # use this to make a high-fidelity measurement\n",
" rec_sample = planner.recommend_target_fidelity(batch_size=1)[0]\n",
" rec_measurement = measure(rec_sample, rec_sample.s)\n",
" print('REC SAMPLE : ', rec_sample)\n",
" print('REC MEASUREMENT : ', rec_measurement)\n",
"\n",
" target_rec_measurements.append(rec_measurement)\n",
" # kill the run if we have found the lowest hse06 bandgap\n",
" # on the most recent high-fidelity measurement\n",
" if rec_measurement == min_hse06_bandgap:\n",
" print('found the min hse06 bandgap!')\n",
" break\n",
" else:\n",
" target_rec_measurements.append(measurement)\n",
" if measurement == min_hse06_bandgap and samples[0].s == 1.:\n",
" print('found the min hse06 bandgap!')\n",
" break"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "atlas",
"language": "python",
"name": "atlas"
},
"language_info": {
"name": "python"
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.20"
},
"orig_nbformat": 4
},
Expand Down
Loading

0 comments on commit 7121e61

Please sign in to comment.