diff --git a/python/python-core/src/main/python/domain/_annotations.py b/python/python-core/src/main/python/domain/_annotations.py index 56c191962d..1c7109be07 100644 --- a/python/python-core/src/main/python/domain/_annotations.py +++ b/python/python-core/src/main/python/domain/_annotations.py @@ -1,5 +1,6 @@ from _jpyinterpreter import JavaAnnotation -from jpype import JImplements, JOverride +from enum import Enum +from jpype import JImplements, JOverride, JClass from typing import Union, List, Callable, Type, TypeVar from ._variable_listener import VariableListener @@ -63,6 +64,15 @@ class PlanningPin: pass +class PlanningVariableGraphType(Enum): + CHAINED = 'CHAINED' + NONE = 'NONE' + + def _to_java_value(self): + return getattr(JClass('ai.timefold.solver.core.api.domain.variable.PlanningVariableGraphType'), + self.name) + + class PlanningVariable(JavaAnnotation): """ Specifies that an attribute can be changed and should be optimized by the optimization algorithms. @@ -83,13 +93,13 @@ class PlanningVariable(JavaAnnotation): def __init__(self, *, value_range_provider_refs: List[str] = None, allows_unassigned: bool = False, - graph_type=None): + graph_type: PlanningVariableGraphType = PlanningVariableGraphType.NONE): ensure_init() from ai.timefold.solver.core.api.domain.variable import PlanningVariable as JavaPlanningVariable super().__init__(JavaPlanningVariable, { 'valueRangeProviderRefs': value_range_provider_refs, - 'graphType': graph_type, + 'graphType': graph_type._to_java_value(), 'allowsUnassigned': allows_unassigned }) @@ -814,7 +824,7 @@ def constraint_configuration(constraint_configuration_class: Type[Solution_]) -> __all__ = ['PlanningId', 'PlanningScore', 'PlanningPin', 'PlanningVariable', - 'PlanningListVariable', 'ShadowVariable', + 'PlanningVariableGraphType', 'PlanningListVariable', 'ShadowVariable', 'PiggybackShadowVariable', 'CascadingUpdateShadowVariable', 'IndexShadowVariable', 'PreviousElementShadowVariable', 'NextElementShadowVariable', 'AnchorShadowVariable', 'InverseRelationShadowVariable', diff --git a/python/python-core/tests/test_anchors.py b/python/python-core/tests/test_anchors.py index 671ea54b16..160b446037 100644 --- a/python/python-core/tests/test_anchors.py +++ b/python/python-core/tests/test_anchors.py @@ -1,111 +1,89 @@ -# import timefold.solver -# import timefold.solver.score -# import timefold.solver.config -# import timefold.solver.constraint -# from timefold.solver.types import PlanningVariableGraphType -# -# -# @timefold.solver.problem_fact -# class ChainedObject: -# pass -# -# -# @timefold.solver.problem_fact -# class ChainedAnchor(ChainedObject): -# def __init__(self, code): -# self.code = code -# -# -# @timefold.solver.planning_entity -# class ChainedEntity(ChainedObject): -# def __init__(self, code, value=None, anchor=None): -# self.code = code -# self.value = value -# self.anchor = anchor -# -# @timefold.solver.planning_variable(ChainedObject, value_range_provider_refs=['chained_anchor_range', -# 'chained_entity_range'], -# graph_type=PlanningVariableGraphType.CHAINED) -# def get_value(self): -# return self.value -# -# def set_value(self, value): -# self.value = value -# -# @timefold.solver.anchor_shadow_variable(ChainedAnchor, source_variable_name='value') -# def get_anchor(self): -# return self.anchor -# -# def set_anchor(self, anchor): -# self.anchor = anchor -# -# def __str__(self): -# return f'ChainedEntity(code={self.code}, value={self.value}, anchor={self.anchor})' -# -# -# @timefold.solver.planning_solution -# class ChainedSolution: -# def __init__(self, anchors, entities, score=None): -# self.anchors = anchors -# self.entities = entities -# self.score = score -# -# @timefold.solver.problem_fact_collection_property(ChainedAnchor) -# @timefold.solver.value_range_provider('chained_anchor_range') -# def get_anchors(self): -# return self.anchors -# -# @timefold.solver.planning_entity_collection_property(ChainedEntity) -# @timefold.solver.value_range_provider('chained_entity_range') -# def get_entities(self): -# return self.entities -# -# @timefold.solver.planning_score(timefold.solver.score.SimpleScore) -# def get_score(self): -# return self.score -# -# def set_score(self, score): -# self.score = score -# -# -# @timefold.solver.constraint_provider -# def chained_constraints(constraint_factory: timefold.solver.constraint.ConstraintFactory): -# return [ -# constraint_factory.for_each(ChainedEntity) -# .group_by(lambda entity: entity.anchor, timefold.solver.constraint.ConstraintCollectors.count()) -# .reward('Maximize chain length', timefold.solver.score.SimpleScore.ONE, -# lambda anchor, count: count * count) -# ] -# -# -# def test_chained(): -# termination = timefold.solver.config.solver.termination.TerminationConfig() -# termination.setBestScoreLimit('9') -# solver_config = timefold.solver.config.solver.SolverConfig() \ -# .withSolutionClass(ChainedSolution) \ -# .withEntityClasses(ChainedEntity) \ -# .withConstraintProviderClass(chained_constraints) \ -# .withTerminationConfig(termination) -# solver = timefold.solver.solver_factory_create(solver_config).buildSolver() -# solution = solver.solve(ChainedSolution( -# [ -# ChainedAnchor('A'), -# ChainedAnchor('B'), -# ChainedAnchor('C') -# ], -# [ -# ChainedEntity('1'), -# ChainedEntity('2'), -# ChainedEntity('3'), -# ] -# )) -# assert solution.score.score == 9 -# anchor = solution.entities[0].anchor -# assert anchor is not None -# anchor_value_count = 0 -# for entity in solution.entities: -# if entity.value == anchor: -# anchor_value_count += 1 -# assert anchor_value_count == 1 -# for entity in solution.entities: -# assert entity.anchor == anchor +from timefold.solver import * +from timefold.solver.config import * +from timefold.solver.domain import * +from timefold.solver.score import * +from typing import Annotated + + +class ChainedObject: + pass + + +class ChainedAnchor(ChainedObject): + def __init__(self, code): + self.code = code + + +@planning_entity +class ChainedEntity(ChainedObject): + value: Annotated[ChainedObject, PlanningVariable(graph_type=PlanningVariableGraphType.CHAINED, + value_range_provider_refs=['chained_anchor_range', + 'chained_entity_range'])] + anchor: Annotated[ChainedAnchor, AnchorShadowVariable(source_variable_name='value')] + + def __init__(self, code, value=None, anchor=None): + self.code = code + self.value = value + self.anchor = anchor + + def __str__(self): + return f'ChainedEntity(code={self.code}, value={self.value}, anchor={self.anchor})' + + +@planning_solution +class ChainedSolution: + anchors: Annotated[ + list[ChainedAnchor], ProblemFactCollectionProperty, ValueRangeProvider(id='chained_anchor_range')] + entities: Annotated[ + list[ChainedEntity], PlanningEntityCollectionProperty, ValueRangeProvider(id='chained_entity_range')] + score: Annotated[SimpleScore, PlanningScore] + + def __init__(self, anchors, entities, score=None): + self.anchors = anchors + self.entities = entities + self.score = score + + +@constraint_provider +def chained_constraints(constraint_factory: ConstraintFactory): + return [ + constraint_factory.for_each(ChainedEntity) + .group_by(lambda entity: entity.anchor, ConstraintCollectors.count()) + .reward(SimpleScore.ONE, lambda anchor, count: count * count) + .as_constraint('Maximize chain length') + ] + + +def test_chained(): + termination = TerminationConfig(best_score_limit='9') + solver_config = SolverConfig( + solution_class=ChainedSolution, + entity_class_list=[ChainedEntity], + score_director_factory_config=ScoreDirectorFactoryConfig( + constraint_provider_function=chained_constraints + ), + termination_config=termination + ) + solver = SolverFactory.create(solver_config).build_solver() + solution = solver.solve(ChainedSolution( + [ + ChainedAnchor('A'), + ChainedAnchor('B'), + ChainedAnchor('C') + ], + [ + ChainedEntity('1'), + ChainedEntity('2'), + ChainedEntity('3'), + ] + )) + assert solution.score.score == 9 + anchor = solution.entities[0].anchor + assert anchor is not None + anchor_value_count = 0 + for entity in solution.entities: + if entity.value == anchor: + anchor_value_count += 1 + assert anchor_value_count == 1 + for entity in solution.entities: + assert entity.anchor == anchor