Skip to content

Commit

Permalink
Revert "Merge pull request #69 from automl/feat/ask_tell"
Browse files Browse the repository at this point in the history
This reverts commit 1ff61a0, reversing
changes made to a7f6bd9.
  • Loading branch information
Bronzila committed Jan 22, 2024
1 parent 7b49f35 commit 652fe0b
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 330 deletions.
4 changes: 4 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,17 @@ markdown_extensions:
- name: mermaid
class: mermaid
format: !!python/name:pymdownx.superfences.fence_code_format
- pymdownx.emoji:
emoji_index: !!python/name:materialx.emoji.twemoji
emoji_generator: !!python/name:materialx.emoji.to_svg

plugins:
- search
- markdown-exec
- mkdocstrings:
default_handler: python
enable_inventory: true
custom_templates: docs/_templates
handlers:
python:
paths: [src]
Expand Down
13 changes: 4 additions & 9 deletions src/dehb/optimizers/de.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def vector_to_configspace(self, vector: np.ndarray) -> ConfigSpace.Configuration
'''
# creates a ConfigSpace object dict with all hyperparameters present, the inactive too
new_config = ConfigSpace.util.impute_inactive_values(
self.cs.get_default_configuration()
self.cs.sample_configuration()
).get_dictionary()
# iterates over all hyperparameters and normalizes each based on its type
for i, hyper in enumerate(self.cs.get_hyperparameters()):
Expand Down Expand Up @@ -304,17 +304,12 @@ def f_objective(self, x, fidelity=None, **kwargs):
raise NotImplementedError("An objective function needs to be passed.")
if self.encoding:
x = self.map_to_original(x)

# Only convert config if configspace is used + configuration has not been converted yet
if self.configspace:
if not isinstance(x, ConfigSpace.Configuration):
# converts [0, 1] vector to a ConfigSpace object
config = self.vector_to_configspace(x)
else:
config = x
# converts [0, 1] vector to a ConfigSpace object
config = self.vector_to_configspace(x)
else:
# can insert custom scaling/transform function here
config = x.copy()

if fidelity is not None: # to be used when called by multi-fidelity based optimizers
res = self.f(config, fidelity=fidelity, **kwargs)
else:
Expand Down
175 changes: 52 additions & 123 deletions src/dehb/optimizers/dehb.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,18 +245,14 @@ def _f_objective(self, job_info):
res = self.de[fidelity].f_objective(config, fidelity, **kwargs)
info = res["info"] if "info" in res else {}
run_info = {
"job_info": {
"config": config,
"config_id": config_id,
"fidelity": fidelity,
"parent_id": parent_id,
"bracket_id": bracket_id,
},
"result": {
"fitness": res["fitness"],
"cost": res["cost"],
"info": info,
},
"fitness": res["fitness"],
"cost": res["cost"],
"config": config,
"config_id": config_id,
"fidelity": fidelity,
"parent_id": parent_id,
"bracket_id": bracket_id,
"info": info,
}

if "gpu_devices" in job_info:
Expand Down Expand Up @@ -546,10 +542,7 @@ def _acquire_config(self, bracket, fidelity):
return config, config_id, parent_id

def _get_next_job(self):
"""Loads a configuration and fidelity to be evaluated next.
Returns:
dict: Dictionary containing all necessary information of the next job.
""" Loads a configuration and fidelity to be evaluated next by a free worker
"""
bracket = None
if len(self.active_brackets) == 0 or \
Expand All @@ -571,50 +564,16 @@ def _get_next_job(self):
# fidelity that the SH bracket allots
fidelity = bracket.get_next_job_fidelity()
config, config_id, parent_id = self._acquire_config(bracket, fidelity)

# transform config to proper representation
if self.configspace:
# converts [0, 1] vector to a ConfigSpace object
config = self.de[fidelity].vector_to_configspace(config)

# notifies the Bracket Manager that a single config is to run for the fidelity chosen
job_info = {
"config": config,
"config_id": config_id,
"fidelity": fidelity,
"parent_id": parent_id,
"bracket_id": bracket.bracket_id,
"bracket_id": bracket.bracket_id
}

# pass information of job submission to Bracket Manager
for bracket in self.active_brackets:
if bracket.bracket_id == job_info['bracket_id']:
# registering is IMPORTANT for Bracket Manager to perform SH
bracket.register_job(job_info['fidelity'])
break
return job_info

def ask(self, n_configs: int=1):
"""Get the next configuration to run from the optimizer.
The retrieved configuration can then be evaluated by the user.
After evaluation use `tell` to report the results back to the optimizer.
For more information, please refer to the description of `tell`.
Args:
n_configs (int, optional): Number of configs to ask for. Defaults to 1.
Returns:
dict or list of dict: Job info(s) of next configuration to evaluate.
"""
if n_configs == 1:
return self._get_next_job()

jobs = []
for _ in range(n_configs):
jobs.append(self._get_next_job())
return jobs

def _get_gpu_id_with_low_load(self):
candidates = []
for k, v in self.gpu_usage.items():
Expand All @@ -635,7 +594,7 @@ def submit_job(self, job_info, **kwargs):
""" Asks a free worker to run the objective function on config and fidelity
"""
job_info["kwargs"] = self.shared_data if self.shared_data is not None else kwargs
# submit to Dask client
# submit to Dask client
if self.n_workers > 1 or isinstance(self.client, Client):
if self.single_node_with_gpus:
# managing GPU allocation for the job to be submitted
Expand All @@ -647,6 +606,13 @@ def submit_job(self, job_info, **kwargs):
# skipping scheduling to Dask worker to avoid added overheads in the synchronous case
self.futures.append(self._f_objective(job_info))

# pass information of job submission to Bracket Manager
for bracket in self.active_brackets:
if bracket.bracket_id == job_info['bracket_id']:
# registering is IMPORTANT for Bracket Manager to perform SH
bracket.register_job(job_info['fidelity'])
break

def _fetch_results_from_workers(self):
""" Iterate over futures and collect results from finished workers
"""
Expand All @@ -670,20 +636,40 @@ def _fetch_results_from_workers(self):
else:
# Dask not invoked in the synchronous case
run_info = future
# tell result
self.tell(run_info["job_info"], run_info["result"])
# update bracket information
fitness, cost = run_info["fitness"], run_info["cost"]
info = run_info["info"] if "info" in run_info else dict()
fidelity, parent_id = run_info["fidelity"], run_info["parent_id"]
config, config_id = run_info["config"], run_info["config_id"]
bracket_id = run_info["bracket_id"]
for bracket in self.active_brackets:
if bracket.bracket_id == bracket_id:
# bracket job complete
bracket.complete_job(fidelity) # IMPORTANT to perform synchronous SH

self.config_repository.tell_result(config_id, fidelity, fitness, cost, info)

# carry out DE selection
if fitness <= self.de[fidelity].fitness[parent_id]:
self.de[fidelity].population[parent_id] = config
self.de[fidelity].population_ids[parent_id] = config_id
self.de[fidelity].fitness[parent_id] = fitness
# updating incumbents
if self.de[fidelity].fitness[parent_id] < self.inc_score:
self._update_incumbents(
config=self.de[fidelity].population[parent_id],
score=self.de[fidelity].fitness[parent_id],
info=info
)
# book-keeping
self._update_trackers(
traj=self.inc_score, runtime=cost, history=(
config.tolist(), float(fitness), float(cost), float(fidelity), info
)
)
# remove processed future
self.futures = np.delete(self.futures, [i for i, _ in done_list]).tolist()

def _adjust_budgets(self, fevals=None, brackets=None):
# only update budgets if it is not the first run
if fevals is not None and len(self.traj) > 0:
fevals = len(self.traj) + fevals
elif brackets is not None and self.iteration_counter > -1:
brackets = self.iteration_counter + brackets

return fevals, brackets

def _is_run_budget_exhausted(self, fevals=None, brackets=None, total_cost=None):
""" Checks if the DEHB run should be terminated or continued
"""
Expand Down Expand Up @@ -759,54 +745,6 @@ def _verbosity_runtime(self, fevals, brackets, total_cost):
"{}/{} {}".format(remaining[0], remaining[1], remaining[2])
)

def tell(self, job_info: dict, result: dict):
"""Feed a result back to the optimizer.
In order to correctly interpret the results, the `job_info` dict, retrieved by `ask`,
has to be given. Moreover, the `result` dict has to contain the keys `fitness` and `cost`.
It is also possible to add the field `info` to the `result` in order to store additional,
user-specific information.
Args:
job_info (dict): Job info returned by ask().
result (dict): Result dictionary with mandatory keys `fitness` and `cost`.
"""
# update bracket information
fitness, cost = result["fitness"], result["cost"]
info = result["info"] if "info" in result else dict()
fidelity, parent_id = job_info["fidelity"], job_info["parent_id"]
config, config_id = job_info["config"], job_info["config_id"]
bracket_id = job_info["bracket_id"]
for bracket in self.active_brackets:
if bracket.bracket_id == bracket_id:
# bracket job complete
bracket.complete_job(fidelity) # IMPORTANT to perform synchronous SH

self.config_repository.tell_result(config_id, fidelity, fitness, cost, info)

# get hypercube representation from config repo
if self.configspace:
config = self.config_repository.get(config_id)

# carry out DE selection
if fitness <= self.de[fidelity].fitness[parent_id]:
self.de[fidelity].population[parent_id] = config
self.de[fidelity].population_ids[parent_id] = config_id
self.de[fidelity].fitness[parent_id] = fitness
# updating incumbents
if self.de[fidelity].fitness[parent_id] < self.inc_score:
self._update_incumbents(
config=self.de[fidelity].population[parent_id],
score=self.de[fidelity].fitness[parent_id],
info=info
)
# book-keeping
self._update_trackers(
traj=self.inc_score, runtime=cost, history=(
config.tolist(), float(fitness), float(cost), float(fidelity), info
)
)

@logger.catch
def run(self, fevals=None, brackets=None, total_cost=None, single_node_with_gpus=False,
verbose=False, debug=False, save_intermediate=True, save_history=True, name=None, **kwargs):
Expand All @@ -823,12 +761,6 @@ def run(self, fevals=None, brackets=None, total_cost=None, single_node_with_gpus
2) Number of Successive Halving brackets run under Hyperband (brackets)
3) Total computational cost (in seconds) aggregated by all function evaluations (total_cost)
"""
# check if run has already been called before
if self.start is not None:
logger.warning("DEHB has already been run. Calling 'run' twice could lead to unintended"
+ " behavior. Please restart DEHB with an increased compute budget"
+ " instead of calling 'run' twice.")

# checks if a Dask client exists
if len(kwargs) > 0 and self.n_workers > 1 and isinstance(self.client, Client):
# broadcasts all additional data passed as **kwargs to all client workers
Expand All @@ -842,8 +774,7 @@ def run(self, fevals=None, brackets=None, total_cost=None, single_node_with_gpus
if self.single_node_with_gpus:
self.distribute_gpus()

self.start = self.start = time.time()
fevals, brackets = self._adjust_budgets(fevals, brackets)
self.start = time.time()
if verbose:
print("\nLogging at {} for optimization starting at {}\n".format(
os.path.join(os.getcwd(), self.log_filename),
Expand All @@ -855,11 +786,11 @@ def run(self, fevals=None, brackets=None, total_cost=None, single_node_with_gpus
if self._is_run_budget_exhausted(fevals, brackets, total_cost):
break
if self.is_worker_available():
job_info = self.ask()
job_info = self._get_next_job()
if brackets is not None and job_info["bracket_id"] >= brackets:
# ignore submission and only collect results
# when brackets are chosen as run budget, an extra bracket is created
# since iteration_counter is incremented in ask() and then checked
# since iteration_counter is incremented in _get_next_job() and then checked
# in _is_run_budget_exhausted(), therefore, need to skip suggestions
# coming from the extra allocated bracket
# _is_run_budget_exhausted() will not return True until all the lower brackets
Expand Down Expand Up @@ -919,6 +850,4 @@ def run(self, fevals=None, brackets=None, total_cost=None, single_node_with_gpus
self.logger.info("{}".format(self.inc_config))
self._save_incumbent(name)
self._save_history(name)
# reset waiting jobs of active bracket to allow for continuation
self.active_brackets[0].reset_waiting_jobs()
return np.array(self.traj), np.array(self.runtime), np.array(self.history, dtype=object)
13 changes: 0 additions & 13 deletions src/dehb/utils/bracket_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,19 +125,6 @@ def is_waiting(self):
"""
return np.any([self._is_rung_waiting(i) > 0 for i, _ in enumerate(self.fidelities)])

def reset_waiting_jobs(self):
"""Resets all waiting jobs and updates the current_rung pointer accordingly."""
for i, fidelity in enumerate(self.fidelities):
pending = self.sh_bracket[fidelity]
done = self._sh_bracket[fidelity]
waiting = np.abs(self.n_configs[i] - pending - done)

# update current_rung pointer to the lowest rung with waiting jobs
if waiting > 0 and self.current_rung > i:
self.current_rung = i
# reset waiting jobs
self.sh_bracket[fidelity] += waiting

def __repr__(self):
cell_width = 10
cell = "{{:^{}}}".format(cell_width)
Expand Down
Loading

0 comments on commit 652fe0b

Please sign in to comment.