Skip to content

Commit

Permalink
Merge branch 'ltl_configue' of https://github.com/bark-simulator/bark-ml
Browse files Browse the repository at this point in the history
 into ltl_configue,after adding stl_evaluator
  • Loading branch information
Xiangzhong Liu committed Dec 4, 2023
2 parents fd9eb6a + 78023a4 commit 5a7c17b
Show file tree
Hide file tree
Showing 10 changed files with 616 additions and 3 deletions.
4 changes: 4 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,7 @@ load("@benchmark_database//util:deps.bzl", "benchmark_database_dependencies")
load("@benchmark_database//load:load.bzl", "benchmark_database_release")
benchmark_database_dependencies()
benchmark_database_release()

# ------------------- LTL RuleMonitor --------------
load("@rule_monitor_project//util:deps.bzl", "rule_monitor_dependencies")
rule_monitor_dependencies()
4 changes: 3 additions & 1 deletion bark_ml/evaluators/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ py_library(
"evaluator_configs.py"],
data = ["@bark_project//bark/python_wrapper:core.so"],
imports = ["../external/bark_project/bark/python_wrapper/"],
deps = ["@bark_project//bark/runtime:runtime"],
deps = ["@bark_project//bark/runtime:runtime",
"//bark_ml/evaluators/stl:evaluator_stl"
],
visibility = ["//visibility:public"],
)

Expand Down
10 changes: 10 additions & 0 deletions bark_ml/evaluators/stl/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
py_library(
name = "evaluator_stl",
srcs = ["__init__.py",
"evaluator_stl.py",
"label_functions/base_label_function.py",
"label_functions/safe_distance_label_function.py"],
data = ["@bark_project//bark/python_wrapper:core.so"],
imports = ["../../external/bark_project/bark/python_wrapper/"],
visibility = ["//visibility:public"],
)
Empty file.
29 changes: 29 additions & 0 deletions bark_ml/evaluators/stl/evaluator_stl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from bark.core.world.evaluation.ltl import EvaluatorLTL
from bark_ml.evaluators.stl.label_functions.base_label_function import BaseQuantizedLabelFunction

class EvaluatorSTL(EvaluatorLTL):
def __init__(self, agent_id: int, ltl_formula_str: str, label_functions):
super().__init__(agent_id, ltl_formula_str, label_functions)
self.robustness = float('inf')

def Evaluate(self, observed_world):
eval_return = super().Evaluate(observed_world)
# print(f"Evaluate return: {eval_return}")
# print(f"Evaluate safety_violations: {super().safety_violations}")
# TODO: Should we remove the # of safety violations? We should subtract the robustness, shouldn't we?
eval_return = eval_return - self.compute_robustness()
# print(f"Evaluate return updated: {eval_return}")
return eval_return

def compute_robustness(self):
self.robustness = float('inf')

for le in self.label_functions:
if isinstance(le, BaseQuantizedLabelFunction):
self.robustness = min(self.robustness, le.get_current_robustness())

if self.robustness == float('inf') or self.robustness == float('-inf'):
self.robustness = 0.0

# print(f'Robustness in EvaluatorSTL: {self.robustness}')
return self.robustness
6 changes: 6 additions & 0 deletions bark_ml/evaluators/stl/label_functions/base_label_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class BaseQuantizedLabelFunction():
def __init__(self, robustness: float = float('-inf')):
self.robustness = robustness

def get_current_robustness(self):
return self.robustness
210 changes: 210 additions & 0 deletions bark_ml/evaluators/stl/label_functions/safe_distance_label_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
import sys
from bark.core.world.evaluation.ltl import SafeDistanceLabelFunction
# TODO
from rtamt.spec.stl.discrete_time.specification import StlDiscreteTimeSpecification
import logging
import rtamt
from bark_ml.evaluators.stl.label_functions.base_label_function import BaseQuantizedLabelFunction

class SafeDistanceQuantizedLabelFunction(SafeDistanceLabelFunction, BaseQuantizedLabelFunction):
robustness_min = float('inf')
robustness_max = float('-inf')

def __init__(self, label_str: str, to_rear: bool, delta_ego: float, delta_others: float, a_e: float, a_o: float,
consider_crossing_corridors: bool, max_agents_for_crossing: int, use_frac_param_from_world: bool,
lateral_difference_threshold: float, angle_difference_threshold: float, check_lateral_dist: bool):
super().__init__(label_str, to_rear, delta_ego, delta_others, a_e, a_o, consider_crossing_corridors,
max_agents_for_crossing, use_frac_param_from_world, lateral_difference_threshold,
angle_difference_threshold, check_lateral_dist)
self.initialize_specs()

self.robustness_lon = float('-inf')
self.robustness_lat = float('-inf')
self.robustness = float('-inf')

def initialize_specs(self):
self.stl_spec_timestep = 0
self.stl_spec_lon_checked = False
self.stl_spec_lat_checked = False

self.stl_spec_lon = StlDiscreteTimeSpecification()
self.stl_spec_lon.declare_var('dist', 'float')
self.stl_spec_lon.declare_var('safe_dist_0', 'float')
self.stl_spec_lon.declare_var('safe_dist_1', 'float')
self.stl_spec_lon.declare_var('safe_dist_2', 'float')
self.stl_spec_lon.declare_var('safe_dist_3', 'float')
self.stl_spec_lon.declare_var('delta', 'float')
self.stl_spec_lon.declare_var('t_stop_f', 'float')
self.stl_spec_lon.declare_var('t_stop_f_star', 'float')
self.stl_spec_lon.declare_var('a_f', 'float')
self.stl_spec_lon.declare_var('a_r', 'float')
self.stl_spec_lon.declare_var('v_f_star', 'float')
self.stl_spec_lon.declare_var('v_r', 'float')
self.stl_spec_lon.declare_var('t_stop_r', 'float')

# TODO Sampling period should be parametric
self.stl_spec_lon.unit = 's'
self.stl_spec_lon.set_sampling_period(200, 'ms', 0.1)

formula_lon = "(dist < 0.0)" \
+ ' or ((dist > safe_dist_0 or (delta <= t_stop_f and dist > safe_dist_3))' \
+ ' or ((delta <= t_stop_f and a_f > a_r and v_f_star < v_r and t_stop_r < t_stop_f_star)) and (dist > safe_dist_2))' \
+ ' or (dist > safe_dist_1)'

self.stl_spec_lon.spec = formula_lon

try:
self.stl_spec_lon.parse()
self.stl_spec_lon.pastify()
except rtamt.RTAMTException as err:
logging.info('RTAMT Exception: {}'.format(err))
sys.exit()

self.stl_spec_lat = StlDiscreteTimeSpecification()
self.stl_spec_lat.declare_var('dist_lat', 'float')
self.stl_spec_lat.declare_var('lateral_positive', 'float')
self.stl_spec_lat.declare_var('v_1_lat', 'float')
self.stl_spec_lat.declare_var('v_2_lat', 'float')
self.stl_spec_lat.declare_var('min_lat_safe_dist', 'float')

# TODO Sampling period should be parametric
self.stl_spec_lat.unit = 's'
self.stl_spec_lat.set_sampling_period(200, 'ms', 0.1)

formula_lat = 'dist_lat !== 0.0 and' \
+ ' ((v_1_lat >= 0.0 and v_2_lat <= 0.0 and dist_lat < 0.0)' \
+ ' or (v_1_lat <= 0.0 and v_2_lat >= 0.0 and dist_lat > 0.0)' \
+ ' or (lateral_positive > min_lat_safe_dist))'

self.stl_spec_lat.spec = formula_lat

try:
self.stl_spec_lat.parse()
self.stl_spec_lat.pastify()
except rtamt.RTAMTException as err:
logging.info('RTAMT Exception: {}'.format(err))
sys.exit()

logging.info("Successfully parsed the SD STL formulas")

def compute_robustness(self, eval_result):
safe_distance = eval_result

if (not self.stl_spec_lon_checked and not self.stl_spec_lat_checked):
if safe_distance:
if SafeDistanceQuantizedLabelFunction.robustness_max >= 0.0:
self.robustness = SafeDistanceQuantizedLabelFunction.robustness_max
else:
# TODO: Should be taken from configuration.
self.robustness = 1.0
else:
if SafeDistanceQuantizedLabelFunction.robustness_min <= 0.0:
self.robustness = SafeDistanceQuantizedLabelFunction.robustness_min
else:
# TODO: Should be taken from configuration.
self.robustness = -1.0
elif (self.stl_spec_lon_checked and self.stl_spec_lat_checked):
if safe_distance and (self.robustness_lon < 0.0 or self.robustness_lat < 0.0):
self.robustness = max(self.robustness_lon, self.robustness_lat)
else:
self.robustness = min(self.robustness_lon, self.robustness_lat)
elif (self.stl_spec_lon_checked):
self.robustness = self.robustness_lon

logging.info(f"Current robustness for SD={self.robustness}")

# print(f'Robustness in Label Function: {self.robustness}')

if self.robustness > SafeDistanceQuantizedLabelFunction.robustness_max:
SafeDistanceQuantizedLabelFunction.robustness_max = self.robustness

if self.robustness < SafeDistanceQuantizedLabelFunction.robustness_min:
SafeDistanceQuantizedLabelFunction.robustness_min = self.robustness

def Evaluate(self, observed_world):
self.stl_spec_timestep = observed_world.time
self.stl_spec_lon_checked = False
self.stl_spec_lat_checked = False

eval_result = super().Evaluate(observed_world)

self.compute_robustness(next(iter(eval_result.values())))

return eval_result

def CheckSafeDistanceLongitudinal(self, v_f: float, v_r: float, dist: float, a_r: float, a_f: float, delta: float):
self.stl_spec_lon_checked = True

v_f_star = self.CalcVelFrontStar(v_f, a_f, delta)
t_stop_f_star = -v_f_star / a_r
t_stop_r = -v_r / a_r
t_stop_f = -v_f / a_f

ZeroToPositive = lambda safe_dist: 0.0 if safe_dist < 0.0 else safe_dist
safe_dist_0 = ZeroToPositive(self.CalcSafeDistance0(v_r, a_r, delta))
safe_dist_1 = ZeroToPositive(self.CalcSafeDistance1(v_r, v_f, a_r, a_f, delta))
safe_dist_2 = ZeroToPositive(self.CalcSafeDistance2(v_r, v_f, a_r, a_f, delta))
safe_dist_3 = ZeroToPositive(self.CalcSafeDistance3(v_r, v_f, a_r, a_f, delta))
logging.info(f"sf0={safe_dist_0}, sf1={safe_dist_1}, sf2={safe_dist_2}, sf3={safe_dist_3}")

# Updating STL monitor
self.robustness_lon = self.stl_spec_lon.update(self.stl_spec_timestep, [('dist', dist),
('safe_dist_0', safe_dist_0),
('safe_dist_1', safe_dist_1),
('safe_dist_2', safe_dist_2),
('safe_dist_3', safe_dist_3),
('delta', delta),
('t_stop_f', t_stop_f),
('t_stop_f_star', t_stop_f_star),
('a_f', a_f),
('a_r', a_r),
('v_f_star', v_f_star),
('v_r', v_r),
('t_stop_r', t_stop_r)])
# print(f"CheckSafeDistanceLongitudinal: Robustness STL spec result in the label function: {self.robustness_lon}")

safe_distance_lon = self.robustness_lon > 0.0

if self.robustness_lon == 0.0:
safe_distance_lon = super().CheckSafeDistanceLongitudinal(v_f, v_r, dist, a_r, a_f, delta)

return safe_distance_lon

def CheckSafeDistanceLateral(self, v_1_lat: float, v_2_lat: float, dist_lat: float, a_1_lat: float, a_2_lat: float, delta1: float, delta2: float):
# return super().CheckSafeDistanceLateral(v_1_lat, v_2_lat, dist_lat, a_1_lat, a_2_lat, delta1, delta2)
self.stl_spec_lat_checked = True

# For convention of RSS paper, make v_1_lat be larger (e.g. positive compared to v_2_lat) ...
v_1_lat_orig = v_1_lat
v_2_lat_orig = v_2_lat

if v_1_lat < v_2_lat:
v_1_lat, v_2_lat = v_2_lat, v_1_lat
delta1, delta2 = delta2, delta1
a_1_lat, a_2_lat = a_2_lat, a_1_lat

# ... lateral distance positive
lateral_positive = abs(dist_lat)

min_lat_safe_dist = (
v_1_lat * delta1 +
(v_1_lat * delta1 if v_1_lat == 0.0 else v_1_lat * v_1_lat / (2 * a_1_lat)) -
(v_2_lat * delta2 - (v_2_lat * delta2 if v_2_lat == 0.0 else v_2_lat * v_2_lat / (2 * a_2_lat)))
)
logging.info("Min lat safe dist:", min_lat_safe_dist)

# Updating STL monitor
self.robustness_lat = self.stl_spec_lat.update(self.stl_spec_timestep, [('dist_lat', dist_lat),
('lateral_positive', lateral_positive),
('v_1_lat', v_1_lat_orig),
('v_2_lat', v_2_lat_orig),
('min_lat_safe_dist', min_lat_safe_dist)
])
# print(f"CheckSafeDistanceLateral: Robustness STL spec result in the label function: {self.robustness_lat}")

safe_distance_lat = self.robustness_lat > 0.0

if self.robustness_lat == 0.0:
safe_distance_lat = super().CheckSafeDistanceLateral(v_1_lat, v_2_lat, dist_lat, a_1_lat, a_2_lat, delta1, delta2)

return safe_distance_lat
13 changes: 13 additions & 0 deletions bark_ml/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,19 @@ py_test(
visibility = ["//visibility:public"],
)

py_test(
name = "py_label_function_tests",
srcs = ["py_label_function_tests.py"],
data = ["@bark_project//bark/python_wrapper:core.so",
"//bark_ml:generate_core"],
imports = ["../external/bark_project/bark/python_wrapper/",
"../python_wrapper/"],
deps = ["//bark_ml/environments:single_agent_runtime",
"//bark_ml/behaviors:behaviors",
"//bark_ml/commons:py_spaces"],
visibility = ["//visibility:public"],
)

py_test(
name = "py_evaluator_tests",
srcs = ["py_evaluator_tests.py"],
Expand Down
Loading

0 comments on commit 5a7c17b

Please sign in to comment.