Skip to content

Commit

Permalink
EvaluatorLTL python wrapper updates
Browse files Browse the repository at this point in the history
  • Loading branch information
acarcelik committed Jan 25, 2024
1 parent 6027c1f commit 17e72ca
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
6 changes: 3 additions & 3 deletions bark/python_wrapper/world/ltl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ void python_ltl(py::module m) {
using namespace bark::world::evaluation;

#ifdef LTL_RULES
py::class_<EvaluatorLTL, BaseEvaluator, std::shared_ptr<EvaluatorLTL>>(

py::class_<EvaluatorLTL, BaseEvaluator, PyEvaluatorLTL, std::shared_ptr<EvaluatorLTL>>(
m, "EvaluatorLTL")
.def(py::init<AgentId, const std::string&, const LabelFunctions&>(),
py::arg("agent_id"), py::arg("ltl_formula"),
py::arg("label_functions"))
.def("Evaluate", py::overload_cast<const ObservedWorld&>(&PyEvaluatorLTL::Evaluate))
.def_property_readonly("rule_states", &EvaluatorLTL::GetRuleStates)
.def_property_readonly("label_functions",
&EvaluatorLTL::GetLabelFunctions)
Expand All @@ -65,7 +65,7 @@ void python_ltl(py::module m) {
std::shared_ptr<BaseLabelFunction>>(m, "BaseLabelFunction")
.def(py::init<const std::string&>())
.def("Evaluate", &BaseLabelFunction::Evaluate);

py::class_<ConstantLabelFunction, BaseLabelFunction,
std::shared_ptr<ConstantLabelFunction>>(m, "ConstantLabelFunction")
.def(py::init<const std::string&>())
Expand Down
11 changes: 11 additions & 0 deletions bark/python_wrapper/world/ltl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,17 @@ class PySafeDistanceLabelFunction : public SafeDistanceLabelFunction {
class PyEvaluatorLTL : public EvaluatorLTL {
public:
using EvaluatorLTL::Evaluate;
/* Inherit the constructors */
using EvaluatorLTL::EvaluatorLTL;

EvaluationReturn Evaluate(const ObservedWorld& observed_world) override {
PYBIND11_OVERLOAD(
EvaluationReturn,
EvaluatorLTL,
Evaluate,
observed_world
);
}
};

void python_ltl(py::module m);
Expand Down
2 changes: 2 additions & 0 deletions bark/world/world.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ void World::AddObject(const objects::ObjectPtr& object) {
void World::AddEvaluator(const std::string& name,
const EvaluatorPtr& evaluator) {
evaluators_[name] = evaluator;
// std::cout << "AddEvaluator in World: " << name << ";" << evaluator << "\n";
}

EvaluationMap World::Evaluate() const {
Expand All @@ -152,6 +153,7 @@ EvaluationMap World::Evaluate() const {
std::vector<ObservedWorld> World::Observe(
const std::vector<AgentId>& agent_ids) const {
WorldPtr current_world(this->Clone());

std::vector<ObservedWorld> observed_worlds;
for (auto agent_id : agent_ids) {
if (agents_.find(agent_id) == agents_.end()) {
Expand Down

0 comments on commit 17e72ca

Please sign in to comment.