Skip to content

Commit

Permalink
[WIP] Fix several bugs.
Browse files Browse the repository at this point in the history
First working version of isolated drone class.
  • Loading branch information
amacati committed Jun 17, 2024
1 parent 4195b93 commit 6f68927
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 598 deletions.
7 changes: 7 additions & 0 deletions safe_control_gym/envs/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,15 @@ class SimConstants: ...
class DroneConstants:
firmware_freq: int = 500 # Firmware frequency in Hz
supply_voltage: float = 3.0 # Power supply voltage
min_pwm: int = 20000 # Minimum PWM signal
max_pwm: int = 65535 # Maximum PWM signal
thrust_curve_a: float = -0.0006239 # Thrust curve parameters for brushed motors
thrust_curve_b: float = 0.088 # Thrust curve parameters for brushed motors
tumble_threshold: float = -0.5 # Vertical acceleration threshold for tumbling detection
tumble_duration: int = 30 # Number of consecutive steps before tumbling is detected
# TODO: acc and gyro were swapped in original implementation. Possible bug?
acc_lpf_cutoff: int = 80 # Low-pass filter cutoff freq
gyro_lpf_cutoff: int = 30 # Low-pass filter cutoff freq
KF: float = 3.16e-10 # Motor force factor
pwm2rpm_scale: float = 0.2685 # mapping factor from PWM to RPM
pwm2rpm_const: float = 4070.3 # mapping constant from PWM to RPM
155 changes: 92 additions & 63 deletions safe_control_gym/envs/drone.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import importlib.util
import logging
from typing import Literal
Expand All @@ -13,8 +15,9 @@


class Drone:
def __init__(self, controller: Literal["pid", "mellinger"], ctrl_freq: int = 30):
def __init__(self, controller: Literal["pid", "mellinger"]):
self.firmware = self._load_firmware()
self.firmware_freq = Constants.firmware_freq
# Initialize firmware states
self._state = self.firmware.state_t()
self._control = self.firmware.control_t()
Expand All @@ -23,13 +26,8 @@ def __init__(self, controller: Literal["pid", "mellinger"], ctrl_freq: int = 30)
self._acc_lpf = [self.firmware.lpf2pData() for _ in range(3)]
self._gyro_lpf = [self.firmware.lpf2pData() for _ in range(3)]

self.ctrl_freq = ctrl_freq
assert controller in ["pid", "mellinger"], f"Invalid controller {controller}."
self._controller = controller
if controller == "pid":
self._controller = self.firmware.controllerPid
else:
self._controller = self.firmware.controllerMellinger
# Helper variables for the controller
self._pwms = np.zeros(4) # PWM signals for each motor
self._tick = 0 # Current controller tick
Expand Down Expand Up @@ -59,55 +57,73 @@ def reset(
self._reset_controller()
# Initialize high level commander
self.firmware.crtpCommanderHighLevelInit()
self._update_state(pos, rpy, vel)
self._update_state(0, pos, rpy * RAD_TO_DEG, vel, np.array([0, 0, 1.0]))
self._last_vel[...], self._last_rpy[...] = vel, rpy
self.firmware.crtpCommanderHighLevelTellState(self._state)

def step(
def step_controller(
self,
pos: npt.NDArray[np.float64],
rpy: npt.NDArray[np.float64],
vel: npt.NDArray[np.float64],
sim_time: float,
):
"""Take a drone controller step.
Args:
sim_time: Time in s from start of flight.
"""
self.firmware.crtpCommanderHighLevelStop() # Resets planner object
self.firmware.crtpCommanderHighLevelUpdateTime(sim_time)
command, args = self.command_queue.pop(0)
getattr(self, command)(*args)

body_rot = R.from_euler("XYZ", rpy).inv()
# Estimate rates
rotation_rates = (rpy - self.prev_rpy) * Constants.firmware_freq # body coord, rad/s
self.prev_rpy = rpy
rotation_rates = (rpy - self._last_rpy) * Constants.firmware_freq # body coord, rad/s
self._last_rpy = rpy
# TODO: Convert to real acc, not multiple of g
acc = (vel - self.prev_vel) * Constants.firmware_freq / 9.8 + np.array([0, 0, 1])
self.prev_vel = vel
acc = (vel - self._last_vel) * Constants.firmware_freq / 9.8 + np.array([0, 0, 1])
self._last_vel = vel
# Update state
timestamp = int(self._tick / Constants.firmware_freq * 1e3)
self._update_state(timestamp, pos, vel, acc, rpy * RAD_TO_DEG)
self._update_state(timestamp, pos, rpy * RAD_TO_DEG, vel, acc)
# Update sensor data
sensor_timestamp = int(self._tick / Constants.firmware_freq * 1e6)
self._update_sensorData(sensor_timestamp, body_rot.apply(acc), rotation_rates * RAD_TO_DEG)
self._update_sensor_data(sensor_timestamp, body_rot.apply(acc), rotation_rates * RAD_TO_DEG)
# Update setpoint
self._updateSetpoint(self._tick / Constants.firmware_freq)
self._update_setpoint(self._tick / Constants.firmware_freq)
# Step controller
self._step_controller()
# Get action. TODO: Is this really needed?
# new_action = (
# self.KF
# * (
# self.PWM2RPM_SCALE * np.clip(np.array(self.pwms), self.MIN_PWM, self.MAX_PWM)
# + self.PWM2RPM_CONST
# )
# ** 2
# )
# action = new_action[[3, 2, 1, 0]]

# region Commands
success = self._step_controller()
self._tick += 1
if not success:
self._pwms[...] = 0
return np.zeros(4)
return self._pwms_to_action(self._pwms)

@property
def tick(self) -> int:
    """Current controller tick.

    Read-only view of the internal ``_tick`` counter, which ``step_controller``
    increments by one on every call.
    """
    return self._tick

def _update_state(
self,
timestamp: float,
pos: npt.NDArray[np.float64],
rpy: npt.NDArray[np.float64],
vel: npt.NDArray[np.float64],
acc: npt.NDArray[np.float64],
):
for name, value in zip(("timestamp", "roll", "pitch", "yaw"), (timestamp, *rpy)):
if name == "pitch":
value = -value # Legacy cf coordinate system uses inverted pitch
setattr(self._state.attitude, name, value)
if self._controller == "mellinger": # Requires quaternion
quat = R.from_euler("XYZ", rpy, degrees=True).as_quat()
for name, value in zip(("x", "y", "z", "w"), quat):
setattr(self._state.attitudeQuaternion, name, value)
for name, value in zip(("x", "y", "z"), pos):
setattr(self._state.position, name, value)
for name, value in zip(("x", "y", "z"), vel):
setattr(self._state.velocity, name, value)
for name, value in zip(("x", "y", "z"), acc):
setattr(self._state.acc, name, value)

def _pwms_to_action(self, pwms: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
    """Convert motor PWM signals into the action returned by ``step_controller``.

    The PWM values are first mapped to RPMs with the affine ``pwm2rpm_scale`` /
    ``pwm2rpm_const`` mapping, then squared and scaled by the motor force factor ``KF``.

    Args:
        pwms: PWM signal of each motor.

    Returns:
        ``KF * rpm**2`` for each motor (presumably the per-motor force; confirm against
        the simulation that consumes it).
    """
    rpms = Constants.pwm2rpm_scale * pwms + Constants.pwm2rpm_const
    return Constants.KF * rpms**2

def full_state_cmd(
self,
Expand All @@ -131,25 +147,29 @@ def full_state_cmd(
rpy_rate: roll, pitch, yaw rates (rad/s)
timestep: simulation time when command is sent (s)
"""
timestep = self._tick / Constants.firmware_freq # TODO: Adopt for all commands, remove arg
self.firmware.crtpCommanderHighLevelStop() # Resets planner object
self.firmware.crtpCommanderHighLevelUpdateTime(timestep)

for name, x in zip(("pos", "vel", "acc", "rpy_rate"), (pos, vel, acc, rpy_rate)):
assert isinstance(x, np.ndarray), f"{name} must be a numpy array."
assert len(x) == 3, f"{name} must have length 3."
self.setpoint.position.x, self.setpoint.position.y, self.setpoint.position.z = pos
self.setpoint.velocity.x, self.setpoint.velocity.y, self.setpoint.velocity.z = vel
s_acc = self.setpoint.acceleration
self._setpoint.position.x, self._setpoint.position.y, self._setpoint.position.z = pos
self._setpoint.velocity.x, self._setpoint.velocity.y, self._setpoint.velocity.z = vel
s_acc = self._setpoint.acceleration
s_acc.x, s_acc.y, s_acc.z = acc
s_a_rate = self.setpoint.attitudeRate
s_a_rate = self._setpoint.attitudeRate
s_a_rate.roll, s_a_rate.pitch, s_a_rate.yaw = rpy_rate * RAD_TO_DEG
s_quat = self.setpoint.attitudeQuaternion
s_quat = self._setpoint.attitudeQuaternion
s_quat.x, s_quat.y, s_quat.z, s_quat.w = R.from_euler("XYZ", [0, 0, yaw]).as_quat()
# Initialize setpoint modes to match cmdFullState
mode = self.setpoint.mode
mode = self._setpoint.mode
mode_abs, mode_disable = self.firmware.modeAbs, self.firmware.modeDisable
mode.x, mode.y, mode.z = mode_abs, mode_abs, mode_abs
mode.quat = mode_abs
mode.roll, mode.pitch, mode.yaw = mode_disable, mode_disable, mode_disable
# This may end up skipping control loops
self.setpoint.timestamp = int(timestep * 1000)
self._setpoint.timestamp = int(timestep * 1000)
self._fullstate_cmd = True

def takeoff_cmd(self, height: float, duration: float, yaw: float | None = None):
Expand Down Expand Up @@ -222,9 +242,6 @@ def notify_setpoint_stop(self):
self.firmware.crtpCommanderHighLevelTellState(self.state)
self._fullstate_cmd = False

# endregion
# region reset

def _reset_firmware_states(self):
self._state = self.firmware.state_t()
self._control = self.firmware.control_t()
Expand All @@ -238,8 +255,8 @@ def _reset_low_pass_filters(self):
self._acc_lpf = [self.firmware.lpf2pData() for _ in range(3)]
self._gyro_lpf = [self.firmware.lpf2pData() for _ in range(3)]
for i in range(3):
self.firmware.lpf2pinit(self._acc_lpf[i], freq, Constants.acc_lpf_cutoff)
self.firmware.lpf2pinit(self._gyro_lpf[i], freq, Constants.gyro_lpf_cutoff)
self.firmware.lpf2pInit(self._acc_lpf[i], freq, Constants.acc_lpf_cutoff)
self.firmware.lpf2pInit(self._gyro_lpf[i], freq, Constants.gyro_lpf_cutoff)

def _reset_helper_variables(self):
self._n_tumble = 0
Expand All @@ -254,18 +271,13 @@ def _reset_controller(self):
else:
self.firmware.controllerMellingerInit()

# endregion
# region Drone step

def _step_controller(self):
"""Step the controller."""
def _step_controller(self) -> bool:
# Check if the drone is tumbling. If yes, set the control signal to zero.
self._n_tumble = 0 if self._state.acc.z > Constants.tumble_threshold else self._n_tumble + 1
if self._n_tumble > Constants.tumble_duration:
logger.debug("CrazyFlie is tumbling. Killing motors to simulate damage prevention.")
self._pwms[...] = 0
self._tick += 1
return # Skip controller step
return False # Skip controller step
# Determine tick based on time passed, allowing us to run pid slower than the 1000Hz it was
# designed for
tick = self._determine_controller_tick()
Expand All @@ -275,8 +287,7 @@ def _step_controller(self):
ctrl = self.firmware.controllerMellinger
ctrl(self._control, self._setpoint, self._sensor_data, self._state, tick)
self._update_pwms(self._control)
self._tick += 1
return
return True

def _determine_controller_tick(self) -> Literal[0, 1, 2]:
"""Determine which controller to run based on time passed.
Expand Down Expand Up @@ -307,14 +318,11 @@ def _update_pwms(self, control):
# Quad formation is X
r = control.roll / 2
p = control.pitch / 2
thrust = [
control.thrust - r + p + control.yaw,
control.thrust - r - p - control.yaw,
control.thrust + r - p + control.yaw,
control.thrust + r + p - control.yaw,
]
y = control.yaw
thrust = control.thrust
thrust = [thrust - r + p + y, thrust - r - p - y, thrust + r - p + y, thrust + r + p - y]
thrust = np.clip(thrust, 0, Constants.max_pwm) # Limit thrust to motor range
self._pwms = self._thrust_to_pwm(thrust)
self._pwms = np.clip(self._thrust_to_pwm(thrust), Constants.min_pwm, Constants.max_pwm)

@staticmethod
def _thrust_to_pwm(thrust: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
Expand All @@ -333,7 +341,28 @@ def _thrust_to_pwm(thrust: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
percentage = np.minimum(1, volts / Constants.supply_voltage)
return percentage * Constants.max_pwm

# endregion
def _update_sensor_data(
self, timestamp: float, acc: npt.NDArray[np.float64], gyro: npt.NDArray[np.float64]
):
"""Update the onboard sensors with low-pass filtered values.
Args:
timestamp: Sensor reading time in microseconds.
acc: Acceleration values in Gs.
gyro: Gyro values in deg/s.
"""
for name, i, val in zip(("x", "y", "z"), range(3), acc):
setattr(self._sensor_data.acc, name, self.firmware.lpf2pApply(self._acc_lpf[i], val))
for name, i, val in zip(("x", "y", "z"), range(3), gyro):
setattr(self._sensor_data.gyro, name, self.firmware.lpf2pApply(self._gyro_lpf[i], val))
self._sensor_data.interruptTimestamp = timestamp

def _update_setpoint(self, timestep: float):
if not self._fullstate_cmd:
self.firmware.crtpCommanderHighLevelTellState(self._state)
self.firmware.crtpCommanderHighLevelUpdateTime(timestep)
self.firmware.crtpCommanderHighLevelGetSetpoint(self._setpoint, self._state)

# region Utils

def _load_firmware(self) -> ModuleType:
Expand Down
Loading

0 comments on commit 6f68927

Please sign in to comment.