From cec439e359a35d21d8df75086f7609ed8028c68a Mon Sep 17 00:00:00 2001 From: usedhondacivic Date: Mon, 28 Oct 2024 11:38:00 -0400 Subject: [PATCH 1/3] Change hal to use tcp --- little_red_rover/little_red_rover/hal.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/little_red_rover/little_red_rover/hal.py b/little_red_rover/little_red_rover/hal.py index 01e5662..430f6b3 100644 --- a/little_red_rover/little_red_rover/hal.py +++ b/little_red_rover/little_red_rover/hal.py @@ -1,6 +1,7 @@ import rospy +import struct -from math import floor, inf, pi +from math import inf, pi from sensor_msgs.msg import Imu, JointState from sensor_msgs.msg import LaserScan @@ -13,12 +14,8 @@ class HAL: def __init__(self): - self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.socket.bind(("0.0.0.0", 8001)) - - self.send_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.connect(("192.168.4.1", 8001)) self.subscription = rospy.Subscriber("cmd_vel", Twist, self.cmd_vel_callback) self.joint_state_publisher = rospy.Publisher( @@ -50,7 +47,9 @@ def run_loop(self): while not rospy.is_shutdown(): self.socket.settimeout(1.0) try: - data = self.socket.recv(1500) + length = struct.unpack("I", self.socket.recv(4))[0] + # print(length) + data = self.socket.recv(length) packet = messages.UdpPacket() packet.ParseFromString(data) except socket.timeout: @@ -135,7 +134,7 @@ def cmd_vel_callback(self, msg: Twist): packet.cmd_vel.v = msg.linear.x packet.cmd_vel.w = msg.angular.z - self.send_socket.sendto(packet.SerializeToString(), ("192.168.4.1", 8001)) + self.socket.send(packet.SerializeToString()) def main(args=None): From cc109b7118a59a831a28298f9c3a4beb2877d93a Mon Sep 17 00:00:00 2001 From: usedhondacivic Date: Mon, 28 Oct 2024 17:14:53 -0400 Subject: [PATCH 2/3] Improve TCP error handling, decompose peripherals into seperate classes --- docker/docker-compose.yml | 5 +- .../little_red_rover/drive_base_peripheral.py | 40 +++++ little_red_rover/little_red_rover/hal.py | 145 +++--------------- .../little_red_rover/imu_peripheral.py | 32 ++++ .../little_red_rover/lidar_peripheral.py | 62 ++++++++ .../little_red_rover/pb/messages.proto | 2 +- .../little_red_rover/pb/messages_pb2.py | 6 +- .../little_red_rover/pb/messages_pb2.pyi | 2 +- .../little_red_rover/rover_connection.py | 53 +++++++ 9 files changed, 218 insertions(+), 129 deletions(-) create mode 100644 little_red_rover/little_red_rover/drive_base_peripheral.py create mode 100644 little_red_rover/little_red_rover/imu_peripheral.py create mode 100644 little_red_rover/little_red_rover/lidar_peripheral.py create mode 100644 little_red_rover/little_red_rover/rover_connection.py diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index d2b34e3..14ba7d0 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -5,12 +5,9 @@ services: context: .. dockerfile: docker/ros_workspace.Dockerfile tty: true # Prevent immediate exit when running with dev containers - network_mode: "host" volumes: - ../little_red_rover:/little_red_rover_ws/src/little_red_rover - ../tools:/tools ports: - - "9002:9002" # gzweb - - "9090:9090" # rosbridge - - "8001:8001/udp" # agent -> rover + - "8001:8001" # agent -> rover - "8765:8765" # foxglove bridge diff --git a/little_red_rover/little_red_rover/drive_base_peripheral.py b/little_red_rover/little_red_rover/drive_base_peripheral.py new file mode 100644 index 0000000..137c21f --- /dev/null +++ b/little_red_rover/little_red_rover/drive_base_peripheral.py @@ -0,0 +1,40 @@ +import rospy + +from sensor_msgs.msg import JointState +from geometry_msgs.msg import Twist + +import little_red_rover.pb.messages_pb2 as messages +from little_red_rover.rover_connection import RoverConnection + + +class DriveBasePeripheral: + def __init__(self, connection: RoverConnection): + self.joint_state_publisher = rospy.Publisher( + "joint_states", JointState, queue_size=10 + ) + + self.joint_state_msg = JointState() + self.joint_state_msg.header.frame_id = "robot_body" + self.subscription = rospy.Subscriber("cmd_vel", Twist, self.cmd_vel_callback) + + self.connection = connection + + def handle_packet(self, packet: messages.NetworkPacket): + if not packet.HasField("joint_states"): + return + + self.joint_state_msg.header.stamp.set( + packet.joint_states.time.sec, packet.joint_states.time.nanosec + ) + self.joint_state_msg.name = list(packet.joint_states.name) + self.joint_state_msg.effort = packet.joint_states.effort + self.joint_state_msg.position = packet.joint_states.position + self.joint_state_msg.velocity = packet.joint_states.velocity + self.joint_state_publisher.publish(self.joint_state_msg) + + def cmd_vel_callback(self, msg: Twist): + packet = messages.NetworkPacket() + packet.cmd_vel.v = msg.linear.x + packet.cmd_vel.w = msg.angular.z + + self.connection.send(packet.SerializeToString()) diff --git a/little_red_rover/little_red_rover/hal.py b/little_red_rover/little_red_rover/hal.py index 430f6b3..9b03a84 100644 --- a/little_red_rover/little_red_rover/hal.py +++ b/little_red_rover/little_red_rover/hal.py @@ -1,143 +1,48 @@ +from google.protobuf.message import DecodeError import rospy -import struct - -from math import inf, pi - -from sensor_msgs.msg import Imu, JointState -from sensor_msgs.msg import LaserScan -from geometry_msgs.msg import Twist import threading -import socket import little_red_rover.pb.messages_pb2 as messages +from little_red_rover.rover_connection import RoverConnection +from little_red_rover.lidar_peripheral import LidarPeripheral +from little_red_rover.drive_base_peripheral import DriveBasePeripheral +from little_red_rover.imu_peripheral import ImuPeripheral + class HAL: def __init__(self): - self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.socket.connect(("192.168.4.1", 8001)) - self.subscription = rospy.Subscriber("cmd_vel", Twist, self.cmd_vel_callback) + self.connection = RoverConnection(("192.168.4.1", 8001)) - self.joint_state_publisher = rospy.Publisher( - "joint_states", JointState, queue_size=10 - ) - self.scan_publisher = rospy.Publisher("scan", LaserScan, queue_size=10) - self.imu_publisher = rospy.Publisher("imu/data_raw", Imu, queue_size=10) + self.peripherals = [ + LidarPeripheral(), + DriveBasePeripheral(self.connection), + ImuPeripheral(), + ] - self.ranges = [0.0] * 720 - self.intensities = [0.0] * 720 + self.decode_error_count = 0 threading.Thread(target=self.run_loop).start() def run_loop(self): - laser_msg = LaserScan() - laser_msg.header.frame_id = "lidar" - laser_msg.range_min = 0.1 - laser_msg.range_max = 8.0 - laser_msg.angle_min = 0.0 - laser_msg.angle_max = 2.0 * pi - laser_msg.angle_increment = (2 * pi) / (len(self.ranges)) - - joint_state_msg = JointState() - joint_state_msg.header.frame_id = "robot_body" - - imu_msg = Imu() - imu_msg.header.frame_id = "robot_body" - while not rospy.is_shutdown(): - self.socket.settimeout(1.0) try: - length = struct.unpack("I", self.socket.recv(4))[0] - # print(length) - data = self.socket.recv(length) - packet = messages.UdpPacket() - packet.ParseFromString(data) - except socket.timeout: - print( - "No data recieved from Little Red Rover for the past 1.0 seconds. As you sure you're connected to the rover's wifi hotspot?" - ) - continue - except socket.error as e: - print(e) - continue - except Exception as e: - print(e) - continue - - if len(packet.laser) != 0: - for scan in packet.laser: - self.handle_laser_scan(laser_msg, scan) - elif packet.HasField("joint_states"): - self.handle_joint_states(joint_state_msg, packet.joint_states) - elif packet.HasField("imu"): - self.handle_imu(imu_msg, packet.imu) - - def handle_joint_states(self, msg: JointState, packet: messages.JointStates): - msg.header.stamp.set(packet.time.sec, packet.time.nanosec) - msg.name = list(packet.name) - msg.effort = packet.effort - msg.position = packet.position - msg.velocity = packet.velocity - self.joint_state_publisher.publish(msg) + data = self.connection.recv_packet() + packet = messages.NetworkPacket() + packet.ParseFromString(bytes(data)) - def handle_laser_scan(self, msg: LaserScan, packet: messages.LaserScan): - break_in_packet = False - for i in range(len(packet.ranges)): - angle = packet.angle_min + (packet.angle_max - packet.angle_min) * ( - i / (len(packet.ranges) - 1) - ) - index = int(((angle % (2.0 * pi)) / (2.0 * pi)) * 720.0) + for peripheral in self.peripherals: + peripheral.handle_packet(packet) - if angle > pi * 2.0 and not break_in_packet: - msg.time_increment = packet.time_increment - msg.scan_time = packet.scan_time - msg.ranges = self.ranges - msg.intensities = self.intensities - - self.scan_publisher.publish(msg) - break_in_packet = True - - self.ranges = [0.0] * 720 - self.intensities = [0.0] * 720 - - msg.header.stamp.set(packet.time.sec, packet.time.nanosec) - # msg.header.stamp = self.get_clock().now().to_msg() - - self.ranges[index] = packet.ranges[i] - self.intensities[index] = packet.intensities[i] - - if ( - self.ranges[index] > 8.0 - or self.ranges[index] < 0.1 - or self.intensities[index] == 0 - ): - self.ranges[index] = inf - self.intensities[index] = 0.0 - - def handle_imu(self, msg: Imu, packet: messages.IMU): - msg.header.stamp.set(packet.time.sec, packet.time.nanosec) - msg.orientation_covariance = [-1.0] + [0.0] * 8 - msg.linear_acceleration.x = packet.accel_x - msg.linear_acceleration.y = packet.accel_y - msg.linear_acceleration.z = packet.accel_z - # TODO - msg.linear_acceleration_covariance = [0.0] * 9 - msg.angular_velocity.x = packet.gyro_x - msg.angular_velocity.y = packet.gyro_y - msg.angular_velocity.z = packet.gyro_z - # TODO - msg.angular_velocity_covariance = [0.0] * 9 - self.imu_publisher.publish(msg) - - def cmd_vel_callback(self, msg: Twist): - packet = messages.UdpPacket() - packet.cmd_vel.v = msg.linear.x - packet.cmd_vel.w = msg.angular.z - - self.socket.send(packet.SerializeToString()) + except DecodeError: + self.decode_error_count += 1 + if self.decode_error_count > 20: + print(f"Failed to decode {self.decode_error_count} packets.") + except Exception as e: + print(f"Error: {e}") -def main(args=None): +def main(_=None): rospy.init_node("hal", anonymous=True) _ = HAL() diff --git a/little_red_rover/little_red_rover/imu_peripheral.py b/little_red_rover/little_red_rover/imu_peripheral.py new file mode 100644 index 0000000..b1c393c --- /dev/null +++ b/little_red_rover/little_red_rover/imu_peripheral.py @@ -0,0 +1,32 @@ +import rospy +from sensor_msgs.msg import Imu +import little_red_rover.pb.messages_pb2 as messages + + +class ImuPeripheral: + def __init__(self): + self.imu_publisher = rospy.Publisher("imu/data_raw", Imu, queue_size=10) + self.imu_msg = Imu() + self.imu_msg.header.frame_id = "robot_body" + + def handle_packet(self, packet: messages.NetworkPacket): + if not packet.HasField("imu"): + return + self.imu_msg.header.stamp.set(packet.imu.time.sec, packet.imu.time.nanosec) + + # disable orientation + self.imu_msg.orientation_covariance = [-1.0] + [0.0] * 8 + + # accel + self.imu_msg.linear_acceleration.x = packet.imu.accel_x + self.imu_msg.linear_acceleration.y = packet.imu.accel_y + self.imu_msg.linear_acceleration.z = packet.imu.accel_z + self.imu_msg.linear_acceleration_covariance = [0.0] * 9 # TODO + + # gyro + self.imu_msg.angular_velocity.x = packet.imu.gyro_x + self.imu_msg.angular_velocity.y = packet.imu.gyro_y + self.imu_msg.angular_velocity.z = packet.imu.gyro_z + self.imu_msg.angular_velocity_covariance = [0.0] * 9 # TODO + + self.imu_publisher.publish(self.imu_msg) diff --git a/little_red_rover/little_red_rover/lidar_peripheral.py b/little_red_rover/little_red_rover/lidar_peripheral.py new file mode 100644 index 0000000..1f98e86 --- /dev/null +++ b/little_red_rover/little_red_rover/lidar_peripheral.py @@ -0,0 +1,62 @@ +import rospy +from sensor_msgs.msg import LaserScan +import little_red_rover.pb.messages_pb2 as messages + +from math import inf, pi + + +class LidarPeripheral: + def __init__(self): + self.ranges = [0.0] * 720 + self.intensities = [0.0] * 720 + + self.scan_publisher = rospy.Publisher("scan", LaserScan, queue_size=10) + + self.laser_msg = LaserScan() + self.laser_msg.header.frame_id = "lidar" + self.laser_msg.range_min = 0.1 + self.laser_msg.range_max = 8.0 + self.laser_msg.angle_min = 0.0 + self.laser_msg.angle_max = 2.0 * pi + self.laser_msg.angle_increment = (2 * pi) / (len(self.ranges)) + + def handle_packet(self, packet: messages.NetworkPacket): + if len(packet.laser) == 0: + return + for scan in packet.laser: + self.handle_laser_scan(self.laser_msg, scan) + pass + + def handle_laser_scan(self, msg: LaserScan, packet: messages.LaserScan): + break_in_packet = False + for i in range(len(packet.ranges)): + angle = packet.angle_min + (packet.angle_max - packet.angle_min) * ( + i / (len(packet.ranges) - 1) + ) + index = int(((angle % (2.0 * pi)) / (2.0 * pi)) * 720.0) + + if angle > pi * 2.0 and not break_in_packet: + msg.time_increment = packet.time_increment + msg.scan_time = packet.scan_time + msg.ranges = self.ranges + msg.intensities = self.intensities + + self.scan_publisher.publish(msg) + break_in_packet = True + + self.ranges = [0.0] * 720 + self.intensities = [0.0] * 720 + + msg.header.stamp.set(packet.time.sec, packet.time.nanosec) + # msg.header.stamp = self.get_clock().now().to_msg() + + self.ranges[index] = packet.ranges[i] + self.intensities[index] = packet.intensities[i] + + if ( + self.ranges[index] > 8.0 + or self.ranges[index] < 0.1 + or self.intensities[index] == 0 + ): + self.ranges[index] = inf + self.intensities[index] = 0.0 diff --git a/little_red_rover/little_red_rover/pb/messages.proto b/little_red_rover/little_red_rover/pb/messages.proto index 4ee6188..932493d 100644 --- a/little_red_rover/little_red_rover/pb/messages.proto +++ b/little_red_rover/little_red_rover/pb/messages.proto @@ -42,7 +42,7 @@ message IMU { float accel_z = 7; } -message UdpPacket { +message NetworkPacket { repeated LaserScan laser = 1; optional JointStates joint_states = 2; optional TwistCmd cmd_vel = 3; diff --git a/little_red_rover/little_red_rover/pb/messages_pb2.py b/little_red_rover/little_red_rover/pb/messages_pb2.py index 56f45a5..9d6e06b 100644 --- a/little_red_rover/little_red_rover/pb/messages_pb2.py +++ b/little_red_rover/little_red_rover/pb/messages_pb2.py @@ -24,7 +24,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0emessages.proto\")\n\tTimeStamp\x12\x0b\n\x03sec\x18\x01 \x01(\x05\x12\x0f\n\x07nanosec\x18\x02 \x01(\r\":\n\x08TwistCmd\x12\x18\n\x04time\x18\x01 \x01(\x0b\x32\n.TimeStamp\x12\t\n\x01v\x18\x02 \x01(\x02\x12\t\n\x01w\x18\x03 \x01(\x02\"\xda\x01\n\tLaserScan\x12\x18\n\x04time\x18\x01 \x01(\x0b\x32\n.TimeStamp\x12\x11\n\tangle_min\x18\x02 \x01(\x02\x12\x11\n\tangle_max\x18\x03 \x01(\x02\x12\x17\n\x0f\x61ngle_increment\x18\x04 \x01(\x02\x12\x16\n\x0etime_increment\x18\x05 \x01(\x02\x12\x11\n\tscan_time\x18\x06 \x01(\x02\x12\x11\n\trange_min\x18\x07 \x01(\x02\x12\x11\n\trange_max\x18\x08 \x01(\x02\x12\x0e\n\x06ranges\x18\t \x03(\x02\x12\x13\n\x0bintensities\x18\n \x03(\x02\"i\n\x0bJointStates\x12\x18\n\x04time\x18\x01 \x01(\x0b\x32\n.TimeStamp\x12\x0c\n\x04name\x18\x02 \x03(\t\x12\x10\n\x08position\x18\x03 \x03(\x01\x12\x10\n\x08velocity\x18\x04 \x03(\x01\x12\x0e\n\x06\x65\x66\x66ort\x18\x05 \x03(\x01\"\x82\x01\n\x03IMU\x12\x18\n\x04time\x18\x01 \x01(\x0b\x32\n.TimeStamp\x12\x0e\n\x06gyro_x\x18\x02 \x01(\x02\x12\x0e\n\x06gyro_y\x18\x03 \x01(\x02\x12\x0e\n\x06gyro_z\x18\x04 \x01(\x02\x12\x0f\n\x07\x61\x63\x63\x65l_x\x18\x05 \x01(\x02\x12\x0f\n\x07\x61\x63\x63\x65l_y\x18\x06 \x01(\x02\x12\x0f\n\x07\x61\x63\x63\x65l_z\x18\x07 \x01(\x02\"\xad\x01\n\tUdpPacket\x12\x19\n\x05laser\x18\x01 \x03(\x0b\x32\n.LaserScan\x12\'\n\x0cjoint_states\x18\x02 \x01(\x0b\x32\x0c.JointStatesH\x00\x88\x01\x01\x12\x1f\n\x07\x63md_vel\x18\x03 \x01(\x0b\x32\t.TwistCmdH\x01\x88\x01\x01\x12\x16\n\x03imu\x18\x04 \x01(\x0b\x32\x04.IMUH\x02\x88\x01\x01\x42\x0f\n\r_joint_statesB\n\n\x08_cmd_velB\x06\n\x04_imub\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0emessages.proto\")\n\tTimeStamp\x12\x0b\n\x03sec\x18\x01 \x01(\x05\x12\x0f\n\x07nanosec\x18\x02 \x01(\r\":\n\x08TwistCmd\x12\x18\n\x04time\x18\x01 \x01(\x0b\x32\n.TimeStamp\x12\t\n\x01v\x18\x02 \x01(\x02\x12\t\n\x01w\x18\x03 \x01(\x02\"\xda\x01\n\tLaserScan\x12\x18\n\x04time\x18\x01 \x01(\x0b\x32\n.TimeStamp\x12\x11\n\tangle_min\x18\x02 \x01(\x02\x12\x11\n\tangle_max\x18\x03 \x01(\x02\x12\x17\n\x0f\x61ngle_increment\x18\x04 \x01(\x02\x12\x16\n\x0etime_increment\x18\x05 \x01(\x02\x12\x11\n\tscan_time\x18\x06 \x01(\x02\x12\x11\n\trange_min\x18\x07 \x01(\x02\x12\x11\n\trange_max\x18\x08 \x01(\x02\x12\x0e\n\x06ranges\x18\t \x03(\x02\x12\x13\n\x0bintensities\x18\n \x03(\x02\"i\n\x0bJointStates\x12\x18\n\x04time\x18\x01 \x01(\x0b\x32\n.TimeStamp\x12\x0c\n\x04name\x18\x02 \x03(\t\x12\x10\n\x08position\x18\x03 \x03(\x01\x12\x10\n\x08velocity\x18\x04 \x03(\x01\x12\x0e\n\x06\x65\x66\x66ort\x18\x05 \x03(\x01\"\x82\x01\n\x03IMU\x12\x18\n\x04time\x18\x01 \x01(\x0b\x32\n.TimeStamp\x12\x0e\n\x06gyro_x\x18\x02 \x01(\x02\x12\x0e\n\x06gyro_y\x18\x03 \x01(\x02\x12\x0e\n\x06gyro_z\x18\x04 \x01(\x02\x12\x0f\n\x07\x61\x63\x63\x65l_x\x18\x05 \x01(\x02\x12\x0f\n\x07\x61\x63\x63\x65l_y\x18\x06 \x01(\x02\x12\x0f\n\x07\x61\x63\x63\x65l_z\x18\x07 \x01(\x02\"\xb1\x01\n\rNetworkPacket\x12\x19\n\x05laser\x18\x01 \x03(\x0b\x32\n.LaserScan\x12\'\n\x0cjoint_states\x18\x02 \x01(\x0b\x32\x0c.JointStatesH\x00\x88\x01\x01\x12\x1f\n\x07\x63md_vel\x18\x03 \x01(\x0b\x32\t.TwistCmdH\x01\x88\x01\x01\x12\x16\n\x03imu\x18\x04 \x01(\x0b\x32\x04.IMUH\x02\x88\x01\x01\x42\x0f\n\r_joint_statesB\n\n\x08_cmd_velB\x06\n\x04_imub\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -41,6 +41,6 @@ _globals['_JOINTSTATES']._serialized_end=447 _globals['_IMU']._serialized_start=450 _globals['_IMU']._serialized_end=580 - _globals['_UDPPACKET']._serialized_start=583 - _globals['_UDPPACKET']._serialized_end=756 + _globals['_NETWORKPACKET']._serialized_start=583 + _globals['_NETWORKPACKET']._serialized_end=760 # @@protoc_insertion_point(module_scope) diff --git a/little_red_rover/little_red_rover/pb/messages_pb2.pyi b/little_red_rover/little_red_rover/pb/messages_pb2.pyi index ee1bc41..089342f 100644 --- a/little_red_rover/little_red_rover/pb/messages_pb2.pyi +++ b/little_red_rover/little_red_rover/pb/messages_pb2.pyi @@ -79,7 +79,7 @@ class IMU(_message.Message): accel_z: float def __init__(self, time: _Optional[_Union[TimeStamp, _Mapping]] = ..., gyro_x: _Optional[float] = ..., gyro_y: _Optional[float] = ..., gyro_z: _Optional[float] = ..., accel_x: _Optional[float] = ..., accel_y: _Optional[float] = ..., accel_z: _Optional[float] = ...) -> None: ... -class UdpPacket(_message.Message): +class NetworkPacket(_message.Message): __slots__ = ("laser", "joint_states", "cmd_vel", "imu") LASER_FIELD_NUMBER: _ClassVar[int] JOINT_STATES_FIELD_NUMBER: _ClassVar[int] diff --git a/little_red_rover/little_red_rover/rover_connection.py b/little_red_rover/little_red_rover/rover_connection.py new file mode 100644 index 0000000..3de9a3e --- /dev/null +++ b/little_red_rover/little_red_rover/rover_connection.py @@ -0,0 +1,53 @@ +import struct +import socket + + +class RoverConnection: + def __init__(self, endpoint): + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.settimeout(5.0) + self.endpoint = endpoint + self.socket.connect(endpoint) + + def recv_packet(self) -> bytes: + """ + Packets are prefixed with the byte string LRR, followed by the message length in bytes. + """ + + data = None + while data == None: + try: + while self.socket.recv(3, socket.MSG_PEEK) != b"LRR": + self.socket.recv(1) + + assert self.recv_length(3) == b"LRR" + length = struct.unpack("H", self.recv_length(2))[0] + data = self.recv_length(length) + except Exception as e: + print(f"Rover connection hit error: {e}. Reconnecting...") + self.socket.close() + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.settimeout(5.0) + self.socket.connect(self.endpoint) + print(f"Error: {e}") + + return data + + def recv_length(self, length) -> bytes: + data = bytearray() + while len(data) < length: + data.extend(self.socket.recv(length - len(data))) + + return bytes(data) + + def send(self, msg: bytes): + try: + self.socket.sendall(msg) + except OSError as e: + if e.errno == 9: + # The socket is currently closed + pass + else: + print(f"OSError while sending: {e}") + except Exception as e: + print(f"Exception while sending: {e}") From 615aebd72034de731dbf77450932c7a9acc95540 Mon Sep 17 00:00:00 2001 From: usedhondacivic Date: Mon, 28 Oct 2024 18:55:54 -0400 Subject: [PATCH 3/3] Improve reconnection and error handling in HAL --- little_red_rover/little_red_rover/hal.py | 12 ++++-- .../little_red_rover/rover_connection.py | 43 +++++++++++++------ 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/little_red_rover/little_red_rover/hal.py b/little_red_rover/little_red_rover/hal.py index 9b03a84..a9636cb 100644 --- a/little_red_rover/little_red_rover/hal.py +++ b/little_red_rover/little_red_rover/hal.py @@ -22,12 +22,18 @@ def __init__(self): self.decode_error_count = 0 - threading.Thread(target=self.run_loop).start() + threading.Thread(target=self.run_loop, daemon=True).start() def run_loop(self): while not rospy.is_shutdown(): try: - data = self.connection.recv_packet() + data = None + while data == None: + try: + data = self.connection.recv_packet() + except Exception as e: + print(e) + packet = messages.NetworkPacket() packet.ParseFromString(bytes(data)) @@ -39,7 +45,7 @@ def run_loop(self): if self.decode_error_count > 20: print(f"Failed to decode {self.decode_error_count} packets.") except Exception as e: - print(f"Error: {e}") + print(f"HAL: Error - {e}") def main(_=None): diff --git a/little_red_rover/little_red_rover/rover_connection.py b/little_red_rover/little_red_rover/rover_connection.py index 3de9a3e..8543d18 100644 --- a/little_red_rover/little_red_rover/rover_connection.py +++ b/little_red_rover/little_red_rover/rover_connection.py @@ -1,5 +1,7 @@ import struct import socket +import typing +import time class RoverConnection: @@ -7,31 +9,42 @@ def __init__(self, endpoint): self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket.settimeout(5.0) self.endpoint = endpoint - self.socket.connect(endpoint) + try: + self.socket.connect(self.endpoint) + except Exception: + pass - def recv_packet(self) -> bytes: + def recv_packet(self) -> typing.Union[bytes, None]: """ Packets are prefixed with the byte string LRR, followed by the message length in bytes. """ data = None - while data == None: - try: - while self.socket.recv(3, socket.MSG_PEEK) != b"LRR": - self.socket.recv(1) + try: + while self.socket.recv(3, socket.MSG_PEEK) != b"LRR": + self.socket.recv(1) - assert self.recv_length(3) == b"LRR" - length = struct.unpack("H", self.recv_length(2))[0] - data = self.recv_length(length) - except Exception as e: - print(f"Rover connection hit error: {e}. Reconnecting...") + assert self.recv_length(3) == b"LRR" + length = struct.unpack("H", self.recv_length(2))[0] + data = self.recv_length(length) + except Exception as e: + print(f"Rover connection hit an error: {e}. Reconnecting...") + self.reconnect() + + return data + + def reconnect(self): + while True: + try: self.socket.close() self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket.settimeout(5.0) self.socket.connect(self.endpoint) - print(f"Error: {e}") - - return data + print("Reconnected!") + return + except Exception as e: + time.sleep(1.0) + print(f"Error while reconnecting: {e}. Trying again in 1 second...") def recv_length(self, length) -> bytes: data = bytearray() @@ -42,6 +55,8 @@ def recv_length(self, length) -> bytes: def send(self, msg: bytes): try: + self.socket.sendall(b"LRR") + self.socket.sendall(len(msg).to_bytes(2, byteorder="little")) self.socket.sendall(msg) except OSError as e: if e.errno == 9: