Skip to content

Commit

Permalink
Merge pull request #392 from sunya-ch/server-api-rebase-patch-4
Browse files Browse the repository at this point in the history
feat: get type in /best-models API and pass source and type from estimator
  • Loading branch information
sthaha authored Aug 22, 2024
2 parents 3d96e35 + d414cf1 commit a0ad11f
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 6 deletions.
10 changes: 8 additions & 2 deletions src/kepler_model/estimate/model_server_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,17 @@ def make_request(power_request):
return unpack(power_request.energy_source, output_type, response)


def list_all_models():
def list_all_models(energy_source=None, node_type=None):
if not is_model_server_enabled():
return dict()
try:
response = requests.get(get_model_server_list_endpoint())
endpoint = get_model_server_list_endpoint()
params= {}
if energy_source:
params["source"] = energy_source
if node_type:
params["type"] = node_type
response = requests.get(endpoint, params=params)
except Exception as err:
print(f"cannot list model: {err}")
return dict()
Expand Down
10 changes: 9 additions & 1 deletion src/kepler_model/server/model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def get_available_models():
fg = request.args.get("fg")
ot = request.args.get("ot")
energy_source = request.args.get("source")
node_type = request.args.get("type")
filter = request.args.get("filter")

try:
Expand All @@ -205,18 +206,25 @@ def get_available_models():
if energy_source is None or "rapl" in energy_source:
energy_source = "rapl-sysfs"

if node_type is None:
node_type = any_node_type
else:
node_type = int(node_type)

if filter is None:
filters = dict()
else:
filters = parse_filters(filter)

model_names = dict()
for output_type in output_types:
logger.debug(f"Searching output type {output_type}")
model_names[output_type.name] = dict()
for fg in valid_fgs:
logger.debug(f"Searching feature group {fg}")
valid_groupath = get_model_group_path(model_toppath, output_type, fg, energy_source, pipeline_name=pipelineName[energy_source])
if os.path.exists(valid_groupath):
best_candidate, _ = select_best_model(None, valid_groupath, filters, energy_source)
best_candidate, _ = select_best_model(None, valid_groupath, filters, energy_source, node_type=node_type)
if best_candidate is None:
continue
model_names[output_type.name][fg.name] = best_candidate["model_name"]
Expand Down
2 changes: 1 addition & 1 deletion tests/estimator_model_request_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_model_request():
energy_source = test_energy_source
# test getting model from server
os.environ["MODEL_SERVER_ENABLE"] = "true"
available_models = list_all_models()
available_models = list_all_models(energy_source=energy_source)
assert len(available_models) > 0, "must have more than one available models"
print("Available Models:", available_models)
for output_type_name, valid_fgs in available_models.items():
Expand Down
4 changes: 2 additions & 2 deletions tests/weight_model_request_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@
os.environ["MODEL_SERVER_ENABLE"] = "true"
energy_source = test_energy_source

available_models = list_all_models()
available_models = list_all_models(energy_source=energy_source)
while len(available_models) == 0:
time.sleep(1)
print("wait for kepler model server response")
available_models = list_all_models()
available_models = list_all_models(energy_source=energy_source)

for output_type_name, valid_fgs in available_models.items():
output_type = ModelOutputType[output_type_name]
Expand Down

0 comments on commit a0ad11f

Please sign in to comment.