Skip to content

Commit

Permalink
enlarge goal naive and fix bug in ltl violation reset
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiangzhong Liu committed Dec 6, 2023
1 parent 97729a3 commit 4423a3d
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 9 deletions.
2 changes: 1 addition & 1 deletion bark_ml/environments/blueprints/single_lane/single_lane.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def __init__(self,
**kwargs)
def goal(self, world):
lane_corr = self._road_corridor.lane_corridors[0]
goal_polygon = Polygon2d([0, 0, 0], [Point2d(-1, -1), Point2d(-1, 1), Point2d(1, 1), Point2d(1, -1)])
goal_polygon = Polygon2d([0, 0, 0], [Point2d(-20, -2.5), Point2d(-20, 2.5), Point2d(5, 2.5), Point2d(5, -2.5)])
goal_polygon = goal_polygon.Translate(Point2d(lane_corr.center_line.ToArray()[-1, 0], lane_corr.center_line.ToArray()[-1, 1]))
return GoalDefinitionPolygon(goal_polygon)

Expand Down
5 changes: 0 additions & 5 deletions bark_ml/evaluators/evaluator_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,3 @@ def __init__(self, params):
eval_fns[functor_n_] = eval("{}(rule_config)".format("TrafficRuleLTLFunctor"))

super().__init__(params=self._params, bark_eval_fns=bark_evals, bark_ml_eval_fns=eval_fns)



def addKeyFunctorPair(self,functor_name,key_name):
self._fn_key_map[functor_name] = key_name
6 changes: 3 additions & 3 deletions bark_ml/evaluators/general_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,19 +401,19 @@ def __init__(self, params):

def __call__(self, observed_world, action, eval_results):
self.traffic_rule_violation_post = eval_results[self._params["RuleName"]]
print("Result from evaluatorLTL:", self.traffic_rule_violation_post)
# print("Result from evaluatorLTL:", self.traffic_rule_violation_post)
max_vio_num = self._params["ViolationTolerance","",15]
if self.traffic_rule_violation_post < self.traffic_rule_violation_pre:
self.traffic_rule_violation_pre = self.traffic_rule_violation_post
current_traffic_rule_violations = self.traffic_rule_violation_post - self.traffic_rule_violation_pre
self.traffic_rule_violations = self.traffic_rule_violations + current_traffic_rule_violations
self.traffic_rule_violation_pre = self.traffic_rule_violation_post
print("current traffic rule violations:", self.traffic_rule_violations)
# print("current traffic rule violations:", self.traffic_rule_violations)
if self.traffic_rule_violations > max_vio_num:
return True, 0, {}
return False, self.WeightedReward(current_traffic_rule_violations/max_vio_num), {}
def Reset(self):
self.traffic_rule_violation_pre = 0
# self.traffic_rule_violation_pre = 0
self.traffic_rule_violation_post = 0
self.traffic_rule_violations = 0
super().Reset()
Expand Down

0 comments on commit 4423a3d

Please sign in to comment.