Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rng): Ensure reproducibility through construction and reset #84

Merged
merged 3 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions src/dehb/optimizers/de.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,16 @@ class DEBase():
'''
def __init__(self, cs=None, f=None, dimensions=None, pop_size=None, max_age=None,
mutation_factor=None, crossover_prob=None, strategy=None,
boundary_fix_type='random', config_repository=None, seed=None, rng=None, **kwargs):
# Rng, either uses the rng passed by user/DEHB or creates its own
self.rng = rng if rng is not None else np.random.default_rng(seed)
boundary_fix_type='random', config_repository=None, seed=None, **kwargs):
if seed is None:
seed = int(np.random.default_rng().integers(0, 2**32 - 1))
elif isinstance(seed, np.random.Generator):
seed = int(seed.integers(0, 2**32 - 1))

assert isinstance(seed, int)

self._original_seed = seed
self.rng = np.random.default_rng(self._original_seed)

# Benchmark related variables
self.cs = cs
Expand All @@ -39,6 +46,7 @@ def __init__(self, cs=None, f=None, dimensions=None, pop_size=None, max_age=None
self.configspace = True if isinstance(self.cs, ConfigSpace.ConfigurationSpace) else False
self.hps = dict()
if self.configspace:
self.cs.seed(self._original_seed)
for i, hp in enumerate(cs.get_hyperparameters()):
# maps hyperparameter name to positional index in vector form
self.hps[hp.name] = i
Expand All @@ -61,14 +69,20 @@ def __init__(self, cs=None, f=None, dimensions=None, pop_size=None, max_age=None
self.history : list[object]
self.reset()

def reset(self):
def reset(self, *, reset_seeds: bool = True):
self.inc_score = np.inf
self.inc_config = None
self.inc_id = -1
self.population = None
self.population_ids = None
self.fitness = None
self.age = None

if reset_seeds:
if isinstance(self.cs, ConfigSpace.ConfigurationSpace):
self.cs.seed(self._original_seed)
self.rng = np.random.default_rng(self._original_seed)

self.history = []

def _shuffle_pop(self):
Expand Down Expand Up @@ -248,10 +262,10 @@ def run(self):
class DE(DEBase):
def __init__(self, cs=None, f=None, dimensions=None, pop_size=20, max_age=np.inf,
mutation_factor=None, crossover_prob=None, strategy='rand1_bin', encoding=False,
dim_map=None, seed=None, rng=None, config_repository=None, **kwargs):
dim_map=None, seed=None, config_repository=None, **kwargs):
super().__init__(cs=cs, f=f, dimensions=dimensions, pop_size=pop_size, max_age=max_age,
mutation_factor=mutation_factor, crossover_prob=crossover_prob,
strategy=strategy, seed=seed, rng=rng, config_repository=config_repository,
strategy=strategy, seed=seed, config_repository=config_repository,
**kwargs)
if self.strategy is not None:
self.mutation_strategy = self.strategy.split('_')[0]
Expand All @@ -276,8 +290,8 @@ def __del__(self):
if hasattr(self, "client") and isinstance(self.client, Client):
self.client.close()

def reset(self):
super().reset()
def reset(self, *, reset_seeds: bool = True):
super().reset(reset_seeds=reset_seeds)
self.traj = []
self.runtime = []
self.history = []
Expand Down
89 changes: 27 additions & 62 deletions src/dehb/optimizers/dehb.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,14 @@ def __init__(self, cs=None, f=None, dimensions=None, mutation_factor=None,
crossover_prob=None, strategy=None, min_fidelity=None,
max_fidelity=None, eta=None, min_clip=None, max_clip=None, seed=None,
boundary_fix_type='random', max_age=np.inf, **kwargs):
# Rng
self.rng = np.random.default_rng(seed)
if seed is None:
seed = int(np.random.default_rng().integers(0, 2**32 - 1))
elif isinstance(seed, np.random.Generator):
seed = int(seed.integers(0, 2**32 - 1))

assert isinstance(seed, int)
self._original_seed = seed
self.rng = np.random.default_rng(self._original_seed)

# Miscellaneous
self._setup_logger(kwargs)
Expand All @@ -39,6 +45,7 @@ def __init__(self, cs=None, f=None, dimensions=None, mutation_factor=None,
self.cs = cs
self.use_configspace = True if isinstance(self.cs, ConfigSpace.ConfigurationSpace) else False
if self.use_configspace:
self.cs.seed(self._original_seed)
self.dimensions = len(self.cs.get_hyperparameters())
elif dimensions is None or not isinstance(dimensions, (int, np.integer)):
assert "Need to specify `dimensions` as an int when `cs` is not available/specified!"
Expand All @@ -61,7 +68,10 @@ def __init__(self, cs=None, f=None, dimensions=None, mutation_factor=None,
"max_age": self.max_age,
"cs": self.cs,
"dimensions": self.dimensions,
"rng": self.rng,
# NOTE(eddiebergman): To make reset work, we pass
# in an explicitly generated seed at construction,
# rather than share the rng state
# "rng": self.rng,
"f": f,
}

Expand Down Expand Up @@ -117,14 +127,18 @@ def _pre_compute_fidelity_spacing(self):
-np.linspace(start=self.max_SH_iter - 1,
stop=0, num=self.max_SH_iter))

def reset(self):
def reset(self, *, reset_seeds: bool = True):
self.inc_score = np.inf
self.inc_config = None
self.population = None
self.fitness = None
self.traj = []
self.runtime = []
self.history = []
if reset_seeds:
if isinstance(self.cs, ConfigSpace.ConfigurationSpace):
self.cs.seed(self._original_seed)
self.rng = np.random.default_rng(self._original_seed)
self.logger.info("\n\nRESET at {}\n\n".format(time.strftime("%x %X %Z")))

def init_population(self):
Expand Down Expand Up @@ -243,11 +257,6 @@ def __init__(self, cs=None, f=None, dimensions=None, mutation_factor=0.5,
self.logger.warning("A checkpoint already exists, " \
"results could potentially be overwritten.")

# Save initial random state
self.random_state = self.rng.bit_generator.state
if self.use_configspace:
self.cs_random_state = self.cs.random.get_state()

def __getstate__(self):
"""Allows the object to picklable while having Dask client as a class attribute."""
d = dict(self.__dict__)
Expand Down Expand Up @@ -341,8 +350,8 @@ def configspace_to_vector(self, config):
assert len(self.fidelities) > 0
return self.de[self.fidelities[0]].configspace_to_vector(config)

def reset(self):
super().reset()
def reset(self, *, reset_seeds: bool = True):
super().reset(reset_seeds=reset_seeds)
if self.n_workers > 1 and hasattr(self, "client") and isinstance(self.client, Client):
self.client.restart()
else:
Expand Down Expand Up @@ -406,9 +415,10 @@ def _get_pop_sizes(self):
def _init_subpop(self):
"""List of DE objects corresponding to the fidelities."""
self.de = {}
for i, f in enumerate(self._max_pop_size.keys()):
seeds = self.rng.integers(0, 2**32 - 1, size=len(self._max_pop_size))
for (i, f), _seed in zip(enumerate(self._max_pop_size.keys()), seeds):
self.de[f] = AsyncDE(**self.de_params, pop_size=self._max_pop_size[f],
config_repository=self.config_repository)
config_repository=self.config_repository, seed=int(_seed))
self.de[f].population = self.de[f].init_population(pop_size=self._max_pop_size[f])
self.de[f].population_ids = self.config_repository.announce_population(self.de[f].population, f)
self.de[f].fitness = np.array([np.inf] * self._max_pop_size[f])
Expand Down Expand Up @@ -661,10 +671,7 @@ def ask(self, n_configs: int=1):
for _ in range(n_configs):
jobs.append(self._get_next_job())
self._ask_counter += 1
# Save random state after ask
self.random_state = self.rng.bit_generator.state
if self.use_configspace:
self.cs_random_state = self.cs.random.get_state()

return jobs

def _get_gpu_id_with_low_load(self):
Expand Down Expand Up @@ -738,9 +745,9 @@ def _get_state(self):
state = {}
# DE parameters
serializable_de_params = self.de_params.copy()
serializable_de_params.pop("cs")
serializable_de_params.pop("rng")
serializable_de_params.pop("f")
serializable_de_params.pop("cs", None)
serializable_de_params.pop("rng", None)
serializable_de_params.pop("f", None)
serializable_de_params["output_path"] = str(serializable_de_params["output_path"])
state["DE_params"] = serializable_de_params
# Hyperband variables
Expand Down Expand Up @@ -772,16 +779,6 @@ def _save_state(self):
with state_path.open("w") as f:
json.dump(state, f, indent=2)

# Write random state to disk
rnd_state_path = self.output_path / "random_state.pkl"
with rnd_state_path.open("wb") as f:
pickle.dump(self.random_state, f)

if self.use_configspace:
cs_rnd_path = self.output_path / "cs_random_state.pkl"
with cs_rnd_path.open("wb") as f:
pickle.dump(self.cs_random_state, f)


def _is_run_budget_exhausted(self, fevals=None, brackets=None, total_cost=None):
"""Checks if the DEHB run should be terminated or continued."""
Expand Down Expand Up @@ -901,28 +898,6 @@ def _load_checkpoint(self, run_dir: str):
return False
self.eta = hb_vars["eta"]

self._pre_compute_fidelity_spacing()
self._get_pop_sizes()
self._init_subpop()
# Reset ConfigRepo after initializing DE
self.config_repository.reset()
# Load initial configurations from config repository
config_repo_path = run_dir / "config_repository.json"
with config_repo_path.open() as f:
config_repo_list = json.load(f)
# Get initial configs
num_initial_configs = sum(self._max_pop_size.values())
initial_config_entries = config_repo_list[:num_initial_configs]
# Filter initial configs by fidelity
initial_configs_by_fidelity = {fidelity: [np.array(item["config"]) for item in initial_config_entries
if str(fidelity) in item["results"]]
for fidelity in self.fidelities}
# Add initial configs to DE and announce them to ConfigRepo
for fidelity, sub_pop in initial_configs_by_fidelity.items():
self.de[fidelity].population = np.array(sub_pop)
self.de[fidelity].population_ids = self.config_repository.announce_population(sub_pop,
fidelity)

# Load history
history_path = run_dir / "history.pkl"
with history_path.open("rb") as f:
Expand All @@ -944,16 +919,6 @@ def _load_checkpoint(self, run_dir: str):
self.tell(job_info, result, replay=True)
# Clean inactive brackets
self.clean_inactive_brackets()

# Load and set random state
rnd_state_path = run_dir / "random_state.pkl"
with rnd_state_path.open("rb") as f:
self.rng.bit_generator.state = pickle.load(f)

if self.use_configspace:
cs_rnd_state_path = run_dir / "cs_random_state.pkl"
with cs_rnd_state_path.open("rb") as f:
self.cs.random.set_state(pickle.load(f))
return True

def save(self):
Expand Down