Skip to content

Commit

Permalink
feat: Add chained variable support to Python
Browse files Browse the repository at this point in the history
  • Loading branch information
Christopher-Chianelli authored and triceo committed Aug 26, 2024
1 parent 716e45e commit 6c7416c
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 115 deletions.
18 changes: 14 additions & 4 deletions python/python-core/src/main/python/domain/_annotations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from _jpyinterpreter import JavaAnnotation
from jpype import JImplements, JOverride
from enum import Enum
from jpype import JImplements, JOverride, JClass
from typing import Union, List, Callable, Type, TypeVar

from ._variable_listener import VariableListener
Expand Down Expand Up @@ -63,6 +64,15 @@ class PlanningPin:
pass


class PlanningVariableGraphType(Enum):
CHAINED = 'CHAINED'
NONE = 'NONE'

def _to_java_value(self):
return getattr(JClass('ai.timefold.solver.core.api.domain.variable.PlanningVariableGraphType'),
self.name)


class PlanningVariable(JavaAnnotation):
"""
Specifies that an attribute can be changed and should be optimized by the optimization algorithms.
Expand All @@ -83,13 +93,13 @@ class PlanningVariable(JavaAnnotation):
def __init__(self, *,
value_range_provider_refs: List[str] = None,
allows_unassigned: bool = False,
graph_type=None):
graph_type: PlanningVariableGraphType = PlanningVariableGraphType.NONE):
ensure_init()
from ai.timefold.solver.core.api.domain.variable import PlanningVariable as JavaPlanningVariable
super().__init__(JavaPlanningVariable,
{
'valueRangeProviderRefs': value_range_provider_refs,
'graphType': graph_type,
'graphType': graph_type._to_java_value(),
'allowsUnassigned': allows_unassigned
})

Expand Down Expand Up @@ -814,7 +824,7 @@ def constraint_configuration(constraint_configuration_class: Type[Solution_]) ->


__all__ = ['PlanningId', 'PlanningScore', 'PlanningPin', 'PlanningVariable',
'PlanningListVariable', 'ShadowVariable',
'PlanningVariableGraphType', 'PlanningListVariable', 'ShadowVariable',
'PiggybackShadowVariable', 'CascadingUpdateShadowVariable',
'IndexShadowVariable', 'PreviousElementShadowVariable', 'NextElementShadowVariable',
'AnchorShadowVariable', 'InverseRelationShadowVariable',
Expand Down
200 changes: 89 additions & 111 deletions python/python-core/tests/test_anchors.py
Original file line number Diff line number Diff line change
@@ -1,111 +1,89 @@
# import timefold.solver
# import timefold.solver.score
# import timefold.solver.config
# import timefold.solver.constraint
# from timefold.solver.types import PlanningVariableGraphType
#
#
# @timefold.solver.problem_fact
# class ChainedObject:
# pass
#
#
# @timefold.solver.problem_fact
# class ChainedAnchor(ChainedObject):
# def __init__(self, code):
# self.code = code
#
#
# @timefold.solver.planning_entity
# class ChainedEntity(ChainedObject):
# def __init__(self, code, value=None, anchor=None):
# self.code = code
# self.value = value
# self.anchor = anchor
#
# @timefold.solver.planning_variable(ChainedObject, value_range_provider_refs=['chained_anchor_range',
# 'chained_entity_range'],
# graph_type=PlanningVariableGraphType.CHAINED)
# def get_value(self):
# return self.value
#
# def set_value(self, value):
# self.value = value
#
# @timefold.solver.anchor_shadow_variable(ChainedAnchor, source_variable_name='value')
# def get_anchor(self):
# return self.anchor
#
# def set_anchor(self, anchor):
# self.anchor = anchor
#
# def __str__(self):
# return f'ChainedEntity(code={self.code}, value={self.value}, anchor={self.anchor})'
#
#
# @timefold.solver.planning_solution
# class ChainedSolution:
# def __init__(self, anchors, entities, score=None):
# self.anchors = anchors
# self.entities = entities
# self.score = score
#
# @timefold.solver.problem_fact_collection_property(ChainedAnchor)
# @timefold.solver.value_range_provider('chained_anchor_range')
# def get_anchors(self):
# return self.anchors
#
# @timefold.solver.planning_entity_collection_property(ChainedEntity)
# @timefold.solver.value_range_provider('chained_entity_range')
# def get_entities(self):
# return self.entities
#
# @timefold.solver.planning_score(timefold.solver.score.SimpleScore)
# def get_score(self):
# return self.score
#
# def set_score(self, score):
# self.score = score
#
#
# @timefold.solver.constraint_provider
# def chained_constraints(constraint_factory: timefold.solver.constraint.ConstraintFactory):
# return [
# constraint_factory.for_each(ChainedEntity)
# .group_by(lambda entity: entity.anchor, timefold.solver.constraint.ConstraintCollectors.count())
# .reward('Maximize chain length', timefold.solver.score.SimpleScore.ONE,
# lambda anchor, count: count * count)
# ]
#
#
# def test_chained():
# termination = timefold.solver.config.solver.termination.TerminationConfig()
# termination.setBestScoreLimit('9')
# solver_config = timefold.solver.config.solver.SolverConfig() \
# .withSolutionClass(ChainedSolution) \
# .withEntityClasses(ChainedEntity) \
# .withConstraintProviderClass(chained_constraints) \
# .withTerminationConfig(termination)
# solver = timefold.solver.solver_factory_create(solver_config).buildSolver()
# solution = solver.solve(ChainedSolution(
# [
# ChainedAnchor('A'),
# ChainedAnchor('B'),
# ChainedAnchor('C')
# ],
# [
# ChainedEntity('1'),
# ChainedEntity('2'),
# ChainedEntity('3'),
# ]
# ))
# assert solution.score.score == 9
# anchor = solution.entities[0].anchor
# assert anchor is not None
# anchor_value_count = 0
# for entity in solution.entities:
# if entity.value == anchor:
# anchor_value_count += 1
# assert anchor_value_count == 1
# for entity in solution.entities:
# assert entity.anchor == anchor
from timefold.solver import *
from timefold.solver.config import *
from timefold.solver.domain import *
from timefold.solver.score import *
from typing import Annotated


class ChainedObject:
pass


class ChainedAnchor(ChainedObject):
def __init__(self, code):
self.code = code


@planning_entity
class ChainedEntity(ChainedObject):
value: Annotated[ChainedObject, PlanningVariable(graph_type=PlanningVariableGraphType.CHAINED,
value_range_provider_refs=['chained_anchor_range',
'chained_entity_range'])]
anchor: Annotated[ChainedAnchor, AnchorShadowVariable(source_variable_name='value')]

def __init__(self, code, value=None, anchor=None):
self.code = code
self.value = value
self.anchor = anchor

def __str__(self):
return f'ChainedEntity(code={self.code}, value={self.value}, anchor={self.anchor})'


@planning_solution
class ChainedSolution:
anchors: Annotated[
list[ChainedAnchor], ProblemFactCollectionProperty, ValueRangeProvider(id='chained_anchor_range')]
entities: Annotated[
list[ChainedEntity], PlanningEntityCollectionProperty, ValueRangeProvider(id='chained_entity_range')]
score: Annotated[SimpleScore, PlanningScore]

def __init__(self, anchors, entities, score=None):
self.anchors = anchors
self.entities = entities
self.score = score


@constraint_provider
def chained_constraints(constraint_factory: ConstraintFactory):
return [
constraint_factory.for_each(ChainedEntity)
.group_by(lambda entity: entity.anchor, ConstraintCollectors.count())
.reward(SimpleScore.ONE, lambda anchor, count: count * count)
.as_constraint('Maximize chain length')
]


def test_chained():
termination = TerminationConfig(best_score_limit='9')
solver_config = SolverConfig(
solution_class=ChainedSolution,
entity_class_list=[ChainedEntity],
score_director_factory_config=ScoreDirectorFactoryConfig(
constraint_provider_function=chained_constraints
),
termination_config=termination
)
solver = SolverFactory.create(solver_config).build_solver()
solution = solver.solve(ChainedSolution(
[
ChainedAnchor('A'),
ChainedAnchor('B'),
ChainedAnchor('C')
],
[
ChainedEntity('1'),
ChainedEntity('2'),
ChainedEntity('3'),
]
))
assert solution.score.score == 9
anchor = solution.entities[0].anchor
assert anchor is not None
anchor_value_count = 0
for entity in solution.entities:
if entity.value == anchor:
anchor_value_count += 1
assert anchor_value_count == 1
for entity in solution.entities:
assert entity.anchor == anchor

0 comments on commit 6c7416c

Please sign in to comment.