diff --git a/README.md b/README.md
index f4ee4f2..156f229 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,25 @@
# edgeimpulse_ros
-ROS2 wrapper for Edge Impulse
+ROS2 wrapper for Edge Impulse on Linux.
-## How to install
+
+## 1. Topics
+
+- `/detection/input/image`, image topic to analyze
+- `/detection/output/image`, image with bounding boxes
+- `/detection/output/info`, VisionInfo message
+- `/detection/output/results`, results as text
+
+## 2. Parameters
+
+- `frame_id` (**string**), _"base_link"_, frame id of output topics
+- `model.filepath` (**string**), _""_, absolute filepath to .eim file
+- `show.overlay` (**bool**), _true_, show bounding boxes on output image
+- `show.labels` (**bool**), _true_, show labels on bounding boxes
+- `show.classification_info` (**bool**), _true_, show the confidence (0-1) of the prediction
+
+
+## 3. How to install
1. install edge_impulse_linux:
`pip3 install edge_impulse_linux`
@@ -30,12 +47,23 @@ ROS2 wrapper for Edge Impulse
`source install/setup.bash`
-## How to run
+## 4. How to run
Launch the node:
`ros2 run edgeimpulse_ros image_classification --ros-args -p model.filepath:="" -r /detection/input/image:="/your_image_topic"`
`
+## 5. Models
+
+Here you find some prebuilt models: [https://github.com/gbr1/edgeimpulse_example_models](https://github.com/gbr1/edgeimpulse_example_models)
+
+## 6. Known issues
+
+- this wrapper works on foxy; galactic and humble support is coming soon (incompatibility with vision_msgs by ros-perception)
+- if you use a classification model, the results topic is empty
+- you cannot change color of bounding boxes (coming soon)
+- other types (imu and sound based ml) are unavailable
+
***Copyright © 2022 Giovanni di Dio Bruno - gbr1.github.io***
diff --git a/edgeimpulse_ros/__pycache__/__init__.cpython-38.pyc b/edgeimpulse_ros/__pycache__/__init__.cpython-38.pyc
index 93d905d..af7ee0b 100644
Binary files a/edgeimpulse_ros/__pycache__/__init__.cpython-38.pyc and b/edgeimpulse_ros/__pycache__/__init__.cpython-38.pyc differ
diff --git a/edgeimpulse_ros/__pycache__/image_classification.cpython-38.pyc b/edgeimpulse_ros/__pycache__/image_classification.cpython-38.pyc
index e1ecdbe..59650d8 100644
Binary files a/edgeimpulse_ros/__pycache__/image_classification.cpython-38.pyc and b/edgeimpulse_ros/__pycache__/image_classification.cpython-38.pyc differ
diff --git a/edgeimpulse_ros/image_classification.py b/edgeimpulse_ros/image_classification.py
index ab750fb..424103c 100644
--- a/edgeimpulse_ros/image_classification.py
+++ b/edgeimpulse_ros/image_classification.py
@@ -13,6 +13,7 @@
# limitations under the License.
+from distutils.log import info
from unittest import result
from .submodules import device_patches
import cv2
@@ -24,8 +25,10 @@
from sensor_msgs.msg import Image
-#from vision_msgs.msg import BoundingBox2DArray
-#from vision_msgs.msg import VisionInfo
+from vision_msgs.msg import Detection2DArray
+from vision_msgs.msg import Detection2D
+from vision_msgs.msg import ObjectHypothesisWithPose
+from vision_msgs.msg import VisionInfo
import os
import time
@@ -35,31 +38,46 @@
+
class EI_Image_node(Node):
+
def __init__(self):
self.occupied = False
self.img = None
+ self.info_msg = VisionInfo()
self.cv_bridge = CvBridge()
super().__init__('ei_image_classifier_node')
self.init_parameters()
self.ei_classifier = self.EI_Classifier(self.modelfile, self.get_logger())
- #self.publisher = self.create_publisher(BoundingBox2DArray,'/edge_impulse/detection',1)
+
+ self.info_msg.header.frame_id = self.frame_id
+ self.info_msg.method = self.ei_classifier.model_info['model_parameters']['model_type']
+ self.info_msg.database_location = self.ei_classifier.model_info['project']['name']+' / '+self.ei_classifier.model_info['project']['owner']
+ self.info_msg.database_version = self.ei_classifier.model_info['project']['deploy_version']
self.timer_parameter = self.create_timer(2,self.parameters_callback)
self.image_publisher = self.create_publisher(Image,'/detection/output/image',1)
+ self.results_publisher = self.create_publisher(Detection2DArray,'/detection/output/results',1)
+ self.info_publisher = self.create_publisher(VisionInfo, '/detection/output/info',1)
+
self.timer_classify = self.create_timer(0.01,self.classify_callback)
self.timer_classify.cancel()
self.subscription = self.create_subscription(Image,'/detection/input/image',self.listener_callback,1)
self.subscription
-
+
+
+
def init_parameters(self):
self.declare_parameter('model.filepath','')
self.modelfile= self.get_parameter('model.filepath').get_parameter_value().string_value
+ self.declare_parameter('frame_id','base_link')
+ self.frame_id= self.get_parameter('frame_id').get_parameter_value().string_value
+
self.declare_parameter('show.overlay', True)
self.show_overlay = self.get_parameter('show.overlay').get_parameter_value().bool_value
@@ -72,9 +90,6 @@ def init_parameters(self):
-
-
-
def parameters_callback(self):
self.show_labels_on_image = self.get_parameter('show.labels').get_parameter_value().bool_value
self.show_extra_classification_info = self.get_parameter('show.classification_info').get_parameter_value().bool_value
@@ -91,13 +106,25 @@ def listener_callback(self, msg):
self.img = current_frame
self.timer_classify.reset()
+
+
+
+
def classify_callback(self):
self.occupied = True
+
+ # vision msgs
+ results_msg = Detection2DArray()
+ time_now = self.get_clock().now().to_msg()
+ results_msg.header.stamp = time_now
+ results_msg.header.frame_id = self.frame_id
+
# classify
features, cropped, res = self.ei_classifier.classify(self.img)
- #prepare output
+
+        # prepare output
if "classification" in res["result"].keys():
if self.show_extra_classification_info:
self.get_logger().info('Result (%d ms.) ' % (res['timing']['dsp'] + res['timing']['classification']), end='')
@@ -112,6 +139,28 @@ def classify_callback(self):
self.get_logger().info('Found %d bounding boxes (%d ms.)' % (len(res["result"]["bounding_boxes"]), res['timing']['dsp'] + res['timing']['classification']))
for bb in res["result"]["bounding_boxes"]:
+ result_msg = Detection2D()
+ result_msg.header.stamp = time_now
+ result_msg.header.frame_id = self.frame_id
+
+            # object with hypothesis
+ obj_hyp = ObjectHypothesisWithPose()
+ obj_hyp.id = bb['label'] #str(self.ei_classifier.labels.index(bb['label']))
+ obj_hyp.score = bb['value']
+ obj_hyp.pose.pose.position.x = float(bb['x'])
+ obj_hyp.pose.pose.position.y = float(bb['y'])
+ result_msg.results.append(obj_hyp)
+
+ # bounding box
+ result_msg.bbox.center.x = float(bb['x'])
+ result_msg.bbox.center.y = float(bb['y'])
+ result_msg.bbox.size_x = float(bb['width'])
+ result_msg.bbox.size_y = float(bb['height'])
+
+
+ results_msg.detections.append(result_msg)
+
+ # image
if self.show_extra_classification_info:
self.get_logger().info('%s (%.2f): x=%d y=%d w=%d h=%d' % (bb['label'], bb['value'], bb['x'], bb['y'], bb['width'], bb['height']))
if self.show_overlay:
@@ -124,6 +173,9 @@ def classify_callback(self):
# publish message
self.image_publisher.publish(self.cv_bridge.cv2_to_imgmsg(cropped,"bgr8"))
+ self.results_publisher.publish(results_msg)
+ self.info_msg.header.stamp = time_now
+ self.info_publisher.publish(self.info_msg)
self.occupied= False
self.timer_classify.cancel()
@@ -166,6 +218,8 @@ def classify(self, img):
self.logger.error('Error on classification')
+
+
def main():
rclpy.init()
node = EI_Image_node()
@@ -174,7 +228,6 @@ def main():
node.destroy_node()
rclpy.shutdown()
-
if __name__ == "__main__":
main()
diff --git a/edgeimpulse_ros/submodules/__pycache__/device_patches.cpython-38.pyc b/edgeimpulse_ros/submodules/__pycache__/device_patches.cpython-38.pyc
index 9346394..4c735e9 100644
Binary files a/edgeimpulse_ros/submodules/__pycache__/device_patches.cpython-38.pyc and b/edgeimpulse_ros/submodules/__pycache__/device_patches.cpython-38.pyc differ
diff --git a/package.xml b/package.xml
index 92eafb6..79ae856 100644
--- a/package.xml
+++ b/package.xml
@@ -14,7 +14,7 @@
ament_pep257
python3-pytest
-
+ vision_msgs
sensor_msgs
ros2launch
diff --git a/setup.py b/setup.py
index a843408..4d5aa3d 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
submodules = 'edgeimpulse_ros/submodules'
setup(
name=package_name,
- version='0.0.1',
+ version='0.0.2',
packages=[package_name, submodules],
data_files=[
('share/ament_index/resource_index/packages',