Skip to content

Commit

Permalink
Merge pull request #186 from automl/predictor_evaluator_fix
Browse files Browse the repository at this point in the history
  • Loading branch information
arberzela authored Jul 3, 2024
2 parents 0b51bc1 + f60a56a commit 8cb5d2b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
21 changes: 16 additions & 5 deletions naslib/defaults/predictor_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
from sklearn import metrics
import math

from naslib.predictors.zerocost import ZeroCost
from naslib.search_spaces.core.query_metrics import Metric
from naslib.utils import generate_kfold, cross_validation

from naslib import utils

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -47,6 +50,9 @@ def __init__(self, predictor, config=None):
self.num_arches_to_mutate = 5
self.max_mutation_rate = 3

# For ZeroCost proxies
self.dataloader = None

def adapt_search_space(
self, search_space, load_labeled, scope=None, dataset_api=None
):
Expand All @@ -70,6 +76,9 @@ def adapt_search_space(
"This search space is not yet implemented in PredictorEvaluator."
)

if isinstance(self.predictor, ZeroCost):
self.dataloader, _, _, _, _ = utils.get_train_val_loaders(self.config)

def get_full_arch_info(self, arch):
"""
Given an arch, return the accuracy, train_time,
Expand Down Expand Up @@ -139,10 +148,8 @@ def load_dataset(self, load_labeled=False, data_size=10, arch_hash_map={}):
arch.load_labeled_architecture(dataset_api=self.dataset_api)

arch_hash = arch.get_hash()
if False: # removing this for consistency, for now
continue
else:
arch_hash_map[arch_hash] = True

arch_hash_map[arch_hash] = True

accuracy, train_time, info_dict = self.get_full_arch_info(arch)
xdata.append(arch)
Expand Down Expand Up @@ -295,7 +302,11 @@ def single_evaluate(self, train_data, test_data, fidelity):
hyperparams = self.predictor.get_hyperparams()

fit_time_end = time.time()
test_pred = self.predictor.query(xtest, test_info)
if isinstance(self.predictor, ZeroCost):
[g.parse() for g in xtest] # parse the graphs because they will be used
test_pred = self.predictor.query_batch(xtest, self.dataloader)
else:
test_pred = self.predictor.query(xtest, test_info)
query_time_end = time.time()

# If the predictor is an ensemble, take the mean
Expand Down
14 changes: 12 additions & 2 deletions naslib/predictors/zerocost.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
based on https://github.com/mohsaied/zero-cost-nas
"""
import torch
import numpy as np
import logging
import math

Expand All @@ -24,12 +25,21 @@ def __init__(self, method_type="jacov"):
self.num_imgs_or_batches = 1
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def query(self, graph, dataloader=None, info=None):
def query_batch(self, graphs, dataloader):
scores = []

for graph in graphs:
score = self.query(graph, dataloader)
scores.append(score)

return np.array(scores)

def query(self, graph, dataloader):
loss_fn = graph.get_loss_fn()

n_classes = graph.num_classes
score = predictive.find_measures(
net_orig=graph,
net_orig=graph.to(self.device),
dataloader=dataloader,
dataload_info=(self.dataload, self.num_imgs_or_batches, n_classes),
device=self.device,
Expand Down

0 comments on commit 8cb5d2b

Please sign in to comment.