Skip to content

Commit

Permalink
possibly retrieve stats via the ortoolscallback used in cpsat solvers…
Browse files Browse the repository at this point in the history
…. test this usage in facility unit test. Also test an alternative stats callback (from do api) that use the cpsat callback stored as attribute in the solver object (introduced in prev commit)
  • Loading branch information
g-poveda committed Oct 24, 2024
1 parent 10f58c1 commit e8f19a9
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 2 deletions.
26 changes: 24 additions & 2 deletions discrete_optimization/generic_tools/ortools_cpsat_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def solve(
parameters_cp: Optional[ParametersCp] = None,
time_limit: Optional[float] = 100.0,
ortools_cpsat_solver_kwargs: Optional[dict[str, Any]] = None,
retrieve_stats: bool = False,
**kwargs: Any,
) -> ResultStorage:
"""Solve the problem with a CpSat solver drom ortools library.
Expand All @@ -74,6 +75,8 @@ def solve(
We use here only `parameters_cp.nb_process`.
ortools_cpsat_solver_kwargs: used to customize the underlying ortools solver.
Each key/value will update the corresponding attribute from the ortools.sat.python.cp_model.CpSolver
retrieve_stats: retrieve detailed stats of cpsat solving in the cpsat callback
and store it in the res object.
**kwargs: keyword arguments passed to `self.init_model()`
Returns:
Expand Down Expand Up @@ -101,7 +104,9 @@ def solve(
# customize solver
for k, v in ortools_cpsat_solver_kwargs.items():
setattr(solver.parameters, k, v)
ortools_callback = OrtoolsCpSatCallback(do_solver=self, callback=callbacks_list)
ortools_callback = OrtoolsCpSatCallback(
do_solver=self, callback=callbacks_list, retrieve_stats=retrieve_stats
)
self.clb = ortools_callback
status = solver.Solve(self.cp_model, ortools_callback)
self.status_solver = cpstatus_to_dostatus(status_from_cpsat=status)
Expand Down Expand Up @@ -130,11 +135,19 @@ def remove_constraints(self, constraints: Iterable[Any]) -> None:


class OrtoolsCpSatCallback(CpSolverSolutionCallback):
def __init__(self, do_solver: OrtoolsCpSatSolver, callback: Callback):
def __init__(
self,
do_solver: OrtoolsCpSatSolver,
callback: Callback,
retrieve_stats: bool = False,
):
super().__init__()
self.do_solver = do_solver
self.callback = callback
self.retrieve_stats = retrieve_stats
self.res = do_solver.create_result_storage()
if retrieve_stats:
self.res.stats = []
self.nb_solutions = 0

def on_solution_callback(self) -> None:
Expand All @@ -160,6 +173,15 @@ def store_current_solution(self):
sol = self.do_solver.retrieve_solution(cpsolvercb=self)
fit = self.do_solver.aggreg_from_sol(sol)
self.res.append((sol, fit))
if self.retrieve_stats:
self.res.stats.append(
{
"bound": self.BestObjectiveBound(),
"obj": self.ObjectiveValue(),
"time": self.UserTime(),
"num_conflicts": self.NumConflicts(),
}
)


def cpstatus_to_dostatus(status_from_cpsat) -> StatusSolver:
Expand Down
111 changes: 111 additions & 0 deletions tests/fjsp/solvers/test_cpsat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,78 @@
# LICENSE file in the root directory of this source tree.

import logging
import time
from typing import Optional

import discrete_optimization.fjsp.parser as fjsp_parser
import discrete_optimization.jsp.parser as jsp_parser
from discrete_optimization.fjsp.problem import Job
from discrete_optimization.fjsp.solvers.cpsat import CpSatFjspSolver, FJobShopProblem
from discrete_optimization.generic_tools.callbacks.callback import Callback
from discrete_optimization.generic_tools.cp_tools import ParametersCp
from discrete_optimization.generic_tools.ortools_cpsat_tools import OrtoolsCpSatSolver
from discrete_optimization.generic_tools.result_storage.result_storage import (
ResultStorage,
)
from discrete_optimization.jsp.problem import JobShopProblem

logging.basicConfig(level=logging.INFO)


class StatsCpsatCallback(Callback):
def __init__(self):
self.starting_time: int = None
self.end_time: int = None
self.stats: list[dict] = []
self.final_status: str = None

def on_step_end(
self, step: int, res: ResultStorage, solver: OrtoolsCpSatSolver
) -> Optional[bool]:
self.stats.append(
{
"obj": solver.clb.ObjectiveValue(),
"bound": solver.clb.BestObjectiveBound(),
"time": time.perf_counter() - self.starting_time,
"time-cpsat": {
"user-time": solver.clb.UserTime(),
"wall-time": solver.clb.WallTime(),
},
}
)
if solver.clb.ObjectiveValue() == solver.clb.BestObjectiveBound():
return False

def on_solve_start(self, solver: OrtoolsCpSatSolver):
self.starting_time = time.perf_counter()

def on_solve_end(self, res: ResultStorage, solver: OrtoolsCpSatSolver):
"""Called at the end of solve.
Args:
res: current result storage
solver: solvers using the callback
"""
status_name = solver.solver.status_name()
# status_do: StatusSolver = cpstatus_to_dostatus(status_name)
if len(self.stats) > 0:
if (
solver.solver.ObjectiveValue() != self.stats[-1]["obj"]
or solver.solver.BestObjectiveBound() != self.stats[-1]["bound"]
):
self.stats.append(
{
"obj": solver.solver.ObjectiveValue(),
"bound": solver.solver.BestObjectiveBound(),
"time": time.perf_counter() - self.starting_time,
"time-cpsat": {
"user-time": solver.solver.UserTime(),
"wall-time": solver.solver.WallTime(),
},
}
)
self.final_status = status_name


def test_fjsp_solver_on_jsp():
file_path = jsp_parser.get_data_available()[1]
# file_path = [f for f in get_data_available() if "abz6" in f][0]
Expand Down Expand Up @@ -52,3 +113,53 @@ def test_cpsat_fjsp():
)
sol, _ = res.get_best_solution_fit()
assert problem.satisfy(sol)


def test_cpsat_retrieve_stats():
files = fjsp_parser.get_data_available()
print(files)
file = [f for f in files if "Behnke60.fjs" in f][0]
print(file)
problem = fjsp_parser.parse_file(file)
print(problem)
solver = CpSatFjspSolver(problem=problem)
p = ParametersCp.default_cpsat()
p.nb_process = 10
res = solver.solve(
parameters_cp=p,
time_limit=5,
ortools_cpsat_solver_kwargs=dict(log_search_progress=True),
duplicate_temporal_var=True,
add_cumulative_constraint=True,
retrieve_stats=True,
)
assert res.stats is not None
assert res.stats[-1]["obj"] == solver.solver.ObjectiveValue()
# assert res.stats[-1]["bound"] == solver.solver.BestObjectiveBound()
sol, _ = res.get_best_solution_fit()
assert problem.satisfy(sol)


def test_cpsat_retrieve_stats_via_clb():
files = fjsp_parser.get_data_available()
file = [f for f in files if "Behnke60.fjs" in f][0]
problem = fjsp_parser.parse_file(file)
solver = CpSatFjspSolver(problem=problem)
p = ParametersCp.default_cpsat()
p.nb_process = 10
stats_clb = StatsCpsatCallback()
res = solver.solve(
callbacks=[stats_clb],
parameters_cp=p,
time_limit=20,
ortools_cpsat_solver_kwargs=dict(log_search_progress=True),
duplicate_temporal_var=True,
add_cumulative_constraint=True,
retrieve_stats=False,
)
assert stats_clb.stats is not None
assert stats_clb.stats[-1]["obj"] == solver.solver.ObjectiveValue()
print(stats_clb.stats)
# assert res.stats[-1]["bound"] == solver.solver.BestObjectiveBound()
sol, _ = res.get_best_solution_fit()
assert problem.satisfy(sol)

0 comments on commit e8f19a9

Please sign in to comment.